
Commit 791af13

Configurable sampling parameter for generator (#550)
Co-authored-by: Felipe Mello <fmellomascarenhas@gmail.com>
1 parent 415c8e9 commit 791af13

File tree

1 file changed (+21 -7 lines)

src/forge/actors/generator.py

Lines changed: 21 additions & 7 deletions
@@ -287,12 +287,20 @@ def split_keys(keys):
         return state_dict
 
     @endpoint
-    async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]:
+    async def generate(
+        self,
+        prompt: str,
+        *,
+        priority: int = 0,
+        sampling_params: SamplingParams | None = None,
+    ) -> list[Completion]:
         """Generate a response for the given prompt
 
         Args:
             prompt (str): The prompt to generate a response for.
             priority (int, optional): The priority of the request. Defaults to 0.
+            sampling_params (SamplingParams, optional): Sampling parameters to use for this request.
+                If not provided, uses self.sampling_params.
 
         Returns:
             list[Completion]: n completions from vLLM based on your prompt.
@@ -301,12 +309,18 @@ async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]:
         t.start()
         record_metric("generator/generate/count_requests", 1, Reduce.SUM)
 
+        if sampling_params is not None:
+            # as in `post_init`
+            sampling_params.output_kind = RequestOutputKind.FINAL_ONLY
+
+        params = sampling_params or self.sampling_params
+
         self.request_id += 1 % sys.maxsize
         request_id = str(self.request_id)
 
         tokenization_kwargs = {}
         # TODO: add truncation support https://github.com/vllm-project/vllm/issues/4507
-        truncate_prompt_tokens = self.sampling_params.truncate_prompt_tokens
+        truncate_prompt_tokens = params.truncate_prompt_tokens
         _validate_truncation_size(
             self.vllm_config.model_config.max_model_len,
             truncate_prompt_tokens,
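
The hunk above resolves which SamplingParams governs the request: a caller-supplied override is normalized the same way the generator's post-init normalizes its defaults, since the output processor expects exactly one final result per request rather than streamed deltas. A minimal sketch of that resolution pattern, using vLLM's public SamplingParams and RequestOutputKind types (the helper name resolve_params is hypothetical, not part of the commit):

from vllm import SamplingParams
from vllm.sampling_params import RequestOutputKind


def resolve_params(
    override: SamplingParams | None, default: SamplingParams
) -> SamplingParams:
    # FINAL_ONLY yields one finished output per request instead of
    # streaming partial deltas; the diff applies the same normalization.
    if override is not None:
        # Note: mutates the caller's object in place, as the diff does.
        override.output_kind = RequestOutputKind.FINAL_ONLY
    return override or default

The remaining hunks below thread the resolved `params` through input processing and request fan-out.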
@@ -315,7 +329,7 @@ async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]:
         prompt_str, request = self.processor.process_inputs(
             request_id=request_id,
             prompt={"prompt": prompt},
-            params=self.sampling_params,
+            params=params,
             arrival_time=None,
             tokenization_kwargs=tokenization_kwargs,
             trace_headers=None,
@@ -331,21 +345,21 @@ async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]:
         await self.request_lock.wait_for(lambda: self.accepting_requests)
 
         # Explicitly keeping the redundant logic to make it easier to pick up vLLM changes
-        if (num_samples := self.sampling_params.n) == 1:
+        if (num_samples := params.n) == 1:
             self.output_processor.add_request(request, prompt_str, None, 0)
             request, _ = self._preprocess_add_request(request)
             request_fut = asyncio.Future()
             self.requests[request_id] = (None, request_fut)
             self.scheduler.add_request(request)
         else:
-            parent_req = ParentRequest(request_id, self.sampling_params)
+            parent_req = ParentRequest(request_id, params)
             for idx in range(num_samples):
                 # Note: `get_child_info` mutates ParentRequest to track the
                 # generated child request
-                child_request_id, params = parent_req.get_child_info(idx)
+                child_request_id, params_child = parent_req.get_child_info(idx)
                 child_request = request if idx == num_samples - 1 else copy(request)
                 child_request.request_id = child_request_id
-                child_request.sampling_params = params
+                child_request.sampling_params = params_child
                 self.output_processor.add_request(
                     child_request, prompt_str, parent_req, idx
                 )
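
Taken together, the change lets each request carry its own decoding settings while keeping the actor-wide defaults as the fallback. A hypothetical call site, assuming a spawned generator service handle and forge's .route(...) endpoint convention (neither is part of this commit):

from vllm import SamplingParams


async def compare_decoding(generator, prompt: str):
    # Per-request override: greedy decoding, one sample, capped at 128 tokens.
    greedy = SamplingParams(n=1, temperature=0.0, max_tokens=128)
    greedy_out = await generator.generate.route(prompt, sampling_params=greedy)

    # Omitting sampling_params falls back to the actor-wide
    # self.sampling_params configured at spawn time.
    default_out = await generator.generate.route(prompt)
    return greedy_out, default_out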
