diff --git a/scripts/jiuge.py b/scripts/jiuge.py
index 523820c9..e23f032f 100644
--- a/scripts/jiuge.py
+++ b/scripts/jiuge.py
@@ -616,7 +616,7 @@ def perplexity(self, test_sequences: List[Sequence[int]], batch_size=10):
             batch_id += 1
 
             batch_inputs = JiugeBatchedTask(tasks[:batch_id])
-            logits = torch.zeros(
+            log_probs = torch.zeros(
                 (batch_inputs.ntok, self.meta.dvoc), dtype=self.meta.torch_dtype_logits
             )
             self.jiuge_model.forward_batch(
@@ -627,12 +627,12 @@ def perplexity(self, test_sequences: List[Sequence[int]], batch_size=10):
                 batch_inputs.nreq,
                 batch_inputs.req_pos,
                 batch_inputs.kv_caches,
-                logits.data_ptr(),
+                log_probs.data_ptr(),
             )
 
-            logits = logits.float()
+            # forward_batch now returns log_softmax results, no need for additional calculation
+            log_probs = log_probs.float()
             token_ids = torch.tensor(true_tokens, dtype=torch.int64)  # [ntok,]
-            log_probs = torch.nn.functional.log_softmax(logits, dim=-1)  # (ntok, vocab)
             token_logprobs = log_probs[
                 torch.arange(batch_inputs.ntok), token_ids
             ]  # (ntok,)
diff --git a/scripts/launch_server.py b/scripts/launch_server.py
index a315b4e6..903a5396 100644
--- a/scripts/launch_server.py
+++ b/scripts/launch_server.py
@@ -3,6 +3,7 @@
 from libinfinicore_infer import DeviceType
 from infer_task import InferTask
 from kvcache_pool import KVCachePool
+import torch
 
 import argparse
 import queue
@@ -176,17 +177,27 @@ def worker_loop(app):
 
 
 def build_task(id_, request_data, request: Request):
-    messages = request_data.get("messages", [])
-    input_content = request.app.state.model.tokenizer.apply_chat_template(
-        conversation=messages,
-        add_generation_prompt=True,
-        tokenize=False,
-    )
-    tokens = request.app.state.model.tokenizer.encode(input_content)
+    # Handle both chat and completion formats
+    if "messages" in request_data:
+        # Chat format
+        messages = request_data.get("messages", [])
+        input_content = request.app.state.model.tokenizer.apply_chat_template(
+            conversation=messages,
+            add_generation_prompt=True,
+            tokenize=False,
+        )
+        tokens = request.app.state.model.tokenizer.encode(input_content)
+        max_tokens = request_data.get("max_tokens", request.app.state.model.max_context_len())
+    else:
+        # Completion format
+        prompt = request_data.get("prompt", "")
+        tokens = request.app.state.model.tokenizer.encode(prompt)
+        max_tokens = request_data.get("max_tokens", 0)
+
     return AsyncInferTask(
         id_,
         tokens,
-        request_data.get("max_tokens", request.app.state.model.max_context_len()),
+        max_tokens,
         request_data.get("temperature", 1.0),
         request_data.get("top_k", 1),
         request_data.get("top_p", 1.0),
@@ -294,6 +305,145 @@ async def chat_completions(request: Request):
 
     return JSONResponse(content=response)
 
+
+
+async def completion(id_, request_data, request: Request):
+    infer_task = None  # Initialize to None to avoid UnboundLocalError
+    try:
+        # Check if max_tokens > 0 is requested
+        max_tokens = request_data.get("max_tokens", 0)
+        if max_tokens > 0:
+            return JSONResponse(
+                content={"error": "max_tokens > 0 is not supported yet. Please use max_tokens=0 for logprobs calculation."},
+                status_code=400
+            )
+
+        infer_task = build_task(id_, request_data, request)
+        await request.app.state.kv_cache_pool.acquire(infer_task)
+
+        output = []
+        logprobs = []
+
+        # Handle echo and logprobs calculation
+        echo = request_data.get("echo", False)
+        if echo:
+            # Add input tokens to output
+            input_tokens = infer_task.tokens
+            for token in input_tokens:
+                content = (
+                    request.app.state.model.tokenizer._tokenizer.id_to_token(token)
+                    .replace("▁", " ")
+                    .replace("<0x0A>", "\n")
+                )
+                output.append(content)
+
+            # Calculate logprobs for input tokens
+            from jiuge import JiugeBatchedTask
+            batch_inputs = JiugeBatchedTask([infer_task])
+            log_probs = torch.zeros(
+                (batch_inputs.ntok, request.app.state.model.meta.dvoc),
+                dtype=request.app.state.model.meta.torch_dtype_logits
+            )
+            request.app.state.model.jiuge_model.forward_batch(
+                request.app.state.model.model_instance,
+                batch_inputs.tokens,
+                batch_inputs.ntok,
+                batch_inputs.req_lens,
+                batch_inputs.nreq,
+                batch_inputs.req_pos,
+                batch_inputs.kv_caches,
+                log_probs.data_ptr(),
+            )
+
+            log_probs = log_probs.float()
+
+            # Calculate correct logprobs for input tokens
+            token_logprobs = []
+            for i in range(len(infer_task.tokens) - 1):  # Only up to second-to-last token
+                next_token = infer_task.tokens[i + 1]  # Next token to predict
+                logprob = log_probs[i, next_token].item()  # Use the log-probs at position i to score the token at position i+1
+                token_logprobs.append(logprob)
+
+            # First token has no context, so its logprob is None
+            logprobs = [None] + token_logprobs
+        else:
+            # echo=false: don't calculate logprobs since the user can't see the input text
+            logprobs = []
+
+        # For max_tokens=0, we need to manually release the KV cache since we don't go through the worker
+        await request.app.state.kv_cache_pool.release(infer_task)
+        print(f"[DEBUG] {id_} Released KV cache for max_tokens=0")
+
+        output_text = "".join(output).strip()
+
+        # Prepare tokens list for logprobs
+        tokens_list = []
+        text_offset_list = []
+        current_offset = 0
+
+        # Build tokens list and text offsets
+        for i, content in enumerate(output):
+            tokens_list.append(content)
+            text_offset_list.append(current_offset)
+            current_offset += len(content)
+
+        # Build response according to the DeepSeek API completion format
+        response = {
+            "id": id_,
+            "object": "text_completion",
+            "created": int(time.time()),
+            "model": "jiuge",
+            "choices": [
+                {
+                    "text": output_text,
+                    "index": 0,
+                    "logprobs": {
+                        "token_logprobs": logprobs,
+                        "tokens": tokens_list,
+                        "text_offset": text_offset_list,
+                        "top_logprobs": []
+                    },
+                    "finish_reason": "stop"
+                }
+            ],
+            "usage": {
+                "prompt_tokens": len(infer_task.tokens),
+                "prompt_cache_hit_tokens": 0,
+                "prompt_cache_miss_tokens": len(infer_task.tokens),
+                "completion_tokens": 0,
+                "total_tokens": len(infer_task.tokens),
+                "completion_tokens_details": {
+                    "reasoning_tokens": 0
+                }
+            }
+        }
+        return response
+
+    except Exception as e:
+        print(f"[Error] ID: {id_} Exception: {e}")
+        return JSONResponse(content={"error": str(e)}, status_code=500)
+    finally:
+        if infer_task and infer_task.finish_reason is None:
+            infer_task.finish_reason = "cancel"
+
+
+@App.post("/completions")
+async def completions(request: Request):
+    data = await request.json()
+
+    if not data.get("prompt"):
+        return JSONResponse(content={"error": "No prompt provided"}, status_code=400)
+
+    id_ = f"cmpl-{uuid.uuid4().hex}"
+    response = await completion(id_, data, request)
+
+    # Check if response is already a JSONResponse (error case)
+    if isinstance(response, JSONResponse):
+        return response
+    else:
+        return JSONResponse(content=response)
+
 
 if __name__ == "__main__":
     uvicorn.run(App, host="0.0.0.0", port=8000)
diff --git a/scripts/test_ppl.py b/scripts/test_ppl.py
index 268a9f7d..5627dc57 100644
--- a/scripts/test_ppl.py
+++ b/scripts/test_ppl.py
@@ -33,8 +33,9 @@
     # endcode, chunk and decode
     tokens = tokenizer.encode(text, add_special_tokens=False)
 
-    for i in range(0, len(tokens), CHUNK_SIZE):
-        chunk_tokens = tokens[i : min(i + CHUNK_SIZE, len(tokens))]
+    # Use the same chunking logic as jiuge_ppl.py: only process complete chunks
+    for i in range(0, len(tokens) - CHUNK_SIZE + 1, CHUNK_SIZE):
+        chunk_tokens = tokens[i : i + CHUNK_SIZE]
         chunk_text = tokenizer.decode(chunk_tokens)
 
         resp = requests.post(
diff --git a/src/cache_manager/opcache_manager.hpp b/src/cache_manager/opcache_manager.hpp
index 333583e8..83ef5aed 100644
--- a/src/cache_manager/opcache_manager.hpp
+++ b/src/cache_manager/opcache_manager.hpp
@@ -158,6 +158,7 @@ class CacheManager {
     DECLARE_OP_CACHE(RoPE)
     DECLARE_OP_CACHE(Rearrange)
    DECLARE_OP_CACHE(CausalSoftmax)
+    DECLARE_OP_CACHE(LogSoftmax)
     DECLARE_OP_CACHE(Topkrouter)
     DECLARE_OP_CACHE(SwiGLU)
     DECLARE_OP_CACHE(RandomSample)
@@ -170,6 +171,7 @@ class CacheManager {
           RoPE_cache(capacity, DESTROY_FUNC(RoPE)),
           Rearrange_cache(capacity, DESTROY_FUNC(Rearrange)),
           CausalSoftmax_cache(capacity, DESTROY_FUNC(CausalSoftmax)),
+          LogSoftmax_cache(capacity, DESTROY_FUNC(LogSoftmax)),
           Topkrouter_cache(capacity, DESTROY_FUNC(Topkrouter)),
           SwiGLU_cache(capacity, DESTROY_FUNC(SwiGLU)),
           RandomSample_cache(capacity, DESTROY_FUNC(RandomSample)),
diff --git a/src/models/inference_context.cpp b/src/models/inference_context.cpp
index e41e4bb3..15517538 100644
--- a/src/models/inference_context.cpp
+++ b/src/models/inference_context.cpp
@@ -143,6 +143,26 @@ void InferenceContext::causalSoftmax(std::shared_ptr<Tensor> y,
                                      y->data(), x->data(), stream));
 }
 
+void InferenceContext::logSoftmax(std::shared_ptr<Tensor> y,
+                                  std::shared_ptr<Tensor> x) {
+    size_t key = CacheManager::createDescriptorKey(y, x);
+
+    infiniopLogSoftmaxDescriptor_t desc;
+    if (!cache_manager->getLogSoftmaxDescriptor(key, desc)) {
+        RUN_INFINI(infiniopCreateLogSoftmaxDescriptor(
+            op_handle, &desc, y->desc(), x->desc()));
+        cache_manager->putLogSoftmaxDescriptor(key, desc);
+    }
+
+    size_t workspace_size = 0;
+    RUN_INFINI(infiniopGetLogSoftmaxWorkspaceSize(desc, &workspace_size));
+    ensure_workspace(workspace_size);
+    void *workspace = workspace_storage->memory();
+
+    RUN_INFINI(infiniopLogSoftmax(desc, workspace, workspace_size,
+                                  y->data(), x->data(), stream));
+}
+
 void InferenceContext::topkrouter(std::shared_ptr<Tensor> values,  // F32
                                   std::shared_ptr<Tensor> indices, // I32
                                   std::shared_ptr<Tensor> x,
diff --git a/src/models/inference_context.hpp b/src/models/inference_context.hpp
index 0cf93f6f..d8597b5c 100644
--- a/src/models/inference_context.hpp
+++ b/src/models/inference_context.hpp
@@ -37,6 +37,8 @@ struct InferenceContext {
               infiniopRoPEAlgo_t algo);
     void causalSoftmax(std::shared_ptr<Tensor> y,
                        std::shared_ptr<Tensor> x);
+    void logSoftmax(std::shared_ptr<Tensor> y,
+                    std::shared_ptr<Tensor> x);
     void topkrouter(std::shared_ptr<Tensor> values,  // F32
                     std::shared_ptr<Tensor> indices, // I32
                     std::shared_ptr<Tensor> x,
@@ -111,6 +113,10 @@ inline void causalSoftmax(std::shared_ptr<Tensor> y, std::shared_ptr<Tensor> x) {
     getInferenceContext().causalSoftmax(y, x);
 }
 
+inline void logSoftmax(std::shared_ptr<Tensor> y, std::shared_ptr<Tensor> x) {
+    getInferenceContext().logSoftmax(y, x);
+}
+
 inline void topkrouter(std::shared_ptr<Tensor> values,  // F32
                        std::shared_ptr<Tensor> indices, // I32
                        std::shared_ptr<Tensor> x,
diff --git a/src/models/jiuge/jiuge.cpp b/src/models/jiuge/jiuge.cpp
index 059842cc..f33b3e1e 100644
--- a/src/models/jiuge/jiuge.cpp
+++ b/src/models/jiuge/jiuge.cpp
@@ -262,8 +262,12 @@ void inferDeviceBatch(const JiugeMeta &meta, JiugeDeviceResource &rsrc,
         rmsnorm(logits_out, logits_in, rsrc.w_out_norm, meta.epsilon);
         auto last_logits_buf = Tensor::buffer(dt_logits, {ntok, dvoc}, rsrc.memory_pool);
         linear(last_logits_buf, logits_out, rsrc.w_out_embd, 1.0, 0.0, nullptr, nullptr);
+
+        auto log_logits_buf = Tensor::buffer(dt_logits, {ntok, dvoc}, rsrc.memory_pool);
+        logSoftmax(log_logits_buf, last_logits_buf);
+
         RUN_INFINI(infinirtStreamSynchronize(stream));
-        RUN_INFINI(infinirtMemcpy(last_logits, last_logits_buf->data(), dsize(dt_logits) * ntok * dvoc, INFINIRT_MEMCPY_D2H));
+        RUN_INFINI(infinirtMemcpy(last_logits, log_logits_buf->data(), dsize(dt_logits) * ntok * dvoc, INFINIRT_MEMCPY_D2H));
     }
     if (output != nullptr) {
         size_t token_offset = 0;
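Note, not part of the patch: with `forward_batch` now filling its output buffer with log-softmax values, perplexity follows directly as the exponential of the mean negative log-likelihood of the true next tokens, with no extra normalization on the host. Below is a minimal sketch of that gather-and-average, using a random stand-in tensor and hypothetical sizes in place of the real `forward_batch` output; the indexing mirrors the `token_logprobs` gather in `perplexity()` above.

```python
import torch

# Hypothetical sizes; log_probs stands in for the (ntok, dvoc) buffer that
# forward_batch now fills with log-softmax values.
ntok, dvoc = 8, 32000
log_probs = torch.randn(ntok, dvoc).log_softmax(dim=-1)
true_tokens = torch.randint(0, dvoc, (ntok,))  # next-token targets, as in perplexity()

# Gather the log-prob of each true next token, then ppl = exp(mean NLL).
token_logprobs = log_probs[torch.arange(ntok), true_tokens]  # (ntok,)
ppl = torch.exp(-token_logprobs.mean())
print(float(ppl))
```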
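Also not part of the patch: a quick local smoke test for the new `/completions` route. The host and port follow the `uvicorn.run(App, host="0.0.0.0", port=8000)` call above, the prompt string is arbitrary, and the response fields mirror the dict built in `completion()`; with `echo=true` and `max_tokens=0` the route returns per-token logprobs for the prompt itself.

```python
import requests

# Assumes launch_server.py is running locally; the first logprob entry is None
# because the first prompt token has no preceding context.
resp = requests.post(
    "http://127.0.0.1:8000/completions",
    json={"prompt": "Hello, world", "max_tokens": 0, "echo": True},
    timeout=60,
)
choice = resp.json()["choices"][0]
for tok, lp in zip(choice["logprobs"]["tokens"], choice["logprobs"]["token_logprobs"]):
    print(repr(tok), lp)
```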