148 changes: 68 additions & 80 deletions src/guidellm/__main__.py
@@ -31,6 +31,7 @@
from typing import Annotated, Union

import click
from pydantic import ValidationError

try:
import uvloop
@@ -55,6 +56,7 @@
)
from guidellm.benchmark.scenario import (
GenerativeTextScenario,
get_builtin_scenarios,
)
from guidellm.mock_server import MockServer, MockServerConfig
from guidellm.preprocess.dataset import ShortPromptStrategy, process_dataset
@@ -135,6 +137,25 @@ def benchmark():
help="Run a benchmark against a generative model using the specified arguments.",
context_settings={"auto_envvar_prefix": "GUIDELLM"},
)
@click.option(
"--scenario",
type=cli_tools.Union(
click.Path(
exists=True,
readable=True,
file_okay=True,
dir_okay=False,
path_type=Path,
),
click.Choice(get_builtin_scenarios()),
),
default=None,
help=(
"The name of a builtin scenario or path to a config file. "
"Missing values from the config will use defaults. "
"Options specified on the commandline will override the scenario."
),
)
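The `--scenario` flag accepts either a config file path or a builtin scenario name via `cli_tools.Union`, which is not shown in this diff. A union param type of that shape would typically try each wrapped click type in turn and fail only if none accepts the value; a rough sketch under that assumption (`UnionParamType` is a hypothetical name, not the project's implementation):

```python
import click


class UnionParamType(click.ParamType):
    """Try each wrapped click type in order and return the first successful conversion."""

    name = "union"

    def __init__(self, *types: click.ParamType) -> None:
        self.types = types

    def convert(self, value, param, ctx):
        errors = []
        for typ in self.types:
            try:
                return typ.convert(value, param, ctx)
            except click.UsageError as err:  # keep trying the remaining types
                errors.append(str(err))
        self.fail("; ".join(errors), param, ctx)
```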
@click.option(
"--target",
type=str,
@@ -161,7 +182,7 @@ def benchmark():
)
@click.option(
"--rate",
default=None,
default=GenerativeTextScenario.get_default("rate"),
help=(
"The rates to run the benchmark at. "
"Can be a single number or a comma-separated list of numbers. "
@@ -183,18 +204,18 @@
"--backend-type", # legacy alias
"backend",
type=click.Choice(list(get_literal_vals(BackendType))),
default=GenerativeTextScenario.get_default("backend"),
help=(
"The type of backend to use to run requests against. Defaults to 'openai_http'."
f" Supported types: {', '.join(get_literal_vals(BackendType))}"
),
default="openai_http",
)
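`GenerativeTextScenario.get_default(...)` becomes the single source of truth for the option defaults above, replacing the duplicated literals. Its implementation is not part of this diff; a minimal sketch of what such a helper could look like on a pydantic v2 model (`ExampleScenario` and its fields are illustrative, not the real class):

```python
from typing import Any, Optional

from pydantic import BaseModel


class ExampleScenario(BaseModel):
    """Illustrative stand-in; the real GenerativeTextScenario defines many more fields."""

    rate: Optional[float] = None
    request_samples: Optional[int] = 20

    @classmethod
    def get_default(cls, field: str) -> Any:
        # Read the declared field default so click option defaults stay in sync
        # with the scenario model instead of being hard-coded twice.
        return cls.model_fields[field].default
```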
@click.option(
"--backend-kwargs",
"--backend-args", # legacy alias
"backend_kwargs",
callback=cli_tools.parse_json,
default=None,
default=GenerativeTextScenario.get_default("backend_kwargs"),
help=(
"A JSON string containing any arguments to pass to the backend as a "
"dict with **kwargs. Headers can be removed by setting their value to "
@@ -204,7 +225,7 @@
)
@click.option(
"--model",
default=None,
default=GenerativeTextScenario.get_default("model"),
type=str,
help=(
"The ID of the model to benchmark within the backend. "
@@ -214,7 +235,7 @@
# Data configuration
@click.option(
"--processor",
default=None,
default=GenerativeTextScenario.get_default("processor"),
type=str,
help=(
"The processor or tokenizer to use to calculate token counts for statistics "
@@ -224,7 +245,7 @@
)
@click.option(
"--processor-args",
default=None,
default=GenerativeTextScenario.get_default("processor_args"),
callback=cli_tools.parse_json,
help=(
"A JSON string containing any arguments to pass to the processor constructor "
@@ -233,7 +254,7 @@
)
@click.option(
"--data-args",
default=None,
default=GenerativeTextScenario.get_default("data_args"),
callback=cli_tools.parse_json,
help=(
"A JSON string containing any arguments to pass to the dataset creation "
@@ -242,7 +263,7 @@
)
@click.option(
"--data-sampler",
default=None,
default=GenerativeTextScenario.get_default("data_sampler"),
type=click.Choice(["random"]),
help=(
"The data sampler type to use. 'random' will add a random shuffle on the data. "
@@ -301,7 +322,7 @@ def benchmark():
"--warmup-percent", # legacy alias
"warmup",
type=float,
default=None,
default=GenerativeTextScenario.get_default("warmup"),
help=(
"The specification around the number of requests to run before benchmarking. "
"If within (0, 1), then the percent of requests/time to use for warmup. "
@@ -315,7 +336,7 @@
"--cooldown-percent", # legacy alias
"cooldown",
type=float,
default=GenerativeTextScenario.get_default("cooldown_percent"),
default=GenerativeTextScenario.get_default("cooldown"),
help=(
"The specification around the number of requests to run after benchmarking. "
"If within (0, 1), then the percent of requests/time to use for cooldown. "
@@ -328,19 +349,19 @@
"--request-samples",
"--output-sampling", # legacy alias
"request_samples",
default=GenerativeTextScenario.get_default("request_samples"),
type=int,
help=(
"The number of samples for each request status and each benchmark to save "
"in the output file. If None (default), will save all samples. "
"Defaults to 20."
),
default=20,
)
# Constraints configuration
@click.option(
"--max-seconds",
type=float,
default=None,
default=GenerativeTextScenario.get_default("max_seconds"),
help=(
"The maximum number of seconds each benchmark can run for. "
"If None, will run until max_requests or the data is exhausted."
@@ -349,7 +370,7 @@
@click.option(
"--max-requests",
type=int,
default=None,
default=GenerativeTextScenario.get_default("max_requests"),
help=(
"The maximum number of requests each benchmark can run for. "
"If None, will run until max_seconds or the data is exhausted."
@@ -358,55 +379,22 @@
@click.option(
"--max-errors",
type=int,
default=None,
default=GenerativeTextScenario.get_default("max_errors"),
help="Maximum number of errors allowed before stopping the benchmark",
)
@click.option(
"--max-error-rate",
type=float,
default=None,
default=GenerativeTextScenario.get_default("max_error_rate"),
help="Maximum error rate allowed before stopping the benchmark",
)
@click.option(
"--max-global-error-rate",
type=float,
default=None,
default=GenerativeTextScenario.get_default("max_global_error_rate"),
help="Maximum global error rate allowed across all benchmarks",
)
def run(
target,
data,
profile,
rate,
random_seed,
# Backend Configuration
backend,
backend_kwargs,
model,
# Data configuration
processor,
processor_args,
data_args,
data_sampler,
# Output configuration
output_path,
output_formats,
# Updates configuration
disable_console_outputs,
disable_progress,
display_scheduler_stats,
# Aggregators configuration
output_extras,
warmup,
cooldown,
request_samples,
# Constraints configuration
max_seconds,
max_requests,
max_errors,
max_error_rate,
max_global_error_rate,
):
def run(**kwargs):
"""
Execute a generative text benchmark against a target model backend.

Expand All @@ -415,53 +403,53 @@ def run(
Supports multiple backends, data sources, output formats, and constraint types
for flexible benchmark configuration.
"""
scenario = kwargs.pop("scenario")
click_ctx = click.get_current_context()
overrides = cli_tools.set_if_not_default(click_ctx, **kwargs)

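`cli_tools.set_if_not_default` is also outside this diff; its name and usage suggest it keeps only the options the user explicitly passed, so values from the scenario are not clobbered by click defaults. One plausible sketch using click's parameter-source tracking (an assumption, not the actual helper):

```python
import click
from click.core import ParameterSource


def set_if_not_default(ctx: click.Context, **kwargs) -> dict:
    """Return only the kwargs whose values were explicitly provided by the user."""
    return {
        name: value
        for name, value in kwargs.items()
        if ctx.get_parameter_source(name) != ParameterSource.DEFAULT
    }
```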
try:
# If a scenario file or builtin name was specified, load from it
if scenario is None:
_scenario = GenerativeTextScenario.model_validate(overrides)
elif isinstance(scenario, Path):
_scenario = GenerativeTextScenario.from_file(scenario, overrides)
else: # Only builtins can make it here; click will catch anything else
_scenario = GenerativeTextScenario.from_builtin(scenario, overrides)
except ValidationError as e:
# Translate pydantic validation error to click argument error
errs = e.errors(include_url=False, include_context=True, include_input=True)
param_name = "--" + str(errs[0]["loc"][0]).replace("_", "-")
raise click.BadParameter(
errs[0]["msg"], ctx=click_ctx, param_hint=param_name
) from e

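For context on the translation above: in pydantic v2, `e.errors(...)` returns a list of dicts, and the first entry's `loc`/`msg` are mapped onto a CLI flag, so a failure on, say, `max_seconds` surfaces as a click error on `--max-seconds`. The entry shape looks roughly like this (field and message strings are illustrative and depend on the pydantic version):

```python
# One element of e.errors(include_url=False, include_context=True, include_input=True):
example_error = {
    "type": "float_parsing",
    "loc": ("max_seconds",),
    "msg": "Input should be a valid number, unable to parse string as a number",
    "input": "abc",
}
```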
if HAS_UVLOOP:
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
asyncio.run(
benchmark_generative_text(
target=target,
data=data,
profile=profile,
rate=rate,
random_seed=random_seed,
# Backend configuration
backend=backend,
backend_kwargs=backend_kwargs,
model=model,
# Data configuration
processor=processor,
processor_args=processor_args,
data_args=data_args,
data_sampler=data_sampler,
scenario=_scenario,
# Output configuration
output_path=output_path,
output_path=kwargs["output_path"],
output_formats=[
fmt
for fmt in output_formats
if not disable_console_outputs or fmt != "console"
for fmt in kwargs["output_formats"]
if not kwargs["disable_console_outputs"] or fmt != "console"
],
# Updates configuration
progress=(
[
GenerativeConsoleBenchmarkerProgress(
display_scheduler_stats=display_scheduler_stats
display_scheduler_stats=kwargs["display_scheduler_stats"]
)
]
if not disable_progress
if not kwargs["disable_progress"]
else None
),
print_updates=not disable_console_outputs,
print_updates=not kwargs["disable_console_outputs"],
# Aggregators configuration
add_aggregators={"extras": InjectExtrasAggregator(extras=output_extras)},
warmup=warmup,
cooldown=cooldown,
request_samples=request_samples,
# Constraints configuration
max_seconds=max_seconds,
max_requests=max_requests,
max_errors=max_errors,
max_error_rate=max_error_rate,
max_global_error_rate=max_global_error_rate,
add_aggregators={
"extras": InjectExtrasAggregator(extras=kwargs["output_extras"])
},
)
)

27 changes: 21 additions & 6 deletions src/guidellm/backends/openai.py
@@ -16,6 +16,7 @@
import json
import time
from collections.abc import AsyncIterator
from itertools import chain
from pathlib import Path
from typing import Any, ClassVar, Optional, Union

@@ -29,7 +30,7 @@
GenerationRequestTimings,
GenerationResponse,
)
from guidellm.scheduler import ScheduledRequestInfo
from guidellm.scheduler import HistoryT, ScheduledRequestInfo

__all__ = ["OpenAIHTTPBackend", "UsageStats"]

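`HistoryT` is imported from `guidellm.scheduler` but not defined in this diff. Given the annotation it replaces (`Optional[list[tuple[GenerationRequest, GenerationResponse]]]`), it is presumably a generic alias along these lines (an assumption about the scheduler's definition, not a copy of it):

```python
from typing import TypeVar

RequestT = TypeVar("RequestT")
ResponseT = TypeVar("ResponseT")

# Conversation history as ordered (request, response) turns, oldest first.
HistoryT = list[tuple[RequestT, ResponseT]]
```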
@@ -280,7 +281,7 @@ async def resolve(
self,
request: GenerationRequest,
request_info: ScheduledRequestInfo,
history: Optional[list[tuple[GenerationRequest, GenerationResponse]]] = None,
history: Optional[HistoryT[GenerationRequest, GenerationResponse]] = None,
) -> AsyncIterator[tuple[GenerationResponse, ScheduledRequestInfo]]:
"""
Process a generation request and yield progressive responses.
@@ -295,10 +296,8 @@ async def resolve(
:yields: Tuples of (response, updated_request_info) as generation progresses.
"""
self._check_in_process()
if history is not None:
raise NotImplementedError(
"Multi-turn requests with conversation history are not yet supported"
)
if history:
request = self._apply_history(request, history)

response = GenerationResponse(
request_id=request.request_id,
@@ -500,6 +499,22 @@ async def chat_completions(
self._get_completions_usage_stats(data),
)

def _apply_history(
self,
request: GenerationRequest,
history: HistoryT[GenerationRequest, GenerationResponse],
) -> GenerationRequest:
"""
Apply conversation history to the current request.
"""

def turn_to_text(turn: tuple[GenerationRequest, GenerationResponse]) -> str:
req, res = turn
return f"{req.content}{res.value}"

request.content = "".join(chain(map(turn_to_text, history), (request.content,)))
return request

Comment on lines +502 to +517 (Collaborator Author): Temporary hack until we land request templates.
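To make the flattening behavior concrete: each prior (request, response) turn is concatenated in order, followed by the new prompt. A self-contained rendition of the same logic with stand-in types (`GenerationRequest`/`GenerationResponse` constructors are not shown in this diff, so `Req`/`Res` below are illustrative):

```python
from dataclasses import dataclass
from itertools import chain


@dataclass
class Req:  # stand-in for GenerationRequest, exposing only .content
    content: str


@dataclass
class Res:  # stand-in for GenerationResponse, exposing only .value
    value: str


def apply_history(request: Req, history: list[tuple[Req, Res]]) -> Req:
    def turn_to_text(turn: tuple[Req, Res]) -> str:
        req, res = turn
        return f"{req.content}{res.value}"

    # Same join as the backend: all prior turns, then the current prompt.
    request.content = "".join(chain(map(turn_to_text, history), (request.content,)))
    return request


history = [(Req("Hi. "), Res("Hello! ")), (Req("Translate 'cat'. "), Res("gato. "))]
print(apply_history(Req("Now 'dog'."), history).content)
# -> "Hi. Hello! Translate 'cat'. gato. Now 'dog'."
```
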
def _build_headers(
self,
api_key: Optional[str],