@@ -287,12 +287,20 @@ def split_keys(keys):
         return state_dict

     @endpoint
-    async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]:
+    async def generate(
+        self,
+        prompt: str,
+        *,
+        priority: int = 0,
+        sampling_params: SamplingParams | None = None,
+    ) -> list[Completion]:
         """Generate a response for the given prompt

         Args:
             prompt (str): The prompt to generate a response for.
             priority (int, optional): The priority of the request. Defaults to 0.
+            sampling_params (SamplingParams, optional): Sampling parameters to use for this request.
+                If not provided, uses self.sampling_params.

         Returns:
             list[Completion]: n completions from vLLM based on your prompt.
@@ -301,12 +309,18 @@ async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]:
         t.start()
         record_metric("generator/generate/count_requests", 1, Reduce.SUM)

+        if sampling_params is not None:
+            # as in `post_init`
+            sampling_params.output_kind = RequestOutputKind.FINAL_ONLY
+
+        params = sampling_params or self.sampling_params
+
         self.request_id += 1 % sys.maxsize
         request_id = str(self.request_id)

         tokenization_kwargs = {}
         # TODO: add truncation support https://github.com/vllm-project/vllm/issues/4507
-        truncate_prompt_tokens = self.sampling_params.truncate_prompt_tokens
+        truncate_prompt_tokens = params.truncate_prompt_tokens
         _validate_truncation_size(
             self.vllm_config.model_config.max_model_len,
             truncate_prompt_tokens,
@@ -315,7 +329,7 @@ async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]:
         prompt_str, request = self.processor.process_inputs(
             request_id=request_id,
             prompt={"prompt": prompt},
-            params=self.sampling_params,
+            params=params,
             arrival_time=None,
             tokenization_kwargs=tokenization_kwargs,
             trace_headers=None,
@@ -331,21 +345,21 @@ async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]:
             await self.request_lock.wait_for(lambda: self.accepting_requests)

             # Explicitly keeping the redundant logic to make it easier to pick up vLLM changes
-            if (num_samples := self.sampling_params.n) == 1:
+            if (num_samples := params.n) == 1:
                 self.output_processor.add_request(request, prompt_str, None, 0)
                 request, _ = self._preprocess_add_request(request)
                 request_fut = asyncio.Future()
                 self.requests[request_id] = (None, request_fut)
                 self.scheduler.add_request(request)
             else:
-                parent_req = ParentRequest(request_id, self.sampling_params)
+                parent_req = ParentRequest(request_id, params)
                 for idx in range(num_samples):
                     # Note: `get_child_info` mutates ParentRequest to track the
                     # generated child request
-                    child_request_id, params = parent_req.get_child_info(idx)
+                    child_request_id, params_child = parent_req.get_child_info(idx)
                     child_request = request if idx == num_samples - 1 else copy(request)
                     child_request.request_id = child_request_id
-                    child_request.sampling_params = params
+                    child_request.sampling_params = params_child
                     self.output_processor.add_request(
                         child_request, prompt_str, parent_req, idx
                     )
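
With this change, callers can override the generator's default sampling configuration per request. Below is a minimal caller-side sketch, assuming a handle to the generator actor and vLLM's standard `SamplingParams`; the plain `await generator.generate(...)` call style is an illustration only, since the actual endpoint invocation depends on the surrounding actor framework:

```python
from vllm import SamplingParams

async def sample_with_override(generator, prompt: str):
    # Hypothetical per-request override: 4 samples at a higher temperature.
    # The override replaces the whole SamplingParams object
    # (params = sampling_params or self.sampling_params); it is not merged
    # field-by-field with the generator's defaults.
    params = SamplingParams(n=4, temperature=1.0, max_tokens=256)

    # output_kind is forced to RequestOutputKind.FINAL_ONLY inside the
    # endpoint, so callers do not need to set it here.
    completions = await generator.generate(prompt, sampling_params=params)

    # Omitting sampling_params preserves the previous behavior and falls
    # back to self.sampling_params.
    default_completions = await generator.generate(prompt)
    return completions, default_completions
```

Because `n` is now read from the resolved `params`, an override with `n > 1` also drives the `ParentRequest` fan-out branch shown above.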