@@ -322,15 +322,21 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int =
                 else:
                     position_ids = None

-                enable_thinking = request.get("enable_thinking", True)
-                enable_thinking = enable_thinking if enable_thinking is not None else True
-                self.share_inputs["enable_thinking"][:] = enable_thinking
-                self.share_inputs["need_think_end"][idx : idx + 1, :] = 1 if enable_thinking else 0
-                self.share_inputs["reasoning_index"][idx : idx + 1, :] = request.get("reasoning_max_tokens", 2048)
                 self.share_inputs["rope_emb"][idx : idx + 1, :] = self.prepare_rope3d(
                     position_ids, request.get("max_tokens", 2048)
                 )

+            if request.get("enable_thinking", False) and request.get("reasoning_max_tokens") is not None:
+                # Enable thinking
+                self.share_inputs["enable_thinking"][:] = True
+                self.share_inputs["need_think_end"][idx : idx + 1, :] = 1
+                self.share_inputs["reasoning_index"][idx : idx + 1, :] = request.get("reasoning_max_tokens")
+            else:
+                # Disable thinking
+                self.share_inputs["enable_thinking"][:] = False
+                self.share_inputs["need_think_end"][idx : idx + 1, :] = 0
+                self.share_inputs["reasoning_index"][idx : idx + 1, :] = 0
+
             if isinstance(request.prompt_token_ids, np.ndarray):
                 prompt_token_ids = request.prompt_token_ids.tolist()
             else:
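Note: this hunk flips the default. Thinking was previously on unless the request explicitly disabled it; it is now off unless the request both opts in and supplies reasoning_max_tokens, and the else branch resets all three per-slot buffers. A minimal sketch of the gating rule, with a plain dict standing in for Request (the helper name resolve_thinking is illustrative, not part of the diff):

def resolve_thinking(req: dict) -> tuple:
    """Return (enable_thinking, reasoning_budget) for one request."""
    if req.get("enable_thinking", False) and req.get("reasoning_max_tokens") is not None:
        return True, req["reasoning_max_tokens"]
    # No opt-in, or no budget supplied: thinking stays off and state is zeroed.
    return False, 0

assert resolve_thinking({"enable_thinking": True, "reasoning_max_tokens": 1024}) == (True, 1024)
assert resolve_thinking({"enable_thinking": True}) == (False, 0)  # budget missing -> disabled
assert resolve_thinking({}) == (False, 0)                         # default flipped from on to off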
@@ -549,16 +555,22 @@ def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests:
             self.share_inputs["prompt_lens"][idx : idx + 1] = length

             if self.enable_mm:
-                enable_thinking = request.get("enable_thinking", True)
-                enable_thinking = enable_thinking if enable_thinking is not None else True
-                self.share_inputs["enable_thinking"][:] = enable_thinking
-                self.share_inputs["need_think_end"][idx : idx + 1, :] = 1 if enable_thinking else 0
-                self.share_inputs["reasoning_index"][idx : idx + 1, :] = request.get("reasoning_max_tokens", 2048)
                 self.share_inputs["rope_emb"][idx : idx + 1, :] = self.prepare_rope3d(
                     position_ids, request.get("max_tokens", 2048)
                 )
                 self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0

+            if request.get("enable_thinking", False) and request.get("reasoning_max_tokens") is not None:
+                # Enable thinking
+                self.share_inputs["enable_thinking"][:] = True
+                self.share_inputs["need_think_end"][idx : idx + 1, :] = 1
+                self.share_inputs["reasoning_index"][idx : idx + 1, :] = request.get("reasoning_max_tokens")
+            else:
+                # Disable thinking
+                self.share_inputs["enable_thinking"][:] = False
+                self.share_inputs["need_think_end"][idx : idx + 1, :] = 0
+                self.share_inputs["reasoning_index"][idx : idx + 1, :] = 0
+
         def get_attr_from_request(request, attr, default_value=None):
             res = request.get(attr, default_value)
             if res is not None:
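Note: enable_thinking is a one-element batch-wide buffer (written with [:] =), while need_think_end and reasoning_index are per-slot rows, so a mixed batch keeps per-request state but shares a single flag whose value is whatever the last inserted request wrote. A small sketch of that aliasing, assuming the buffer shapes from the init hunk below:

import paddle

max_num_seqs = 4
enable_thinking = paddle.full(shape=[1], fill_value=False, dtype="bool")
need_think_end = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")

enable_thinking[:] = False    # request at idx 0 disables thinking
need_think_end[0:1, :] = 0
enable_thinking[:] = True     # request at idx 1 enables it
need_think_end[1:2, :] = 1
print(bool(enable_thinking))            # True: the batch-wide flag kept the last writer's value
print(need_think_end.numpy().ravel())   # [0 1 0 0]: per-slot state stays independent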
@@ -853,6 +865,11 @@ def _init_share_inputs(self, max_num_seqs: int):
         # Initialize rotary position embedding
         tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))

+        # Initialize thinking related buffers
+        self.share_inputs["need_think_end"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
+        self.share_inputs["enable_thinking"] = paddle.full(shape=[1], fill_value=False, dtype="bool")
+        self.share_inputs["reasoning_index"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
+
         # TODO(gongshaotian): move to models
         if not self.enable_mm:
             self.share_inputs["rope_emb"] = get_rope(
@@ -952,11 +969,6 @@ def _init_share_inputs(self, max_num_seqs: int):
                 dtype="float32",
             )
             self.share_inputs["image_features"] = None
-            self.share_inputs["need_think_end"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
-            self.share_inputs["enable_thinking"] = paddle.full(
-                shape=[1], fill_value=("ernie" in self.model_config.model_type), dtype="bool"
-            )
-            self.share_inputs["reasoning_index"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")

     def _prepare_inputs(self) -> None:
         """Prepare the model inputs"""
@@ -1392,10 +1404,10 @@ def _dummy_run(
             ),
             accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None),
             accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
-            enable_thinking=(self.share_inputs["enable_thinking"] if self.enable_mm else None),
-            think_end_id=(getattr(self.model_config, "think_end_id", -1) if self.enable_mm else -1),
-            need_think_end=(self.share_inputs["need_think_end"] if self.enable_mm else None),
-            reasoning_index=(self.share_inputs["reasoning_index"] if self.enable_mm else None),
+            enable_thinking=self.share_inputs["enable_thinking"],
+            think_end_id=self.model_config.think_end_id,
+            need_think_end=self.share_inputs["need_think_end"],
+            reasoning_index=self.share_inputs["reasoning_index"],
             stop_token_ids=self.share_inputs["stop_seqs"],
             stop_seqs_len=self.share_inputs["stop_seqs_len"],
         )
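Note: the call sites now pass the thinking buffers and think_end_id unconditionally, and self.model_config.think_end_id replaces the getattr(..., -1) fallback, so the config is expected to always carry that attribute. The consumer of these buffers is not shown in this diff; as a rough, hypothetical sketch of how such a gate typically works (names and logic are assumptions, not code from this repository):

def apply_thinking_gate(next_token, slot, need_think_end, reasoning_index, think_end_id):
    # Pass tokens through untouched unless the slot is in a thinking phase.
    if need_think_end[slot] == 0:
        return next_token
    reasoning_index[slot] -= 1                 # spend one token of the budget
    if next_token == think_end_id or reasoning_index[slot] <= 0:
        need_think_end[slot] = 0               # leave the thinking phase
        return think_end_id if reasoning_index[slot] <= 0 else next_token
    return next_token

need, budget = [1], [2]
print(apply_thinking_gate(42, 0, need, budget, think_end_id=7))  # 42: budget 2 -> 1
print(apply_thinking_gate(42, 0, need, budget, think_end_id=7))  # 7: budget exhausted, end forced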
@@ -1703,10 +1715,10 @@ class at the server level, which is too granular for ModelRunner.
             ),
             accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None),
             accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
-            enable_thinking=(self.share_inputs["enable_thinking"] if self.enable_mm else None),
-            think_end_id=(getattr(self.model_config, "think_end_id", -1) if self.enable_mm else -1),
-            need_think_end=(self.share_inputs["need_think_end"][:num_running_requests] if self.enable_mm else None),
-            reasoning_index=(self.share_inputs["reasoning_index"][:num_running_requests] if self.enable_mm else None),
+            enable_thinking=self.share_inputs["enable_thinking"],
+            think_end_id=self.model_config.think_end_id,
+            need_think_end=self.share_inputs["need_think_end"][:num_running_requests],
+            reasoning_index=self.share_inputs["reasoning_index"][:num_running_requests],
             stop_token_ids=self.share_inputs["stop_seqs"],
             stop_seqs_len=self.share_inputs["stop_seqs_len"],
         )
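Note: this execution-path hunk differs from the _dummy_run one only in slicing the per-slot buffers down to the live batch, while the batch-wide flag is passed whole. Illustrative shapes only:

import paddle
max_num_seqs, num_running_requests = 8, 3
need_think_end = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
print(need_think_end[:num_running_requests].shape)                  # [3, 1]: per-slot, sliced
print(paddle.full(shape=[1], fill_value=False, dtype="bool").shape)  # [1]: batch-wide, never sliced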