
Commit f95d772

Rebase onto develop and re-run ruff formatter
1 parent 04e3511 commit f95d772

File tree

6 files changed: +38 −119 lines

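Every hunk in this commit is mechanical: calls and signatures that the formatter previously wrapped across three lines are collapsed onto one. The longest collapsed line runs to just over 100 characters, past ruff's default 88-character limit, so the formatter was evidently run with a longer line length configured. The repository's actual settings are not part of this diff; a plausible pyproject.toml fragment that would produce this style (the value 120 is an assumption, not taken from the repo):

[tool.ruff]
# Hypothetical: the real config is not shown in this commit. The longest
# reformatted line below exceeds 100 characters, so the limit is at least
# that; 120 is a common choice.
line-length = 120

Re-running `ruff format .` after the rebase then rewraps every statement that fits within the limit, which is the shape of every hunk below.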

aiperf/aiperf_models.py

Lines changed: 7 additions & 21 deletions

@@ -33,22 +33,16 @@ class BaseConfig(BaseModel):
         description="Optional tokenizer Huggingface name, or local directory",
     )
     url: str = Field(..., description="Model base URL")
-    endpoint: str = Field(
-        default="/v1/chat/completions", description="API endpoint path"
-    )
+    endpoint: str = Field(default="/v1/chat/completions", description="API endpoint path")
     endpoint_type: Literal["chat", "completions"] = Field(
         default="chat",
         description="Type of endpoint (chat or completions)",
     )
-    api_key_env_var: Optional[str] = Field(
-        default=None, description="API key environment variable"
-    )
+    api_key_env_var: Optional[str] = Field(default=None, description="API key environment variable")
     streaming: Optional[bool] = Field(default=False, description="Streaming mode")

     # Load generation settings
-    warmup_request_count: int = Field(
-        description="Requests to send before beginning performance-test"
-    )
+    warmup_request_count: int = Field(description="Requests to send before beginning performance-test")
     benchmark_duration: int = Field(description="Benchmark duration in seconds")
     concurrency: int = Field(description="Number of concurrent requests")
     request_rate: Optional[float] = Field(
@@ -61,9 +55,7 @@ class BaseConfig(BaseModel):
     )

     # Synthetic data generation
-    random_seed: Optional[int] = Field(
-        default=None, description="Random seed for reproducibility"
-    )
+    random_seed: Optional[int] = Field(default=None, description="Random seed for reproducibility")
     prompt_input_tokens_mean: Optional[int] = Field(
         default=None,
         description="Mean number of input tokens",
@@ -85,26 +77,20 @@
 class AIPerfConfig(BaseModel):
     """Main configuration model for AIPerf benchmark runner."""

-    batch_name: str = Field(
-        default="benchmark", description="Name for this batch of benchmarks"
-    )
+    batch_name: str = Field(default="benchmark", description="Name for this batch of benchmarks")
     output_base_dir: str = Field(
         default="aiperf_results",
         description="Base directory for benchmark results",
     )
-    base_config: BaseConfig = Field(
-        ..., description="Base configuration applied to all benchmark runs"
-    )
+    base_config: BaseConfig = Field(..., description="Base configuration applied to all benchmark runs")
     sweeps: Optional[Dict[str, List[Union[int, str]]]] = Field(
         default=None,
         description="Parameter sweeps. Key is the parameter to change, value is a list of values to use",
     )

     @field_validator("sweeps")
     @classmethod
-    def validate_sweeps(
-        cls, v: Optional[Dict[str, List[Any]]]
-    ) -> Optional[Dict[str, List[Any]]]:
+    def validate_sweeps(cls, v: Optional[Dict[str, List[Any]]]) -> Optional[Dict[str, List[Any]]]:
         """Validate that sweep values are lists of ints or strings."""
         if v is None:
             return v
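
The reformatting is behavior-preserving: validate_sweeps still accepts only lists of ints or strings per sweep key. A minimal usage sketch in Python, assuming no required fields beyond those visible in the hunks above (the concrete values are invented for illustration):

# Sketch: values are illustrative; the fields are the ones shown in this diff.
from aiperf.aiperf_models import AIPerfConfig, BaseConfig

base = BaseConfig(
    url="http://localhost:8000",  # Model base URL (required)
    warmup_request_count=10,      # requests sent before the performance test begins
    benchmark_duration=60,        # benchmark duration in seconds
    concurrency=4,                # number of concurrent requests
)

config = AIPerfConfig(
    base_config=base,
    # Each key names a parameter to vary; each value is a list of ints or strings.
    sweeps={"concurrency": [1, 2, 4, 8]},
)

A non-int/str sweep value, e.g. sweeps={"concurrency": [{"bad": 1}]}, raises a pydantic ValidationError; as the tests further down note, Pydantic's own type validation catches it before the custom validator runs.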

aiperf/run_aiperf.py

Lines changed: 8 additions & 22 deletions

@@ -25,7 +25,7 @@
 from datetime import datetime
 from pathlib import Path
 from subprocess import CompletedProcess
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Union

 import httpx
 import typer
@@ -37,9 +37,7 @@
 log = logging.getLogger(__name__)
 log.setLevel(logging.INFO)

-formatter = logging.Formatter(
-    "%(asctime)s %(levelname)s: %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
-)
+formatter = logging.Formatter("%(asctime)s %(levelname)s: %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
 console_handler = logging.StreamHandler()
 console_handler.setLevel(logging.DEBUG)
 console_handler.setFormatter(formatter)
@@ -143,9 +141,7 @@ def _sanitize_command_for_logging(cmd: List[str]) -> str:

         return " ".join(sanitized)

-    def _build_command(
-        self, sweep_params: Optional[Dict[str, Union[str, int]]], output_dir: Path
-    ) -> List[str]:
+    def _build_command(self, sweep_params: Optional[Dict[str, Union[str, int]]], output_dir: Path) -> List[str]:
         """Create a list of strings with the aiperf command and arguments to execute"""

         # Run aiperf in profile mode: `aiperf profile`
@@ -239,9 +235,7 @@ def _save_run_metadata(
             json.dump(metadata, f, indent=2)

     @staticmethod
-    def _save_subprocess_result_json(
-        output_dir: Path, result: CompletedProcess
-    ) -> None:
+    def _save_subprocess_result_json(output_dir: Path, result: CompletedProcess) -> None:
         """Save the subprocess result to the given filename"""

         process_result_file = output_dir / "process_result.json"
@@ -252,15 +246,11 @@ def _save_subprocess_result_json(
                 json.dump(save_data, f, indent=2)

         except (IOError, OSError) as e:
-            log.error(
-                "Could not write %s to file %s: %s", save_data, process_result_file, e
-            )
+            log.error("Could not write %s to file %s: %s", save_data, process_result_file, e)
             raise

         except TypeError as e:
-            log.error(
-                "Couldn't serialize %s to %s: %s", save_data, process_result_file, e
-            )
+            log.error("Couldn't serialize %s to %s: %s", save_data, process_result_file, e)
             raise

     def _check_service(self, endpoint: Optional[str] = "/v1/models") -> None:
@@ -357,9 +347,7 @@ def run_single_benchmark(
             log.info("Run completed successfully")
             self._save_subprocess_result_json(run_output_dir, result)
             run_completed = 1 if result.returncode == 0 else 0
-            return AIPerfSummary(
-                total=1, completed=run_completed, failed=1 - run_completed
-            )
+            return AIPerfSummary(total=1, completed=run_completed, failed=1 - run_completed)

         except subprocess.CalledProcessError as e:
             log.error("Run failed with exit code %s", e.returncode)
@@ -379,9 +367,7 @@ def run_batch_benchmarks(
         # Generate all sweep combinations
         combinations = self._get_sweep_combinations()
         if not combinations:
-            raise RuntimeError(
-                f"Can't generate sweep combinations from {self.config.sweeps}"
-            )
+            raise RuntimeError(f"Can't generate sweep combinations from {self.config.sweeps}")

         num_combinations = len(combinations)
         log.info("Running %s benchmarks", num_combinations)

nemoguardrails/cli/__init__.py

Lines changed: 0 additions & 1 deletion

@@ -25,7 +25,6 @@

 from nemoguardrails import __version__
 from nemoguardrails.actions_server import actions_server
-from nemoguardrails.benchmark.aiperf.run_aiperf import app as aiperf_app
 from nemoguardrails.cli.chat import run_chat
 from nemoguardrails.cli.migration import migrate
 from nemoguardrails.cli.providers import _list_providers, select_provider_with_type

nemoguardrails/server/api.py

Lines changed: 1 addition & 3 deletions

@@ -475,9 +475,7 @@ async def chat_completion(body: RequestBody, request: Request):

     except Exception as ex:
         log.exception(ex)
-        return ResponseBody(
-            messages=[{"role": "assistant", "content": "Internal server error."}]
-        )
+        return ResponseBody(messages=[{"role": "assistant", "content": "Internal server error."}])


 # By default, there are no challenges

tests/benchmark/test_aiperf_models.py

Lines changed: 3 additions & 13 deletions

@@ -236,11 +236,7 @@ def test_aiperf_config_sweep_invalid_value_type_dict(self, valid_base_config):
         error_msg = str(exc_info.value)
         # Pydantic catches this during type validation
         assert "sweeps.concurrency" in error_msg
-        assert (
-            "must be int or str" in error_msg
-            or "int_type" in error_msg
-            or "string_type" in error_msg
-        )
+        assert "must be int or str" in error_msg or "int_type" in error_msg or "string_type" in error_msg

     def test_aiperf_config_sweep_invalid_value_type_list(self, valid_base_config):
         """Test that list values in sweeps raise validation error."""
@@ -254,11 +250,7 @@ def test_aiperf_config_sweep_invalid_value_type_list(self, valid_base_config):
         error_msg = str(exc_info.value)
         # Pydantic catches this during type validation
         assert "sweeps.concurrency" in error_msg
-        assert (
-            "must be int or str" in error_msg
-            or "int_type" in error_msg
-            or "string_type" in error_msg
-        )
+        assert "must be int or str" in error_msg or "int_type" in error_msg or "string_type" in error_msg

     def test_aiperf_config_sweep_empty_list(self, valid_base_config):
         """Test that empty sweep list raises validation error."""
@@ -302,9 +294,7 @@ def test_aiperf_config_multiple_invalid_sweep_keys(self, valid_base_config):

     def test_aiperf_config_get_output_base_path(self, valid_base_config):
         """Test get_output_base_path method."""
-        config = AIPerfConfig(
-            output_base_dir="custom_results", base_config=valid_base_config
-        )
+        config = AIPerfConfig(output_base_dir="custom_results", base_config=valid_base_config)
         path = config.get_output_base_path()
         assert isinstance(path, Path)
         assert str(path) == "custom_results"
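
The three-way or in these assertions covers both places the failure can surface: the custom validator raises "must be int or str", while Pydantic v2's built-in type validation reports machine-readable error types such as int_type and string_type. A quick way to see the second form (a sketch assuming Pydantic v2; Demo is a stand-in, not a project class):

from typing import Dict, List, Union

from pydantic import BaseModel, ValidationError

class Demo(BaseModel):
    sweeps: Dict[str, List[Union[int, str]]]

try:
    Demo(sweeps={"concurrency": [{"bad": 1}]})
except ValidationError as e:
    # The v2 message includes [type=int_type, ...] and [type=string_type, ...]
    print(str(e))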
