1 change: 1 addition & 0 deletions fastdeploy/config.py
@@ -131,6 +131,7 @@ def __init__(
self.reasoning_parser = None
self.pad_token_id: int = -1
self.eos_tokens_lens: int = 2
self.think_end_id = None
self.lm_head_fp32: bool = False
self.model_format = "auto"
self.partial_rotary_factor: float = 1.0
2 changes: 0 additions & 2 deletions fastdeploy/entrypoints/engine_client.py
@@ -154,8 +154,6 @@ async def add_requests(self, task):
task["prompt_token_ids_len"] = len(task["prompt_token_ids"])
input_ids_len = task["prompt_token_ids_len"]
task["max_tokens"] = min(self.max_model_len - input_ids_len, task.get("max_tokens"))
if task.get("reasoning_max_tokens", None) is None:
task["reasoning_max_tokens"] = max(int(task["max_tokens"] * 0.8), 1)
min_tokens = task.get("min_tokens", 1)
if "messages" in task:
del task["messages"]
@@ -253,6 +253,10 @@ def process_request_dict(self, request, max_model_len=None):
request["prompt_token_ids"] = request["prompt_token_ids"][: max_model_len - 1]
if request.get("max_tokens") is None:
request["max_tokens"] = max(1, max_model_len - len(request["prompt_token_ids"]))
else:
request["max_tokens"] = min(max_model_len - len(request["prompt_token_ids"]), request["max_tokens"])
if request.get("reasoning_max_tokens") is None:
request["reasoning_max_tokens"] = max(int(request["max_tokens"] * 0.8), 1)
data_processor_logger.info(f"Processed request {request}")

return request
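A minimal sketch of the token budgeting this hunk introduces (plain Python with a hypothetical `budget_tokens` helper and a bare `request` dict, not the actual processor class): `max_tokens` is capped to the remaining context window, and `reasoning_max_tokens` falls back to 80% of the final `max_tokens` when the caller leaves it unset.

```python
def budget_tokens(request: dict, max_model_len: int) -> dict:
    """Sketch of the max_tokens / reasoning_max_tokens defaults (hypothetical helper)."""
    prompt_len = len(request["prompt_token_ids"])
    if request.get("max_tokens") is None:
        # No explicit budget: use whatever context window remains.
        request["max_tokens"] = max(1, max_model_len - prompt_len)
    else:
        # Explicit budget: cap it to the remaining context window.
        request["max_tokens"] = min(max_model_len - prompt_len, request["max_tokens"])
    if request.get("reasoning_max_tokens") is None:
        # Default reasoning budget: 80% of the generation budget, at least 1 token.
        request["reasoning_max_tokens"] = max(int(request["max_tokens"] * 0.8), 1)
    return request


# Example: a 20-token prompt in a 32-token window, caller asked for 100 new tokens.
req = budget_tokens({"prompt_token_ids": list(range(20)), "max_tokens": 100}, max_model_len=32)
assert req["max_tokens"] == 12 and req["reasoning_max_tokens"] == 9
```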
2 changes: 1 addition & 1 deletion fastdeploy/model_executor/pre_and_post_process.py
@@ -192,7 +192,7 @@ def post_process_normal(
) -> ModelRunnerOutput:
"""Post-processing steps after completing a single token generation."""
# handle vl:
if model_output.enable_thinking:
if model_output.enable_thinking and model_output.think_end_id is not None:
exists_think_end = sampler_output.sampled_token_ids == model_output.think_end_id
paddle.assign(
paddle.where(
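In plain terms, the extra `think_end_id is not None` check means the think-end bookkeeping is skipped for models whose vocabulary has no `</think>` token. A rough sketch of the guard (hypothetical helper, not the real `post_process_normal` tensor code):

```python
def maybe_mark_think_end(sampled_token_id: int, enable_thinking: bool, think_end_id) -> bool:
    """Sketch (hypothetical helper): only compare against </think> when its id was resolved."""
    if enable_thinking and think_end_id is not None:
        return sampled_token_id == think_end_id
    # No </think> id (or thinking disabled): never trigger the think-end bookkeeping.
    return False


assert maybe_mark_think_end(7, enable_thinking=True, think_end_id=7) is True
assert maybe_mark_think_end(7, enable_thinking=True, think_end_id=None) is False
```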
58 changes: 35 additions & 23 deletions fastdeploy/worker/gpu_model_runner.py
@@ -322,15 +322,21 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int =
else:
position_ids = None

enable_thinking = request.get("enable_thinking", True)
enable_thinking = enable_thinking if enable_thinking is not None else True
self.share_inputs["enable_thinking"][:] = enable_thinking
self.share_inputs["need_think_end"][idx : idx + 1, :] = 1 if enable_thinking else 0
self.share_inputs["reasoning_index"][idx : idx + 1, :] = request.get("reasoning_max_tokens", 2048)
self.share_inputs["rope_emb"][idx : idx + 1, :] = self.prepare_rope3d(
position_ids, request.get("max_tokens", 2048)
)

if request.get("enable_thinking", False) and request.get("reasoning_max_tokens") is not None:
# Enable thinking
self.share_inputs["enable_thinking"][:] = True
self.share_inputs["need_think_end"][idx : idx + 1, :] = 1
self.share_inputs["reasoning_index"][idx : idx + 1, :] = request.get("reasoning_max_tokens")
else:
# Disable thinking
self.share_inputs["enable_thinking"][:] = False
self.share_inputs["need_think_end"][idx : idx + 1, :] = 0
self.share_inputs["reasoning_index"][idx : idx + 1, :] = 0

if isinstance(request.prompt_token_ids, np.ndarray):
prompt_token_ids = request.prompt_token_ids.tolist()
else:
@@ -549,16 +555,22 @@ def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests:
self.share_inputs["prompt_lens"][idx : idx + 1] = length

if self.enable_mm:
enable_thinking = request.get("enable_thinking", True)
enable_thinking = enable_thinking if enable_thinking is not None else True
self.share_inputs["enable_thinking"][:] = enable_thinking
self.share_inputs["need_think_end"][idx : idx + 1, :] = 1 if enable_thinking else 0
self.share_inputs["reasoning_index"][idx : idx + 1, :] = request.get("reasoning_max_tokens", 2048)
self.share_inputs["rope_emb"][idx : idx + 1, :] = self.prepare_rope3d(
position_ids, request.get("max_tokens", 2048)
)
self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0

if request.get("enable_thinking", False) and request.get("reasoning_max_tokens") is not None:
# Enable thinking
self.share_inputs["enable_thinking"][:] = True
self.share_inputs["need_think_end"][idx : idx + 1, :] = 1
self.share_inputs["reasoning_index"][idx : idx + 1, :] = request.get("reasoning_max_tokens")
else:
# Disable thinking
self.share_inputs["enable_thinking"][:] = False
self.share_inputs["need_think_end"][idx : idx + 1, :] = 0
self.share_inputs["reasoning_index"][idx : idx + 1, :] = 0

def get_attr_from_request(request, attr, default_value=None):
res = request.get(attr, default_value)
if res is not None:
@@ -853,6 +865,11 @@ def _init_share_inputs(self, max_num_seqs: int):
# Initialize rotary position embedding
tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))

# Initialize thinking related buffers
self.share_inputs["need_think_end"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
self.share_inputs["enable_thinking"] = paddle.full(shape=[1], fill_value=False, dtype="bool")
self.share_inputs["reasoning_index"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")

# TODO(gongshaotian): move to models
if not self.enable_mm:
self.share_inputs["rope_emb"] = get_rope(
@@ -952,11 +969,6 @@ def _init_share_inputs(self, max_num_seqs: int):
dtype="float32",
)
self.share_inputs["image_features"] = None
self.share_inputs["need_think_end"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
self.share_inputs["enable_thinking"] = paddle.full(
shape=[1], fill_value=("ernie" in self.model_config.model_type), dtype="bool"
)
self.share_inputs["reasoning_index"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")

def _prepare_inputs(self) -> None:
"""Prepare the model inputs"""
@@ -1392,10 +1404,10 @@ def _dummy_run(
),
accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None),
accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
enable_thinking=(self.share_inputs["enable_thinking"] if self.enable_mm else None),
think_end_id=(getattr(self.model_config, "think_end_id", -1) if self.enable_mm else -1),
need_think_end=(self.share_inputs["need_think_end"] if self.enable_mm else None),
reasoning_index=(self.share_inputs["reasoning_index"] if self.enable_mm else None),
enable_thinking=self.share_inputs["enable_thinking"],
think_end_id=self.model_config.think_end_id,
need_think_end=self.share_inputs["need_think_end"],
reasoning_index=self.share_inputs["reasoning_index"],
stop_token_ids=self.share_inputs["stop_seqs"],
stop_seqs_len=self.share_inputs["stop_seqs_len"],
)
@@ -1703,10 +1715,10 @@ class at the server level, which is too granular for ModelRunner.
),
accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None),
accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
enable_thinking=(self.share_inputs["enable_thinking"] if self.enable_mm else None),
think_end_id=(getattr(self.model_config, "think_end_id", -1) if self.enable_mm else -1),
need_think_end=(self.share_inputs["need_think_end"][:num_running_requests] if self.enable_mm else None),
reasoning_index=(self.share_inputs["reasoning_index"][:num_running_requests] if self.enable_mm else None),
enable_thinking=self.share_inputs["enable_thinking"],
think_end_id=self.model_config.think_end_id,
need_think_end=self.share_inputs["need_think_end"][:num_running_requests],
reasoning_index=self.share_inputs["reasoning_index"][:num_running_requests],
stop_token_ids=self.share_inputs["stop_seqs"],
stop_seqs_len=self.share_inputs["stop_seqs_len"],
)
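Both insertion paths (`insert_tasks_v1` and `insert_prefill_inputs`) now apply the same rule: thinking is active for a request only when `enable_thinking` is truthy and a `reasoning_max_tokens` budget is present; otherwise the per-slot flag, the need-think-end marker, and the budget are zeroed. A condensed sketch (hypothetical helper returning the values the real code writes into `share_inputs`):

```python
from typing import Tuple


def resolve_thinking_state(request: dict) -> Tuple[bool, int, int]:
    """Sketch: (enable_thinking, need_think_end, reasoning_index) as written into share_inputs."""
    if request.get("enable_thinking", False) and request.get("reasoning_max_tokens") is not None:
        # Thinking on: flag the slot as waiting for </think> and hand it its reasoning budget.
        return True, 1, request["reasoning_max_tokens"]
    # Thinking off: clear the flag, the need-think-end marker, and the budget.
    return False, 0, 0


assert resolve_thinking_state({"enable_thinking": True, "reasoning_max_tokens": 512}) == (True, 1, 512)
assert resolve_thinking_state({"enable_thinking": True}) == (False, 0, 0)            # no budget -> off
assert resolve_thinking_state({"enable_thinking": False, "reasoning_max_tokens": 512}) == (False, 0, 0)
```

For requests that went through the processor hunk above, `reasoning_max_tokens` is already defaulted, so in practice the branch mostly tracks the `enable_thinking` flag.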
23 changes: 23 additions & 0 deletions fastdeploy/worker/worker_process.py
@@ -130,6 +130,28 @@ def update_fd_config_for_mm(fd_config: FDConfig) -> None:
fd_config.model_config.sequence_parallel = fd_config.parallel_config.sequence_parallel


def update_think_end_id_for_ernie(fd_config: FDConfig) -> None:
"""
Updates the think_end_id in the model config. Uses the ID of '</think>'
if it exists, otherwise defaults to None.
"""
is_ernie = ErnieArchitectures.contains_ernie_arch(fd_config.model_config.architectures)
if current_platform.is_cuda() and is_ernie:
tokenizer = Ernie4_5Tokenizer.from_pretrained(
fd_config.model_config.model,
model_max_length=fd_config.parallel_config.max_model_len,
padding_side="right",
use_fast=False,
)

vocab = tokenizer.get_vocab()
fd_config.model_config.think_end_id = vocab.get("</think>", None)
if fd_config.model_config.think_end_id is not None:
logger.info(f"Get think_end_id {fd_config.model_config.think_end_id} from vocab.")
else:
logger.info("No </think> token found in vocabulary, the model can not do reasoning.")


class PaddleDisWorkerProc:
"""
Paddle Distributed wrapper for fastdeploy.worker.Worker,
@@ -767,6 +789,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
moba_attention_config=moba_attention_config,
)
update_fd_config_for_mm(fd_config)
update_think_end_id_for_ernie(fd_config)

return fd_config

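The resolution itself is a plain vocabulary lookup; roughly (a sketch assuming a tokenizer that exposes `get_vocab()`, as the Ernie tokenizer used here does):

```python
from typing import Optional


def resolve_think_end_id(tokenizer) -> Optional[int]:
    """Sketch: map '</think>' to its token id, or None when the vocabulary has no such token."""
    vocab = tokenizer.get_vocab()       # {token_string: token_id}
    return vocab.get("</think>")        # None keeps downstream thinking logic disabled
```

Returning `None` instead of a sentinel such as `-1` is what lets the post-processing guard above (`think_end_id is not None`) stay explicit.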
62 changes: 62 additions & 0 deletions tests/ci_use/EB_VL_Lite/test_EB_VL_Lite_serving.py
@@ -516,6 +516,21 @@ def test_chat_with_thinking(openai_client, capsys):
assert response.choices[0].message.reasoning_content is None
assert "</think>" not in response.choices[0].message.content

# test logic
reasoning_max_tokens = None
response = openai_client.chat.completions.create(
model="default",
messages=[{"role": "user", "content": "Explain gravity in a way that a five-year-old child can understand."}],
temperature=1,
stream=False,
max_tokens=20,
extra_body={
"chat_template_kwargs": {"enable_thinking": True},
"reasoning_max_tokens": reasoning_max_tokens,
},
)
assert response.choices[0].message.reasoning_content is not None

# enable thinking, streaming
reasoning_max_tokens = 3
response = openai_client.chat.completions.create(
@@ -927,3 +942,50 @@ def test_profile_reset_block_num():
f"Reset total_block_num {actual_value} 与 baseline {baseline} diff需要在5%以内"
f"Allowed range: [{lower_bound:.1f}, {upper_bound:.1f}]"
)


def test_thinking_logic_flag(openai_client, capsys):
"""
Test the interaction between token calculation logic and conditional thinking.
This test covers:
1. Default max_tokens calculation when not provided.
2. Capping of max_tokens when it exceeds model limits.
3. Default reasoning_max_tokens calculation when not provided.
4. Activation of thinking based on the final state of reasoning_max_tokens.
"""

response_case_1 = openai_client.chat.completions.create(
model="default",
messages=[{"role": "user", "content": "Explain gravity briefly."}],
temperature=1,
stream=False,
extra_body={
"chat_template_kwargs": {"enable_thinking": True},
},
)
assert response_case_1.choices[0].message.reasoning_content is not None

response_case_2 = openai_client.chat.completions.create(
model="default",
messages=[{"role": "user", "content": "Explain gravity in a way that a five-year-old child can understand."}],
temperature=1,
stream=False,
max_tokens=20,
extra_body={
"chat_template_kwargs": {"enable_thinking": True},
"reasoning_max_tokens": 5,
},
)
assert response_case_2.choices[0].message.reasoning_content is not None

response_case_3 = openai_client.chat.completions.create(
model="default",
messages=[{"role": "user", "content": "Explain gravity in a way that a five-year-old child can understand."}],
temperature=1,
stream=False,
max_tokens=20,
extra_body={
"chat_template_kwargs": {"enable_thinking": False},
},
)
assert response_case_3.choices[0].message.reasoning_content is None
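For readers mapping the three cases back to the server-side changes above, a condensed decision table (informal summary; only the `reasoning_content` assertions are actually checked by the test):

```python
# enable_thinking | max_tokens | reasoning_max_tokens    | expected reasoning_content
# ----------------+------------+-------------------------+---------------------------
# True            | unset      | unset (server defaults) | present (case 1)
# True            | 20         | 5                       | present (case 2)
# False           | 20         | unset                   | absent  (case 3)
```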
62 changes: 62 additions & 0 deletions tests/e2e/test_EB_VL_Lite_serving.py
@@ -535,6 +535,21 @@ def test_chat_with_thinking(openai_client, capsys):
assert response.choices[0].message.reasoning_content is None
assert "</think>" not in response.choices[0].message.content

# test logic
reasoning_max_tokens = None
response = openai_client.chat.completions.create(
model="default",
messages=[{"role": "user", "content": "Explain gravity in a way that a five-year-old child can understand."}],
temperature=1,
stream=False,
max_tokens=20,
extra_body={
"chat_template_kwargs": {"enable_thinking": True},
"reasoning_max_tokens": reasoning_max_tokens,
},
)
assert response.choices[0].message.reasoning_content is not None

# enable thinking, streaming
reasoning_max_tokens = 3
response = openai_client.chat.completions.create(
@@ -642,3 +657,50 @@ def test_profile_reset_block_num():
f"Reset total_block_num {actual_value} 与 baseline {baseline} diff需要在5%以内"
f"Allowed range: [{lower_bound:.1f}, {upper_bound:.1f}]"
)


def test_thinking_logic_flag(openai_client, capsys):
"""
Test the interaction between token calculation logic and conditional thinking.
This test covers:
1. Default max_tokens calculation when not provided.
2. Capping of max_tokens when it exceeds model limits.
3. Default reasoning_max_tokens calculation when not provided.
4. Activation of thinking based on the final state of reasoning_max_tokens.
"""

response_case_1 = openai_client.chat.completions.create(
model="default",
messages=[{"role": "user", "content": "Explain gravity briefly."}],
temperature=1,
stream=False,
extra_body={
"chat_template_kwargs": {"enable_thinking": True},
},
)
assert response_case_1.choices[0].message.reasoning_content is not None

response_case_2 = openai_client.chat.completions.create(
model="default",
messages=[{"role": "user", "content": "Explain gravity in a way that a five-year-old child can understand."}],
temperature=1,
stream=False,
max_tokens=20,
extra_body={
"chat_template_kwargs": {"enable_thinking": True},
"reasoning_max_tokens": 5,
},
)
assert response_case_2.choices[0].message.reasoning_content is not None

response_case_3 = openai_client.chat.completions.create(
model="default",
messages=[{"role": "user", "content": "Explain gravity in a way that a five-year-old child can understand."}],
temperature=1,
stream=False,
max_tokens=20,
extra_body={
"chat_template_kwargs": {"enable_thinking": False},
},
)
assert response_case_3.choices[0].message.reasoning_content is None