diff --git a/mlx_lm/server.py b/mlx_lm/server.py
index 7fc91fa2a..ca6c2b3ed 100644
--- a/mlx_lm/server.py
+++ b/mlx_lm/server.py
@@ -734,7 +734,7 @@ def _tokenize(self, tokenizer, request, args):
         else:
             return tokenizer.encode(request.prompt)
 
-    def _compute_prompt_checkpoint(self, tokenizer, request, prompt):
+    def _compute_prompt_checkpoint(self, tokenizer, request, prompt, args):
         if request.request_type != "chat":
             return False, -1
         if request.messages[-1]["role"] != "user":
@@ -744,7 +744,7 @@ def _compute_prompt_checkpoint(self, tokenizer, request, prompt):
         # the think start token which will likely be removed in the
         # next turn.
         prompt_checkpoint = -1
-        if tokenizer.has_thinking:
+        if self._has_thinking(tokenizer, args):
             for i in range(1, min(11, len(prompt)) - 1, 1):
                 if prompt[-i] == tokenizer.think_start_id:
                     prompt_checkpoint = -i - 1
@@ -752,6 +752,20 @@ def _compute_prompt_checkpoint(self, tokenizer, request, prompt):
 
         return True, prompt_checkpoint
 
+    def _has_thinking(self, tokenizer, args):
+        """Return whether thinking is active for this request.
+
+        Check (in priority order) the per-request chat_template_kwargs, the
+        CLI --chat-template-args, and finally the tokenizer's own capability
+        flag.
+        """
+        if args.chat_template_kwargs and "enable_thinking" in args.chat_template_kwargs:
+            return args.chat_template_kwargs["enable_thinking"]
+        cli_args = self.model_provider.cli_args.chat_template_args
+        if "enable_thinking" in cli_args:
+            return cli_args["enable_thinking"]
+        return tokenizer.has_thinking
+
     def _is_batchable(self, args):
         if not self.model_provider.is_batchable:
             return False
@@ -829,7 +843,7 @@ def checkpoint_callback(prompts):
                 tool_call_start=tokenizer.tool_call_start,
                 tool_call_end=tokenizer.tool_call_end,
                 tool_parser=tokenizer.tool_parser,
-                has_thinking=tokenizer.has_thinking,
+                has_thinking=self._has_thinking(tokenizer, args),
                 think_start_id=tokenizer.think_start_id,
                 think_end=tokenizer.think_end,
                 think_end_id=tokenizer.think_end_id,
@@ -851,7 +865,9 @@
             cache = make_prompt_cache(self.model_provider.model)
 
             do_checkpoint, checkpoint_position = (
-                self._compute_prompt_checkpoint(tokenizer, request, prompt)
+                self._compute_prompt_checkpoint(
+                    tokenizer, request, prompt, args
+                )
             )
 
             (uid,) = batch_generator.insert(
@@ -1001,7 +1017,7 @@ def progress(tokens_processed, tokens_total):
                 tool_call_start=tokenizer.tool_call_start,
                 tool_call_end=tokenizer.tool_call_end,
                 tool_parser=tokenizer.tool_parser,
-                has_thinking=tokenizer.has_thinking,
+                has_thinking=self._has_thinking(tokenizer, args),
                 think_start_id=tokenizer.think_start_id,
                 think_end=tokenizer.think_end,
                 think_end_id=tokenizer.think_end_id,