diff --git a/.containerignore b/.containerignore new file mode 120000 index 00000000..3e4e48b0 --- /dev/null +++ b/.containerignore @@ -0,0 +1 @@ +.gitignore \ No newline at end of file diff --git a/.github/workflows/container-maintenance.yml b/.github/workflows/container-maintenance.yml new file mode 100644 index 00000000..ff75a3f4 --- /dev/null +++ b/.github/workflows/container-maintenance.yml @@ -0,0 +1,49 @@ +name: Container Image Maintenance + +on: + schedule: + - cron: '0 2 * * 3' # Runs at 2am on Wednesdays + workflow_dispatch: # Enables manual triggering of the workflow + +# Only run one at a time +concurrency: + group: ${{ github.workflow }} + +jobs: + cleanup-container-tags: + runs-on: ubuntu-latest + steps: + - name: Delete PR and untagged images older than 2 weeks + uses: snok/container-retention-policy@v3.0.0 + with: + account: ${{ github.actor }} + token: ${{ github.token }} + image-names: ${{ github.event.repository.name }} + image-tags: "pr-*" + cut-off: 2w + dry-run: true + + push-container-tags: + runs-on: ubuntu-latest + needs: cleanup-container-tags + if: always() # Run after cleanup even if it fails + steps: + - name: Log into ghcr.io + uses: redhat-actions/podman-login@v1 + with: + username: ${{ github.actor }} + password: ${{ github.token }} + registry: ghcr.io/${{ github.repository_owner }} + - name: Get list of tags + run: | + skopeo list-tags docker://${{ github.repository }} | jq --raw-output '.Tags[]' > tags + - name: Get latest release and rc tags + run: | + STABLE_TAG="$(grep -P '^v\d+\.\d+\.\d+$' tags | sort -rV | head -n1)" + echo "STABLE_TAG=${STABLE_TAG:-v0.0.0}" >> $GITHUB_ENV + LATEST_TAG="$(grep -P '^v\d+\.\d+\.\d+' tags | sort -rV | head -n1)" + echo "LATEST_TAG=${LATEST_TAG:-v0.0.0}" >> $GITHUB_ENV + - name: Update latest and stable tags + run: | + skopeo copy docker://${{ github.repository }}:${{ env.stable_tag }} docker://${{ github.repository }}:stable + skopeo copy docker://${{ github.repository }}:${{ env.latest_tag }} docker://${{ github.repository }}:latest diff --git a/.github/workflows/development.yml b/.github/workflows/development.yml index ff8fe2e9..eabf1934 100644 --- a/.github/workflows/development.yml +++ b/.github/workflows/development.yml @@ -293,14 +293,18 @@ jobs: steps: - name: Checkout uses: actions/checkout@v4 + with: + fetch-depth: 0 - name: Buildah build id: build-image uses: redhat-actions/buildah-build@v2 with: image: ${{ github.event.repository.name }} + build-args: | + GUIDELLM_BUILD_TYPE=dev tags: "pr-${{ github.event.number }}" containerfiles: | - ./deploy/Containerfile + ./Containerfile - name: Push To ghcr.io id: push-to-ghcr uses: redhat-actions/push-to-registry@v2 diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index f732330e..87ff04ad 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -251,14 +251,18 @@ jobs: steps: - name: Checkout uses: actions/checkout@v4 + with: + fetch-depth: 0 - name: Buildah build id: build-image uses: redhat-actions/buildah-build@v2 with: image: ${{ github.event.repository.name }} + build-args: | + GUIDELLM_BUILD_TYPE=nightly tags: nightly containerfiles: | - ./deploy/Containerfile + ./Containerfile - name: Push To ghcr.io id: push-to-ghcr uses: redhat-actions/push-to-registry@v2 diff --git a/.github/workflows/release-candidate.yml b/.github/workflows/release-candidate.yml index d48c27f7..703ca4c9 100644 --- a/.github/workflows/release-candidate.yml +++ b/.github/workflows/release-candidate.yml @@ -295,15 +295,20 @@ jobs: steps: - 
name: Checkout uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Get version from branch + run: echo "PACKAGE_VERSION=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV - name: Buildah build id: build-image uses: redhat-actions/buildah-build@v2 with: image: ${{ github.event.repository.name }} - # TODO: Tag version - tags: latest + build-args: | + GUIDELLM_BUILD_TYPE=candidate + tags: ${{ env.package_version }}~rc containerfiles: | - ./deploy/Containerfile + ./Containerfile - name: Push To ghcr.io id: push-to-ghcr uses: redhat-actions/push-to-registry@v2 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index bb329251..9f3d9d75 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -294,15 +294,20 @@ jobs: steps: - name: Checkout uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Get version from branch + run: echo "PACKAGE_VERSION=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV - name: Buildah build id: build-image uses: redhat-actions/buildah-build@v2 with: image: ${{ github.event.repository.name }} - # TODO: Tag version - tags: latest stable + build-args: | + GUIDELLM_BUILD_TYPE=release + tags: ${{ env.package_version }} containerfiles: | - ./deploy/Containerfile + ./Containerfile - name: Push To ghcr.io id: push-to-ghcr uses: redhat-actions/push-to-registry@v2 diff --git a/Containerfile b/Containerfile new file mode 100644 index 00000000..1f935623 --- /dev/null +++ b/Containerfile @@ -0,0 +1,65 @@ +# TODO: Update to official python-3.13-minimal image when available +ARG BASE_IMAGE=quay.io/fedora/python-313-minimal:latest + +# release: take the last version and add a post if build iteration +# candidate: increment to next minor, add 'rc' with build iteration +# nightly: increment to next minor, add 'a' with build iteration +# alpha: increment to next minor, add 'a' with build iteration +# dev: increment to next minor, add 'dev' with build iteration +ARG GUIDELLM_BUILD_TYPE=dev + +# Use a multi-stage build to create a lightweight production image +FROM $BASE_IMAGE as builder + +# Switch to root for installing packages +USER root + +# Install build tooling +RUN dnf install -y git \ + && /usr/bin/python3 -m venv /tmp/pdm \ + && /tmp/pdm/bin/pip install --no-cache-dir -U pdm \ + && ln -s /tmp/pdm/bin/pdm /usr/local/bin/pdm + +# Disable pdm update check +# Set correct build type for versioning +ENV PDM_CHECK_UPDATE=false \ + GUIDELLM_BUILD_TYPE=$GUIDELLM_BUILD_TYPE + +# Copy repository files +# Do this as late as possible to leverage layer caching +COPY / /src + +# Install guidellm and locked dependencies +RUN pdm use -p /src -f /opt/app-root \ + && pdm install -p /src --check --prod --no-editable + +# Prod image +FROM $BASE_IMAGE + +# Add guidellm bin to PATH +# Argument defaults can be set with GUIDELLM_ +ENV HOME="/home/guidellm" \ + GUIDELLM_OUTPUT_PATH="/results/benchmarks.json" + +# Make sure root is the primary group +USER 1001:0 + +# Create the user home dir +WORKDIR $HOME + +# Create a volume for results +VOLUME /results + +# Metadata +LABEL io.k8s.display-name="GuideLLM" \ + org.opencontainers.image.description="GuideLLM Performance Benchmarking Container" \ + org.opencontainers.image.source="https://github.com/vllm-project/guidellm" \ + org.opencontainers.image.documentation="https://blog.vllm.ai/guidellm/stable" \ + org.opencontainers.image.license="Apache-2.0" + +# Copy the virtual environment from the builder stage +# Do this as late as possible to leverage layer caching +COPY --chown=1001:0 --from=builder /opt/app-root 
/opt/app-root + +ENTRYPOINT [ "/opt/app-root/bin/guidellm" ] +CMD [ "benchmark", "run" ] diff --git a/README.md b/README.md index 55f8e815..2de7b4a9 100644 --- a/README.md +++ b/README.md @@ -71,6 +71,15 @@ podman run \ Replace `latest` with `stable` for the newest tagged release or set a specific release if desired. +#### Available Tags + +| Tags | Notes | +| ------------------------------------------------------------------------------------------ | --------------------------------------------- | +| `nightly` | Built from `main` every night | +| [`v0.3.0`](https://github.com/vllm-project/guidellm/releases/tag/v0.3.0) `stable` `latest` | - | +| [`v0.2.1`](https://github.com/vllm-project/guidellm/releases/tag/v0.2.1) | - | +| `pr-*` | Development builds (DO NOT USE IN PRODUCTION) | + ### Quick Start #### 1. Start an OpenAI Compatible Server (vLLM) diff --git a/deploy/Containerfile b/deploy/Containerfile deleted file mode 100644 index 7715de93..00000000 --- a/deploy/Containerfile +++ /dev/null @@ -1,42 +0,0 @@ -ARG BASE_IMAGE=docker.io/python:3.13-slim - -# Use a multi-stage build to create a lightweight production image -FROM $BASE_IMAGE as builder - -# Ensure files are installed as root -USER root - -# Copy repository files -COPY / /opt/app-root/src - -# Create a venv and install guidellm -RUN python3 -m venv /opt/app-root/guidellm \ - && /opt/app-root/guidellm/bin/pip install --no-cache-dir /opt/app-root/src - -# Prod image -FROM $BASE_IMAGE - -# Copy the virtual environment from the builder stage -COPY --from=builder /opt/app-root/guidellm /opt/app-root/guidellm - -# Add guidellm bin to PATH -ENV PATH="/opt/app-root/guidellm/bin:$PATH" - -# Create a non-root user -RUN useradd -md /results guidellm - -# Switch to non-root user -USER guidellm - -# Set working directory -WORKDIR /results - -# Metadata -LABEL org.opencontainers.image.source="https://github.com/vllm-project/guidellm" \ - org.opencontainers.image.description="GuideLLM Performance Benchmarking Container" - -# Argument defaults can be set with GUIDELLM_ -ENV GUIDELLM_OUTPUT_PATH="/results/benchmarks.json" - -ENTRYPOINT [ "/opt/app-root/guidellm/bin/guidellm" ] -CMD [ "benchmark", "run" ] diff --git a/pyproject.toml b/pyproject.toml index 29ae92c5..fbe054ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,9 @@ include = ["*"] [tool.setuptools.package-data] "guidellm.data" = ["*.gz"] +[tool.pdm] +distribution = true + # ************************************************ # ********** Project Metadata ********** @@ -70,6 +73,10 @@ perf = [ "msgspec", "uvloop", ] +recommended = [ + "tiktoken>=0.11.0", # For OpenAI tokenizer + "blobfile>=3.1.0", # For OpenAI tokenizer +] dev = [ # build "build>=1.0.0", @@ -110,6 +117,9 @@ dev = [ "mkdocs-linkcheck~=1.0.6", ] +[dependency-groups] +dev = [ "guidellm[dev]" ] + [project.urls] homepage = "https://github.com/vllm-project/guidellm" source = "https://github.com/vllm-project/guidellm" diff --git a/src/guidellm/backends/openai.py b/src/guidellm/backends/openai.py index ce83076f..acce5f88 100644 --- a/src/guidellm/backends/openai.py +++ b/src/guidellm/backends/openai.py @@ -16,6 +16,7 @@ import json import time from collections.abc import AsyncIterator +from itertools import chain from pathlib import Path from typing import Any, ClassVar, Optional, Union @@ -29,7 +30,7 @@ GenerationRequestTimings, GenerationResponse, ) -from guidellm.scheduler import ScheduledRequestInfo +from guidellm.scheduler import HistoryT, ScheduledRequestInfo __all__ = ["OpenAIHTTPBackend", "UsageStats"] @@ 
-280,7 +281,7 @@ async def resolve( self, request: GenerationRequest, request_info: ScheduledRequestInfo, - history: Optional[list[tuple[GenerationRequest, GenerationResponse]]] = None, + history: Optional[HistoryT[GenerationRequest, GenerationResponse]] = None, ) -> AsyncIterator[tuple[GenerationResponse, ScheduledRequestInfo]]: """ Process a generation request and yield progressive responses. @@ -295,10 +296,8 @@ async def resolve( :yields: Tuples of (response, updated_request_info) as generation progresses. """ self._check_in_process() - if history is not None: - raise NotImplementedError( - "Multi-turn requests with conversation history are not yet supported" - ) + if history: + request = self._apply_history(request, history) response = GenerationResponse( request_id=request.request_id, @@ -500,6 +499,22 @@ async def chat_completions( self._get_completions_usage_stats(data), ) + def _apply_history( + self, + request: GenerationRequest, + history: HistoryT[GenerationRequest, GenerationResponse], + ) -> GenerationRequest: + """ + Apply conversation history to the current request. + """ + + def turn_to_text(turn: tuple[GenerationRequest, GenerationResponse]) -> str: + req, res = turn + return f"{req.content}{res.value}" + + request.content = "".join(chain(map(turn_to_text, history), (request.content,))) + return request + def _build_headers( self, api_key: Optional[str], diff --git a/src/guidellm/benchmark/aggregator.py b/src/guidellm/benchmark/aggregator.py index 9db93a12..e965c482 100644 --- a/src/guidellm/benchmark/aggregator.py +++ b/src/guidellm/benchmark/aggregator.py @@ -441,7 +441,7 @@ def __call__( def compile( self, state: AggregatorState, scheduler_state: SchedulerState - ) -> dict[Literal["scheduler_stats"], BenchmarkSchedulerStats]: + ) -> dict[Literal["run_stats"], BenchmarkSchedulerStats]: """ Compile scheduler timing metrics into benchmark statistics. @@ -473,7 +473,7 @@ def compile( key="worker_resolve_time", type_="avg", default=0.0 ), worker_resolve_end_delay_avg=state.get_metric( - key="worker_resolve_end_delay", type_="avg" + key="worker_resolve_end_delay", type_="avg", default=0.0 ), finalized_delay_avg=state.get_metric( key="finalized_delay", type_="avg", default=0.0 diff --git a/src/guidellm/benchmark/profile.py b/src/guidellm/benchmark/profile.py index 3d4e7287..3ff8d0e0 100644 --- a/src/guidellm/benchmark/profile.py +++ b/src/guidellm/benchmark/profile.py @@ -678,6 +678,8 @@ def next_strategy( self.throughput_rate = ( prev_benchmark.metrics.requests_per_second.successful.mean ) + if self.synchronous_rate <= 0 and self.throughput_rate <= 0: + raise RuntimeError("Invalid rates in sweep; aborting. 
Were there any successful requests?") self.measured_rates = list( np.linspace( self.synchronous_rate, @@ -696,7 +698,6 @@ def next_strategy( if strat.type_ == self.strategy_type ] ) - if self.strategy_type == "constant": return AsyncConstantStrategy( rate=self.measured_rates[next_rate_index], diff --git a/src/guidellm/dataset/synthetic.py b/src/guidellm/dataset/synthetic.py index 8c30f0f7..06972643 100644 --- a/src/guidellm/dataset/synthetic.py +++ b/src/guidellm/dataset/synthetic.py @@ -3,7 +3,7 @@ from collections.abc import Iterable, Iterator from itertools import cycle from pathlib import Path -from typing import Any, Literal, Optional, Union +from typing import Any, Optional, TypedDict, Union import yaml from datasets import ( @@ -69,6 +69,26 @@ class SyntheticDatasetConfig(BaseModel): gt=0, default=None, ) + turns: int = Field( + description="The number of turns in the conversation.", + gt=0, + default=1, + ) + turns_stdev: Optional[int] = Field( + description="The standard deviation of the number of turns.", + gt=0, + default=None, + ) + turns_min: Optional[int] = Field( + description="The minimum number of turns in the conversation.", + gt=0, + default=None, + ) + turns_max: Optional[int] = Field( + description="The maximum number of turns in the conversation.", + gt=0, + default=None, + ) samples: int = Field( description="The number of samples to generate for the dataset.", gt=0, @@ -124,14 +144,13 @@ def parse_config_file(data: Union[str, Path]) -> "SyntheticDatasetConfig": return SyntheticDatasetConfig(**config_dict) -class SyntheticTextItemsGenerator( - Iterable[ - dict[ - Literal["prompt", "prompt_tokens_count", "output_tokens_count"], - Union[str, int], - ] - ] -): +class SyntheticDatasetRow(TypedDict): + prompt: list[str] + prompt_tokens_count: list[int] + output_tokens_count: list[int] + + +class SyntheticTextItemsGenerator(Iterable[SyntheticDatasetRow]): def __init__( self, config: SyntheticDatasetConfig, @@ -147,12 +166,7 @@ def __init__( def __iter__( self, - ) -> Iterator[ - dict[ - Literal["prompt", "prompt_tokens_count", "output_tokens_count"], - Union[str, int], - ] - ]: + ) -> Iterator[SyntheticDatasetRow]: prompt_tokens_sampler = IntegerRangeSampler( average=self.config.prompt_tokens, variance=self.config.prompt_tokens_stdev, @@ -167,6 +181,13 @@ def __iter__( max_value=self.config.output_tokens_max, random_seed=self.random_seed + 1, # ensure diff dist from prompts ) + turns_sampler = IntegerRangeSampler( + average=self.config.turns, + variance=self.config.turns_stdev, + min_value=self.config.turns_min, + max_value=self.config.turns_max, + random_seed=self.random_seed + 7, # ensure diff dist + ) # ensure diff distribution from output tokens rand = random.Random(self.random_seed + 2) # noqa: S311 unique_prefix_iter = cycle(self.processor.get_vocab().values()) @@ -174,24 +195,42 @@ def __iter__( prefix_index = rand.randint(0, len(self.text_creator.words)) prefix_tokens = self._create_prompt(self.config.prefix_tokens, prefix_index) - for _, prompt_tokens, output_tokens in zip( - range(self.config.samples), - prompt_tokens_sampler, - output_tokens_sampler, - ): - start_index = rand.randint(0, len(self.text_creator.words)) - prompt_text = self.processor.decode( - prefix_tokens - + self._create_prompt( - prompt_tokens, start_index, next(unique_prefix_iter) - ), - skip_special_tokens=True, - ) - yield { - "prompt": prompt_text, - "prompt_tokens_count": self.config.prefix_tokens + prompt_tokens, - "output_tokens_count": output_tokens, + for _, turns in 
zip(range(self.config.samples), turns_sampler): + row: SyntheticDatasetRow = { + "prompt": [], + "prompt_tokens_count": [], + "output_tokens_count": [], } + for i, prompt_tokens, output_tokens in zip( + range(turns), + prompt_tokens_sampler, + output_tokens_sampler, + ): + start_index = rand.randint(0, len(self.text_creator.words)) + # Append the prefix tokens only for the first turn + if i == 0: + prompt_text = self.processor.decode( + prefix_tokens + + self._create_prompt( + prompt_tokens, start_index, next(unique_prefix_iter) + ), + skip_special_tokens=True, + ) + row["prompt"].append(prompt_text) + row["prompt_tokens_count"].append(self.config.prefix_tokens + prompt_tokens) + row["output_tokens_count"].append(output_tokens) + else: + prompt_text = self.processor.decode( + self._create_prompt( + prompt_tokens, start_index, next(unique_prefix_iter) + ), + skip_special_tokens=True, + ) + row["prompt"].append(prompt_text) + row["prompt_tokens_count"].append(prompt_tokens) + row["output_tokens_count"].append(output_tokens) + + yield row def _create_prompt( self, prompt_tokens: int, start_index: int, unique_prefix: Optional[int] = None diff --git a/src/guidellm/request/loader.py b/src/guidellm/request/loader.py index 607a7455..81aae8fb 100644 --- a/src/guidellm/request/loader.py +++ b/src/guidellm/request/loader.py @@ -105,14 +105,14 @@ def __init__( self.preserve_iter_state = iter_type == "infinite" # ensure no caching requests self._preserved_iter = None - def __iter__(self) -> Iterator[GenerationRequest]: + def __iter__(self) -> Iterator[list[tuple[GenerationRequest, float]]]: scope_create_count = 0 while (dataset_iter := self._get_dataset_iter(scope_create_count)) is not None: scope_create_count += 1 for item in dataset_iter: - yield self._create_request(item) + yield self._create_requests(item) self._preserved_iter = None @@ -260,25 +260,47 @@ def _get_dataset_iter( return dataset_iter - def _create_request(self, item: dict[str, Any]) -> GenerationRequest: - prompt_tokens = ( - item[self.column_mappings["prompt_tokens_count_column"]] + def _create_requests( + self, item: dict[str, Any] + ) -> list[tuple[GenerationRequest, float]]: + prompts = list(item[self.column_mappings["prompt_column"]]) + prompts_tokens: list[Optional[int]] = ( + list(item[self.column_mappings["prompt_tokens_count_column"]]) if "prompt_tokens_count_column" in self.column_mappings - else None + else [None] * len(prompts) ) - output_tokens = ( - item[self.column_mappings["output_tokens_count_column"]] + outputs_tokens: list[Optional[int]] = ( + list(item[self.column_mappings["output_tokens_count_column"]]) if "output_tokens_count_column" in self.column_mappings - else None + else [None] * len(prompts) ) - return GenerationRequest( - request_type=settings.preferred_route, - content=item[self.column_mappings["prompt_column"]], - stats=( - {"prompt_tokens": prompt_tokens} if prompt_tokens is not None else {} - ), - constraints=( - {"output_tokens": output_tokens} if output_tokens is not None else {} - ), - ) + if len(prompts) != len(prompts_tokens) != len(outputs_tokens): + raise ValueError( + "Mismatched lengths between prompts and token counts. 
" + f"Prompts: {len(prompts)}, Prompt Tokens: {len(prompts_tokens)}, " + f"Output Tokens: {len(outputs_tokens)}" + ) + + return [ + ( + GenerationRequest( + request_type=settings.preferred_route, + content=prompt, + stats=( + {"prompt_tokens": prompt_tokens} + if prompt_tokens is not None + else {} + ), + constraints=( + {"output_tokens": output_tokens} + if output_tokens is not None + else {} + ), + ), + 0.0, # TODO: delay + ) + for prompt, prompt_tokens, output_tokens in zip( + prompts, prompts_tokens, outputs_tokens + ) + ] diff --git a/src/guidellm/scheduler/__init__.py b/src/guidellm/scheduler/__init__.py index 64647424..4eff5c12 100644 --- a/src/guidellm/scheduler/__init__.py +++ b/src/guidellm/scheduler/__init__.py @@ -15,11 +15,14 @@ from .objects import ( BackendInterface, BackendT, + DatasetIterT, + HistoryT, MeasuredRequestTimings, - MultiTurnRequestT, + RequestDataT, RequestSchedulerTimings, RequestT, ResponseT, + ScheduledRequestAugmentation, ScheduledRequestInfo, SchedulerMessagingPydanticRegistry, SchedulerState, @@ -55,7 +58,9 @@ "Constraint", "ConstraintInitializer", "ConstraintsInitializerFactory", + "DatasetIterT", "Environment", + "HistoryT", "LastCompletionRequestTimings", "MaxDurationConstraint", "MaxErrorRateConstraint", @@ -63,14 +68,15 @@ "MaxGlobalErrorRateConstraint", "MaxNumberConstraint", "MeasuredRequestTimings", - "MultiTurnRequestT", "NoDelayRequestTimings", "NonDistributedEnvironment", "PoissonRateRequestTimings", "PydanticConstraintInitializer", + "RequestDataT", "RequestSchedulerTimings", "RequestT", "ResponseT", + "ScheduledRequestAugmentation", "ScheduledRequestInfo", "ScheduledRequestTimings", "Scheduler", diff --git a/src/guidellm/scheduler/environments.py b/src/guidellm/scheduler/environments.py index 6234f8f6..a9853544 100644 --- a/src/guidellm/scheduler/environments.py +++ b/src/guidellm/scheduler/environments.py @@ -19,14 +19,14 @@ import time from abc import ABC, abstractmethod -from collections.abc import AsyncIterator, Iterable +from collections.abc import AsyncIterator from typing import ( Generic, ) from guidellm.scheduler.constraints import Constraint from guidellm.scheduler.objects import ( - MultiTurnRequestT, + DatasetIterT, RequestT, ResponseT, ScheduledRequestInfo, @@ -52,11 +52,11 @@ class Environment(ABC, Generic[RequestT, ResponseT], InfoMixin): @abstractmethod async def sync_run_params( self, - requests: Iterable[RequestT | MultiTurnRequestT[RequestT]], + requests: DatasetIterT[RequestT], strategy: SchedulingStrategy, constraints: dict[str, Constraint], ) -> tuple[ - Iterable[RequestT | MultiTurnRequestT[RequestT]], + DatasetIterT[RequestT], SchedulingStrategy, dict[str, Constraint], ]: @@ -130,7 +130,7 @@ async def sync_run_end( ) -> AsyncIterator[ tuple[ ResponseT, - RequestT | MultiTurnRequestT[RequestT], + RequestT, ScheduledRequestInfo, SchedulerState, ] @@ -194,11 +194,11 @@ def __init__(self): async def sync_run_params( self, - requests: Iterable[RequestT | MultiTurnRequestT[RequestT]], + requests: DatasetIterT[RequestT], strategy: SchedulingStrategy, constraints: dict[str, Constraint], ) -> tuple[ - Iterable[RequestT | MultiTurnRequestT[RequestT]], + DatasetIterT[RequestT], SchedulingStrategy, dict[str, Constraint], ]: @@ -250,7 +250,7 @@ async def sync_run_end( ) -> AsyncIterator[ tuple[ ResponseT, - RequestT | MultiTurnRequestT[RequestT], + RequestT, ScheduledRequestInfo, SchedulerState, ] diff --git a/src/guidellm/scheduler/objects.py b/src/guidellm/scheduler/objects.py index b7f2efc3..e7d4c6c7 100644 --- 
a/src/guidellm/scheduler/objects.py +++ b/src/guidellm/scheduler/objects.py @@ -11,7 +11,7 @@ import time import uuid -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Iterable from typing import ( Any, ClassVar, @@ -19,7 +19,6 @@ Literal, Protocol, TypeVar, - Union, ) from pydantic import Field, computed_field @@ -35,11 +34,14 @@ __all__ = [ "BackendInterface", "BackendT", + "DatasetIterT", + "HistoryT", "MeasuredRequestTimings", - "MultiTurnRequestT", + "RequestDataT", "RequestSchedulerTimings", "RequestT", "ResponseT", + "ScheduledRequestAugmentation", "ScheduledRequestInfo", "SchedulerMessagingPydanticRegistry", "SchedulerState", @@ -53,15 +55,24 @@ ResponseT = TypeVar("ResponseT") """Generic response object type returned by backend processing.""" -MultiTurnRequestT = TypeAliasType( - "MultiTurnRequestT", - Union[ - list[Union[RequestT, tuple[RequestT, float]]], - tuple[Union[RequestT, tuple[RequestT, float]]], - ], +RequestDataT = TypeAliasType( + "RequestDataT", + tuple[RequestT, "ScheduledRequestAugmentation", "ScheduledRequestInfo"], type_params=(RequestT,), ) -"""Multi-turn request structure supporting conversation history with optional delays.""" +"""Request including external metadata and scheduling config.""" + +HistoryT = TypeAliasType( + "HistoryT", + list[tuple[RequestT, ResponseT]], + type_params=(RequestT, ResponseT), +) +"""Record of requests + responses in conversation.""" + + +DatasetIterT = TypeAliasType( + "DatasetIterT", Iterable[Iterable[tuple[RequestT, float]]], type_params=(RequestT,) +) class SchedulerMessagingPydanticRegistry(RegistryMixin[RegistryObjT]): @@ -71,6 +82,21 @@ class SchedulerMessagingPydanticRegistry(RegistryMixin[RegistryObjT]): """ +@SchedulerMessagingPydanticRegistry.register() +class ScheduledRequestAugmentation(StandardBaseModel): + """ + Adjustments to scheduler logic for a paired request. + """ + + post_requeue_delay: float = Field( + description=( + "Delay in seconds to wait after a request to " + "queue the next request in the conversation." 
+ ), + default=0.0, + ) + + @SchedulerMessagingPydanticRegistry.register() class RequestSchedulerTimings(StandardBaseModel): """ diff --git a/src/guidellm/scheduler/scheduler.py b/src/guidellm/scheduler/scheduler.py index e7d8b2c6..43948d18 100644 --- a/src/guidellm/scheduler/scheduler.py +++ b/src/guidellm/scheduler/scheduler.py @@ -10,7 +10,7 @@ from __future__ import annotations -from collections.abc import AsyncIterator, Iterable +from collections.abc import AsyncIterator from typing import Any, Generic from guidellm.scheduler.constraints import ( @@ -20,7 +20,7 @@ from guidellm.scheduler.environments import Environment, NonDistributedEnvironment from guidellm.scheduler.objects import ( BackendInterface, - MultiTurnRequestT, + DatasetIterT, RequestT, ResponseT, ScheduledRequestInfo, @@ -66,7 +66,7 @@ class Scheduler( async def run( self, - requests: Iterable[RequestT | MultiTurnRequestT[RequestT]], + requests: DatasetIterT[RequestT], backend: BackendInterface[RequestT, ResponseT], strategy: SchedulingStrategy, env: Environment | None, diff --git a/src/guidellm/scheduler/worker.py b/src/guidellm/scheduler/worker.py index 5f2fb74b..4c5903fb 100644 --- a/src/guidellm/scheduler/worker.py +++ b/src/guidellm/scheduler/worker.py @@ -13,7 +13,7 @@ import time from multiprocessing.synchronize import Barrier as ProcessingBarrier from multiprocessing.synchronize import Event as ProcessingEvent -from typing import Annotated, Generic, Literal +from typing import Annotated, Generic, Literal, TypeAliasType try: import uvloop @@ -31,9 +31,11 @@ from guidellm.scheduler.objects import ( BackendInterface, - MultiTurnRequestT, + HistoryT, + RequestDataT, RequestT, ResponseT, + ScheduledRequestAugmentation, ScheduledRequestInfo, SchedulerMessagingPydanticRegistry, ) @@ -47,6 +49,16 @@ __all__ = ["WorkerProcess"] +ProcessRequestT = TypeAliasType( + "ProcessRequestT", + tuple[ + HistoryT[RequestT, ResponseT], + list[RequestDataT[RequestT]], + ScheduledRequestAugmentation, + ], + type_params=(RequestT, ResponseT), +) + class WorkerProcess(Generic[RequestT, ResponseT]): """ @@ -74,11 +86,8 @@ class WorkerProcess(Generic[RequestT, ResponseT]): def __init__( self, messaging: InterProcessMessaging[ - tuple[ - ResponseT | None, - RequestT | MultiTurnRequestT[RequestT], - ScheduledRequestInfo, - ], + tuple[ResponseT | None, RequestT, ScheduledRequestInfo], + list[RequestDataT[RequestT]], ], backend: BackendInterface[RequestT, ResponseT], request_timings: ScheduledRequestTimings, @@ -118,6 +127,9 @@ def __init__( self.startup_completed = False self.backend_started = False self.messaging_started = False + self.turns_queue: list[ + tuple[HistoryT[RequestT, ResponseT], list[RequestDataT[RequestT]]] + ] = [] def run(self): """ @@ -265,12 +277,20 @@ async def _process_requests_loop(self): async_semaphore = asyncio.Semaphore(self.async_limit) pending_tasks: set[asyncio.Task] = set() - def _task_done(task): + def _task_done(task: asyncio.Task[ProcessRequestT[RequestT, ResponseT]]): pending_tasks.discard(task) async_semaphore.release() - if not task.cancelled() and (exception := task.exception()): - raise exception + if not task.cancelled(): + if exception := task.exception(): + raise exception + + history, conversation, aug = task.result() + if conversation: + requeue_task = asyncio.create_task( + self._wait_then_requeue(history, conversation, aug) + ) + pending_tasks.add(requeue_task) # Main loop; loop until canceled while True: @@ -290,28 +310,40 @@ async def _cancel_requests_loop(self): try: request: RequestT 
request_info: ScheduledRequestInfo - request, request_info = await self.messaging.get( - timeout=self.messaging.poll_interval + _, conversation = ( + self.turns_queue.pop(0) + if self.turns_queue + else ( + None, + await self.messaging.get(timeout=self.messaging.poll_interval), + ) ) except asyncio.TimeoutError: continue - request_info.scheduler_node_id = self.messaging.worker_index - request_info.error = "Request was cancelled" - request_info.scheduler_timings.resolve_end = time.time() - self._send_update("cancelled", None, request, request_info) + for request, _, request_info in conversation: + request_info.scheduler_node_id = self.messaging.worker_index + request_info.error = "Request was cancelled" + request_info.scheduler_timings.resolve_end = time.time() + self._send_update("cancelled", None, request, request_info) - async def _process_next_request(self): - request: RequestT | MultiTurnRequestT[RequestT] | None = None + async def _process_next_request(self) -> ProcessRequestT[RequestT, ResponseT]: + conversation: list[RequestDataT[RequestT]] = [] + history: HistoryT[RequestT, ResponseT] = [] + request: RequestT | None = None request_info: ScheduledRequestInfo | None = None response: ResponseT | None = None + aug: ScheduledRequestAugmentation | None = None + premature_exit: bool = False try: # Pull request from the queue - request, request_info = await self.messaging.get() - - if isinstance(request, (list, tuple)): - raise NotImplementedError("Multi-turn requests are not yet supported") + history, conversation = ( + self.turns_queue.pop(0) + if self.turns_queue + else ([], await self.messaging.get()) + ) + request, aug, request_info = conversation.pop(0) # Calculate targeted start and set pending state for request request_info.scheduler_node_id = self.messaging.worker_index @@ -333,7 +365,9 @@ async def _process_next_request(self): # Process the request with the backend request_info.scheduler_timings.resolve_start = time.time() self._send_update("in_progress", response, request, request_info) - async for resp, info in self.backend.resolve(request, request_info, None): + async for resp, info in self.backend.resolve( + request, request_info, history + ): response = resp request_info = info @@ -341,8 +375,12 @@ async def _process_next_request(self): request_info.scheduler_timings.resolve_end = time.time() self._send_update("completed", response, request, request_info) + # Record Turn + history.append((request, response)) + response = request = request_info = None except asyncio.CancelledError: + premature_exit = True # Handle cancellation if request is not None and request_info is not None: request_info.error = "Request was cancelled" @@ -350,10 +388,34 @@ async def _process_next_request(self): self._send_update("cancelled", response, request, request_info) raise except Exception as exc: # noqa: BLE001 + premature_exit = True if request is not None and request_info is not None: request_info.error = str(exc) request_info.scheduler_timings.resolve_end = time.time() self._send_update("errored", response, request, request_info) + finally: + if premature_exit and conversation: + for request, _, request_info in conversation: + request_info.error = "Request was cancelled" + request_info.scheduler_timings.resolve_end = time.time() + self._send_update("cancelled", response, request, request_info) + + return history, conversation, aug + + async def _wait_then_requeue( + self, + history: HistoryT[RequestT, ResponseT], + conversation: list[RequestDataT[RequestT]], + aug: 
ScheduledRequestAugmentation, + ): + try: + if aug.post_requeue_delay > 0: + await asyncio.sleep(aug.post_requeue_delay) + except asyncio.CancelledError: + # If we are cancelled, dump straight to queue + raise + finally: + self.turns_queue.append((history, conversation)) def _send_update( self, @@ -361,7 +423,7 @@ def _send_update( "pending", "in_progress", "completed", "errored", "cancelled" ], response: ResponseT | None, - request: RequestT | MultiTurnRequestT[RequestT], + request: RequestT, request_info: ScheduledRequestInfo, ): prev_status = request_info.status diff --git a/src/guidellm/scheduler/worker_group.py b/src/guidellm/scheduler/worker_group.py index c1d516f1..296152a8 100644 --- a/src/guidellm/scheduler/worker_group.py +++ b/src/guidellm/scheduler/worker_group.py @@ -25,9 +25,11 @@ from guidellm.scheduler.constraints import Constraint, RequestsExhaustedConstraint from guidellm.scheduler.objects import ( BackendInterface, - MultiTurnRequestT, + DatasetIterT, + RequestDataT, RequestT, ResponseT, + ScheduledRequestAugmentation, ScheduledRequestInfo, SchedulerMessagingPydanticRegistry, SchedulerState, @@ -81,8 +83,8 @@ class WorkerProcessGroup(Generic[RequestT, ResponseT]): def __init__( self, - requests: Iterable[RequestT | MultiTurnRequestT[RequestT]] | None, - cycle_requests: Iterable[RequestT | MultiTurnRequestT[RequestT]] | None, + requests: DatasetIterT[RequestT] | None, + cycle_requests: DatasetIterT[RequestT] | None, backend: BackendInterface[RequestT, ResponseT], strategy: SchedulingStrategy, constraints: dict[str, Constraint], @@ -129,16 +131,8 @@ def __init__( # Scheduler and messaging state, created in start self.state: WorkerGroupState[ResponseT, RequestT] = None self.messaging: InterProcessMessaging[ - tuple[ - RequestT | MultiTurnRequestT[RequestT], - ScheduledRequestInfo, - ], - tuple[ - ResponseT | None, - RequestT | MultiTurnRequestT[RequestT], - ScheduledRequestInfo, - SchedulerState, - ], + list[RequestDataT[RequestT]], + tuple[ResponseT | None, RequestT, ScheduledRequestInfo, SchedulerState], ] = None async def create_processes(self): @@ -471,9 +465,9 @@ def __init__( def requests_generator( self, - requests: Iterable[RequestT | MultiTurnRequestT[RequestT]] | None, - cycle_requests: Iterable[RequestT | MultiTurnRequestT[RequestT]] | None, - ) -> Generator[tuple[RequestT | MultiTurnRequestT[RequestT],], None, None]: + requests: DatasetIterT[RequestT] | None, + cycle_requests: DatasetIterT[RequestT] | None, + ) -> Generator[list[RequestDataT[RequestT]], None, None]: """ Generate request-info pairs for worker processing with constraint evaluation. 
@@ -494,27 +488,40 @@ def _iter(): while True: yield from cycle_requests - count = 0 - request_info: ScheduledRequestInfo = None - for request in _iter(): - count += 1 - - if hasattr(request, "request_id"): - request_id = request.request_id - elif hasattr(request, "id"): - request_id = request.id - else: - request_id = str(uuid.uuid4()) - request_info: ScheduledRequestInfo = ScheduledRequestInfo( - request_id=request_id, - status="queued", - scheduler_process_id=0, - scheduler_start_time=self.start_time, - ) - state_update = self._locked_update(request_info) - yield (request, request_info) + count: int = 0 + stop_queueing: bool = False + + def _turn_iter(requests_chain: Iterable[tuple[RequestT, float]]): + nonlocal count, stop_queueing + for request, delay in requests_chain: + count += 1 + + if hasattr(request, "request_id"): + request_id = request.request_id + elif hasattr(request, "id"): + request_id = request.id + else: + request_id = str(uuid.uuid4()) + request_augmentation = ScheduledRequestAugmentation( + post_requeue_delay=delay + ) + request_info: ScheduledRequestInfo = ScheduledRequestInfo( + request_id=request_id, + status="queued", + scheduler_process_id=0, + scheduler_start_time=self.start_time, + ) + state_update = self._locked_update(request_info) + yield (request, request_augmentation, request_info) - if state_update.stop_queueing: + if state_update.stop_queueing: + stop_queueing = True + return + + for request_chain in _iter(): + yield list(_turn_iter(request_chain)) + + if stop_queueing: self.stop_send_requests_event.set() return @@ -529,12 +536,12 @@ def received_callback( self, update: tuple[ ResponseT | None, - RequestT | MultiTurnRequestT, + RequestT, ScheduledRequestInfo, ], ) -> tuple[ ResponseT | None, - RequestT | MultiTurnRequestT, + RequestT, ScheduledRequestInfo, SchedulerState, ]: diff --git a/src/guidellm/utils/encoding.py b/src/guidellm/utils/encoding.py index ccd26982..d4fa007b 100644 --- a/src/guidellm/utils/encoding.py +++ b/src/guidellm/utils/encoding.py @@ -390,23 +390,11 @@ def to_dict(self, obj: Any) -> Any: if isinstance(obj, BaseModel): return self.to_dict_pydantic(obj) - if isinstance(obj, (list, tuple)) and any( - isinstance(item, BaseModel) for item in obj - ): - return [ - self.to_dict_pydantic(item) if isinstance(item, BaseModel) else item - for item in obj - ] + if isinstance(obj, (list, tuple)): + return [self.to_dict(item) for item in obj] - if isinstance(obj, dict) and any( - isinstance(value, BaseModel) for value in obj.values() - ): - return { - key: self.to_dict_pydantic(value) - if isinstance(value, BaseModel) - else value - for key, value in obj.items() - } + if isinstance(obj, dict): + return {key: self.to_dict(value) for key, value in obj.items()} return obj @@ -418,22 +406,13 @@ def from_dict(self, data: Any) -> Any: :return: Reconstructed object with proper types restored """ if isinstance(data, (list, tuple)): - return [ - self.from_dict_pydantic(item) - if isinstance(item, dict) and "*PYD*" in item - else item - for item in data - ] - elif isinstance(data, dict) and data: + return [self.from_dict(item) for item in data] + + if isinstance(data, dict) and data: if "*PYD*" in data: return self.from_dict_pydantic(data) - return { - key: self.from_dict_pydantic(value) - if isinstance(value, dict) and "*PYD*" in value - else value - for key, value in data.items() - } + return {key: self.from_dict(value) for key, value in data.items()} return data diff --git a/tests/unit/scheduler/test_objects.py b/tests/unit/scheduler/test_objects.py 
index df794ff8..f76fcfd1 100644 --- a/tests/unit/scheduler/test_objects.py +++ b/tests/unit/scheduler/test_objects.py @@ -7,13 +7,11 @@ import pytest from pydantic import ValidationError -from typing_extensions import TypeAliasType from guidellm.scheduler import ( BackendInterface, BackendT, MeasuredRequestTimings, - MultiTurnRequestT, RequestSchedulerTimings, RequestT, ResponseT, @@ -49,20 +47,6 @@ def test_backend_t(): assert BackendT.__constraints__ == () -def test_multi_turn_request_t(): - """Validate MultiTurnRequestT is a TypeAliasType for multi-turn requests.""" - assert isinstance(MultiTurnRequestT, TypeAliasType) - assert MultiTurnRequestT.__name__ == "MultiTurnRequestT" - - value = MultiTurnRequestT.__value__ - assert hasattr(value, "__origin__") - assert value.__origin__ is Union - - type_params = getattr(MultiTurnRequestT, "__type_params__", ()) - assert len(type_params) == 1 - assert type_params[0].__name__ == "RequestT" - - class TestBackendInterface: """Test the BackendInterface abstract base class."""
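The patch above replaces the unimplemented `MultiTurnRequestT` path with real multi-turn support: `SyntheticDatasetConfig` gains `turns`, `turns_stdev`, `turns_min`, and `turns_max`, and each generated sample becomes a `SyntheticDatasetRow` whose fields are per-turn lists rather than scalars. The sketch below is not part of the patch; it only illustrates how those new pieces fit together. The class and field names are taken from the diff, while the concrete values (and whether `prompt_tokens`/`output_tokens` are required at construction) are assumptions for illustration.

```python
# Illustrative sketch only, not part of the patch. Assumes a guidellm build
# that includes this change; names mirror src/guidellm/dataset/synthetic.py.
from guidellm.dataset.synthetic import SyntheticDatasetConfig, SyntheticDatasetRow

# A config requesting conversations that average 3 turns, bounded to 1-5 turns.
config = SyntheticDatasetConfig(
    prompt_tokens=256,
    output_tokens=128,
    samples=10,
    turns=3,
    turns_stdev=1,
    turns_min=1,
    turns_max=5,
)

# With the change, each sample yielded by SyntheticTextItemsGenerator is a
# SyntheticDatasetRow with one entry per turn, e.g. a 2-turn conversation:
row: SyntheticDatasetRow = {
    "prompt": ["first user turn ...", "second user turn ..."],
    "prompt_tokens_count": [256, 251],  # first turn includes prefix_tokens
    "output_tokens_count": [128, 131],
}
```

Downstream, `GenerationRequestLoader._create_requests` turns each such row into a list of `(GenerationRequest, delay)` pairs, and the worker threads the accumulated `(request, response)` history back into the backend via `OpenAIHTTPBackend._apply_history`.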