From 1fa186fe1e813e604eff1fa23377a21f07cee802 Mon Sep 17 00:00:00 2001
From: Samuel Monson
Date: Fri, 18 Jul 2025 13:17:05 -0400
Subject: [PATCH 01/10] Add turns support to synthetic dataset

Signed-off-by: Samuel Monson
---
 src/guidellm/dataset/synthetic.py | 103 ++++++++++++++++++++----------
 1 file changed, 71 insertions(+), 32 deletions(-)

diff --git a/src/guidellm/dataset/synthetic.py b/src/guidellm/dataset/synthetic.py
index 8c30f0f7..06972643 100644
--- a/src/guidellm/dataset/synthetic.py
+++ b/src/guidellm/dataset/synthetic.py
@@ -3,7 +3,7 @@
 from collections.abc import Iterable, Iterator
 from itertools import cycle
 from pathlib import Path
-from typing import Any, Literal, Optional, Union
+from typing import Any, Optional, TypedDict, Union
 
 import yaml
 from datasets import (
@@ -69,6 +69,26 @@ class SyntheticDatasetConfig(BaseModel):
         gt=0,
         default=None,
     )
+    turns: int = Field(
+        description="The number of turns in the conversation.",
+        gt=0,
+        default=1,
+    )
+    turns_stdev: Optional[int] = Field(
+        description="The standard deviation of the number of turns.",
+        gt=0,
+        default=None,
+    )
+    turns_min: Optional[int] = Field(
+        description="The minimum number of turns in the conversation.",
+        gt=0,
+        default=None,
+    )
+    turns_max: Optional[int] = Field(
+        description="The maximum number of turns in the conversation.",
+        gt=0,
+        default=None,
+    )
     samples: int = Field(
         description="The number of samples to generate for the dataset.",
         gt=0,
@@ -124,14 +144,13 @@ def parse_config_file(data: Union[str, Path]) -> "SyntheticDatasetConfig":
         return SyntheticDatasetConfig(**config_dict)
 
 
-class SyntheticTextItemsGenerator(
-    Iterable[
-        dict[
-            Literal["prompt", "prompt_tokens_count", "output_tokens_count"],
-            Union[str, int],
-        ]
-    ]
-):
+class SyntheticDatasetRow(TypedDict):
+    prompt: list[str]
+    prompt_tokens_count: list[int]
+    output_tokens_count: list[int]
+
+
+class SyntheticTextItemsGenerator(Iterable[SyntheticDatasetRow]):
     def __init__(
         self,
         config: SyntheticDatasetConfig,
@@ -147,12 +166,7 @@ def __init__(
 
     def __iter__(
         self,
-    ) -> Iterator[
-        dict[
-            Literal["prompt", "prompt_tokens_count", "output_tokens_count"],
-            Union[str, int],
-        ]
-    ]:
+    ) -> Iterator[SyntheticDatasetRow]:
         prompt_tokens_sampler = IntegerRangeSampler(
             average=self.config.prompt_tokens,
             variance=self.config.prompt_tokens_stdev,
@@ -167,6 +181,13 @@ def __iter__(
             max_value=self.config.output_tokens_max,
             random_seed=self.random_seed + 1,  # ensure diff dist from prompts
         )
+        turns_sampler = IntegerRangeSampler(
+            average=self.config.turns,
+            variance=self.config.turns_stdev,
+            min_value=self.config.turns_min,
+            max_value=self.config.turns_max,
+            random_seed=self.random_seed + 7,  # ensure diff dist
+        )
         # ensure diff distribution from output tokens
         rand = random.Random(self.random_seed + 2)  # noqa: S311
         unique_prefix_iter = cycle(self.processor.get_vocab().values())
@@ -174,24 +195,42 @@ def __iter__(
         prefix_index = rand.randint(0, len(self.text_creator.words))
         prefix_tokens = self._create_prompt(self.config.prefix_tokens, prefix_index)
 
-        for _, prompt_tokens, output_tokens in zip(
-            range(self.config.samples),
-            prompt_tokens_sampler,
-            output_tokens_sampler,
-        ):
-            start_index = rand.randint(0, len(self.text_creator.words))
-            prompt_text = self.processor.decode(
-                prefix_tokens
-                + self._create_prompt(
-                    prompt_tokens, start_index, next(unique_prefix_iter)
-                ),
-                skip_special_tokens=True,
-            )
-            yield {
-                "prompt": prompt_text,
-                "prompt_tokens_count": self.config.prefix_tokens + prompt_tokens,
-                "output_tokens_count": output_tokens,
+        for _, turns in zip(range(self.config.samples), turns_sampler):
+            row: SyntheticDatasetRow = {
+                "prompt": [],
+                "prompt_tokens_count": [],
+                "output_tokens_count": [],
             }
+            for i, prompt_tokens, output_tokens in zip(
+                range(turns),
+                prompt_tokens_sampler,
+                output_tokens_sampler,
+            ):
+                start_index = rand.randint(0, len(self.text_creator.words))
+                # Append the prefix tokens only for the first turn
+                if i == 0:
+                    prompt_text = self.processor.decode(
+                        prefix_tokens
+                        + self._create_prompt(
+                            prompt_tokens, start_index, next(unique_prefix_iter)
+                        ),
+                        skip_special_tokens=True,
+                    )
+                    row["prompt"].append(prompt_text)
+                    row["prompt_tokens_count"].append(self.config.prefix_tokens + prompt_tokens)
+                    row["output_tokens_count"].append(output_tokens)
+                else:
+                    prompt_text = self.processor.decode(
+                        self._create_prompt(
+                            prompt_tokens, start_index, next(unique_prefix_iter)
+                        ),
+                        skip_special_tokens=True,
+                    )
+                    row["prompt"].append(prompt_text)
+                    row["prompt_tokens_count"].append(prompt_tokens)
+                    row["output_tokens_count"].append(output_tokens)
+
+            yield row
 
     def _create_prompt(
         self, prompt_tokens: int, start_index: int, unique_prefix: Optional[int] = None
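With this patch a sample is a whole conversation rather than a single prompt, and the shared prefix is prepended only to the first turn. A minimal sketch of the resulting shape, assuming the existing SyntheticDatasetConfig entry point (the field values below are illustrative, not taken from the diff):

    from guidellm.dataset.synthetic import SyntheticDatasetConfig

    config = SyntheticDatasetConfig(
        prompt_tokens=256,
        output_tokens=128,
        turns=3,      # average number of turns sampled per conversation
        turns_min=1,  # clamp the sampled turn count
        turns_max=5,
        samples=100,  # now counts conversations (rows), not prompts
    )
    # Each generated row is a SyntheticDatasetRow: parallel lists with one
    # entry per turn, roughly:
    # {
    #     "prompt": ["<turn 1 text>", "<turn 2 text>", "<turn 3 text>"],
    #     "prompt_tokens_count": [266, 249, 262],  # first entry includes prefix_tokens
    #     "output_tokens_count": [128, 131, 125],
    # }
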
From 7efb7b174336881be06a298d9c214eec650c4d1a Mon Sep 17 00:00:00 2001
From: Samuel Monson
Date: Tue, 23 Sep 2025 15:59:05 -0400
Subject: [PATCH 02/10] Add basic multiturn loader support

Signed-off-by: Samuel Monson
---
 src/guidellm/request/loader.py | 47 +++++++++++++++++++++-------------
 1 file changed, 29 insertions(+), 18 deletions(-)

diff --git a/src/guidellm/request/loader.py b/src/guidellm/request/loader.py
index 607a7455..e23e3111 100644
--- a/src/guidellm/request/loader.py
+++ b/src/guidellm/request/loader.py
@@ -105,14 +105,14 @@ def __init__(
         self.preserve_iter_state = iter_type == "infinite"  # ensure no caching requests
         self._preserved_iter = None
 
-    def __iter__(self) -> Iterator[GenerationRequest]:
+    def __iter__(self) -> Iterator[list[GenerationRequest]]:
         scope_create_count = 0
 
         while (dataset_iter := self._get_dataset_iter(scope_create_count)) is not None:
             scope_create_count += 1
 
             for item in dataset_iter:
-                yield self._create_request(item)
+                yield self._create_requests(item)
 
             self._preserved_iter = None
 
@@ -260,25 +260,36 @@ def _get_dataset_iter(
 
         return dataset_iter
 
-    def _create_request(self, item: dict[str, Any]) -> GenerationRequest:
-        prompt_tokens = (
-            item[self.column_mappings["prompt_tokens_count_column"]]
+    def _create_requests(self, item: dict[str, Any]) -> list[GenerationRequest]:
+        prompts = list(item[self.column_mappings["prompt_column"]])
+        prompts_tokens: list[Optional[int]] = (
+            list(item[self.column_mappings["prompt_tokens_count_column"]])
             if "prompt_tokens_count_column" in self.column_mappings
-            else None
+            else [None] * len(prompts)
         )
-        output_tokens = (
-            item[self.column_mappings["output_tokens_count_column"]]
+        outputs_tokens: list[Optional[int]] = (
+            list(item[self.column_mappings["output_tokens_count_column"]])
             if "output_tokens_count_column" in self.column_mappings
-            else None
+            else [None] * len(prompts)
        )
-        return GenerationRequest(
-            request_type=settings.preferred_route,
-            content=item[self.column_mappings["prompt_column"]],
-            stats=(
-                {"prompt_tokens": prompt_tokens} if prompt_tokens is not None else {}
-            ),
-            constraints=(
-                {"output_tokens": output_tokens} if output_tokens is not None else {}
-            ),
-        )
+        if not (len(prompts) == len(prompts_tokens) == len(outputs_tokens)):
+            raise ValueError(
+                "Mismatched lengths between prompts and token counts. "
+                f"Prompts: {len(prompts)}, Prompt Tokens: {len(prompts_tokens)}, "
+                f"Output Tokens: {len(outputs_tokens)}"
+            )
+
+        return [
+            GenerationRequest(
+                request_type=settings.preferred_route,
+                content=prompt,
+                stats=(
+                    {"prompt_tokens": prompt_tokens} if prompt_tokens is not None else {}
+                ),
+                constraints=(
+                    {"output_tokens": output_tokens} if output_tokens is not None else {}
+                ),
+            )
+            for prompt, prompt_tokens, output_tokens in zip(prompts, prompts_tokens, outputs_tokens)
+        ]
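Note the length check above is written as a straight three-way equality: a chained `a != b != c` only tests adjacent pairs and would miss a mismatch when the first two lengths happen to agree. With the check in place, the loader now yields one list of GenerationRequest objects per dataset row. Roughly how a consumer iterates it (a sketch; `loader` stands in for a configured request-loader instance):

    for conversation in loader:  # each item is list[GenerationRequest]
        for turn_index, request in enumerate(conversation):
            # stats/constraints are plain dicts, possibly empty
            print(
                turn_index,
                request.stats.get("prompt_tokens"),
                request.constraints.get("output_tokens"),
            )
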
" + f"Prompts: {len(prompts)}, Prompt Tokens: {len(prompts_tokens)}, " + f"Output Tokens: {len(outputs_tokens)}" + ) + + return [ + GenerationRequest( + request_type=settings.preferred_route, + content=prompt, + stats=( {"prompt_tokens": prompt_tokens} if prompt_tokens is not None else {} - ), - constraints=( - {"output_tokens": output_tokens} if output_tokens is not None else {} - ), - ) + ), + constraints=( + {"output_tokens": output_tokens} if output_tokens is not None else {} + ), + ) + for prompt, prompt_tokens, output_tokens in zip(prompts, prompts_tokens, outputs_tokens) + ] From 3f0cdbc1f5594d74d91bf0443cf652fd851f1182 Mon Sep 17 00:00:00 2001 From: Samuel Monson Date: Thu, 25 Sep 2025 15:42:05 -0400 Subject: [PATCH 03/10] Make dict encoding recursive Signed-off-by: Samuel Monson --- src/guidellm/utils/encoding.py | 37 ++++++++-------------------------- 1 file changed, 8 insertions(+), 29 deletions(-) diff --git a/src/guidellm/utils/encoding.py b/src/guidellm/utils/encoding.py index ccd26982..d4fa007b 100644 --- a/src/guidellm/utils/encoding.py +++ b/src/guidellm/utils/encoding.py @@ -390,23 +390,11 @@ def to_dict(self, obj: Any) -> Any: if isinstance(obj, BaseModel): return self.to_dict_pydantic(obj) - if isinstance(obj, (list, tuple)) and any( - isinstance(item, BaseModel) for item in obj - ): - return [ - self.to_dict_pydantic(item) if isinstance(item, BaseModel) else item - for item in obj - ] + if isinstance(obj, (list, tuple)): + return [self.to_dict(item) for item in obj] - if isinstance(obj, dict) and any( - isinstance(value, BaseModel) for value in obj.values() - ): - return { - key: self.to_dict_pydantic(value) - if isinstance(value, BaseModel) - else value - for key, value in obj.items() - } + if isinstance(obj, dict): + return {key: self.to_dict(value) for key, value in obj.items()} return obj @@ -418,22 +406,13 @@ def from_dict(self, data: Any) -> Any: :return: Reconstructed object with proper types restored """ if isinstance(data, (list, tuple)): - return [ - self.from_dict_pydantic(item) - if isinstance(item, dict) and "*PYD*" in item - else item - for item in data - ] - elif isinstance(data, dict) and data: + return [self.from_dict(item) for item in data] + + if isinstance(data, dict) and data: if "*PYD*" in data: return self.from_dict_pydantic(data) - return { - key: self.from_dict_pydantic(value) - if isinstance(value, dict) and "*PYD*" in value - else value - for key, value in data.items() - } + return {key: self.from_dict(value) for key, value in data.items()} return data From 220377e31b3c9f9b14a33e991e6ebfc45fb88c08 Mon Sep 17 00:00:00 2001 From: Samuel Monson Date: Thu, 25 Sep 2025 15:54:13 -0400 Subject: [PATCH 04/10] Use details for next request in chain Signed-off-by: Samuel Monson --- src/guidellm/scheduler/worker_group.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/guidellm/scheduler/worker_group.py b/src/guidellm/scheduler/worker_group.py index c1d516f1..355ca86b 100644 --- a/src/guidellm/scheduler/worker_group.py +++ b/src/guidellm/scheduler/worker_group.py @@ -496,7 +496,11 @@ def _iter(): count = 0 request_info: ScheduledRequestInfo = None - for request in _iter(): + for request_chain in _iter(): + if isinstance(request_chain, (list, tuple)): + request = request_chain[0] + else: + request = request_chain count += 1 if hasattr(request, "request_id"): From 3ac4df61a555d3b195f6bcf93c15ffe0a8f3d17d Mon Sep 17 00:00:00 2001 From: Samuel Monson Date: Fri, 26 Sep 2025 12:10:21 -0400 Subject: [PATCH 05/10] Implement worker 
From 220377e31b3c9f9b14a33e991e6ebfc45fb88c08 Mon Sep 17 00:00:00 2001
From: Samuel Monson
Date: Thu, 25 Sep 2025 15:54:13 -0400
Subject: [PATCH 04/10] Use details for next request in chain

Signed-off-by: Samuel Monson
---
 src/guidellm/scheduler/worker_group.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/guidellm/scheduler/worker_group.py b/src/guidellm/scheduler/worker_group.py
index c1d516f1..355ca86b 100644
--- a/src/guidellm/scheduler/worker_group.py
+++ b/src/guidellm/scheduler/worker_group.py
@@ -496,7 +496,11 @@ def _iter():
         count = 0
         request_info: ScheduledRequestInfo = None
 
-        for request in _iter():
+        for request_chain in _iter():
+            if isinstance(request_chain, (list, tuple)):
+                request = request_chain[0]
+            else:
+                request = request_chain
             count += 1
 
             if hasattr(request, "request_id"):

From 3ac4df61a555d3b195f6bcf93c15ffe0a8f3d17d Mon Sep 17 00:00:00 2001
From: Samuel Monson
Date: Fri, 26 Sep 2025 12:10:21 -0400
Subject: [PATCH 05/10] Implement worker support for multiturn

Signed-off-by: Samuel Monson
---
 src/guidellm/request/loader.py         | 33 +++++++++-----
 src/guidellm/scheduler/__init__.py     |  8 ++++
 src/guidellm/scheduler/objects.py      | 44 +++++++++++++++---
 src/guidellm/scheduler/worker.py       | 25 ++++++++--
 src/guidellm/scheduler/worker_group.py | 63 +++++++++++++++-----------
 5 files changed, 124 insertions(+), 49 deletions(-)

diff --git a/src/guidellm/request/loader.py b/src/guidellm/request/loader.py
index e23e3111..81aae8fb 100644
--- a/src/guidellm/request/loader.py
+++ b/src/guidellm/request/loader.py
@@ -105,7 +105,7 @@ def __init__(
         self.preserve_iter_state = iter_type == "infinite"  # ensure no caching requests
         self._preserved_iter = None
 
-    def __iter__(self) -> Iterator[list[GenerationRequest]]:
+    def __iter__(self) -> Iterator[list[tuple[GenerationRequest, float]]]:
         scope_create_count = 0
 
         while (dataset_iter := self._get_dataset_iter(scope_create_count)) is not None:
@@ -260,7 +260,9 @@ def _get_dataset_iter(
 
         return dataset_iter
 
-    def _create_requests(self, item: dict[str, Any]) -> list[GenerationRequest]:
+    def _create_requests(
+        self, item: dict[str, Any]
+    ) -> list[tuple[GenerationRequest, float]]:
         prompts = list(item[self.column_mappings["prompt_column"]])
         prompts_tokens: list[Optional[int]] = (
             list(item[self.column_mappings["prompt_tokens_count_column"]])
@@ -281,15 +283,24 @@ def _create_requests(
         )
 
         return [
-            GenerationRequest(
-                request_type=settings.preferred_route,
-                content=prompt,
-                stats=(
-                    {"prompt_tokens": prompt_tokens} if prompt_tokens is not None else {}
-                ),
-                constraints=(
-                    {"output_tokens": output_tokens} if output_tokens is not None else {}
-                ),
-            )
-            for prompt, prompt_tokens, output_tokens in zip(prompts, prompts_tokens, outputs_tokens)
+            (
+                GenerationRequest(
+                    request_type=settings.preferred_route,
+                    content=prompt,
+                    stats=(
+                        {"prompt_tokens": prompt_tokens}
+                        if prompt_tokens is not None
+                        else {}
+                    ),
+                    constraints=(
+                        {"output_tokens": output_tokens}
+                        if output_tokens is not None
+                        else {}
+                    ),
+                ),
+                0.0,  # TODO: delay
+            )
+            for prompt, prompt_tokens, output_tokens in zip(
+                prompts, prompts_tokens, outputs_tokens
+            )
         ]
diff --git a/src/guidellm/scheduler/__init__.py b/src/guidellm/scheduler/__init__.py
index 64647424..cb225460 100644
--- a/src/guidellm/scheduler/__init__.py
+++ b/src/guidellm/scheduler/__init__.py
@@ -15,16 +15,20 @@
 from .objects import (
     BackendInterface,
     BackendT,
+    HistoryT,
     MeasuredRequestTimings,
     MultiTurnRequestT,
+    MultiTurnT,
     RequestSchedulerTimings,
     RequestT,
     ResponseT,
+    ScheduledRequestAugmentation,
     ScheduledRequestInfo,
     SchedulerMessagingPydanticRegistry,
     SchedulerState,
     SchedulerUpdateAction,
     SchedulerUpdateActionProgress,
+    TurnT,
 )
 from .scheduler import Scheduler
 from .strategies import (
@@ -56,6 +60,7 @@
     "ConstraintInitializer",
     "ConstraintsInitializerFactory",
     "Environment",
+    "HistoryT",
     "LastCompletionRequestTimings",
     "MaxDurationConstraint",
     "MaxErrorRateConstraint",
@@ -64,6 +69,7 @@
     "MaxNumberConstraint",
     "MeasuredRequestTimings",
     "MultiTurnRequestT",
+    "MultiTurnT",
     "NoDelayRequestTimings",
     "NonDistributedEnvironment",
     "PoissonRateRequestTimings",
@@ -71,6 +77,7 @@
     "PydanticConstraintInitializer",
     "RequestSchedulerTimings",
     "RequestT",
+    "ScheduledRequestAugmentation",
     "ScheduledRequestInfo",
     "ScheduledRequestTimings",
     "Scheduler",
@@ -84,6 +91,7 @@
     "StrategyType",
     "SynchronousStrategy",
     "ThroughputStrategy",
+    "TurnT",
     "UnserializableConstraintInitializer",
     "WorkerProcess",
     "WorkerProcessGroup",
diff --git a/src/guidellm/scheduler/objects.py b/src/guidellm/scheduler/objects.py
index b7f2efc3..a58d9225 100644
--- a/src/guidellm/scheduler/objects.py
+++ b/src/guidellm/scheduler/objects.py
@@ -19,7 +19,6 @@
     Literal,
     Protocol,
     TypeVar,
-    Union,
 )
 
 from pydantic import Field, computed_field
@@ -35,34 +34,50 @@
 __all__ = [
     "BackendInterface",
     "BackendT",
+    "HistoryT",
     "MeasuredRequestTimings",
     "MultiTurnRequestT",
+    "MultiTurnT",
     "RequestSchedulerTimings",
     "RequestT",
     "ResponseT",
+    "ScheduledRequestAugmentation",
     "ScheduledRequestInfo",
     "SchedulerMessagingPydanticRegistry",
     "SchedulerState",
     "SchedulerUpdateAction",
     "SchedulerUpdateActionProgress",
+    "TurnT",
 ]
 
 RequestT = TypeVar("RequestT")
 """Generic request object type for scheduler processing."""
 
+# TODO: Remove
+MultiTurnRequestT = RequestT
+
 ResponseT = TypeVar("ResponseT")
 """Generic response object type returned by backend processing."""
 
-MultiTurnRequestT = TypeAliasType(
-    "MultiTurnRequestT",
-    Union[
-        list[Union[RequestT, tuple[RequestT, float]]],
-        tuple[Union[RequestT, tuple[RequestT, float]]],
-    ],
+TurnT = TypeAliasType(
+    "TurnT",
+    tuple[RequestT, "ScheduledRequestAugmentation", "ScheduledRequestInfo"],
+    type_params=(RequestT,),
+)
+
+MultiTurnT = TypeAliasType(
+    "MultiTurnT",
+    list[TurnT[RequestT]],
     type_params=(RequestT,),
 )
 """Multi-turn request structure supporting conversation history with optional delays."""
 
+HistoryT = TypeAliasType(
+    "HistoryT",
+    list[tuple[RequestT, ResponseT]],
+    type_params=(RequestT, ResponseT),
+)
+
 
 class SchedulerMessagingPydanticRegistry(RegistryMixin[RegistryObjT]):
     """
@@ -71,6 +86,21 @@ class SchedulerMessagingPydanticRegistry(RegistryMixin[RegistryObjT]):
     """
 
 
+@SchedulerMessagingPydanticRegistry.register()
+class ScheduledRequestAugmentation(StandardBaseModel):
+    """
+    Adjustments to scheduler logic for a paired request.
+    """
+
+    post_requeue_delay: float = Field(
+        description=(
+            "Delay in seconds to wait after a request to "
+            "queue the next request in the conversation."
+        ),
+        default=0.0,
+    )
+
+
 @SchedulerMessagingPydanticRegistry.register()
 class RequestSchedulerTimings(StandardBaseModel):
     """
diff --git a/src/guidellm/scheduler/worker.py b/src/guidellm/scheduler/worker.py
index 5f2fb74b..4513fe3a 100644
--- a/src/guidellm/scheduler/worker.py
+++ b/src/guidellm/scheduler/worker.py
@@ -31,9 +31,12 @@
 
 from guidellm.scheduler.objects import (
     BackendInterface,
+    HistoryT,
     MultiTurnRequestT,
+    MultiTurnT,
     RequestT,
     ResponseT,
+    ScheduledRequestAugmentation,
     ScheduledRequestInfo,
     SchedulerMessagingPydanticRegistry,
 )
@@ -118,6 +121,9 @@ def __init__(
         self.startup_completed = False
         self.backend_started = False
         self.messaging_started = False
+        self.turns_queue: list[
+            tuple[HistoryT[RequestT, ResponseT], MultiTurnT[RequestT]]
+        ] = []
 
     def run(self):
         """
@@ -302,16 +308,19 @@ async def _cancel_requests_loop(self):
                 self._send_update("cancelled", None, request, request_info)
 
     async def _process_next_request(self):
-        request: RequestT | MultiTurnRequestT[RequestT] | None = None
+        request: RequestT | None = None
         request_info: ScheduledRequestInfo | None = None
         response: ResponseT | None = None
+        aug: ScheduledRequestAugmentation | None = None
 
         try:
             # Pull request from the queue
-            request, request_info = await self.messaging.get()
-
-            if isinstance(request, (list, tuple)):
-                raise NotImplementedError("Multi-turn requests are not yet supported")
+            history, conversation = (
+                self.turns_queue.pop(0)
+                if self.turns_queue
+                else ([], await self.messaging.get())
+            )
+            request, aug, request_info = conversation.pop(0)
 
             # Calculate targeted start and set pending state for request
             request_info.scheduler_node_id = self.messaging.worker_index
@@ -341,6 +350,12 @@ async def _process_next_request(self):
                 request_info.scheduler_timings.resolve_end = time.time()
                 self._send_update("completed", response, request, request_info)
 
+            # If multi-turn, queue up next turn(s)
+            # TODO: Move to callback and support delay
+            if conversation:  # more turns to process
+                history.append((request, response))
+                self.turns_queue.append((history, conversation))
+
             response = request = request_info = None
         except asyncio.CancelledError:
             # Handle cancellation
diff --git a/src/guidellm/scheduler/worker_group.py b/src/guidellm/scheduler/worker_group.py
index 355ca86b..221e95e1 100644
--- a/src/guidellm/scheduler/worker_group.py
+++ b/src/guidellm/scheduler/worker_group.py
@@ -26,8 +26,10 @@
 from guidellm.scheduler.objects import (
     BackendInterface,
     MultiTurnRequestT,
+    MultiTurnT,
     RequestT,
     ResponseT,
+    ScheduledRequestAugmentation,
     ScheduledRequestInfo,
     SchedulerMessagingPydanticRegistry,
     SchedulerState,
@@ -471,9 +473,9 @@ def __init__(
 
     def requests_generator(
         self,
-        requests: Iterable[RequestT | MultiTurnRequestT[RequestT]] | None,
-        cycle_requests: Iterable[RequestT | MultiTurnRequestT[RequestT]] | None,
-    ) -> Generator[tuple[RequestT | MultiTurnRequestT[RequestT],], None, None]:
+        requests: Iterable[Iterable[tuple[RequestT, float]]] | None,
+        cycle_requests: Iterable[Iterable[tuple[RequestT, float]]] | None,
+    ) -> Generator[MultiTurnT[RequestT], None, None]:
         """
         Generate request-info pairs for worker processing with constraint evaluation.
 
@@ -494,31 +496,40 @@ def _iter():
             while True:
                 yield from cycle_requests
 
-        count = 0
-        request_info: ScheduledRequestInfo = None
+        count: int = 0
+        stop_queueing: bool = False
+
+        def _turn_iter(requests_chain: Iterable[tuple[RequestT, float]]):
+            nonlocal count, stop_queueing
+            for request, delay in requests_chain:
+                count += 1
+
+                if hasattr(request, "request_id"):
+                    request_id = request.request_id
+                elif hasattr(request, "id"):
+                    request_id = request.id
+                else:
+                    request_id = str(uuid.uuid4())
+                request_augmentation = ScheduledRequestAugmentation(
+                    post_requeue_delay=delay
+                )
+                request_info: ScheduledRequestInfo = ScheduledRequestInfo(
+                    request_id=request_id,
+                    status="queued",
+                    scheduler_process_id=0,
+                    scheduler_start_time=self.start_time,
+                )
+                state_update = self._locked_update(request_info)
+                yield (request, request_augmentation, request_info)
+
+                if state_update.stop_queueing:
+                    stop_queueing = True
+                    return
 
         for request_chain in _iter():
-            if isinstance(request_chain, (list, tuple)):
-                request = request_chain[0]
-            else:
-                request = request_chain
-            count += 1
-
-            if hasattr(request, "request_id"):
-                request_id = request.request_id
-            elif hasattr(request, "id"):
-                request_id = request.id
-            else:
-                request_id = str(uuid.uuid4())
-            request_info: ScheduledRequestInfo = ScheduledRequestInfo(
-                request_id=request_id,
-                status="queued",
-                scheduler_process_id=0,
-                scheduler_start_time=self.start_time,
-            )
-            state_update = self._locked_update(request_info)
-            yield (request, request_info)
+            yield list(_turn_iter(request_chain))
 
-            if state_update.stop_queueing:
+            if stop_queueing:
                 self.stop_send_requests_event.set()
                 return
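At this point a conversation moves through the system in three shapes. A sketch with RequestT narrowed to str (the names on the left are for illustration only, not part of the diff):

    # 1. One dataset item from the loader: [(request, post_requeue_delay), ...]
    loader_item: list[tuple[str, float]] = [
        ("turn 1 prompt", 0.0),
        ("turn 2 prompt", 0.0),  # delay is hardcoded to 0.0 for now (TODO above)
    ]

    # 2. requests_generator wraps every turn with scheduling metadata and
    #    yields MultiTurnT[RequestT]:
    #    [(request, ScheduledRequestAugmentation, ScheduledRequestInfo), ...]

    # 3. The worker pops turn 0, resolves it, then parks the remainder on
    #    turns_queue together with the accumulated history:
    #    (history: [(request, response), ...], conversation: remaining turns)

Note the asymmetry this introduces: new conversations arrive over inter-process messaging, while follow-up turns are always taken from the worker-local turns_queue first.
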
From a7bf6900fc77125bc3d887701c8a4e89856b89c2 Mon Sep 17 00:00:00 2001
From: Samuel Monson
Date: Fri, 26 Sep 2025 12:55:42 -0400
Subject: [PATCH 06/10] Cancel requests in conversation

Signed-off-by: Samuel Monson
---
 src/guidellm/scheduler/worker.py | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/src/guidellm/scheduler/worker.py b/src/guidellm/scheduler/worker.py
index 4513fe3a..155552a8 100644
--- a/src/guidellm/scheduler/worker.py
+++ b/src/guidellm/scheduler/worker.py
@@ -296,16 +296,22 @@ async def _cancel_requests_loop(self):
             try:
                 request: RequestT
                 request_info: ScheduledRequestInfo
-                request, request_info = await self.messaging.get(
-                    timeout=self.messaging.poll_interval
+                _, conversation = (
+                    self.turns_queue.pop(0)
+                    if self.turns_queue
+                    else (
+                        None,
+                        await self.messaging.get(timeout=self.messaging.poll_interval),
+                    )
                 )
             except asyncio.TimeoutError:
                 continue
 
-            request_info.scheduler_node_id = self.messaging.worker_index
-            request_info.error = "Request was cancelled"
-            request_info.scheduler_timings.resolve_end = time.time()
-            self._send_update("cancelled", None, request, request_info)
+            for request, _, request_info in conversation:
+                request_info.scheduler_node_id = self.messaging.worker_index
+                request_info.error = "Request was cancelled"
+                request_info.scheduler_timings.resolve_end = time.time()
+                self._send_update("cancelled", None, request, request_info)

From e276f6c091a9c995d85644dfbbb32ec24a344d33 Mon Sep 17 00:00:00 2001
From: Samuel Monson
Date: Fri, 26 Sep 2025 14:38:11 -0400
Subject: [PATCH 07/10] Cancel whole conversation

Signed-off-by: Samuel Monson
---
 src/guidellm/scheduler/worker.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/src/guidellm/scheduler/worker.py b/src/guidellm/scheduler/worker.py
index 155552a8..33be659f 100644
--- a/src/guidellm/scheduler/worker.py
+++ b/src/guidellm/scheduler/worker.py
@@ -314,6 +314,7 @@ async def _cancel_requests_loop(self):
                 self._send_update("cancelled", None, request, request_info)
 
     async def _process_next_request(self):
+        conversation: MultiTurnT[RequestT] | None = None
         request: RequestT | None = None
         request_info: ScheduledRequestInfo | None = None
         response: ResponseT | None = None
@@ -362,7 +363,7 @@ async def _process_next_request(self):
                 history.append((request, response))
                 self.turns_queue.append((history, conversation))
 
-            response = request = request_info = None
+            response = request = request_info = conversation = None
         except asyncio.CancelledError:
             # Handle cancellation
             if request is not None and request_info is not None:
@@ -375,6 +376,12 @@ async def _process_next_request(self):
                 request_info.error = str(exc)
                 request_info.scheduler_timings.resolve_end = time.time()
                 self._send_update("errored", response, request, request_info)
+        finally:
+            if conversation is not None:
+                for request, _, request_info in conversation:
+                    request_info.error = "Request was cancelled"
+                    request_info.scheduler_timings.resolve_end = time.time()
+                    self._send_update("cancelled", response, request, request_info)
 
     def _send_update(
         self,

From 1de1c64a072b00512c97b364d62474d5ff436226 Mon Sep 17 00:00:00 2001
From: Samuel Monson
Date: Fri, 26 Sep 2025 14:57:52 -0400
Subject: [PATCH 08/10] Implement multiturn history in openai backend

Signed-off-by: Samuel Monson
---
 src/guidellm/backends/openai.py  | 27 +++++++++++++++++++++------
 src/guidellm/scheduler/worker.py |  4 +++-
 2 files changed, 24 insertions(+), 7 deletions(-)

diff --git a/src/guidellm/backends/openai.py b/src/guidellm/backends/openai.py
index ce83076f..acce5f88 100644
--- a/src/guidellm/backends/openai.py
+++ b/src/guidellm/backends/openai.py
@@ -16,6 +16,7 @@
 import json
 import time
 from collections.abc import AsyncIterator
+from itertools import chain
 from pathlib import Path
 from typing import Any, ClassVar, Optional, Union
 
@@ -29,7 +30,7 @@
     GenerationRequestTimings,
     GenerationResponse,
 )
-from guidellm.scheduler import ScheduledRequestInfo
+from guidellm.scheduler import HistoryT, ScheduledRequestInfo
 
 __all__ = ["OpenAIHTTPBackend", "UsageStats"]
 
@@ -280,7 +281,7 @@ async def resolve(
         self,
         request: GenerationRequest,
         request_info: ScheduledRequestInfo,
-        history: Optional[list[tuple[GenerationRequest, GenerationResponse]]] = None,
+        history: Optional[HistoryT[GenerationRequest, GenerationResponse]] = None,
     ) -> AsyncIterator[tuple[GenerationResponse, ScheduledRequestInfo]]:
         """
         Process a generation request and yield progressive responses.
@@ -295,10 +296,8 @@ async def resolve(
         :yields: Tuples of (response, updated_request_info) as generation progresses.
         """
         self._check_in_process()
-        if history is not None:
-            raise NotImplementedError(
-                "Multi-turn requests with conversation history are not yet supported"
-            )
+        if history:
+            request = self._apply_history(request, history)
 
         response = GenerationResponse(
             request_id=request.request_id,
@@ -500,6 +499,22 @@ async def chat_completions(
             self._get_completions_usage_stats(data),
         )
 
+    def _apply_history(
+        self,
+        request: GenerationRequest,
+        history: HistoryT[GenerationRequest, GenerationResponse],
+    ) -> GenerationRequest:
+        """
+        Apply conversation history to the current request.
+        """
+
+        def turn_to_text(turn: tuple[GenerationRequest, GenerationResponse]) -> str:
+            req, res = turn
+            return f"{req.content}{res.value}"
+
+        request.content = "".join(chain(map(turn_to_text, history), (request.content,)))
+        return request
+
     def _build_headers(
         self,
         api_key: Optional[str],
diff --git a/src/guidellm/scheduler/worker.py b/src/guidellm/scheduler/worker.py
index 33be659f..3c980e60 100644
--- a/src/guidellm/scheduler/worker.py
+++ b/src/guidellm/scheduler/worker.py
@@ -349,7 +349,9 @@ async def _process_next_request(self):
             # Process the request with the backend
             request_info.scheduler_timings.resolve_start = time.time()
             self._send_update("in_progress", response, request, request_info)
-            async for resp, info in self.backend.resolve(request, request_info, None):
+            async for resp, info in self.backend.resolve(
+                request, request_info, history
+            ):
                 response = resp
                 request_info = info
 
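_apply_history flattens prior turns into plain text: each turn's prompt is concatenated with its response and the result is prepended to the current prompt, with no separators or chat-template formatting. A standalone sketch with strings standing in for GenerationRequest and GenerationResponse:

    history = [("What is 2+2?", " 4."), ("And doubled?", " 8.")]
    current = " Now squared?"

    flattened = "".join(f"{req}{res}" for req, res in history) + current
    print(flattened)
    # What is 2+2? 4.And doubled? 8. Now squared?

This keeps the whole conversation in a single completion-style prompt; per-message chat formatting would be a separate change.
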
+ """ + + def turn_to_text(turn: tuple[GenerationRequest, GenerationResponse]) -> str: + req, res = turn + return f"{req.content}{res.value}" + + request.content = "".join(chain(map(turn_to_text, history), (request.content,))) + return request + def _build_headers( self, api_key: Optional[str], diff --git a/src/guidellm/scheduler/worker.py b/src/guidellm/scheduler/worker.py index 33be659f..3c980e60 100644 --- a/src/guidellm/scheduler/worker.py +++ b/src/guidellm/scheduler/worker.py @@ -349,7 +349,9 @@ async def _process_next_request(self): # Process the request with the backend request_info.scheduler_timings.resolve_start = time.time() self._send_update("in_progress", response, request, request_info) - async for resp, info in self.backend.resolve(request, request_info, None): + async for resp, info in self.backend.resolve( + request, request_info, history + ): response = resp request_info = info From 0e8713c776958f230171414fbe271e52ef63198f Mon Sep 17 00:00:00 2001 From: Samuel Monson Date: Fri, 26 Sep 2025 15:49:53 -0400 Subject: [PATCH 09/10] Add wait_then_requeue behavior Signed-off-by: Samuel Monson --- src/guidellm/scheduler/worker.py | 62 +++++++++++++++++++++++++------- 1 file changed, 49 insertions(+), 13 deletions(-) diff --git a/src/guidellm/scheduler/worker.py b/src/guidellm/scheduler/worker.py index 3c980e60..81b9e2b1 100644 --- a/src/guidellm/scheduler/worker.py +++ b/src/guidellm/scheduler/worker.py @@ -13,7 +13,7 @@ import time from multiprocessing.synchronize import Barrier as ProcessingBarrier from multiprocessing.synchronize import Event as ProcessingEvent -from typing import Annotated, Generic, Literal +from typing import Annotated, Generic, Literal, TypeAliasType try: import uvloop @@ -50,6 +50,16 @@ __all__ = ["WorkerProcess"] +ProcessRequestT = TypeAliasType( + "ProcessRequestT", + tuple[ + HistoryT[RequestT, ResponseT], + MultiTurnT[RequestT], + ScheduledRequestAugmentation, + ], + type_params=(RequestT, ResponseT), +) + class WorkerProcess(Generic[RequestT, ResponseT]): """ @@ -271,12 +281,20 @@ async def _process_requests_loop(self): async_semaphore = asyncio.Semaphore(self.async_limit) pending_tasks: set[asyncio.Task] = set() - def _task_done(task): + def _task_done(task: asyncio.Task[ProcessRequestT[RequestT, ResponseT]]): pending_tasks.discard(task) async_semaphore.release() - if not task.cancelled() and (exception := task.exception()): - raise exception + if not task.cancelled(): + if exception := task.exception(): + raise exception + + history, conversation, aug = task.result() + if conversation: + requeue_task = asyncio.create_task( + self._wait_then_requeue(history, conversation, aug) + ) + pending_tasks.add(requeue_task) # Main loop; loop until canceled while True: @@ -313,12 +331,14 @@ async def _cancel_requests_loop(self): request_info.scheduler_timings.resolve_end = time.time() self._send_update("cancelled", None, request, request_info) - async def _process_next_request(self): - conversation: MultiTurnT[RequestT] | None = None + async def _process_next_request(self) -> ProcessRequestT[RequestT, ResponseT]: + conversation: MultiTurnT[RequestT] = [] + history: HistoryT[RequestT, ResponseT] = [] request: RequestT | None = None request_info: ScheduledRequestInfo | None = None response: ResponseT | None = None aug: ScheduledRequestAugmentation | None = None + premature_exit: bool = False try: # Pull request from the queue @@ -359,14 +379,12 @@ async def _process_next_request(self): request_info.scheduler_timings.resolve_end = time.time() 
From cd43b2cf768d21201ec50e17c8025de54e78474b Mon Sep 17 00:00:00 2001
From: Samuel Monson
Date: Fri, 26 Sep 2025 16:55:18 -0400
Subject: [PATCH 10/10] Type cleanup

Signed-off-by: Samuel Monson
---
 src/guidellm/scheduler/__init__.py     | 10 ++++-----
 src/guidellm/scheduler/environments.py | 16 +++++++-------
 src/guidellm/scheduler/objects.py      | 28 +++++++++++-------------
 src/guidellm/scheduler/scheduler.py    |  6 +++---
 src/guidellm/scheduler/worker.py       | 20 +++++++----------
 src/guidellm/scheduler/worker_group.py | 30 ++++++++++----------------
 tests/unit/scheduler/test_objects.py   | 16 --------------
 7 files changed, 46 insertions(+), 80 deletions(-)

diff --git a/src/guidellm/scheduler/__init__.py b/src/guidellm/scheduler/__init__.py
index cb225460..4eff5c12 100644
--- a/src/guidellm/scheduler/__init__.py
+++ b/src/guidellm/scheduler/__init__.py
@@ -15,10 +15,10 @@
 from .objects import (
     BackendInterface,
     BackendT,
+    DatasetIterT,
     HistoryT,
     MeasuredRequestTimings,
-    MultiTurnRequestT,
-    MultiTurnT,
+    RequestDataT,
     RequestSchedulerTimings,
     RequestT,
     ResponseT,
@@ -28,7 +28,6 @@
     SchedulerState,
     SchedulerUpdateAction,
     SchedulerUpdateActionProgress,
-    TurnT,
 )
 from .scheduler import Scheduler
 from .strategies import (
@@ -59,6 +58,7 @@
     "Constraint",
     "ConstraintInitializer",
     "ConstraintsInitializerFactory",
+    "DatasetIterT",
     "Environment",
     "HistoryT",
     "LastCompletionRequestTimings",
@@ -68,12 +68,11 @@
     "MaxGlobalErrorRateConstraint",
     "MaxNumberConstraint",
     "MeasuredRequestTimings",
-    "MultiTurnRequestT",
-    "MultiTurnT",
     "NoDelayRequestTimings",
     "NonDistributedEnvironment",
     "PoissonRateRequestTimings",
     "PydanticConstraintInitializer",
+    "RequestDataT",
     "RequestSchedulerTimings",
     "RequestT",
     "ResponseT",
@@ -91,7 +90,6 @@
     "StrategyType",
     "SynchronousStrategy",
     "ThroughputStrategy",
-    "TurnT",
     "UnserializableConstraintInitializer",
     "WorkerProcess",
     "WorkerProcessGroup",
diff --git a/src/guidellm/scheduler/environments.py b/src/guidellm/scheduler/environments.py
index 6234f8f6..a9853544 100644
--- a/src/guidellm/scheduler/environments.py
+++ b/src/guidellm/scheduler/environments.py
@@ -19,14 +19,14 @@
 
 import time
 from abc import ABC, abstractmethod
-from collections.abc import AsyncIterator, Iterable
+from collections.abc import AsyncIterator
 from typing import (
     Generic,
 )
 
 from guidellm.scheduler.constraints import Constraint
 from guidellm.scheduler.objects import (
-    MultiTurnRequestT,
+    DatasetIterT,
     RequestT,
     ResponseT,
     ScheduledRequestInfo,
@@ -52,11 +52,11 @@ class Environment(ABC, Generic[RequestT, ResponseT], InfoMixin):
     @abstractmethod
     async def sync_run_params(
         self,
-        requests: Iterable[RequestT | MultiTurnRequestT[RequestT]],
+        requests: DatasetIterT[RequestT],
         strategy: SchedulingStrategy,
         constraints: dict[str, Constraint],
     ) -> tuple[
-        Iterable[RequestT | MultiTurnRequestT[RequestT]],
+        DatasetIterT[RequestT],
         SchedulingStrategy,
         dict[str, Constraint],
     ]:
@@ -130,7 +130,7 @@ async def sync_run_end(
     ) -> AsyncIterator[
         tuple[
             ResponseT,
-            RequestT | MultiTurnRequestT[RequestT],
+            RequestT,
             ScheduledRequestInfo,
             SchedulerState,
         ]
@@ -194,11 +194,11 @@ def __init__(self):
 
     async def sync_run_params(
         self,
-        requests: Iterable[RequestT | MultiTurnRequestT[RequestT]],
+        requests: DatasetIterT[RequestT],
         strategy: SchedulingStrategy,
         constraints: dict[str, Constraint],
     ) -> tuple[
-        Iterable[RequestT | MultiTurnRequestT[RequestT]],
+        DatasetIterT[RequestT],
         SchedulingStrategy,
         dict[str, Constraint],
     ]:
@@ -250,7 +250,7 @@ async def sync_run_end(
     ) -> AsyncIterator[
         tuple[
             ResponseT,
-            RequestT | MultiTurnRequestT[RequestT],
+            RequestT,
             ScheduledRequestInfo,
             SchedulerState,
         ]
diff --git a/src/guidellm/scheduler/objects.py b/src/guidellm/scheduler/objects.py
index a58d9225..e7d4c6c7 100644
--- a/src/guidellm/scheduler/objects.py
+++ b/src/guidellm/scheduler/objects.py
@@ -11,7 +11,7 @@
 
 import time
 import uuid
-from collections.abc import AsyncIterator
+from collections.abc import AsyncIterator, Iterable
 from typing import (
     Any,
     ClassVar,
@@ -34,10 +34,10 @@
 __all__ = [
     "BackendInterface",
     "BackendT",
+    "DatasetIterT",
     "HistoryT",
     "MeasuredRequestTimings",
-    "MultiTurnRequestT",
-    "MultiTurnT",
+    "RequestDataT",
     "RequestSchedulerTimings",
     "RequestT",
     "ResponseT",
@@ -47,36 +47,32 @@
     "SchedulerState",
     "SchedulerUpdateAction",
     "SchedulerUpdateActionProgress",
-    "TurnT",
 ]
 
 RequestT = TypeVar("RequestT")
 """Generic request object type for scheduler processing."""
 
-# TODO: Remove
-MultiTurnRequestT = RequestT
-
 ResponseT = TypeVar("ResponseT")
 """Generic response object type returned by backend processing."""
 
-TurnT = TypeAliasType(
-    "TurnT",
+RequestDataT = TypeAliasType(
+    "RequestDataT",
     tuple[RequestT, "ScheduledRequestAugmentation", "ScheduledRequestInfo"],
     type_params=(RequestT,),
 )
-
-MultiTurnT = TypeAliasType(
-    "MultiTurnT",
-    list[TurnT[RequestT]],
-    type_params=(RequestT,),
-)
-"""Multi-turn request structure supporting conversation history with optional delays."""
+"""Request including external metadata and scheduling config."""
 
 HistoryT = TypeAliasType(
     "HistoryT",
     list[tuple[RequestT, ResponseT]],
     type_params=(RequestT, ResponseT),
 )
+"""Record of requests + responses in conversation."""
+
+
+DatasetIterT = TypeAliasType(
+    "DatasetIterT", Iterable[Iterable[tuple[RequestT, float]]], type_params=(RequestT,)
+)
 
 
 class SchedulerMessagingPydanticRegistry(RegistryMixin[RegistryObjT]):
diff --git a/src/guidellm/scheduler/scheduler.py b/src/guidellm/scheduler/scheduler.py
index e7d8b2c6..43948d18 100644
--- a/src/guidellm/scheduler/scheduler.py
+++ b/src/guidellm/scheduler/scheduler.py
@@ -10,7 +10,7 @@
 
 from __future__ import annotations
 
-from collections.abc import AsyncIterator, Iterable
+from collections.abc import AsyncIterator
 from typing import Any, Generic
 
 from guidellm.scheduler.constraints import (
@@ -20,7 +20,7 @@
 from guidellm.scheduler.environments import Environment, NonDistributedEnvironment
 from guidellm.scheduler.objects import (
     BackendInterface,
-    MultiTurnRequestT,
+    DatasetIterT,
     RequestT,
     ResponseT,
     ScheduledRequestInfo,
@@ -66,7 +66,7 @@ class Scheduler(
 
     async def run(
         self,
-        requests: Iterable[RequestT | MultiTurnRequestT[RequestT]],
+        requests: DatasetIterT[RequestT],
         backend: BackendInterface[RequestT, ResponseT],
         strategy: SchedulingStrategy,
         env: Environment | None,
diff --git a/src/guidellm/scheduler/worker.py b/src/guidellm/scheduler/worker.py
index 81b9e2b1..4c5903fb 100644
--- a/src/guidellm/scheduler/worker.py
+++ b/src/guidellm/scheduler/worker.py
@@ -32,8 +32,7 @@
 from guidellm.scheduler.objects import (
     BackendInterface,
     HistoryT,
-    MultiTurnRequestT,
-    MultiTurnT,
+    RequestDataT,
     RequestT,
     ResponseT,
     ScheduledRequestAugmentation,
@@ -54,7 +53,7 @@
     "ProcessRequestT",
     tuple[
         HistoryT[RequestT, ResponseT],
-        MultiTurnT[RequestT],
+        list[RequestDataT[RequestT]],
         ScheduledRequestAugmentation,
     ],
     type_params=(RequestT, ResponseT),
@@ -87,11 +86,8 @@ class WorkerProcess(Generic[RequestT, ResponseT]):
     def __init__(
         self,
         messaging: InterProcessMessaging[
-            tuple[
-                ResponseT | None,
-                RequestT | MultiTurnRequestT[RequestT],
-                ScheduledRequestInfo,
-            ],
+            tuple[ResponseT | None, RequestT, ScheduledRequestInfo],
+            list[RequestDataT[RequestT]],
         ],
         backend: BackendInterface[RequestT, ResponseT],
         request_timings: ScheduledRequestTimings,
@@ -132,7 +128,7 @@ def __init__(
         self.backend_started = False
         self.messaging_started = False
         self.turns_queue: list[
-            tuple[HistoryT[RequestT, ResponseT], MultiTurnT[RequestT]]
+            tuple[HistoryT[RequestT, ResponseT], list[RequestDataT[RequestT]]]
         ] = []
 
     def run(self):
@@ -332,7 +328,7 @@ async def _cancel_requests_loop(self):
                 self._send_update("cancelled", None, request, request_info)
 
     async def _process_next_request(self) -> ProcessRequestT[RequestT, ResponseT]:
-        conversation: MultiTurnT[RequestT] = []
+        conversation: list[RequestDataT[RequestT]] = []
         history: HistoryT[RequestT, ResponseT] = []
         request: RequestT | None = None
         request_info: ScheduledRequestInfo | None = None
@@ -409,7 +405,7 @@ async def _process_next_request(self) -> ProcessRequestT[RequestT, ResponseT]:
     async def _wait_then_requeue(
         self,
         history: HistoryT[RequestT, ResponseT],
-        conversation: MultiTurnT[RequestT],
+        conversation: list[RequestDataT[RequestT]],
         aug: ScheduledRequestAugmentation,
     ):
         try:
@@ -427,7 +423,7 @@ def _send_update(
             "pending", "in_progress", "completed", "errored", "cancelled"
         ],
         response: ResponseT | None,
-        request: RequestT | MultiTurnRequestT[RequestT],
+        request: RequestT,
         request_info: ScheduledRequestInfo,
     ):
         prev_status = request_info.status
diff --git a/src/guidellm/scheduler/worker_group.py b/src/guidellm/scheduler/worker_group.py
index 221e95e1..296152a8 100644
--- a/src/guidellm/scheduler/worker_group.py
+++ b/src/guidellm/scheduler/worker_group.py
@@ -25,8 +25,8 @@
 from guidellm.scheduler.constraints import Constraint, RequestsExhaustedConstraint
 from guidellm.scheduler.objects import (
     BackendInterface,
-    MultiTurnRequestT,
-    MultiTurnT,
+    DatasetIterT,
+    RequestDataT,
     RequestT,
     ResponseT,
     ScheduledRequestAugmentation,
@@ -83,8 +83,8 @@ class WorkerProcessGroup(Generic[RequestT, ResponseT]):
 
     def __init__(
         self,
-        requests: Iterable[RequestT | MultiTurnRequestT[RequestT]] | None,
-        cycle_requests: Iterable[RequestT | MultiTurnRequestT[RequestT]] | None,
+        requests: DatasetIterT[RequestT] | None,
+        cycle_requests: DatasetIterT[RequestT] | None,
         backend: BackendInterface[RequestT, ResponseT],
         strategy: SchedulingStrategy,
         constraints: dict[str, Constraint],
@@ -131,16 +131,8 @@ def __init__(
         # Scheduler and messaging state, created in start
         self.state: WorkerGroupState[ResponseT, RequestT] = None
         self.messaging: InterProcessMessaging[
-            tuple[
-                RequestT | MultiTurnRequestT[RequestT],
-                ScheduledRequestInfo,
-            ],
-            tuple[
-                ResponseT | None,
-                RequestT | MultiTurnRequestT[RequestT],
-                ScheduledRequestInfo,
-                SchedulerState,
-            ],
+            list[RequestDataT[RequestT]],
+            tuple[ResponseT | None, RequestT, ScheduledRequestInfo, SchedulerState],
         ] = None
 
     async def create_processes(self):
@@ -473,9 +465,9 @@ def __init__(
 
     def requests_generator(
         self,
-        requests: Iterable[Iterable[tuple[RequestT, float]]] | None,
-        cycle_requests: Iterable[Iterable[tuple[RequestT, float]]] | None,
-    ) -> Generator[MultiTurnT[RequestT], None, None]:
+        requests: DatasetIterT[RequestT] | None,
+        cycle_requests: DatasetIterT[RequestT] | None,
+    ) -> Generator[list[RequestDataT[RequestT]], None, None]:
         """
         Generate request-info pairs for worker processing with constraint evaluation.
 
@@ -544,12 +536,12 @@ def received_callback(
         self,
         update: tuple[
             ResponseT | None,
-            RequestT | MultiTurnRequestT,
+            RequestT,
             ScheduledRequestInfo,
         ],
     ) -> tuple[
         ResponseT | None,
-        RequestT | MultiTurnRequestT,
+        RequestT,
         ScheduledRequestInfo,
         SchedulerState,
     ]:
diff --git a/tests/unit/scheduler/test_objects.py b/tests/unit/scheduler/test_objects.py
index df794ff8..f76fcfd1 100644
--- a/tests/unit/scheduler/test_objects.py
+++ b/tests/unit/scheduler/test_objects.py
@@ -7,13 +7,11 @@
 
 import pytest
 from pydantic import ValidationError
-from typing_extensions import TypeAliasType
 
 from guidellm.scheduler import (
     BackendInterface,
     BackendT,
     MeasuredRequestTimings,
-    MultiTurnRequestT,
     RequestSchedulerTimings,
     RequestT,
     ResponseT,
@@ -49,20 +47,6 @@ def test_backend_t():
     assert BackendT.__constraints__ == ()
 
 
-def test_multi_turn_request_t():
-    """Validate MultiTurnRequestT is a TypeAliasType for multi-turn requests."""
-    assert isinstance(MultiTurnRequestT, TypeAliasType)
-    assert MultiTurnRequestT.__name__ == "MultiTurnRequestT"
-
-    value = MultiTurnRequestT.__value__
-    assert hasattr(value, "__origin__")
-    assert value.__origin__ is Union
-
-    type_params = getattr(MultiTurnRequestT, "__type_params__", ())
-    assert len(type_params) == 1
-    assert type_params[0].__name__ == "RequestT"
-
-
 class TestBackendInterface:
     """Test the BackendInterface abstract base class."""
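After the cleanup, the scheduler-facing contract is a single alias: DatasetIterT[RequestT] is Iterable[Iterable[tuple[RequestT, float]]], an iterable of conversations where each conversation is an iterable of (request, delay) pairs. A sketch with str requests, showing that single-turn data is just the one-element case (values illustrative):

    dataset: list[list[tuple[str, float]]] = [
        [("conv A, turn 1", 0.0), ("conv A, turn 2", 1.5)],  # multi-turn
        [("conv B, turn 1", 0.0)],                           # single-turn
    ]
    for conversation in dataset:
        for request, delay in conversation:
            print(request, delay)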