From 1726dab9e3d938f1743c7c2bf35c15af5bf7fb20 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ey=C3=BCp=20Can=20Akman?= Date: Fri, 20 Mar 2026 15:49:05 +0300 Subject: [PATCH] Respect per-request enable_thinking in server response handling The chat_template_kwargs from client requests (e.g. enable_thinking) were applied to the chat template during tokenization but ignored when building the GenerationContext. This meant the response handler always used the static tokenizer.has_thinking flag, so reasoning detection and prompt checkpointing could not be toggled per request. Add _has_thinking() that resolves the effective thinking state from per-request kwargs, CLI --chat-template-args, then the tokenizer default. Use it in both generation paths and prompt checkpointing. Fixes #914 --- mlx_lm/server.py | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/mlx_lm/server.py b/mlx_lm/server.py index 7fc91fa2a..ca6c2b3ed 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -734,7 +734,7 @@ def _tokenize(self, tokenizer, request, args): else: return tokenizer.encode(request.prompt) - def _compute_prompt_checkpoint(self, tokenizer, request, prompt): + def _compute_prompt_checkpoint(self, tokenizer, request, prompt, args): if request.request_type != "chat": return False, -1 if request.messages[-1]["role"] != "user": @@ -744,7 +744,7 @@ def _compute_prompt_checkpoint(self, tokenizer, request, prompt): # the think start token which will likely be removed in the # next turn. prompt_checkpoint = -1 - if tokenizer.has_thinking: + if self._has_thinking(tokenizer, args): for i in range(1, min(11, len(prompt)) - 1, 1): if prompt[-i] == tokenizer.think_start_id: prompt_checkpoint = -i - 1 @@ -752,6 +752,20 @@ def _compute_prompt_checkpoint(self, tokenizer, request, prompt): return True, prompt_checkpoint + def _has_thinking(self, tokenizer, args): + """Return whether thinking is active for this request. + + Check (in priority order) the per-request chat_template_kwargs, the + CLI --chat-template-args, and finally the tokenizer's own capability + flag. + """ + if args.chat_template_kwargs and "enable_thinking" in args.chat_template_kwargs: + return args.chat_template_kwargs["enable_thinking"] + cli_args = self.model_provider.cli_args.chat_template_args + if "enable_thinking" in cli_args: + return cli_args["enable_thinking"] + return tokenizer.has_thinking + def _is_batchable(self, args): if not self.model_provider.is_batchable: return False @@ -829,7 +843,7 @@ def checkpoint_callback(prompts): tool_call_start=tokenizer.tool_call_start, tool_call_end=tokenizer.tool_call_end, tool_parser=tokenizer.tool_parser, - has_thinking=tokenizer.has_thinking, + has_thinking=self._has_thinking(tokenizer, args), think_start_id=tokenizer.think_start_id, think_end=tokenizer.think_end, think_end_id=tokenizer.think_end_id, @@ -851,7 +865,9 @@ def checkpoint_callback(prompts): cache = make_prompt_cache(self.model_provider.model) do_checkpoint, checkpoint_position = ( - self._compute_prompt_checkpoint(tokenizer, request, prompt) + self._compute_prompt_checkpoint( + tokenizer, request, prompt, args + ) ) (uid,) = batch_generator.insert( @@ -1001,7 +1017,7 @@ def progress(tokens_processed, tokens_total): tool_call_start=tokenizer.tool_call_start, tool_call_end=tokenizer.tool_call_end, tool_parser=tokenizer.tool_parser, - has_thinking=tokenizer.has_thinking, + has_thinking=self._has_thinking(tokenizer, args), think_start_id=tokenizer.think_start_id, think_end=tokenizer.think_end, think_end_id=tokenizer.think_end_id,