From 62b7317b7040784fd5bad9937c0ced1af5b87bd9 Mon Sep 17 00:00:00 2001 From: Mark Kurtz Date: Wed, 20 Aug 2025 13:19:12 +0000 Subject: [PATCH 1/3] Starting point for scheduler refactor --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 0b1014cb..38733843 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ keywords = [ "inference", "language-models", "large-language-model", + "load-generation", "llm", "machine-learning", "model-benchmark", From 3f7f7ac87796579f610c85ba3babc73550e7033d Mon Sep 17 00:00:00 2001 From: Mark Kurtz Date: Tue, 26 Aug 2025 15:34:34 -0400 Subject: [PATCH 2/3] Scheduler refactor [utils]: auto_importer, pydantic_utils, registry, singleton (#289) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary This PR introduces four new utility modules that provide foundational capabilities for auto-discovery, registry patterns, singleton management, and enhanced Pydantic utilities. It additionally consolidates the pydantic utility classes from guidellm.objects into the dedicated guidellm.utils module containing the new additions. This change improves code organization and establishes the foundation for upcoming scheduler refactoring work. ## Details - **Moved core utilities** from `guidellm.objects.pydantic` to `guidellm.utils.pydantic_utils` - **Added `AutoImporterMixin`** for dynamic module discovery within packages - **Added `RegistryMixin`** for object registration with optional auto-discovery - **Added singleton mixins** (`SingletonMixin`, `ThreadSafeSingletonMixin`) for instance management - **Enhanced Pydantic utilities** with `ReloadableBaseModel`, `StandardBaseDict`, and `PydanticClassRegistryMixin` for polymorphic model serialization - **Updated all imports** across the codebase to reference new module locations - **Added comprehensive test coverage** for all new utility modules - **Removed obsolete** `guidellm.objects.pydantic` module and associated tests ## Test Plan - Run existing test suite to ensure no regressions from import changes - Execute new utility module test suites covering smoke, sanity, and regression scenarios ## Related Issues --- - [x] "I certify that all code in this PR is my own, except as noted below." 
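## Usage Sketch

As a quick orientation for reviewers, here is a minimal sketch of how the new registry utilities are intended to compose; the `Strategy` classes below are illustrative stand-ins, not classes added by this PR:

```python
# Hypothetical example classes; only RegistryMixin comes from this PR.
from guidellm.utils import RegistryMixin


class Strategy(RegistryMixin):
    """Base class that owns its own registry of implementations."""


@Strategy.register()  # registered under the class name, stored lowercase
class ConstantStrategy(Strategy):
    pass


@Strategy.register("poisson")  # registered under an explicit name
class PoissonStrategy(Strategy):
    pass


assert Strategy.is_registered("Poisson")  # lookups are case-insensitive
assert Strategy.get_registered_object("ConstantStrategy") is ConstantStrategy
assert set(Strategy.registered_objects()) == {ConstantStrategy, PoissonStrategy}
```

`PydanticClassRegistryMixin` layers the same pattern onto Pydantic models, selecting the registered subclass during validation via the `schema_discriminator` field, and setting `registry_auto_discovery = True` together with `auto_package` lets `AutoImporterMixin` import and register implementations automatically.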
## Use of AI - [x] Includes AI-assisted code completion - [x] Includes code generated by an AI application - [ ] Includes AI-generated tests (NOTE: AI written tests should have a docstring that includes `## WRITTEN BY AI ##`) --------- Signed-off-by: Mark Kurtz Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/guidellm/backend/response.py | 2 +- src/guidellm/benchmark/aggregator.py | 4 +- src/guidellm/benchmark/benchmark.py | 3 +- src/guidellm/benchmark/benchmarker.py | 2 +- src/guidellm/benchmark/output.py | 3 +- src/guidellm/benchmark/profile.py | 2 +- src/guidellm/benchmark/scenario.py | 2 +- src/guidellm/objects/__init__.py | 3 - src/guidellm/objects/pydantic.py | 89 --- src/guidellm/objects/statistics.py | 2 +- src/guidellm/request/loader.py | 2 +- src/guidellm/request/request.py | 2 +- src/guidellm/scheduler/result.py | 2 +- src/guidellm/scheduler/strategy.py | 2 +- src/guidellm/scheduler/worker.py | 2 +- src/guidellm/utils/__init__.py | 19 + src/guidellm/utils/auto_importer.py | 98 ++++ src/guidellm/utils/pydantic_utils.py | 302 ++++++++++ src/guidellm/utils/registry.py | 206 +++++++ src/guidellm/utils/singleton.py | 130 +++++ tests/unit/objects/test_pydantic.py | 43 -- tests/unit/utils/test_auto_importer.py | 269 +++++++++ tests/unit/utils/test_pydantic_utils.py | 710 ++++++++++++++++++++++++ tests/unit/utils/test_registry.py | 533 ++++++++++++++++++ tests/unit/utils/test_singleton.py | 371 +++++++++++++ 25 files changed, 2651 insertions(+), 152 deletions(-) delete mode 100644 src/guidellm/objects/pydantic.py create mode 100644 src/guidellm/utils/auto_importer.py create mode 100644 src/guidellm/utils/pydantic_utils.py create mode 100644 src/guidellm/utils/registry.py create mode 100644 src/guidellm/utils/singleton.py delete mode 100644 tests/unit/objects/test_pydantic.py create mode 100644 tests/unit/utils/test_auto_importer.py create mode 100644 tests/unit/utils/test_pydantic_utils.py create mode 100644 tests/unit/utils/test_registry.py create mode 100644 tests/unit/utils/test_singleton.py diff --git a/src/guidellm/backend/response.py b/src/guidellm/backend/response.py index ee2101d7..bfa738d8 100644 --- a/src/guidellm/backend/response.py +++ b/src/guidellm/backend/response.py @@ -3,7 +3,7 @@ from pydantic import computed_field from guidellm.config import settings -from guidellm.objects.pydantic import StandardBaseModel +from guidellm.utils import StandardBaseModel __all__ = [ "RequestArgs", diff --git a/src/guidellm/benchmark/aggregator.py b/src/guidellm/benchmark/aggregator.py index af7f1a13..b322eadd 100644 --- a/src/guidellm/benchmark/aggregator.py +++ b/src/guidellm/benchmark/aggregator.py @@ -24,8 +24,6 @@ from guidellm.config import settings from guidellm.objects import ( RunningStats, - StandardBaseModel, - StatusBreakdown, TimeRunningStats, ) from guidellm.request import ( @@ -40,7 +38,7 @@ SchedulerRequestResult, WorkerDescription, ) -from guidellm.utils import check_load_processor +from guidellm.utils import StandardBaseModel, StatusBreakdown, check_load_processor __all__ = [ "AggregatorT", diff --git a/src/guidellm/benchmark/benchmark.py b/src/guidellm/benchmark/benchmark.py index 02eea02b..77d0fe38 100644 --- a/src/guidellm/benchmark/benchmark.py +++ b/src/guidellm/benchmark/benchmark.py @@ -13,8 +13,6 @@ ThroughputProfile, ) from guidellm.objects import ( - StandardBaseModel, - StatusBreakdown, StatusDistributionSummary, ) from guidellm.request import ( @@ -32,6 +30,7 @@ ThroughputStrategy, WorkerDescription, ) +from guidellm.utils import 
StandardBaseModel, StatusBreakdown __all__ = [ "Benchmark", diff --git a/src/guidellm/benchmark/benchmarker.py b/src/guidellm/benchmark/benchmarker.py index 11b6d245..876e6f43 100644 --- a/src/guidellm/benchmark/benchmarker.py +++ b/src/guidellm/benchmark/benchmarker.py @@ -22,7 +22,6 @@ ) from guidellm.benchmark.benchmark import BenchmarkArgs, GenerativeBenchmark from guidellm.benchmark.profile import Profile -from guidellm.objects import StandardBaseModel from guidellm.request import ( GenerationRequest, GenerativeRequestLoaderDescription, @@ -37,6 +36,7 @@ SchedulerRequestResult, SchedulingStrategy, ) +from guidellm.utils import StandardBaseModel __all__ = ["Benchmarker", "BenchmarkerResult", "GenerativeBenchmarker"] diff --git a/src/guidellm/benchmark/output.py b/src/guidellm/benchmark/output.py index 8a113f72..dd94f899 100644 --- a/src/guidellm/benchmark/output.py +++ b/src/guidellm/benchmark/output.py @@ -23,13 +23,12 @@ from guidellm.config import settings from guidellm.objects import ( DistributionSummary, - StandardBaseModel, StatusDistributionSummary, ) from guidellm.presentation import UIDataBuilder from guidellm.presentation.injector import create_report from guidellm.scheduler import strategy_display_str -from guidellm.utils import Colors, split_text_list_by_length +from guidellm.utils import Colors, StandardBaseModel, split_text_list_by_length __all__ = [ "GenerativeBenchmarksConsole", diff --git a/src/guidellm/benchmark/profile.py b/src/guidellm/benchmark/profile.py index 642cb7a8..d46f2b16 100644 --- a/src/guidellm/benchmark/profile.py +++ b/src/guidellm/benchmark/profile.py @@ -5,7 +5,6 @@ from pydantic import Field, computed_field from guidellm.config import settings -from guidellm.objects import StandardBaseModel from guidellm.scheduler import ( AsyncConstantStrategy, AsyncPoissonStrategy, @@ -15,6 +14,7 @@ SynchronousStrategy, ThroughputStrategy, ) +from guidellm.utils import StandardBaseModel __all__ = [ "AsyncProfile", diff --git a/src/guidellm/benchmark/scenario.py b/src/guidellm/benchmark/scenario.py index af43e426..57dfa98b 100644 --- a/src/guidellm/benchmark/scenario.py +++ b/src/guidellm/benchmark/scenario.py @@ -11,8 +11,8 @@ from guidellm.backend.backend import BackendType from guidellm.benchmark.profile import ProfileType -from guidellm.objects.pydantic import StandardBaseModel from guidellm.scheduler.strategy import StrategyType +from guidellm.utils import StandardBaseModel __ALL__ = ["Scenario", "GenerativeTextScenario", "get_builtin_scenarios"] diff --git a/src/guidellm/objects/__init__.py b/src/guidellm/objects/__init__.py index 89e3c9b9..119ac6e7 100644 --- a/src/guidellm/objects/__init__.py +++ b/src/guidellm/objects/__init__.py @@ -1,4 +1,3 @@ -from .pydantic import StandardBaseModel, StatusBreakdown from .statistics import ( DistributionSummary, Percentiles, @@ -11,8 +10,6 @@ "DistributionSummary", "Percentiles", "RunningStats", - "StandardBaseModel", - "StatusBreakdown", "StatusDistributionSummary", "TimeRunningStats", ] diff --git a/src/guidellm/objects/pydantic.py b/src/guidellm/objects/pydantic.py deleted file mode 100644 index fcededcf..00000000 --- a/src/guidellm/objects/pydantic.py +++ /dev/null @@ -1,89 +0,0 @@ -import json -from pathlib import Path -from typing import Any, Generic, Optional, TypeVar - -import yaml -from loguru import logger -from pydantic import BaseModel, ConfigDict, Field - -__all__ = ["StandardBaseModel", "StatusBreakdown"] - -T = TypeVar("T", bound="StandardBaseModel") - - -class StandardBaseModel(BaseModel): - """ - A 
base class for Pydantic models throughout GuideLLM enabling standard - configuration and logging. - """ - - model_config = ConfigDict( - extra="ignore", - use_enum_values=True, - validate_assignment=True, - from_attributes=True, - ) - - def __init__(self, /, **data: Any) -> None: - super().__init__(**data) - logger.debug( - "Initialized new instance of {} with data: {}", - self.__class__.__name__, - data, - ) - - @classmethod - def get_default(cls: type[T], field: str) -> Any: - """Get default values for model fields""" - return cls.model_fields[field].default - - @classmethod - def from_file(cls: type[T], filename: Path, overrides: Optional[dict] = None) -> T: - """ - Attempt to create a new instance of the model using - data loaded from json or yaml file. - """ - try: - with filename.open() as f: - if str(filename).endswith(".json"): - data = json.load(f) - else: # Assume everything else is yaml - data = yaml.safe_load(f) - except (json.JSONDecodeError, yaml.YAMLError) as e: - logger.error(f"Failed to parse {filename} as type {cls.__name__}") - raise ValueError(f"Error when parsing file: {filename}") from e - - data.update(overrides) - return cls.model_validate(data) - - -SuccessfulT = TypeVar("SuccessfulT") -ErroredT = TypeVar("ErroredT") -IncompleteT = TypeVar("IncompleteT") -TotalT = TypeVar("TotalT") - - -class StatusBreakdown(BaseModel, Generic[SuccessfulT, ErroredT, IncompleteT, TotalT]): - """ - A base class for Pydantic models that are separated by statuses including - successful, incomplete, and errored. It additionally enables the inclusion - of total, which is intended as the combination of all statuses. - Total may or may not be used depending on if it duplicates information. - """ - - successful: SuccessfulT = Field( - description="The results with a successful status.", - default=None, # type: ignore[assignment] - ) - errored: ErroredT = Field( - description="The results with an errored status.", - default=None, # type: ignore[assignment] - ) - incomplete: IncompleteT = Field( - description="The results with an incomplete status.", - default=None, # type: ignore[assignment] - ) - total: TotalT = Field( - description="The combination of all statuses.", - default=None, # type: ignore[assignment] - ) diff --git a/src/guidellm/objects/statistics.py b/src/guidellm/objects/statistics.py index 8ba504be..669aef6d 100644 --- a/src/guidellm/objects/statistics.py +++ b/src/guidellm/objects/statistics.py @@ -6,7 +6,7 @@ import numpy as np from pydantic import Field, computed_field -from guidellm.objects.pydantic import StandardBaseModel, StatusBreakdown +from guidellm.utils import StandardBaseModel, StatusBreakdown __all__ = [ "DistributionSummary", diff --git a/src/guidellm/request/loader.py b/src/guidellm/request/loader.py index 50ab3cca..2eff87d5 100644 --- a/src/guidellm/request/loader.py +++ b/src/guidellm/request/loader.py @@ -13,8 +13,8 @@ from guidellm.config import settings from guidellm.dataset import ColumnInputTypes, load_dataset -from guidellm.objects import StandardBaseModel from guidellm.request.request import GenerationRequest +from guidellm.utils import StandardBaseModel __all__ = [ "GenerativeRequestLoader", diff --git a/src/guidellm/request/request.py b/src/guidellm/request/request.py index 81c8cabd..bf4e59fb 100644 --- a/src/guidellm/request/request.py +++ b/src/guidellm/request/request.py @@ -3,7 +3,7 @@ from pydantic import Field -from guidellm.objects.pydantic import StandardBaseModel +from guidellm.utils import StandardBaseModel __all__ = ["GenerationRequest"] 
diff --git a/src/guidellm/scheduler/result.py b/src/guidellm/scheduler/result.py index 0f12687f..0cca530b 100644 --- a/src/guidellm/scheduler/result.py +++ b/src/guidellm/scheduler/result.py @@ -4,9 +4,9 @@ Optional, ) -from guidellm.objects import StandardBaseModel from guidellm.scheduler.strategy import SchedulingStrategy from guidellm.scheduler.types import RequestT, ResponseT +from guidellm.utils import StandardBaseModel __all__ = [ "SchedulerRequestInfo", diff --git a/src/guidellm/scheduler/strategy.py b/src/guidellm/scheduler/strategy.py index 200c799e..d4c065da 100644 --- a/src/guidellm/scheduler/strategy.py +++ b/src/guidellm/scheduler/strategy.py @@ -12,7 +12,7 @@ from pydantic import Field from guidellm.config import settings -from guidellm.objects import StandardBaseModel +from guidellm.utils import StandardBaseModel __all__ = [ "AsyncConstantStrategy", diff --git a/src/guidellm/scheduler/worker.py b/src/guidellm/scheduler/worker.py index a53b14c2..ab16e4db 100644 --- a/src/guidellm/scheduler/worker.py +++ b/src/guidellm/scheduler/worker.py @@ -24,10 +24,10 @@ ResponseSummary, StreamingTextResponse, ) -from guidellm.objects import StandardBaseModel from guidellm.request import GenerationRequest from guidellm.scheduler.result import SchedulerRequestInfo from guidellm.scheduler.types import RequestT, ResponseT +from guidellm.utils import StandardBaseModel __all__ = [ "GenerativeRequestsWorker", diff --git a/src/guidellm/utils/__init__.py b/src/guidellm/utils/__init__.py index fb9262c3..98ac1c36 100644 --- a/src/guidellm/utils/__init__.py +++ b/src/guidellm/utils/__init__.py @@ -1,3 +1,4 @@ +from .auto_importer import AutoImporterMixin from .colors import Colors from .default_group import DefaultGroupHandler from .hf_datasets import ( @@ -7,7 +8,16 @@ from .hf_transformers import ( check_load_processor, ) +from .pydantic_utils import ( + PydanticClassRegistryMixin, + ReloadableBaseModel, + StandardBaseDict, + StandardBaseModel, + StatusBreakdown, +) from .random import IntegerRangeSampler +from .registry import RegistryMixin +from .singleton import SingletonMixin, ThreadSafeSingletonMixin from .text import ( EndlessTextCreator, clean_text, @@ -20,10 +30,19 @@ __all__ = [ "SUPPORTED_TYPES", + "AutoImporterMixin", "Colors", "DefaultGroupHandler", "EndlessTextCreator", "IntegerRangeSampler", + "PydanticClassRegistryMixin", + "RegistryMixin", + "ReloadableBaseModel", + "SingletonMixin", + "StandardBaseDict", + "StandardBaseModel", + "StatusBreakdown", + "ThreadSafeSingletonMixin", "check_load_processor", "clean_text", "filter_text", diff --git a/src/guidellm/utils/auto_importer.py b/src/guidellm/utils/auto_importer.py new file mode 100644 index 00000000..5b939014 --- /dev/null +++ b/src/guidellm/utils/auto_importer.py @@ -0,0 +1,98 @@ +""" +Automatic module importing utilities for dynamic class discovery. + +This module provides a mixin class for automatic module importing within a package, +enabling dynamic discovery of classes and implementations without explicit imports. +It is particularly useful for auto-registering classes in a registry pattern where +subclasses need to be discoverable at runtime. + +The AutoImporterMixin can be combined with registration mechanisms to create +extensible systems where new implementations are automatically discovered and +registered when they are placed in the correct package structure. 
+""" + +from __future__ import annotations + +import importlib +import pkgutil +import sys +from typing import ClassVar + +__all__ = ["AutoImporterMixin"] + + +class AutoImporterMixin: + """ + Mixin class for automatic module importing within packages. + + This mixin enables dynamic discovery of classes and implementations without + explicit imports by automatically importing all modules within specified + packages. It is designed for use with class registration mechanisms to enable + automatic discovery and registration of classes when they are placed in the + correct package structure. + + Example: + :: + from guidellm.utils import AutoImporterMixin + + class MyRegistry(AutoImporterMixin): + auto_package = "my_package.implementations" + + MyRegistry.auto_import_package_modules() + + :cvar auto_package: Package name or tuple of package names to import modules from + :cvar auto_ignore_modules: Module names to ignore during import + :cvar auto_imported_modules: List tracking which modules have been imported + """ + + auto_package: ClassVar[str | tuple[str, ...] | None] = None + auto_ignore_modules: ClassVar[tuple[str, ...] | None] = None + auto_imported_modules: ClassVar[list[str] | None] = None + + @classmethod + def auto_import_package_modules(cls) -> None: + """ + Automatically import all modules within the specified package(s). + + Scans the package(s) defined in the `auto_package` class variable and imports + all modules found, tracking them in `auto_imported_modules`. Skips packages + (directories) and any modules listed in `auto_ignore_modules`. + + :raises ValueError: If the `auto_package` class variable is not set + """ + if cls.auto_package is None: + raise ValueError( + "The class variable 'auto_package' must be set to the package name to " + "import modules from." + ) + + cls.auto_imported_modules = [] + packages = ( + cls.auto_package + if isinstance(cls.auto_package, tuple) + else (cls.auto_package,) + ) + + for package_name in packages: + package = importlib.import_module(package_name) + + for _, module_name, is_pkg in pkgutil.walk_packages( + package.__path__, package.__name__ + "." + ): + if ( + is_pkg + or ( + cls.auto_ignore_modules is not None + and module_name in cls.auto_ignore_modules + ) + or module_name in cls.auto_imported_modules + ): + # Skip packages and ignored modules + continue + + if module_name in sys.modules: + # Avoid circular imports + cls.auto_imported_modules.append(module_name) + else: + importlib.import_module(module_name) + cls.auto_imported_modules.append(module_name) diff --git a/src/guidellm/utils/pydantic_utils.py b/src/guidellm/utils/pydantic_utils.py new file mode 100644 index 00000000..52bf6564 --- /dev/null +++ b/src/guidellm/utils/pydantic_utils.py @@ -0,0 +1,302 @@ +""" +Pydantic utilities for polymorphic model serialization and registry integration. + +Provides integration between Pydantic and the registry system, enabling +polymorphic serialization and deserialization of Pydantic models using +a discriminator field and dynamic class registry. Includes base model classes +with standardized configurations and generic status breakdown models for +structured result organization. 
+""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, ClassVar, Generic, TypeVar + +from pydantic import BaseModel, ConfigDict, Field, GetCoreSchemaHandler +from pydantic_core import CoreSchema, core_schema + +from guidellm.utils.registry import RegistryMixin + +__all__ = [ + "PydanticClassRegistryMixin", + "ReloadableBaseModel", + "StandardBaseDict", + "StandardBaseModel", + "StatusBreakdown", +] + + +BaseModelT = TypeVar("BaseModelT", bound=BaseModel) +SuccessfulT = TypeVar("SuccessfulT") +ErroredT = TypeVar("ErroredT") +IncompleteT = TypeVar("IncompleteT") +TotalT = TypeVar("TotalT") + + +class ReloadableBaseModel(BaseModel): + """ + Base Pydantic model with schema reloading capabilities. + + Provides dynamic schema rebuilding functionality for models that need to + update their validation schemas at runtime, particularly useful when + working with registry-based polymorphic models where new types are + registered after initial class definition. + """ + + model_config = ConfigDict( + extra="ignore", + use_enum_values=True, + validate_assignment=True, + from_attributes=True, + arbitrary_types_allowed=True, + ) + + @classmethod + def reload_schema(cls) -> None: + """ + Reload the class schema with updated registry information. + + Forces a complete rebuild of the Pydantic model schema to incorporate + any changes made to associated registries or validation rules. + """ + cls.model_rebuild(force=True) + + +class StandardBaseModel(BaseModel): + """ + Base Pydantic model with standardized configuration for GuideLLM. + + Provides consistent validation behavior and configuration settings across + all Pydantic models in the application, including field validation, + attribute conversion, and default value handling. + + Example: + :: + class MyModel(StandardBaseModel): + name: str + value: int = 42 + + # Access default values + default_value = MyModel.get_default("value") # Returns 42 + """ + + model_config = ConfigDict( + extra="ignore", + use_enum_values=True, + validate_assignment=True, + from_attributes=True, + ) + + @classmethod + def get_default(cls: type[BaseModelT], field: str) -> Any: + """ + Get default value for a model field. + + :param field: Name of the field to get the default value for + :return: Default value of the specified field + :raises KeyError: If the field does not exist in the model + """ + return cls.model_fields[field].default + + +class StandardBaseDict(StandardBaseModel): + """ + Base Pydantic model allowing arbitrary additional fields. + + Extends StandardBaseModel to accept extra fields beyond those explicitly + defined in the model schema. Useful for flexible data structures that + need to accommodate varying or unknown field sets while maintaining + type safety for known fields. + """ + + model_config = ConfigDict( + extra="allow", + use_enum_values=True, + validate_assignment=True, + from_attributes=True, + arbitrary_types_allowed=True, + ) + + +class StatusBreakdown(BaseModel, Generic[SuccessfulT, ErroredT, IncompleteT, TotalT]): + """ + Generic model for organizing results by processing status. + + Provides structured categorization of results into successful, errored, + incomplete, and total status groups. Supports flexible typing for each + status category to accommodate different result types while maintaining + consistent organization patterns across the application. 
+ + Example: + :: + from guidellm.utils.pydantic_utils import StatusBreakdown + + # Define a breakdown for request counts + breakdown = StatusBreakdown[int, int, int, int]( + successful=150, + errored=5, + incomplete=10, + total=165 + ) + """ + + successful: SuccessfulT = Field( + description="Results or metrics for requests with successful completion status", + default=None, # type: ignore[assignment] + ) + errored: ErroredT = Field( + description="Results or metrics for requests with error completion status", + default=None, # type: ignore[assignment] + ) + incomplete: IncompleteT = Field( + description="Results or metrics for requests with incomplete processing status", + default=None, # type: ignore[assignment] + ) + total: TotalT = Field( + description="Aggregated results or metrics combining all status categories", + default=None, # type: ignore[assignment] + ) + + +class PydanticClassRegistryMixin( + ReloadableBaseModel, RegistryMixin[type[BaseModelT]], ABC, Generic[BaseModelT] +): + """ + Polymorphic Pydantic model mixin enabling registry-based dynamic instantiation. + + Integrates Pydantic validation with the registry system to enable polymorphic + serialization and deserialization based on a discriminator field. Automatically + instantiates the correct subclass during validation based on registry mappings, + providing a foundation for extensible plugin-style architectures. + + Example: + :: + from guidellm.utils.pydantic_utils import PydanticClassRegistryMixin + + class BaseConfig(PydanticClassRegistryMixin["BaseConfig"]): + schema_discriminator: ClassVar[str] = "config_type" + config_type: str = Field(description="Configuration type identifier") + + @classmethod + def __pydantic_schema_base_type__(cls) -> type["BaseConfig"]: + return BaseConfig + + @BaseConfig.register("database") + class DatabaseConfig(BaseConfig): + config_type: str = "database" + connection_string: str = Field(description="Database connection string") + + # Dynamic instantiation based on discriminator + config = BaseConfig.model_validate({ + "config_type": "database", + "connection_string": "postgresql://localhost:5432/db" + }) + + :cvar schema_discriminator: Field name used for polymorphic type discrimination + """ + + schema_discriminator: ClassVar[str] = "model_type" + + @classmethod + def register_decorator( + cls, clazz: type[BaseModelT], name: str | list[str] | None = None + ) -> type[BaseModelT]: + """ + Register a Pydantic model class with type validation and schema reload. + + Validates that the class is a proper Pydantic BaseModel subclass before + registering it in the class registry. Automatically triggers schema + reload to incorporate the new type into polymorphic validation. + + :param clazz: Pydantic model class to register in the polymorphic hierarchy + :param name: Registry identifier for the class. Uses class name if None + :return: The registered class unchanged for decorator chaining + :raises TypeError: If clazz is not a Pydantic BaseModel subclass + """ + if not issubclass(clazz, BaseModel): + raise TypeError( + f"Cannot register {clazz.__name__} as it is not a subclass of " + "Pydantic BaseModel" + ) + + dec_clazz = super().register_decorator(clazz, name=name) + cls.reload_schema() + + return dec_clazz + + @classmethod + def __get_pydantic_core_schema__( + cls, source_type: Any, handler: GetCoreSchemaHandler + ) -> CoreSchema: + """ + Generate polymorphic validation schema for dynamic type instantiation. 
+ + Creates a tagged union schema that enables Pydantic to automatically + instantiate the correct subclass based on the discriminator field value. + Falls back to base schema generation when no registry is available. + + :param source_type: Type being processed for schema generation + :param handler: Pydantic core schema generation handler + :return: Tagged union schema for polymorphic validation or base schema + """ + if source_type == cls.__pydantic_schema_base_type__(): + if not cls.registry: + return cls.__pydantic_generate_base_schema__(handler) + + choices = { + name: handler(model_class) for name, model_class in cls.registry.items() + } + + return core_schema.tagged_union_schema( + choices=choices, + discriminator=cls.schema_discriminator, + ) + + return handler(cls) + + @classmethod + @abstractmethod + def __pydantic_schema_base_type__(cls) -> type[BaseModelT]: + """ + Define the base type for polymorphic validation hierarchy. + + Must be implemented by subclasses to specify which type serves as the + root of the polymorphic hierarchy for schema generation and validation. + + :return: Base class type for the polymorphic model hierarchy + """ + ... + + @classmethod + def __pydantic_generate_base_schema__( + cls, handler: GetCoreSchemaHandler + ) -> CoreSchema: + """ + Generate fallback schema for polymorphic models without registry. + + Provides a base schema that accepts any valid input when no registry + is available for polymorphic validation. Used as fallback during + schema generation when the registry has not been populated. + + :param handler: Pydantic core schema generation handler + :return: Base CoreSchema that accepts any valid input + """ + return core_schema.any_schema() + + @classmethod + def auto_populate_registry(cls) -> bool: + """ + Initialize registry with auto-discovery and reload validation schema. + + Triggers automatic population of the class registry through the parent + RegistryMixin functionality and ensures the Pydantic validation schema + is updated to include all discovered types for polymorphic validation. + + :return: True if registry was populated, False if already populated + :raises ValueError: If called when registry_auto_discovery is disabled + """ + populated = super().auto_populate_registry() + cls.reload_schema() + + return populated diff --git a/src/guidellm/utils/registry.py b/src/guidellm/utils/registry.py new file mode 100644 index 00000000..5d4bc055 --- /dev/null +++ b/src/guidellm/utils/registry.py @@ -0,0 +1,206 @@ +""" +Registry system for dynamic object registration and discovery. + +Provides a flexible object registration system with optional auto-discovery +capabilities through decorators and module imports. Enables dynamic discovery +and instantiation of implementations based on configuration parameters, supporting +both manual registration and automatic package-based discovery for extensible +plugin architectures. +""" + +from __future__ import annotations + +from typing import Any, Callable, ClassVar, Generic, TypeVar + +from guidellm.utils.auto_importer import AutoImporterMixin + +__all__ = ["RegistryMixin", "RegistryObjT"] + + +RegistryObjT = TypeVar("RegistryObjT", bound=Any) +""" +Generic type variable for objects managed by the registry system. +""" + + +class RegistryMixin(Generic[RegistryObjT], AutoImporterMixin): + """ + Generic mixin for creating object registries with optional auto-discovery. 
+
+    Enables classes to maintain separate registries of objects that can be
+    dynamically discovered and instantiated through decorators and module imports.
+    Supports both manual registration via decorators and automatic discovery
+    through package scanning for extensible plugin architectures.
+
+    Example:
+    ::
+        class BaseAlgorithm(RegistryMixin):
+            pass
+
+        @BaseAlgorithm.register()
+        class ConcreteAlgorithm(BaseAlgorithm):
+            pass
+
+        @BaseAlgorithm.register("custom_name")
+        class AnotherAlgorithm(BaseAlgorithm):
+            pass
+
+        # Get all registered implementations
+        algorithms = BaseAlgorithm.registered_objects()
+
+    Example with auto-discovery:
+    ::
+        class TokenProposal(RegistryMixin):
+            registry_auto_discovery = True
+            auto_package = "mypackage.proposals"
+
+        # Automatically imports and registers decorated objects
+        proposals = TokenProposal.registered_objects()
+
+    :cvar registry: Dictionary mapping names to registered objects
+    :cvar registry_auto_discovery: Enable automatic package-based discovery
+    :cvar registry_populated: Track whether auto-discovery has completed
+    """
+
+    registry: ClassVar[dict[str, RegistryObjT] | None] = None
+    registry_auto_discovery: ClassVar[bool] = False
+    registry_populated: ClassVar[bool] = False
+
+    @classmethod
+    def register(
+        cls, name: str | list[str] | None = None
+    ) -> Callable[[RegistryObjT], RegistryObjT]:
+        """
+        Decorator that registers an object with the registry.
+
+        :param name: Optional name(s) to register the object under.
+            If None, the object name is used as the registry key.
+        :return: A decorator function that registers the decorated object.
+        :raises ValueError: If name is provided but is not a string or list of strings.
+        """
+        if name is not None and not isinstance(name, (str, list)):
+            raise ValueError(
+                "RegistryMixin.register() name must be a string, list of strings, "
+                f"or None. Got {name}."
+            )
+
+        return lambda obj: cls.register_decorator(obj, name=name)
+
+    @classmethod
+    def register_decorator(
+        cls, obj: RegistryObjT, name: str | list[str] | None = None
+    ) -> RegistryObjT:
+        """
+        Direct decorator that registers an object with the registry.
+
+        :param obj: The object to register.
+        :param name: Optional name(s) to register the object under.
+            If None, the object name is used as the registry key.
+        :return: The registered object.
+        :raises ValueError: If the object is already registered or if name is invalid.
+        """
+
+        if not name:
+            name = obj.__name__
+        elif not isinstance(name, (str, list)):
+            raise ValueError(
+                "RegistryMixin.register_decorator name must be a string or "
+                f"an iterable of strings. Got {name}."
+            )
+
+        if cls.registry is None:
+            cls.registry = {}
+
+        names = [name] if isinstance(name, str) else list(name)
+
+        for register_name in names:
+            if not isinstance(register_name, str):
+                raise ValueError(
+                    "RegistryMixin.register_decorator name must be a string or "
+                    f"a list of strings. Got {register_name}."
+                )
+
+            if register_name.lower() in cls.registry:
+                raise ValueError(
+                    f"RegistryMixin.register_decorator cannot register an object "
+                    f"{obj} with the name {register_name} because it is already "
+                    "registered."
+                )
+
+            cls.registry[register_name.lower()] = obj
+
+        return obj
+
+    @classmethod
+    def auto_populate_registry(cls) -> bool:
+        """
+        Import and register all modules from the specified auto_package.
+
+        Automatically called by registered_objects when registry_auto_discovery is True
+        to ensure all available implementations are discovered before returning results.
+ + :return: True if the registry was populated, False if already populated. + :raises ValueError: If called when registry_auto_discovery is False. + """ + if not cls.registry_auto_discovery: + raise ValueError( + "RegistryMixin.auto_populate_registry() cannot be called " + "because registry_auto_discovery is set to False. " + "Set registry_auto_discovery to True to enable auto-discovery." + ) + + if cls.registry_populated: + return False + + cls.auto_import_package_modules() + cls.registry_populated = True + + return True + + @classmethod + def registered_objects(cls) -> tuple[RegistryObjT, ...]: + """ + Get all registered objects from the registry. + + Automatically triggers auto-discovery if registry_auto_discovery is enabled + to ensure all available implementations are included. + + :return: Tuple of all registered objects including auto-discovered ones. + :raises ValueError: If called before any objects have been registered. + """ + if cls.registry_auto_discovery: + cls.auto_populate_registry() + + if cls.registry is None: + raise ValueError( + "RegistryMixin.registered_objects() must be called after " + "registering objects with RegistryMixin.register()." + ) + + return tuple(cls.registry.values()) + + @classmethod + def is_registered(cls, name: str) -> bool: + """ + Check if an object is registered under the given name. + + :param name: The name to check for registration. + :return: True if the object is registered, False otherwise. + """ + if cls.registry is None: + return False + + return name.lower() in cls.registry + + @classmethod + def get_registered_object(cls, name: str) -> RegistryObjT | None: + """ + Get a registered object by its name. + + :param name: The name of the registered object. + :return: The registered object if found, None otherwise. + """ + if cls.registry is None: + return None + + return cls.registry.get(name.lower()) diff --git a/src/guidellm/utils/singleton.py b/src/guidellm/utils/singleton.py new file mode 100644 index 00000000..3ec10f79 --- /dev/null +++ b/src/guidellm/utils/singleton.py @@ -0,0 +1,130 @@ +""" +Singleton pattern implementations for ensuring single instance classes. + +Provides singleton mixins for creating classes that maintain a single instance +throughout the application lifecycle, with support for both basic and thread-safe +implementations. These mixins integrate with the scheduler and other system components +to ensure consistent state management and prevent duplicate resource allocation. +""" + +from __future__ import annotations + +import threading + +__all__ = ["SingletonMixin", "ThreadSafeSingletonMixin"] + + +class SingletonMixin: + """ + Basic singleton mixin ensuring single instance per class. + + Implements the singleton pattern using class variables to control instance + creation. Subclasses must call super().__init__() for proper initialization + state management. Suitable for single-threaded environments or when external + synchronization is provided. + + Example: + :: + class ConfigManager(SingletonMixin): + def __init__(self, config_path: str): + super().__init__() + if not self.initialized: + self.config = load_config(config_path) + + manager1 = ConfigManager("config.json") + manager2 = ConfigManager("config.json") + assert manager1 is manager2 + """ + + def __new__(cls, *args, **kwargs): # noqa: ARG004 + """ + Create or return the singleton instance. 
+ + :param args: Positional arguments passed to the constructor + :param kwargs: Keyword arguments passed to the constructor + :return: The singleton instance of the class + """ + # Use class-specific attribute name to avoid inheritance issues + attr_name = f"_singleton_instance_{cls.__name__}" + + if not hasattr(cls, attr_name) or getattr(cls, attr_name) is None: + instance = super().__new__(cls) + setattr(cls, attr_name, instance) + instance._singleton_initialized = False + return getattr(cls, attr_name) + + def __init__(self): + """Initialize the singleton instance exactly once.""" + if hasattr(self, "_singleton_initialized") and self._singleton_initialized: + return + self._singleton_initialized = True + + @property + def initialized(self): + """Return True if the singleton has been initialized.""" + return getattr(self, "_singleton_initialized", False) + + +class ThreadSafeSingletonMixin(SingletonMixin): + """ + Thread-safe singleton mixin with locking mechanisms. + + Extends SingletonMixin with thread safety using locks to prevent race + conditions during instance creation in multi-threaded environments. Essential + for scheduler components and other shared resources accessed concurrently. + + Example: + :: + class SchedulerResource(ThreadSafeSingletonMixin): + def __init__(self): + super().__init__() + if not self.initialized: + self.resource_pool = initialize_resources() + """ + + def __new__(cls, *args, **kwargs): # noqa: ARG004 + """ + Create or return the singleton instance with thread safety. + + :param args: Positional arguments passed to the constructor + :param kwargs: Keyword arguments passed to the constructor + :return: The singleton instance of the class + """ + # Use class-specific lock and instance names to avoid inheritance issues + lock_attr_name = f"_singleton_lock_{cls.__name__}" + instance_attr_name = f"_singleton_instance_{cls.__name__}" + + with getattr(cls, lock_attr_name): + instance_exists = ( + hasattr(cls, instance_attr_name) + and getattr(cls, instance_attr_name) is not None + ) + if not instance_exists: + instance = super(SingletonMixin, cls).__new__(cls) + setattr(cls, instance_attr_name, instance) + instance._singleton_initialized = False + instance._init_lock = threading.Lock() + return getattr(cls, instance_attr_name) + + def __init_subclass__(cls, *args, **kwargs): + super().__init_subclass__(*args, **kwargs) + lock_attr_name = f"_singleton_lock_{cls.__name__}" + setattr(cls, lock_attr_name, threading.Lock()) + + def __init__(self): + """Initialize the singleton instance with thread-safe initialization.""" + with self._init_lock: + if hasattr(self, "_singleton_initialized") and self._singleton_initialized: + return + self._singleton_initialized = True + + @property + def thread_lock(self): + """Return the thread lock for this singleton instance.""" + return getattr(self, "_init_lock", None) + + @classmethod + def get_singleton_lock(cls): + """Get the class-specific singleton creation lock.""" + lock_attr_name = f"_singleton_lock_{cls.__name__}" + return getattr(cls, lock_attr_name, None) diff --git a/tests/unit/objects/test_pydantic.py b/tests/unit/objects/test_pydantic.py deleted file mode 100644 index cb7f438f..00000000 --- a/tests/unit/objects/test_pydantic.py +++ /dev/null @@ -1,43 +0,0 @@ -import pytest -from pydantic import computed_field - -from guidellm.objects.pydantic import StandardBaseModel - - -class ExampleModel(StandardBaseModel): - name: str - age: int - - @computed_field # type: ignore[misc] - @property - def computed(self) -> str: - 
return self.name + " " + str(self.age) - - -@pytest.mark.smoke -def test_standard_base_model_initialization(): - example = ExampleModel(name="John Doe", age=30) - assert example.name == "John Doe" - assert example.age == 30 - assert example.computed == "John Doe 30" - - -@pytest.mark.smoke -def test_standard_base_model_invalid_initialization(): - with pytest.raises(ValueError): - ExampleModel(name="John Doe", age="thirty") # type: ignore[arg-type] - - -@pytest.mark.smoke -def test_standard_base_model_marshalling(): - example = ExampleModel(name="John Doe", age=30) - serialized = example.model_dump() - assert serialized["name"] == "John Doe" - assert serialized["age"] == 30 - assert serialized["computed"] == "John Doe 30" - - serialized["computed"] = "Jane Doe 40" - deserialized = ExampleModel.model_validate(serialized) - assert deserialized.name == "John Doe" - assert deserialized.age == 30 - assert deserialized.computed == "John Doe 30" diff --git a/tests/unit/utils/test_auto_importer.py b/tests/unit/utils/test_auto_importer.py new file mode 100644 index 00000000..cc71bce3 --- /dev/null +++ b/tests/unit/utils/test_auto_importer.py @@ -0,0 +1,269 @@ +""" +Unit tests for the auto_importer module. +""" + +from __future__ import annotations + +from unittest import mock + +import pytest + +from guidellm.utils import AutoImporterMixin + + +class TestAutoImporterMixin: + """Test suite for AutoImporterMixin functionality.""" + + @pytest.fixture( + params=[ + { + "auto_package": "test.package", + "auto_ignore_modules": None, + "modules": [ + ("test.package.module1", False), + ("test.package.module2", False), + ], + "expected_imports": ["test.package.module1", "test.package.module2"], + }, + { + "auto_package": ("test.package1", "test.package2"), + "auto_ignore_modules": None, + "modules": [ + ("test.package1.moduleA", False), + ("test.package2.moduleB", False), + ], + "expected_imports": ["test.package1.moduleA", "test.package2.moduleB"], + }, + { + "auto_package": "test.package", + "auto_ignore_modules": ("test.package.module1",), + "modules": [ + ("test.package.module1", False), + ("test.package.module2", False), + ], + "expected_imports": ["test.package.module2"], + }, + ], + ids=["single_package", "multiple_packages", "ignored_modules"], + ) + def valid_instances(self, request): + """Fixture providing test data for AutoImporterMixin subclasses.""" + config = request.param + + class TestClass(AutoImporterMixin): + auto_package = config["auto_package"] + auto_ignore_modules = config["auto_ignore_modules"] + + return TestClass, config + + @pytest.mark.smoke + def test_class_signatures(self): + """Test AutoImporterMixin class signatures and attributes.""" + assert hasattr(AutoImporterMixin, "auto_package") + assert hasattr(AutoImporterMixin, "auto_ignore_modules") + assert hasattr(AutoImporterMixin, "auto_imported_modules") + assert hasattr(AutoImporterMixin, "auto_import_package_modules") + assert callable(AutoImporterMixin.auto_import_package_modules) + + # Test default class variables + assert AutoImporterMixin.auto_package is None + assert AutoImporterMixin.auto_ignore_modules is None + assert AutoImporterMixin.auto_imported_modules is None + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test AutoImporterMixin subclass initialization.""" + test_class, config = valid_instances + assert issubclass(test_class, AutoImporterMixin) + assert test_class.auto_package == config["auto_package"] + assert test_class.auto_ignore_modules == config["auto_ignore_modules"] + assert 
test_class.auto_imported_modules is None + + @pytest.mark.sanity + def test_invalid_initialization_missing(self): + """Test AutoImporterMixin with missing auto_package.""" + + class TestClass(AutoImporterMixin): + pass + + with pytest.raises(ValueError, match="auto_package.*must be set"): + TestClass.auto_import_package_modules() + + @pytest.mark.smoke + @mock.patch("importlib.import_module") + @mock.patch("pkgutil.walk_packages") + def test_auto_import_package_modules(self, mock_walk, mock_import, valid_instances): + """Test auto_import_package_modules core functionality.""" + test_class, config = valid_instances + + # Setup mocks based on config + packages = {} + modules = {} + + if isinstance(config["auto_package"], tuple): + for pkg in config["auto_package"]: + pkg_path = pkg.replace(".", "/") + packages[pkg] = MockHelper.create_mock_package(pkg, pkg_path) + else: + pkg = config["auto_package"] + packages[pkg] = MockHelper.create_mock_package(pkg, pkg.replace(".", "/")) + + for module_name, is_pkg in config["modules"]: + if not is_pkg: + modules[module_name] = MockHelper.create_mock_module(module_name) + + mock_import.side_effect = lambda name: {**packages, **modules}.get( + name, mock.MagicMock() + ) + + def walk_side_effect(path, prefix): + return [ + (None, module_name, is_pkg) + for module_name, is_pkg in config["modules"] + if module_name.startswith(prefix) + ] + + mock_walk.side_effect = walk_side_effect + + # Execute + test_class.auto_import_package_modules() + + # Verify + assert test_class.auto_imported_modules == config["expected_imports"] + + # Verify package imports + if isinstance(config["auto_package"], tuple): + for pkg in config["auto_package"]: + mock_import.assert_any_call(pkg) + else: + mock_import.assert_any_call(config["auto_package"]) + + # Verify expected module imports + for expected_module in config["expected_imports"]: + mock_import.assert_any_call(expected_module) + + @pytest.mark.sanity + @mock.patch("importlib.import_module") + @mock.patch("pkgutil.walk_packages") + def test_auto_import_package_modules_invalid(self, mock_walk, mock_import): + """Test auto_import_package_modules with invalid configurations.""" + + class TestClass(AutoImporterMixin): + auto_package = "test.package" + + # Test import error handling + mock_import.side_effect = ImportError("Module not found") + + with pytest.raises(ImportError): + TestClass.auto_import_package_modules() + + @pytest.mark.sanity + @mock.patch("importlib.import_module") + @mock.patch("pkgutil.walk_packages") + def test_skip_packages(self, mock_walk, mock_import): + """Test that packages (is_pkg=True) are skipped.""" + + class TestClass(AutoImporterMixin): + auto_package = "test.package" + + # Setup mocks + mock_package = MockHelper.create_mock_package("test.package", "test/package") + mock_module = MockHelper.create_mock_module("test.package.module") + + mock_import.side_effect = lambda name: { + "test.package": mock_package, + "test.package.module": mock_module, + }[name] + + mock_walk.return_value = [ + (None, "test.package.subpackage", True), + (None, "test.package.module", False), + ] + + # Execute + TestClass.auto_import_package_modules() + + # Verify + assert TestClass.auto_imported_modules == ["test.package.module"] + mock_import.assert_any_call("test.package.module") + # subpackage should not be imported + with pytest.raises(AssertionError): + mock_import.assert_any_call("test.package.subpackage") + + @pytest.mark.sanity + @mock.patch("sys.modules", {"test.package.existing": mock.MagicMock()}) + 
@mock.patch("importlib.import_module") + @mock.patch("pkgutil.walk_packages") + def test_skip_already_imported_modules(self, mock_walk, mock_import): + """Test that modules already in sys.modules are tracked but not re-imported.""" + + class TestClass(AutoImporterMixin): + auto_package = "test.package" + + # Setup mocks + mock_package = MockHelper.create_mock_package("test.package", "test/package") + mock_import.side_effect = lambda name: { + "test.package": mock_package, + }.get(name, mock.MagicMock()) + + mock_walk.return_value = [ + (None, "test.package.existing", False), + ] + + # Execute + TestClass.auto_import_package_modules() + + # Verify + assert TestClass.auto_imported_modules == ["test.package.existing"] + mock_import.assert_called_once_with("test.package") + with pytest.raises(AssertionError): + mock_import.assert_any_call("test.package.existing") + + @pytest.mark.sanity + @mock.patch("importlib.import_module") + @mock.patch("pkgutil.walk_packages") + def test_prevent_duplicate_module_imports(self, mock_walk, mock_import): + """Test that modules already in auto_imported_modules are not re-imported.""" + + class TestClass(AutoImporterMixin): + auto_package = "test.package" + + # Setup mocks + mock_package = MockHelper.create_mock_package("test.package", "test/package") + mock_module = MockHelper.create_mock_module("test.package.module") + + mock_import.side_effect = lambda name: { + "test.package": mock_package, + "test.package.module": mock_module, + }[name] + + mock_walk.return_value = [ + (None, "test.package.module", False), + (None, "test.package.module", False), + ] + + # Execute + TestClass.auto_import_package_modules() + + # Verify + assert TestClass.auto_imported_modules == ["test.package.module"] + assert mock_import.call_count == 2 # Package + module (not duplicate) + + +class MockHelper: + """Helper class to create consistent mock objects for testing.""" + + @staticmethod + def create_mock_package(name: str, path: str): + """Create a mock package with required attributes.""" + package = mock.MagicMock() + package.__name__ = name + package.__path__ = [path] + return package + + @staticmethod + def create_mock_module(name: str): + """Create a mock module with required attributes.""" + module = mock.MagicMock() + module.__name__ = name + return module diff --git a/tests/unit/utils/test_pydantic_utils.py b/tests/unit/utils/test_pydantic_utils.py new file mode 100644 index 00000000..8683604b --- /dev/null +++ b/tests/unit/utils/test_pydantic_utils.py @@ -0,0 +1,710 @@ +""" +Unit tests for the pydantic_utils module. 
+""" + +from __future__ import annotations + +from typing import ClassVar +from unittest import mock + +import pytest +from pydantic import BaseModel, Field, ValidationError + +from guidellm.utils.pydantic_utils import ( + PydanticClassRegistryMixin, + ReloadableBaseModel, + StandardBaseDict, + StandardBaseModel, + StatusBreakdown, +) + + +class TestReloadableBaseModel: + """Test suite for ReloadableBaseModel.""" + + @pytest.fixture( + params=[ + {"name": "test_value"}, + {"name": "hello_world"}, + {"name": "another_test"}, + ], + ids=["basic_string", "multi_word", "underscore"], + ) + def valid_instances(self, request) -> tuple[ReloadableBaseModel, dict[str, str]]: + """Fixture providing test data for ReloadableBaseModel.""" + + class TestModel(ReloadableBaseModel): + name: str + + constructor_args = request.param + instance = TestModel(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_class_signatures(self): + """Test ReloadableBaseModel inheritance and class variables.""" + assert issubclass(ReloadableBaseModel, BaseModel) + assert hasattr(ReloadableBaseModel, "model_config") + assert hasattr(ReloadableBaseModel, "reload_schema") + + # Check model configuration + config = ReloadableBaseModel.model_config + assert config["extra"] == "ignore" + assert config["use_enum_values"] is True + assert config["validate_assignment"] is True + assert config["from_attributes"] is True + assert config["arbitrary_types_allowed"] is True + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test ReloadableBaseModel initialization.""" + instance, constructor_args = valid_instances + assert isinstance(instance, ReloadableBaseModel) + assert instance.name == constructor_args["name"] # type: ignore[attr-defined] + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("name", None), + ("name", 123), + ("name", []), + ], + ) + def test_invalid_initialization_values(self, field, value): + """Test ReloadableBaseModel with invalid field values.""" + + class TestModel(ReloadableBaseModel): + name: str + + data = {field: value} + with pytest.raises(ValidationError): + TestModel(**data) + + @pytest.mark.sanity + def test_invalid_initialization_missing(self): + """Test ReloadableBaseModel initialization without required field.""" + + class TestModel(ReloadableBaseModel): + name: str + + with pytest.raises(ValidationError): + TestModel() # type: ignore[call-arg] + + @pytest.mark.smoke + def test_reload_schema(self): + """Test ReloadableBaseModel.reload_schema method.""" + + class TestModel(ReloadableBaseModel): + name: str + + # Mock the model_rebuild method to simulate schema reload + with mock.patch.object(TestModel, "model_rebuild") as mock_rebuild: + TestModel.reload_schema() + mock_rebuild.assert_called_once_with(force=True) + + @pytest.mark.sanity + def test_marshalling(self, valid_instances): + """Test ReloadableBaseModel serialization and deserialization.""" + instance, constructor_args = valid_instances + data_dict = instance.model_dump() + assert isinstance(data_dict, dict) + assert data_dict["name"] == constructor_args["name"] + + recreated = instance.__class__.model_validate(data_dict) + assert isinstance(recreated, instance.__class__) + assert recreated.name == constructor_args["name"] + + +class TestStandardBaseModel: + """Test suite for StandardBaseModel.""" + + @pytest.fixture( + params=[ + {"field_str": "test_value", "field_int": 42}, + {"field_str": "hello_world", "field_int": 100}, + {"field_str": 
"another_test", "field_int": 0}, + ], + ids=["basic_values", "positive_values", "zero_value"], + ) + def valid_instances( + self, request + ) -> tuple[StandardBaseModel, dict[str, int | str]]: + """Fixture providing test data for StandardBaseModel.""" + + class TestModel(StandardBaseModel): + field_str: str = Field(description="Test string field") + field_int: int = Field(default=10, description="Test integer field") + + constructor_args = request.param + instance = TestModel(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_class_signatures(self): + """Test StandardBaseModel inheritance and class variables.""" + assert issubclass(StandardBaseModel, BaseModel) + assert hasattr(StandardBaseModel, "model_config") + assert hasattr(StandardBaseModel, "get_default") + + # Check model configuration + config = StandardBaseModel.model_config + assert config["extra"] == "ignore" + assert config["use_enum_values"] is True + assert config["validate_assignment"] is True + assert config["from_attributes"] is True + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test StandardBaseModel initialization.""" + instance, constructor_args = valid_instances + assert isinstance(instance, StandardBaseModel) + assert instance.field_str == constructor_args["field_str"] # type: ignore[attr-defined] + assert instance.field_int == constructor_args["field_int"] # type: ignore[attr-defined] + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("field_str", None), + ("field_str", 123), + ("field_int", "not_int"), + ], + ) + def test_invalid_initialization_values(self, field, value): + """Test StandardBaseModel with invalid field values.""" + + class TestModel(StandardBaseModel): + field_str: str = Field(description="Test string field") + field_int: int = Field(default=10, description="Test integer field") + + data = {field: value} + if field == "field_str": + data["field_int"] = 42 + else: + data["field_str"] = "test" + + with pytest.raises(ValidationError): + TestModel(**data) + + @pytest.mark.sanity + def test_invalid_initialization_missing(self): + """Test StandardBaseModel initialization without required field.""" + + class TestModel(StandardBaseModel): + field_str: str = Field(description="Test string field") + field_int: int = Field(default=10, description="Test integer field") + + with pytest.raises(ValidationError): + TestModel() # type: ignore[call-arg] + + @pytest.mark.smoke + def test_get_default(self): + """Test StandardBaseModel.get_default method.""" + + class TestModel(StandardBaseModel): + field_str: str = Field(description="Test string field") + field_int: int = Field(default=42, description="Test integer field") + + default_value = TestModel.get_default("field_int") + assert default_value == 42 + + @pytest.mark.sanity + def test_get_default_invalid(self): + """Test StandardBaseModel.get_default with invalid field.""" + + class TestModel(StandardBaseModel): + field_str: str = Field(description="Test string field") + + with pytest.raises(KeyError): + TestModel.get_default("nonexistent_field") + + @pytest.mark.sanity + def test_marshalling(self, valid_instances): + """Test StandardBaseModel serialization and deserialization.""" + instance, constructor_args = valid_instances + data_dict = instance.model_dump() + assert isinstance(data_dict, dict) + assert data_dict["field_str"] == constructor_args["field_str"] + assert data_dict["field_int"] == constructor_args["field_int"] + + recreated = 
instance.__class__.model_validate(data_dict) + assert isinstance(recreated, instance.__class__) + assert recreated.field_str == constructor_args["field_str"] + assert recreated.field_int == constructor_args["field_int"] + + +class TestStandardBaseDict: + """Test suite for StandardBaseDict.""" + + @pytest.fixture( + params=[ + {"field_str": "test_value", "extra_field": "extra_value"}, + {"field_str": "hello_world", "another_extra": 123}, + {"field_str": "another_test", "complex_extra": {"nested": "value"}}, + ], + ids=["string_extra", "int_extra", "dict_extra"], + ) + def valid_instances( + self, request + ) -> tuple[StandardBaseDict, dict[str, str | int | dict[str, str]]]: + """Fixture providing test data for StandardBaseDict.""" + + class TestModel(StandardBaseDict): + field_str: str = Field(description="Test string field") + + constructor_args = request.param + instance = TestModel(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_class_signatures(self): + """Test StandardBaseDict inheritance and class variables.""" + assert issubclass(StandardBaseDict, StandardBaseModel) + assert hasattr(StandardBaseDict, "model_config") + + # Check model configuration + config = StandardBaseDict.model_config + assert config["extra"] == "allow" + assert config["use_enum_values"] is True + assert config["validate_assignment"] is True + assert config["from_attributes"] is True + assert config["arbitrary_types_allowed"] is True + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test StandardBaseDict initialization.""" + instance, constructor_args = valid_instances + assert isinstance(instance, StandardBaseDict) + assert instance.field_str == constructor_args["field_str"] # type: ignore[attr-defined] + + # Check extra fields are preserved + for key, value in constructor_args.items(): + if key != "field_str": + assert hasattr(instance, key) + assert getattr(instance, key) == value + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("field_str", None), + ("field_str", 123), + ], + ) + def test_invalid_initialization_values(self, field, value): + """Test StandardBaseDict with invalid field values.""" + + class TestModel(StandardBaseDict): + field_str: str = Field(description="Test string field") + + data = {field: value} + with pytest.raises(ValidationError): + TestModel(**data) + + @pytest.mark.sanity + def test_invalid_initialization_missing(self): + """Test StandardBaseDict initialization without required field.""" + + class TestModel(StandardBaseDict): + field_str: str = Field(description="Test string field") + + with pytest.raises(ValidationError): + TestModel() # type: ignore[call-arg] + + @pytest.mark.sanity + def test_marshalling(self, valid_instances): + """Test StandardBaseDict serialization and deserialization.""" + instance, constructor_args = valid_instances + data_dict = instance.model_dump() + assert isinstance(data_dict, dict) + assert data_dict["field_str"] == constructor_args["field_str"] + + # Check extra fields are in the serialized data + for key, value in constructor_args.items(): + if key != "field_str": + assert key in data_dict + assert data_dict[key] == value + + recreated = instance.__class__.model_validate(data_dict) + assert isinstance(recreated, instance.__class__) + assert recreated.field_str == constructor_args["field_str"] + + # Check extra fields are preserved after deserialization + for key, value in constructor_args.items(): + if key != "field_str": + assert hasattr(recreated, 
key) + assert getattr(recreated, key) == value + + +class TestStatusBreakdown: + """Test suite for StatusBreakdown.""" + + @pytest.fixture( + params=[ + {"successful": 100, "errored": 5, "incomplete": 10, "total": 115}, + { + "successful": "success_data", + "errored": "error_data", + "incomplete": "incomplete_data", + "total": "total_data", + }, + { + "successful": [1, 2, 3], + "errored": [4, 5], + "incomplete": [6], + "total": [1, 2, 3, 4, 5, 6], + }, + ], + ids=["int_values", "string_values", "list_values"], + ) + def valid_instances(self, request) -> tuple[StatusBreakdown, dict]: + """Fixture providing test data for StatusBreakdown.""" + constructor_args = request.param + instance = StatusBreakdown(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_class_signatures(self): + """Test StatusBreakdown inheritance and type relationships.""" + assert issubclass(StatusBreakdown, BaseModel) + # Check if Generic is in the MRO (method resolution order) + assert any(cls.__name__ == "Generic" for cls in StatusBreakdown.__mro__) + assert "successful" in StatusBreakdown.model_fields + assert "errored" in StatusBreakdown.model_fields + assert "incomplete" in StatusBreakdown.model_fields + assert "total" in StatusBreakdown.model_fields + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test StatusBreakdown initialization.""" + instance, constructor_args = valid_instances + assert isinstance(instance, StatusBreakdown) + assert instance.successful == constructor_args["successful"] + assert instance.errored == constructor_args["errored"] + assert instance.incomplete == constructor_args["incomplete"] + assert instance.total == constructor_args["total"] + + @pytest.mark.smoke + def test_initialization_defaults(self): + """Test StatusBreakdown initialization with default values.""" + instance: StatusBreakdown = StatusBreakdown() + assert instance.successful is None + assert instance.errored is None + assert instance.incomplete is None + assert instance.total is None + + @pytest.mark.sanity + def test_marshalling(self, valid_instances): + """Test StatusBreakdown serialization and deserialization.""" + instance, constructor_args = valid_instances + data_dict = instance.model_dump() + assert isinstance(data_dict, dict) + assert data_dict["successful"] == constructor_args["successful"] + assert data_dict["errored"] == constructor_args["errored"] + assert data_dict["incomplete"] == constructor_args["incomplete"] + assert data_dict["total"] == constructor_args["total"] + + recreated: StatusBreakdown = StatusBreakdown.model_validate(data_dict) + assert isinstance(recreated, StatusBreakdown) + assert recreated.successful == constructor_args["successful"] + assert recreated.errored == constructor_args["errored"] + assert recreated.incomplete == constructor_args["incomplete"] + assert recreated.total == constructor_args["total"] + + +class TestPydanticClassRegistryMixin: + """Test suite for PydanticClassRegistryMixin.""" + + @pytest.fixture( + params=[ + {"test_type": "test_sub", "value": "test_value"}, + {"test_type": "test_sub", "value": "hello_world"}, + ], + ids=["basic_value", "multi_word"], + ) + def valid_instances( + self, request + ) -> tuple[PydanticClassRegistryMixin, dict, type, type]: + """Fixture providing test data for PydanticClassRegistryMixin.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_type__(cls) -> 
type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + @TestBaseModel.register("test_sub") + class TestSubModel(TestBaseModel): + test_type: str = "test_sub" + value: str + + TestBaseModel.reload_schema() + + constructor_args = request.param + instance = TestSubModel(value=constructor_args["value"]) + return instance, constructor_args, TestBaseModel, TestSubModel + + @pytest.mark.smoke + def test_class_signatures(self): + """Test PydanticClassRegistryMixin inheritance and class variables.""" + assert issubclass(PydanticClassRegistryMixin, ReloadableBaseModel) + assert hasattr(PydanticClassRegistryMixin, "schema_discriminator") + assert PydanticClassRegistryMixin.schema_discriminator == "model_type" + assert hasattr(PydanticClassRegistryMixin, "register_decorator") + assert hasattr(PydanticClassRegistryMixin, "__get_pydantic_core_schema__") + assert hasattr(PydanticClassRegistryMixin, "__pydantic_generate_base_schema__") + assert hasattr(PydanticClassRegistryMixin, "auto_populate_registry") + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test PydanticClassRegistryMixin initialization.""" + instance, constructor_args, base_class, sub_class = valid_instances + assert isinstance(instance, sub_class) + assert isinstance(instance, base_class) + assert instance.test_type == constructor_args["test_type"] + assert instance.value == constructor_args["value"] + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("test_type", None), + ("test_type", 123), + ("value", None), + ], + ) + def test_invalid_initialization_values(self, field, value): + """Test PydanticClassRegistryMixin with invalid field values.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + @TestBaseModel.register("test_sub") + class TestSubModel(TestBaseModel): + test_type: str = "test_sub" + value: str + + data = {field: value} + if field == "test_type": + data["value"] = "test" + else: + data["test_type"] = "test_sub" + + with pytest.raises(ValidationError): + TestSubModel(**data) + + @pytest.mark.sanity + def test_invalid_initialization_missing(self): + """Test PydanticClassRegistryMixin initialization without required field.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + @TestBaseModel.register("test_sub") + class TestSubModel(TestBaseModel): + test_type: str = "test_sub" + value: str + + with pytest.raises(ValidationError): + TestSubModel() # type: ignore[call-arg] + + @pytest.mark.smoke + def test_register_decorator(self): + """Test PydanticClassRegistryMixin.register_decorator method.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + @TestBaseModel.register() + class TestSubModel(TestBaseModel): + test_type: str = "TestSubModel" + value: str + + assert TestBaseModel.registry is not None # type: ignore[misc] + assert "testsubmodel" in 
TestBaseModel.registry # type: ignore[misc] + assert TestBaseModel.registry["testsubmodel"] is TestSubModel # type: ignore[misc] + + @pytest.mark.sanity + def test_register_decorator_with_name(self): + """Test PydanticClassRegistryMixin.register_decorator with custom name.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + @TestBaseModel.register("custom_name") + class TestSubModel(TestBaseModel): + test_type: str = "custom_name" + value: str + + assert TestBaseModel.registry is not None # type: ignore[misc] + assert "custom_name" in TestBaseModel.registry # type: ignore[misc] + assert TestBaseModel.registry["custom_name"] is TestSubModel # type: ignore[misc] + + @pytest.mark.sanity + def test_register_decorator_invalid_type(self): + """Test PydanticClassRegistryMixin.register_decorator with invalid type.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + class RegularClass: + pass + + with pytest.raises(TypeError) as exc_info: + TestBaseModel.register_decorator(RegularClass) # type: ignore[arg-type] + + assert "not a subclass of Pydantic BaseModel" in str(exc_info.value) + + @pytest.mark.smoke + def test_auto_populate_registry(self): + """Test PydanticClassRegistryMixin.auto_populate_registry method.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + registry_auto_discovery: ClassVar[bool] = True + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + with ( + mock.patch.object(TestBaseModel, "reload_schema") as mock_reload, + mock.patch( + "guidellm.utils.registry.RegistryMixin.auto_populate_registry", + return_value=True, + ), + ): + result = TestBaseModel.auto_populate_registry() + assert result is True + mock_reload.assert_called_once() + + @pytest.mark.sanity + def test_marshalling(self, valid_instances): + """Test PydanticClassRegistryMixin serialization and deserialization.""" + instance, constructor_args, base_class, sub_class = valid_instances + + # Test serialization with model_dump + dump_data = instance.model_dump() + assert isinstance(dump_data, dict) + assert dump_data["test_type"] == constructor_args["test_type"] + assert dump_data["value"] == constructor_args["value"] + + # Test deserialization via subclass + recreated = sub_class.model_validate(dump_data) + assert isinstance(recreated, sub_class) + assert recreated.test_type == constructor_args["test_type"] + assert recreated.value == constructor_args["value"] + + # Test polymorphic deserialization via base class + recreated_base = base_class.model_validate(dump_data) # type: ignore[assignment] + assert isinstance(recreated_base, sub_class) + assert recreated_base.test_type == constructor_args["test_type"] + assert recreated_base.value == constructor_args["value"] + + @pytest.mark.regression + def test_polymorphic_container_marshalling(self): + """Test PydanticClassRegistryMixin in container models.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: 
ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + @classmethod + def __pydantic_generate_base_schema__(cls, handler): + return handler(cls) + + @TestBaseModel.register("sub_a") + class TestSubModelA(TestBaseModel): + test_type: str = "sub_a" + value_a: str + + @TestBaseModel.register("sub_b") + class TestSubModelB(TestBaseModel): + test_type: str = "sub_b" + value_b: int + + class ContainerModel(BaseModel): + name: str + model: TestBaseModel + models: list[TestBaseModel] + + sub_a = TestSubModelA(value_a="test") + sub_b = TestSubModelB(value_b=123) + + container = ContainerModel(name="container", model=sub_a, models=[sub_a, sub_b]) + + # Verify container construction + assert isinstance(container.model, TestSubModelA) + assert container.model.test_type == "sub_a" + assert container.model.value_a == "test" + assert len(container.models) == 2 + assert isinstance(container.models[0], TestSubModelA) + assert isinstance(container.models[1], TestSubModelB) + + # Test serialization + dump_data = container.model_dump() + assert isinstance(dump_data, dict) + assert dump_data["name"] == "container" + assert dump_data["model"]["test_type"] == "sub_a" + assert dump_data["model"]["value_a"] == "test" + assert len(dump_data["models"]) == 2 + assert dump_data["models"][0]["test_type"] == "sub_a" + assert dump_data["models"][1]["test_type"] == "sub_b" + + # Test deserialization + recreated = ContainerModel.model_validate(dump_data) + assert isinstance(recreated, ContainerModel) + assert recreated.name == "container" + assert isinstance(recreated.model, TestSubModelA) + assert len(recreated.models) == 2 + assert isinstance(recreated.models[0], TestSubModelA) + assert isinstance(recreated.models[1], TestSubModelB) diff --git a/tests/unit/utils/test_registry.py b/tests/unit/utils/test_registry.py new file mode 100644 index 00000000..b5c17975 --- /dev/null +++ b/tests/unit/utils/test_registry.py @@ -0,0 +1,533 @@ +""" +Unit tests for the registry module. 
+""" + +from __future__ import annotations + +from typing import TypeVar +from unittest import mock + +import pytest + +from guidellm.utils.registry import RegistryMixin, RegistryObjT + + +def test_registry_obj_type(): + """Test that RegistryObjT is configured correctly as a TypeVar.""" + assert isinstance(RegistryObjT, type(TypeVar("test"))) + assert RegistryObjT.__name__ == "RegistryObjT" + assert RegistryObjT.__bound__ is not None # bound to Any + assert RegistryObjT.__constraints__ == () + + +class TestRegistryMixin: + """Test suite for RegistryMixin class.""" + + @pytest.fixture( + params=[ + {"registry_auto_discovery": False, "auto_package": None}, + {"registry_auto_discovery": True, "auto_package": "test.package"}, + ], + ids=["manual_registry", "auto_discovery"], + ) + def valid_instances(self, request): + """Fixture providing test data for RegistryMixin subclasses.""" + config = request.param + + class TestRegistryClass(RegistryMixin): + registry_auto_discovery = config["registry_auto_discovery"] + if config["auto_package"]: + auto_package = config["auto_package"] + + return TestRegistryClass, config + + @pytest.mark.smoke + def test_class_signatures(self): + """Test RegistryMixin inheritance and exposed methods.""" + assert hasattr(RegistryMixin, "registry") + assert hasattr(RegistryMixin, "registry_auto_discovery") + assert hasattr(RegistryMixin, "registry_populated") + assert hasattr(RegistryMixin, "register") + assert hasattr(RegistryMixin, "register_decorator") + assert hasattr(RegistryMixin, "auto_populate_registry") + assert hasattr(RegistryMixin, "registered_objects") + assert hasattr(RegistryMixin, "is_registered") + assert hasattr(RegistryMixin, "get_registered_object") + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test RegistryMixin initialization.""" + registry_class, config = valid_instances + + assert registry_class.registry is None + assert ( + registry_class.registry_auto_discovery == config["registry_auto_discovery"] + ) + assert registry_class.registry_populated is False + + @pytest.mark.sanity + def test_invalid_initialization_missing(self): + """Test RegistryMixin with missing auto_package when auto_discovery enabled.""" + + class TestRegistryClass(RegistryMixin): + registry_auto_discovery = True + + with pytest.raises(ValueError, match="auto_package.*must be set"): + TestRegistryClass.auto_import_package_modules() + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("name", "expected_key"), + [ + ("custom_name", "custom_name"), + (["name1", "name2"], ["name1", "name2"]), + (None, None), # Uses class name + ], + ) + def test_register(self, valid_instances, name, expected_key): + """Test register method with various name configurations.""" + registry_class, _ = valid_instances + + if name is None: + + @registry_class.register() + class TestClass: + pass + + expected_key = "testclass" + else: + + @registry_class.register(name) + class TestClass: + pass + + assert registry_class.registry is not None + if isinstance(expected_key, list): + for key in expected_key: + assert key in registry_class.registry + assert registry_class.registry[key] is TestClass + else: + assert expected_key in registry_class.registry + assert registry_class.registry[expected_key] is TestClass + + @pytest.mark.sanity + @pytest.mark.parametrize( + "invalid_name", + [123, 42.5, True, {"key": "value"}], + ) + def test_register_invalid(self, valid_instances, invalid_name): + """Test register method with invalid name types.""" + registry_class, _ = valid_instances 
+ + with pytest.raises(ValueError, match="name must be a string, list of strings"): + registry_class.register(invalid_name) + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("name", "expected_key"), + [ + ("custom_name", "custom_name"), + (["name1", "name2"], ["name1", "name2"]), + (None, "testclass"), + ], + ) + def test_register_decorator(self, valid_instances, name, expected_key): + """Test register_decorator method with various name configurations.""" + registry_class, _ = valid_instances + + class TestClass: + pass + + registry_class.register_decorator(TestClass, name=name) + + assert registry_class.registry is not None + if isinstance(expected_key, list): + for key in expected_key: + assert key in registry_class.registry + assert registry_class.registry[key] is TestClass + else: + assert expected_key in registry_class.registry + assert registry_class.registry[expected_key] is TestClass + + @pytest.mark.sanity + @pytest.mark.parametrize( + "invalid_name", + [123, 42.5, True, {"key": "value"}], + ) + def test_register_decorator_invalid(self, valid_instances, invalid_name): + """Test register_decorator with invalid name types.""" + registry_class, _ = valid_instances + + class TestClass: + pass + + with pytest.raises( + ValueError, match="name must be a string or an iterable of strings" + ): + registry_class.register_decorator(TestClass, name=invalid_name) + + @pytest.mark.smoke + def test_auto_populate_registry(self): + """Test auto_populate_registry method with valid configuration.""" + + class TestAutoRegistry(RegistryMixin): + registry_auto_discovery = True + auto_package = "test.package" + + with mock.patch.object( + TestAutoRegistry, "auto_import_package_modules" + ) as mock_import: + result = TestAutoRegistry.auto_populate_registry() + assert result is True + mock_import.assert_called_once() + assert TestAutoRegistry.registry_populated is True + + # Second call should return False + result = TestAutoRegistry.auto_populate_registry() + assert result is False + mock_import.assert_called_once() # Should not be called again + + @pytest.mark.sanity + def test_auto_populate_registry_invalid(self): + """Test auto_populate_registry when auto-discovery is disabled.""" + + class TestDisabledRegistry(RegistryMixin): + registry_auto_discovery = False + + with pytest.raises(ValueError, match="registry_auto_discovery is set to False"): + TestDisabledRegistry.auto_populate_registry() + + @pytest.mark.smoke + def test_registered_objects(self, valid_instances): + """Test registered_objects method with manual registration.""" + registry_class, config = valid_instances + + @registry_class.register("class1") + class TestClass1: + pass + + @registry_class.register("class2") + class TestClass2: + pass + + if config["registry_auto_discovery"]: + with mock.patch.object(registry_class, "auto_import_package_modules"): + objects = registry_class.registered_objects() + else: + objects = registry_class.registered_objects() + + assert isinstance(objects, tuple) + assert len(objects) == 2 + assert TestClass1 in objects + assert TestClass2 in objects + + @pytest.mark.sanity + def test_registered_objects_invalid(self): + """Test registered_objects when no objects are registered.""" + + class TestRegistryClass(RegistryMixin): + pass + + with pytest.raises( + ValueError, match="must be called after registering objects" + ): + TestRegistryClass.registered_objects() + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("register_name", "check_name", "expected"), + [ + ("test_name", "test_name", True), + 
("TestName", "testname", True), + ("UPPERCASE", "uppercase", True), + ("test_name", "nonexistent", False), + ], + ) + def test_is_registered(self, valid_instances, register_name, check_name, expected): + """Test is_registered with various name combinations.""" + registry_class, _ = valid_instances + + @registry_class.register(register_name) + class TestClass: + pass + + result = registry_class.is_registered(check_name) + assert result == expected + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("register_name", "lookup_name"), + [ + ("test_name", "test_name"), + ("TestName", "testname"), + ("UPPERCASE", "uppercase"), + ], + ) + def test_get_registered_object(self, valid_instances, register_name, lookup_name): + """Test get_registered_object with valid names.""" + registry_class, _ = valid_instances + + @registry_class.register(register_name) + class TestClass: + pass + + result = registry_class.get_registered_object(lookup_name) + assert result is TestClass + + @pytest.mark.sanity + @pytest.mark.parametrize( + "lookup_name", + ["nonexistent", "wrong_name", "DIFFERENT_CASE"], + ) + def test_get_registered_object_invalid(self, valid_instances, lookup_name): + """Test get_registered_object with invalid names.""" + registry_class, _ = valid_instances + + @registry_class.register("valid_name") + class TestClass: + pass + + result = registry_class.get_registered_object(lookup_name) + assert result is None + + @pytest.mark.regression + def test_multiple_registries_isolation(self): + """Test that different registry classes maintain separate registries.""" + + class Registry1(RegistryMixin): + pass + + class Registry2(RegistryMixin): + pass + + @Registry1.register() + class TestClass1: + pass + + @Registry2.register() + class TestClass2: + pass + + assert Registry1.registry is not None + assert Registry2.registry is not None + assert Registry1.registry != Registry2.registry + assert "testclass1" in Registry1.registry + assert "testclass2" in Registry2.registry + assert "testclass1" not in Registry2.registry + assert "testclass2" not in Registry1.registry + + @pytest.mark.regression + def test_inheritance_registry_sharing(self): + """Test that inherited registry classes share the same registry.""" + + class BaseRegistry(RegistryMixin): + pass + + class ChildRegistry(BaseRegistry): + pass + + @BaseRegistry.register() + class BaseClass: + pass + + @ChildRegistry.register() + class ChildClass: + pass + + # Child classes share the same registry as their parent + assert BaseRegistry.registry is ChildRegistry.registry + + # Both classes can see all registered objects + base_objects = BaseRegistry.registered_objects() + child_objects = ChildRegistry.registered_objects() + + assert len(base_objects) == 2 + assert len(child_objects) == 2 + assert base_objects == child_objects + assert BaseClass in base_objects + assert ChildClass in base_objects + + @pytest.mark.smoke + def test_auto_discovery_initialization(self): + """Test initialization of auto-discovery enabled registry.""" + + class TestAutoRegistry(RegistryMixin): + registry_auto_discovery = True + auto_package = "test_package.modules" + + assert TestAutoRegistry.registry is None + assert TestAutoRegistry.registry_populated is False + assert TestAutoRegistry.auto_package == "test_package.modules" + assert TestAutoRegistry.registry_auto_discovery is True + + @pytest.mark.smoke + def test_auto_discovery_registered_objects(self): + """Test automatic population during registered_objects call.""" + + class TestAutoRegistry(RegistryMixin): + 
registry_auto_discovery = True + auto_package = "test_package.modules" + + with mock.patch.object( + TestAutoRegistry, "auto_populate_registry" + ) as mock_populate: + TestAutoRegistry.registry = {"class1": "obj1", "class2": "obj2"} + objects = TestAutoRegistry.registered_objects() + mock_populate.assert_called_once() + assert objects == ("obj1", "obj2") + + @pytest.mark.sanity + def test_register_duplicate_registration(self, valid_instances): + """Test register method with duplicate names.""" + registry_class, _ = valid_instances + + @registry_class.register("duplicate_name") + class TestClass1: + pass + + with pytest.raises(ValueError, match="already registered"): + + @registry_class.register("duplicate_name") + class TestClass2: + pass + + @pytest.mark.sanity + def test_register_decorator_duplicate_registration(self, valid_instances): + """Test register_decorator with duplicate names.""" + registry_class, _ = valid_instances + + class TestClass1: + pass + + class TestClass2: + pass + + registry_class.register_decorator(TestClass1, name="duplicate_name") + with pytest.raises(ValueError, match="already registered"): + registry_class.register_decorator(TestClass2, name="duplicate_name") + + @pytest.mark.sanity + def test_register_decorator_invalid_list_element(self, valid_instances): + """Test register_decorator with invalid elements in name list.""" + registry_class, _ = valid_instances + + class TestClass: + pass + + with pytest.raises( + ValueError, match="name must be a string or a list of strings" + ): + registry_class.register_decorator(TestClass, name=["valid", 123]) + + @pytest.mark.sanity + def test_register_decorator_invalid_object(self, valid_instances): + """Test register_decorator with object lacking __name__ attribute.""" + registry_class, _ = valid_instances + + with pytest.raises(AttributeError): + registry_class.register_decorator("not_a_class") + + @pytest.mark.smoke + def test_is_registered_empty_registry(self, valid_instances): + """Test is_registered with empty registry.""" + registry_class, _ = valid_instances + + result = registry_class.is_registered("any_name") + assert result is False + + @pytest.mark.smoke + def test_get_registered_object_empty_registry(self, valid_instances): + """Test get_registered_object with empty registry.""" + registry_class, _ = valid_instances + + result = registry_class.get_registered_object("any_name") + assert result is None + + @pytest.mark.regression + def test_auto_registry_integration(self): + """Test complete auto-discovery workflow with mocked imports.""" + + class TestAutoRegistry(RegistryMixin): + registry_auto_discovery = True + auto_package = "test_package.modules" + + with ( + mock.patch("pkgutil.walk_packages") as walk_mock, + mock.patch("importlib.import_module") as import_mock, + ): + # Setup mock package + package_mock = mock.MagicMock() + package_mock.__path__ = ["test_package/modules"] + package_mock.__name__ = "test_package.modules" + + # Setup mock module with test class + module_mock = mock.MagicMock() + module_mock.__name__ = "test_package.modules.module1" + + class Module1Class: + pass + + TestAutoRegistry.register_decorator(Module1Class, "Module1Class") + + # Setup import behavior + import_mock.side_effect = lambda name: ( + package_mock + if name == "test_package.modules" + else module_mock + if name == "test_package.modules.module1" + else (_ for _ in ()).throw(ImportError(f"No module named {name}")) + ) + + # Setup package walking behavior + walk_mock.side_effect = lambda path, prefix: ( + [(None, 
"test_package.modules.module1", False)] + if prefix == "test_package.modules." + else (_ for _ in ()).throw(ValueError(f"Unknown package: {prefix}")) + ) + + objects = TestAutoRegistry.registered_objects() + assert len(objects) == 1 + assert TestAutoRegistry.registry_populated is True + assert TestAutoRegistry.registry is not None + assert "module1class" in TestAutoRegistry.registry + + class TestAutoRegistry(RegistryMixin): + registry_auto_discovery = True + auto_package = "test_package.modules" + + with ( + mock.patch("pkgutil.walk_packages") as mock_walk, + mock.patch("importlib.import_module") as mock_import, + ): + mock_package = mock.MagicMock() + mock_package.__path__ = ["test_package/modules"] + mock_package.__name__ = "test_package.modules" + + def import_module(name: str): + if name == "test_package.modules": + return mock_package + elif name == "test_package.modules.module1": + module = mock.MagicMock() + module.__name__ = "test_package.modules.module1" + + class Module1Class: + pass + + TestAutoRegistry.register_decorator(Module1Class, "Module1Class") + return module + else: + raise ImportError(f"No module named {name}") + + def walk_packages(package_path, package_name): + if package_name == "test_package.modules.": + return [(None, "test_package.modules.module1", False)] + else: + raise ValueError(f"Unknown package: {package_name}") + + mock_walk.side_effect = walk_packages + mock_import.side_effect = import_module + + objects = TestAutoRegistry.registered_objects() + assert len(objects) == 1 + assert TestAutoRegistry.registry_populated is True + assert TestAutoRegistry.registry is not None diff --git a/tests/unit/utils/test_singleton.py b/tests/unit/utils/test_singleton.py new file mode 100644 index 00000000..ee01ead1 --- /dev/null +++ b/tests/unit/utils/test_singleton.py @@ -0,0 +1,371 @@ +from __future__ import annotations + +import threading +import time + +import pytest + +from guidellm.utils.singleton import SingletonMixin, ThreadSafeSingletonMixin + + +class TestSingletonMixin: + """Test suite for SingletonMixin class.""" + + @pytest.fixture( + params=[ + {"init_value": "test_value"}, + {"init_value": "another_value"}, + ], + ids=["basic_singleton", "different_value"], + ) + def valid_instances(self, request): + """Provide parameterized test configurations for singleton testing.""" + config = request.param + + class TestSingleton(SingletonMixin): + def __init__(self): + # Check if we need to initialize before calling super().__init__() + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + self.value = config["init_value"] + + return TestSingleton, config + + @pytest.mark.smoke + def test_class_signatures(self): + """Test SingletonMixin inheritance and exposed attributes.""" + assert hasattr(SingletonMixin, "__new__") + assert hasattr(SingletonMixin, "__init__") + assert hasattr(SingletonMixin, "initialized") + assert isinstance(SingletonMixin.initialized, property) + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test SingletonMixin initialization.""" + singleton_class, config = valid_instances + + # Create first instance + instance1 = singleton_class() + + assert isinstance(instance1, singleton_class) + assert instance1.initialized is True + assert hasattr(instance1, "value") + assert instance1.value == config["init_value"] + + # Check that the class has the singleton instance stored + instance_attr = f"_singleton_instance_{singleton_class.__name__}" + assert 
hasattr(singleton_class, instance_attr) + assert getattr(singleton_class, instance_attr) is instance1 + + @pytest.mark.smoke + def test_singleton_behavior(self, valid_instances): + """Test that multiple instantiations return the same instance.""" + singleton_class, config = valid_instances + + # Create multiple instances + instance1 = singleton_class() + instance2 = singleton_class() + instance3 = singleton_class() + + # All should be the same instance + assert instance1 is instance2 + assert instance2 is instance3 + assert instance1 is instance3 + + # Value should remain from first initialization + assert hasattr(instance1, "value") + assert instance1.value == config["init_value"] + assert instance2.value == config["init_value"] + assert instance3.value == config["init_value"] + + @pytest.mark.sanity + def test_initialization_called_once(self, valid_instances): + """Test that __init__ is only called once despite multiple instantiations.""" + singleton_class, config = valid_instances + + class TestSingletonWithCounter(SingletonMixin): + init_count = 0 + + def __init__(self): + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + TestSingletonWithCounter.init_count += 1 + self.value = config["init_value"] + + # Create multiple instances + instance1 = TestSingletonWithCounter() + instance2 = TestSingletonWithCounter() + + assert TestSingletonWithCounter.init_count == 1 + assert instance1 is instance2 + assert hasattr(instance1, "value") + assert instance1.value == config["init_value"] + + @pytest.mark.regression + def test_multiple_singleton_classes_isolation(self): + """Test that different singleton classes maintain separate instances.""" + + class Singleton1(SingletonMixin): + def __init__(self): + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + self.value = "value1" + + class Singleton2(SingletonMixin): + def __init__(self): + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + self.value = "value2" + + instance1a = Singleton1() + instance2a = Singleton2() + instance1b = Singleton1() + instance2b = Singleton2() + + # Each class has its own singleton instance + assert instance1a is instance1b + assert instance2a is instance2b + assert instance1a is not instance2a + + # Each maintains its own value + assert hasattr(instance1a, "value") + assert hasattr(instance2a, "value") + assert instance1a.value == "value1" + assert instance2a.value == "value2" + + @pytest.mark.regression + def test_inheritance_singleton_sharing(self): + """Test that inherited singleton classes share the same singleton_instance.""" + + class BaseSingleton(SingletonMixin): + def __init__(self): + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + self.value = "base_value" + + class ChildSingleton(BaseSingleton): + def __init__(self): + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + self.extra = "extra_value" + + # Child classes now have separate singleton instances + base_instance = BaseSingleton() + child_instance = ChildSingleton() + + # They should be different instances now (fixed inheritance behavior) + assert base_instance is not child_instance + assert hasattr(base_instance, "value") + assert base_instance.value == "base_value" + assert hasattr(child_instance, "value") + assert 
child_instance.value == "base_value" + assert hasattr(child_instance, "extra") + assert child_instance.extra == "extra_value" + + @pytest.mark.sanity + def test_without_super_init_call(self): + """Test singleton behavior when subclass doesn't call super().__init__().""" + + class BadSingleton(SingletonMixin): + def __init__(self): + # Not calling super().__init__() + self.value = "bad_value" + + instance1 = BadSingleton() + instance2 = BadSingleton() + + assert instance1 is instance2 + assert hasattr(instance1, "initialized") + assert instance1.initialized is False + + +class TestThreadSafeSingletonMixin: + """Test suite for ThreadSafeSingletonMixin class.""" + + @pytest.fixture( + params=[ + {"init_value": "thread_safe_value"}, + {"init_value": "concurrent_value"}, + ], + ids=["basic_thread_safe", "concurrent_test"], + ) + def valid_instances(self, request): + """Fixture providing test data for ThreadSafeSingletonMixin subclasses.""" + config = request.param + + class TestThreadSafeSingleton(ThreadSafeSingletonMixin): + def __init__(self): + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + self.value = config["init_value"] + + return TestThreadSafeSingleton, config + + @pytest.mark.smoke + def test_class_signatures(self): + """Test ThreadSafeSingletonMixin inheritance and exposed attributes.""" + assert issubclass(ThreadSafeSingletonMixin, SingletonMixin) + assert hasattr(ThreadSafeSingletonMixin, "get_singleton_lock") + assert hasattr(ThreadSafeSingletonMixin, "__new__") + assert hasattr(ThreadSafeSingletonMixin, "__init__") + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test ThreadSafeSingletonMixin initialization.""" + singleton_class, config = valid_instances + + instance = singleton_class() + + assert isinstance(instance, singleton_class) + assert instance.initialized is True + assert hasattr(instance, "value") + assert instance.value == config["init_value"] + assert hasattr(instance, "thread_lock") + lock_type = type(threading.Lock()) + assert isinstance(instance.thread_lock, lock_type) + + @pytest.mark.smoke + def test_singleton_behavior(self, valid_instances): + """Test multiple instantiations return same instance with thread safety.""" + singleton_class, config = valid_instances + + instance1 = singleton_class() + instance2 = singleton_class() + + assert instance1 is instance2 + assert hasattr(instance1, "value") + assert instance1.value == config["init_value"] + assert hasattr(instance1, "thread_lock") + + @pytest.mark.regression + def test_thread_safety_concurrent_creation(self, valid_instances): + """Test thread safety during concurrent instance creation.""" + singleton_class, config = valid_instances + + instances = [] + exceptions = [] + creation_count = 0 + lock = threading.Lock() + + class ThreadSafeTestSingleton(ThreadSafeSingletonMixin): + def __init__(self): + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + nonlocal creation_count + with lock: + creation_count += 1 + + time.sleep(0.01) + self.value = config["init_value"] + + def create_instance(): + try: + instance = ThreadSafeTestSingleton() + instances.append(instance) + except (TypeError, ValueError, AttributeError) as exc: + exceptions.append(exc) + + threads = [] + for _ in range(10): + thread = threading.Thread(target=create_instance) + threads.append(thread) + + for thread in threads: + thread.start() + + for thread in threads: + thread.join() + 
+ assert len(exceptions) == 0, f"Exceptions occurred: {exceptions}" + + assert len(instances) == 10 + for instance in instances: + assert instance is instances[0] + + assert creation_count == 1 + assert all(instance.value == config["init_value"] for instance in instances) + + @pytest.mark.sanity + def test_thread_lock_creation(self, valid_instances): + """Test that thread_lock is created during initialization.""" + singleton_class, config = valid_instances + + instance1 = singleton_class() + instance2 = singleton_class() + + assert hasattr(instance1, "thread_lock") + lock_type = type(threading.Lock()) + assert isinstance(instance1.thread_lock, lock_type) + assert instance1.thread_lock is instance2.thread_lock + + @pytest.mark.regression + def test_multiple_thread_safe_classes_isolation(self): + """Test thread-safe singleton classes behavior with separate locks.""" + + class ThreadSafeSingleton1(ThreadSafeSingletonMixin): + def __init__(self): + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + self.value = "value1" + + class ThreadSafeSingleton2(ThreadSafeSingletonMixin): + def __init__(self): + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + self.value = "value2" + + instance1 = ThreadSafeSingleton1() + instance2 = ThreadSafeSingleton2() + + lock1 = ThreadSafeSingleton1.get_singleton_lock() + lock2 = ThreadSafeSingleton2.get_singleton_lock() + + assert lock1 is not None + assert lock2 is not None + assert lock1 is not lock2 + + assert instance1 is not instance2 + assert hasattr(instance1, "value") + assert hasattr(instance2, "value") + assert instance1.value == "value1" + assert instance2.value == "value2" + + @pytest.mark.sanity + def test_inheritance_with_thread_safety(self): + """Test inheritance behavior with thread-safe singletons.""" + + class BaseThreadSafeSingleton(ThreadSafeSingletonMixin): + def __init__(self): + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + self.value = "base_value" + + class ChildThreadSafeSingleton(BaseThreadSafeSingleton): + def __init__(self): + super().__init__() + + base_instance = BaseThreadSafeSingleton() + child_instance = ChildThreadSafeSingleton() + + base_lock = BaseThreadSafeSingleton.get_singleton_lock() + child_lock = ChildThreadSafeSingleton.get_singleton_lock() + + assert base_lock is not None + assert child_lock is not None + assert base_lock is not child_lock + + assert base_instance is not child_instance + assert hasattr(base_instance, "value") + assert base_instance.value == "base_value" + assert hasattr(base_instance, "thread_lock") From 3db57b05db327b8980915e4aa99b19e69ac35a68 Mon Sep 17 00:00:00 2001 From: Mark Kurtz Date: Tue, 26 Aug 2025 15:41:01 -0400 Subject: [PATCH 3/3] Scheduler refactor [utils]: functions, mixins, statistics, text (#290) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary This PR adds new utility modules for safe math function operations, object info extraction, and reorganizes to consolidate into the utils package. 
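As a quick orientation before the details below, here is a minimal, hypothetical usage sketch. The `RequestStats` class and its field values are invented purely for illustration; the imported names and signatures are the ones added by this PR in `functions.py` and `mixins.py`.

```python
from __future__ import annotations

from guidellm.utils import (
    all_defined,
    safe_divide,
    safe_format_timestamp,
    safe_getattr,
)
from guidellm.utils.mixins import InfoMixin


class RequestStats(InfoMixin):
    # Toy container used only for this sketch.
    def __init__(self, completed: int, total: int | None, started_at: float | None):
        self.completed = completed
        self.total = total
        self.started_at = started_at


stats = RequestStats(completed=5, total=None, started_at=None)

# The safe_* helpers fold optional/None values into computations without explicit checks.
errored = safe_getattr(stats, "errored", default=0)  # missing attribute -> default (0)
ratio = safe_divide(stats.completed, stats.total)    # None denominator -> den_default (1.0)
ready = all_defined(stats.completed, stats.total)    # False: total is None
started = safe_format_timestamp(stats.started_at)    # None/invalid timestamp -> "N/A"

# InfoMixin exposes structured metadata (type, module, public attributes) for debugging.
print(stats.info["type"], stats.info["attributes"])
```

The intent of the `safe_*` family is to let aggregation and display code tolerate partially populated results without scattering `if value is not None` checks, which is also why the fallback defaults are parameters rather than hard-coded constants.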
## Details - **Move statistics classes from `objects` to `utils`**: Relocates `DistributionSummary`, `Percentiles`, `RunningStats`, `StatusDistributionSummary`, and `TimeRunningStats` from `guidellm.objects.statistics` to `guidellm.utils.statistics` - **Add new `functions.py` module**: Implements defensive programming utilities including `safe_getattr`, `safe_divide`, `safe_multiply`, `safe_add`, `safe_format_timestamp`, and `all_defined` for handling None values and edge cases - **Add new `mixins.py` module**: Provides `InfoMixin` class for standardized metadata extraction and object introspection across different class hierarchies - **Enhance `text.py` module**: Adds comprehensive documentation, `format_value_display` function for consistent metric formatting, and improved text processing utilities - **Update import statements**: Modifies all affected modules (`benchmark`, `presentation`) to import statistics classes from their new location in `utils` - **Remove deprecated `objects` package**: Deletes the now-empty `objects` directory and associated test files - **Add comprehensive test coverage**: Includes new test suites for `functions.py`, `mixins.py`, and `statistics.py` - **Update `__init__.py` exports**: Adds new utility functions and classes to the main utils package exports for easy access ## Test Plan - Run the existing test suite to ensure no regressions from the statistics class relocation - Execute new test files: - test_functions.py - Tests for safe operation utilities - test_mixins.py - Tests for InfoMixin functionality - test_statistics.py - Comprehensive tests for statistical analysis utilities - test_text.py - Tests for enhanced text processing functions ## Related Issues - Resolves # --- - [X] "I certify that all code in this PR is my own, except as noted below." 
## Use of AI - [X] Includes AI-assisted code completion - [X] Includes code generated by an AI application - [ ] Includes AI-generated tests (NOTE: AI written tests should have a docstring that includes `## WRITTEN BY AI ##`) --------- Signed-off-by: Mark Kurtz Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Samuel Monson --- src/guidellm/benchmark/aggregator.py | 12 +- src/guidellm/benchmark/benchmark.py | 9 +- src/guidellm/benchmark/output.py | 12 +- src/guidellm/objects/__init__.py | 15 - src/guidellm/presentation/data_models.py | 2 +- src/guidellm/utils/__init__.py | 30 +- src/guidellm/utils/functions.py | 130 +++++ src/guidellm/utils/mixins.py | 115 ++++ src/guidellm/{objects => utils}/statistics.py | 0 src/guidellm/utils/text.py | 200 +++++-- tests/unit/mock_benchmark.py | 2 +- tests/unit/objects/__init__.py | 0 tests/unit/utils/test_functions.py | 222 ++++++++ tests/unit/utils/test_mixins.py | 245 ++++++++ .../{objects => utils}/test_statistics.py | 79 --- tests/unit/utils/test_text.py | 531 ++++++++++++++++++ 16 files changed, 1454 insertions(+), 150 deletions(-) delete mode 100644 src/guidellm/objects/__init__.py create mode 100644 src/guidellm/utils/functions.py create mode 100644 src/guidellm/utils/mixins.py rename src/guidellm/{objects => utils}/statistics.py (100%) delete mode 100644 tests/unit/objects/__init__.py create mode 100644 tests/unit/utils/test_functions.py create mode 100644 tests/unit/utils/test_mixins.py rename tests/unit/{objects => utils}/test_statistics.py (90%) create mode 100644 tests/unit/utils/test_text.py diff --git a/src/guidellm/benchmark/aggregator.py b/src/guidellm/benchmark/aggregator.py index b322eadd..450b536a 100644 --- a/src/guidellm/benchmark/aggregator.py +++ b/src/guidellm/benchmark/aggregator.py @@ -22,10 +22,6 @@ GenerativeTextResponseStats, ) from guidellm.config import settings -from guidellm.objects import ( - RunningStats, - TimeRunningStats, -) from guidellm.request import ( GenerationRequest, GenerativeRequestLoaderDescription, @@ -38,7 +34,13 @@ SchedulerRequestResult, WorkerDescription, ) -from guidellm.utils import StandardBaseModel, StatusBreakdown, check_load_processor +from guidellm.utils import ( + RunningStats, + StandardBaseModel, + StatusBreakdown, + TimeRunningStats, + check_load_processor, +) __all__ = [ "AggregatorT", diff --git a/src/guidellm/benchmark/benchmark.py b/src/guidellm/benchmark/benchmark.py index 77d0fe38..eadcf984 100644 --- a/src/guidellm/benchmark/benchmark.py +++ b/src/guidellm/benchmark/benchmark.py @@ -12,9 +12,6 @@ SynchronousProfile, ThroughputProfile, ) -from guidellm.objects import ( - StatusDistributionSummary, -) from guidellm.request import ( GenerativeRequestLoaderDescription, RequestLoaderDescription, @@ -30,7 +27,11 @@ ThroughputStrategy, WorkerDescription, ) -from guidellm.utils import StandardBaseModel, StatusBreakdown +from guidellm.utils import ( + StandardBaseModel, + StatusBreakdown, + StatusDistributionSummary, +) __all__ = [ "Benchmark", diff --git a/src/guidellm/benchmark/output.py b/src/guidellm/benchmark/output.py index dd94f899..225ed2b1 100644 --- a/src/guidellm/benchmark/output.py +++ b/src/guidellm/benchmark/output.py @@ -21,14 +21,16 @@ ThroughputProfile, ) from guidellm.config import settings -from guidellm.objects import ( - DistributionSummary, - StatusDistributionSummary, -) from guidellm.presentation import UIDataBuilder from guidellm.presentation.injector import create_report from guidellm.scheduler import strategy_display_str -from 
guidellm.utils import Colors, StandardBaseModel, split_text_list_by_length +from guidellm.utils import ( + Colors, + DistributionSummary, + StandardBaseModel, + StatusDistributionSummary, + split_text_list_by_length, +) __all__ = [ "GenerativeBenchmarksConsole", diff --git a/src/guidellm/objects/__init__.py b/src/guidellm/objects/__init__.py deleted file mode 100644 index 119ac6e7..00000000 --- a/src/guidellm/objects/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -from .statistics import ( - DistributionSummary, - Percentiles, - RunningStats, - StatusDistributionSummary, - TimeRunningStats, -) - -__all__ = [ - "DistributionSummary", - "Percentiles", - "RunningStats", - "StatusDistributionSummary", - "TimeRunningStats", -] diff --git a/src/guidellm/presentation/data_models.py b/src/guidellm/presentation/data_models.py index ff5221e3..3164dc86 100644 --- a/src/guidellm/presentation/data_models.py +++ b/src/guidellm/presentation/data_models.py @@ -8,7 +8,7 @@ if TYPE_CHECKING: from guidellm.benchmark.benchmark import GenerativeBenchmark -from guidellm.objects.statistics import DistributionSummary +from guidellm.utils.statistics import DistributionSummary class Bucket(BaseModel): diff --git a/src/guidellm/utils/__init__.py b/src/guidellm/utils/__init__.py index 98ac1c36..576fe64d 100644 --- a/src/guidellm/utils/__init__.py +++ b/src/guidellm/utils/__init__.py @@ -1,6 +1,14 @@ from .auto_importer import AutoImporterMixin from .colors import Colors from .default_group import DefaultGroupHandler +from .functions import ( + all_defined, + safe_add, + safe_divide, + safe_format_timestamp, + safe_getattr, + safe_multiply, +) from .hf_datasets import ( SUPPORTED_TYPES, save_dataset_to_file, @@ -18,11 +26,18 @@ from .random import IntegerRangeSampler from .registry import RegistryMixin from .singleton import SingletonMixin, ThreadSafeSingletonMixin +from .statistics import ( + DistributionSummary, + Percentiles, + RunningStats, + StatusDistributionSummary, + TimeRunningStats, +) from .text import ( EndlessTextCreator, clean_text, filter_text, - is_puncutation, + is_punctuation, load_text, split_text, split_text_list_by_length, @@ -33,21 +48,32 @@ "AutoImporterMixin", "Colors", "DefaultGroupHandler", + "DistributionSummary", "EndlessTextCreator", "IntegerRangeSampler", + "Percentiles", "PydanticClassRegistryMixin", "RegistryMixin", "ReloadableBaseModel", + "RunningStats", "SingletonMixin", "StandardBaseDict", "StandardBaseModel", "StatusBreakdown", + "StatusDistributionSummary", "ThreadSafeSingletonMixin", + "TimeRunningStats", + "all_defined", "check_load_processor", "clean_text", "filter_text", - "is_puncutation", + "is_punctuation", "load_text", + "safe_add", + "safe_divide", + "safe_format_timestamp", + "safe_getattr", + "safe_multiply", "save_dataset_to_file", "split_text", "split_text_list_by_length", diff --git a/src/guidellm/utils/functions.py b/src/guidellm/utils/functions.py new file mode 100644 index 00000000..b28aa21e --- /dev/null +++ b/src/guidellm/utils/functions.py @@ -0,0 +1,130 @@ +""" +Utility functions for safe operations and value handling. + +Provides defensive programming utilities for common operations that may encounter +None values, invalid inputs, or edge cases. Includes safe arithmetic operations, +attribute access, and timestamp formatting. 
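+
+Example (illustrative values only):
+::
+    from guidellm.utils.functions import all_defined, safe_divide, safe_getattr
+
+    duration = None
+    rate = safe_divide(128, duration)  # None denominator falls back to den_default (1.0)
+    ready = all_defined(128, duration)  # False: one of the values is None
+    name = safe_getattr(None, "model", "unknown")  # None object -> default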
+""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any + +__all__ = [ + "all_defined", + "safe_add", + "safe_divide", + "safe_format_timestamp", + "safe_getattr", + "safe_multiply", +] + + +def safe_getattr(obj: Any | None, attr: str, default: Any = None) -> Any: + """ + Safely get an attribute from an object with None handling. + + :param obj: Object to get the attribute from, or None + :param attr: Name of the attribute to retrieve + :param default: Value to return if object is None or attribute doesn't exist + :return: Attribute value or default if not found or object is None + """ + if obj is None: + return default + + return getattr(obj, attr, default) + + +def all_defined(*values: Any | None) -> bool: + """ + Check if all provided values are defined (not None). + + :param values: Variable number of values to check for None + :return: True if all values are not None, False otherwise + """ + return all(value is not None for value in values) + + +def safe_divide( + numerator: int | float | None, + denominator: int | float | None, + num_default: float = 0.0, + den_default: float = 1.0, +) -> float: + """ + Safely divide two numbers with None handling and zero protection. + + :param numerator: Number to divide, or None to use num_default + :param denominator: Number to divide by, or None to use den_default + :param num_default: Default value for numerator if None + :param den_default: Default value for denominator if None + :return: Division result with protection against division by zero + """ + numerator = numerator if numerator is not None else num_default + denominator = denominator if denominator is not None else den_default + + return numerator / (denominator or 1e-10) + + +def safe_multiply(*values: int | float | None, default: float = 1.0) -> float: + """ + Safely multiply multiple numbers with None handling. + + :param values: Variable number of values to multiply, None values treated as 1.0 + :param default: Starting value for multiplication + :return: Product of all non-None values multiplied by default + """ + result = default + for val in values: + result *= val if val is not None else 1.0 + return result + + +def safe_add( + *values: int | float | None, signs: list[int] | None = None, default: float = 0.0 +) -> float: + """ + Safely add multiple numbers with None handling and optional signs. + + :param values: Variable number of values to add, None values use default + :param signs: Optional list of 1 (add) or -1 (subtract) for each value. + If None, all values are added. Must match length of values. + :param default: Value to substitute for None values + :return: Result of adding all values safely (default used when value is None) + """ + if not values: + return default + + values = list(values) + + if signs is None: + signs = [1] * len(values) + + if len(signs) != len(values): + raise ValueError("Length of signs must match length of values") + + result = values[0] if values[0] is not None else default + + for ind in range(1, len(values)): + val = values[ind] if values[ind] is not None else default + result += signs[ind] * val + + return result + + +def safe_format_timestamp( + timestamp: float | None, format_: str = "%H:%M:%S", default: str = "N/A" +) -> str: + """ + Safely format a timestamp with error handling and validation. 
+ + :param timestamp: Unix timestamp to format, or None + :param format_: Strftime format string for timestamp formatting + :param default: Value to return if timestamp is invalid or None + :return: Formatted timestamp string or default value + """ + try: + return datetime.fromtimestamp(timestamp).strftime(format_) + except (ValueError, TypeError, OverflowError, OSError): + return default diff --git a/src/guidellm/utils/mixins.py b/src/guidellm/utils/mixins.py new file mode 100644 index 00000000..b001ff2d --- /dev/null +++ b/src/guidellm/utils/mixins.py @@ -0,0 +1,115 @@ +""" +Mixin classes for common metadata extraction and object introspection. + +Provides reusable mixins for extracting structured metadata from objects, +enabling consistent information exposure across different class hierarchies. +""" + +from __future__ import annotations + +from typing import Any + +__all__ = ["InfoMixin"] + + +PYTHON_PRIMITIVES = (str, int, float, bool, list, tuple, dict) +"""Type alias for serialized object representations""" + + +class InfoMixin: + """ + Mixin class providing standardized metadata extraction for introspection. + + Enables consistent object metadata extraction patterns across different + class hierarchies for debugging, serialization, and runtime analysis. + Provides both instance and class-level methods for extracting structured + information from arbitrary objects with fallback handling for objects + without built-in info capabilities. + + Example: + :: + from guidellm.utils.mixins import InfoMixin + + class ConfiguredClass(InfoMixin): + def __init__(self, setting: str): + self.setting = setting + + obj = ConfiguredClass("value") + # Returns {'str': 'ConfiguredClass(...)', 'type': 'ConfiguredClass', ...} + print(obj.info) + """ + + @classmethod + def extract_from_obj(cls, obj: Any) -> dict[str, Any]: + """ + Extract structured metadata from any object. + + Attempts to use the object's own `info` method or property if available, + otherwise constructs metadata from object attributes and type information. + Provides consistent metadata format across different object types. + + :param obj: Object to extract metadata from + :return: Dictionary containing object metadata including type, class, + module, and public attributes + """ + if hasattr(obj, "info"): + return obj.info() if callable(obj.info) else obj.info + + return { + "str": str(obj), + "type": type(obj).__name__, + "class": obj.__class__.__name__ if hasattr(obj, "__class__") else None, + "module": obj.__class__.__module__ if hasattr(obj, "__class__") else None, + "attributes": ( + { + key: val if isinstance(val, PYTHON_PRIMITIVES) else repr(val) + for key, val in obj.__dict__.items() + if not key.startswith("_") + } + if hasattr(obj, "__dict__") + else {} + ), + } + + @classmethod + def create_info_dict(cls, obj: Any) -> dict[str, Any]: + """ + Create a structured info dictionary for the given object. + + Builds standardized metadata dictionary containing object identification, + type information, and accessible attributes. Used internally by other + info extraction methods and available for direct metadata construction. 
+ + :param obj: Object to extract info from + :return: Dictionary containing structured metadata about the object + """ + return { + "str": str(obj), + "type": type(obj).__name__, + "class": obj.__class__.__name__ if hasattr(obj, "__class__") else None, + "module": obj.__class__.__module__ if hasattr(obj, "__class__") else None, + "attributes": ( + { + key: val + if isinstance(val, (str, int, float, bool, list, dict)) + else repr(val) + for key, val in obj.__dict__.items() + if not key.startswith("_") + } + if hasattr(obj, "__dict__") + else {} + ), + } + + @property + def info(self) -> dict[str, Any]: + """ + Return structured metadata about this instance. + + Provides consistent access to object metadata for debugging, serialization, + and introspection. Uses the create_info_dict method to generate standardized + metadata format including class information and public attributes. + + :return: Dictionary containing class name, module, and public attributes + """ + return self.create_info_dict(self) diff --git a/src/guidellm/objects/statistics.py b/src/guidellm/utils/statistics.py similarity index 100% rename from src/guidellm/objects/statistics.py rename to src/guidellm/utils/statistics.py diff --git a/src/guidellm/utils/text.py b/src/guidellm/utils/text.py index cdefaa14..beebfe37 100644 --- a/src/guidellm/utils/text.py +++ b/src/guidellm/utils/text.py @@ -1,9 +1,21 @@ +""" +Text processing utilities for content manipulation and formatting operations. + +Provides comprehensive text processing capabilities including cleaning, filtering, +splitting, loading from various sources, and formatting utilities. Supports loading +text from URLs, compressed files, package resources, and local files with automatic +encoding detection. Includes specialized formatting for display values and text +wrapping operations for consistent presentation across the system. +""" + +from __future__ import annotations + import gzip import re import textwrap from importlib.resources import as_file, files # type: ignore[attr-defined] from pathlib import Path -from typing import Any, Optional, Union +from typing import Any import ftfy import httpx @@ -11,35 +23,86 @@ from guidellm import data as package_data from guidellm.config import settings +from guidellm.utils.colors import Colors __all__ = [ + "MAX_PATH_LENGTH", "EndlessTextCreator", "clean_text", "filter_text", - "is_puncutation", + "format_value_display", + "is_punctuation", "load_text", "split_text", "split_text_list_by_length", ] -MAX_PATH_LENGTH = 4096 +MAX_PATH_LENGTH: int = 4096 + + +def format_value_display( + value: float, + label: str, + units: str = "", + total_characters: int | None = None, + digits_places: int | None = None, + decimal_places: int | None = None, +) -> str: + """ + Format a numeric value with units and label for consistent display output. + + Creates standardized display strings for metrics and measurements with + configurable precision, width, and color formatting. Supports both + fixed-width and variable-width output for tabular displays. 
+ + :param value: Numeric value to format and display + :param label: Descriptive label for the value + :param units: Units string to append after the value + :param total_characters: Total width for right-aligned output formatting + :param digits_places: Total number of digits for numeric formatting + :param decimal_places: Number of decimal places for numeric precision + :return: Formatted string with value, units, and colored label + """ + if decimal_places is None and digits_places is None: + formatted_number = f"{value}:.0f" + elif digits_places is None: + formatted_number = f"{value:.{decimal_places}f}" + elif decimal_places is None: + formatted_number = f"{value:>{digits_places}f}" + else: + formatted_number = f"{value:>{digits_places}.{decimal_places}f}" + + result = f"{formatted_number}{units} [{Colors.info}]{label}[/{Colors.info}]" + + if total_characters is not None: + total_characters += len(Colors.info) * 2 + 5 + + if len(result) < total_characters: + result = result.rjust(total_characters) + + return result def split_text_list_by_length( text_list: list[Any], - max_characters: Union[int, list[int]], + max_characters: int | list[int], pad_horizontal: bool = True, pad_vertical: bool = True, ) -> list[list[str]]: """ - Split a list of strings into a list of strings, - each with a maximum length of max_characters - - :param text_list: the list of strings to split - :param max_characters: the maximum length of each string - :param pad_horizontal: whether to pad the strings horizontally, defaults to True - :param pad_vertical: whether to pad the strings vertically, defaults to True - :return: a list of strings + Split text strings into wrapped lines with specified maximum character limits. + + Processes each string in the input list by wrapping text to fit within character + limits, with optional padding for consistent formatting in tabular displays. + Supports different character limits per string and uniform padding across results. + + :param text_list: List of strings to process and wrap + :param max_characters: Maximum characters per line, either single value or + per-string limits + :param pad_horizontal: Right-align lines within their character limits + :param pad_vertical: Pad shorter results to match the longest wrapped result + :return: List of wrapped line lists, one per input string + :raises ValueError: If max_characters list length doesn't match text_list length """ if not isinstance(max_characters, list): max_characters = [max_characters] * len(text_list) @@ -75,16 +138,21 @@ def split_text_list_by_length( def filter_text( text: str, - filter_start: Optional[Union[str, int]] = None, - filter_end: Optional[Union[str, int]] = None, + filter_start: str | int | None = None, + filter_end: str | int | None = None, ) -> str: """ - Filter text by start and end strings or indices + Extract text substring using start and end markers or indices. + + Filters text content by locating string markers or using numeric indices + to extract specific portions. Supports flexible filtering for content + extraction and preprocessing operations. 
- :param text: the text to filter - :param filter_start: the start string or index to filter from - :param filter_end: the end string or index to filter to - :return: the filtered text + :param text: Source text to filter and extract from + :param filter_start: Starting marker string or index position + :param filter_end: Ending marker string or index position + :return: Filtered text substring between specified boundaries + :raises ValueError: If filter indices are invalid or markers not found """ filter_start_index = -1 filter_end_index = -1 @@ -112,10 +180,29 @@ def filter_text( def clean_text(text: str) -> str: + """ + Normalize text by fixing encoding issues and standardizing whitespace. + + Applies Unicode normalization and whitespace standardization for consistent + text processing. Removes excessive whitespace and fixes common encoding problems. + + :param text: Raw text string to clean and normalize + :return: Cleaned text with normalized encoding and whitespace + """ return re.sub(r"\s+", " ", ftfy.fix_text(text)).strip() def split_text(text: str, split_punctuation: bool = False) -> list[str]: + """ + Split text into tokens with optional punctuation separation. + + Tokenizes text into words and optionally separates punctuation marks + for detailed text analysis and processing operations. + + :param text: Text string to tokenize and split + :param split_punctuation: Separate punctuation marks as individual tokens + :return: List of text tokens + """ text = clean_text(text) if split_punctuation: @@ -124,16 +211,20 @@ def split_text(text: str, split_punctuation: bool = False) -> list[str]: return text.split() -def load_text(data: Union[str, Path], encoding: Optional[str] = None) -> str: +def load_text(data: str | Path, encoding: str | None = None) -> str: """ - Load an HTML file from a path or URL - - :param data: the path or URL to load the HTML file from - :type data: Union[str, Path] - :param encoding: the encoding to use when reading the file - :type encoding: str - :return: the HTML content - :rtype: str + Load text content from various sources including URLs, files, and package data. + + Supports loading from HTTP/FTP URLs, local files, compressed archives, package + resources, and raw text strings. Automatically detects source type and applies + appropriate loading strategy with encoding support. + + :param data: Source location or raw text - URL, file path, package resource + identifier, or text content + :param encoding: Character encoding for file reading operations + :return: Loaded text content as string + :raises FileNotFoundError: If local file path does not exist + :raises httpx.HTTPStatusError: If URL request fails """ logger.debug("Loading text: {}", data) @@ -177,38 +268,71 @@ def load_text(data: Union[str, Path], encoding: Optional[str] = None) -> str: return data.read_text(encoding=encoding) -def is_puncutation(text: str) -> bool: +def is_punctuation(text: str) -> bool: """ - Check if the text is a punctuation + Check if a single character is a punctuation mark. + + Identifies punctuation characters by excluding alphanumeric characters + and whitespace from single-character strings. 
- :param text: the text to check - :type text: str - :return: True if the text is a punctuation, False otherwise - :rtype: bool + :param text: Single character string to test + :return: True if the character is punctuation, False otherwise """ return len(text) == 1 and not text.isalnum() and not text.isspace() class EndlessTextCreator: + """ + Infinite text generator for load testing and content creation operations. + + Provides deterministic text generation by cycling through preprocessed word + tokens from source content. Supports filtering and punctuation handling for + realistic text patterns in benchmarking scenarios. + + Example: + :: + creator = EndlessTextCreator("path/to/source.txt") + generated = creator.create_text(start=0, length=100) + more_text = creator.create_text(start=50, length=200) + """ + def __init__( self, - data: Union[str, Path], - filter_start: Optional[Union[str, int]] = None, - filter_end: Optional[Union[str, int]] = None, + data: str | Path, + filter_start: str | int | None = None, + filter_end: str | int | None = None, ): + """ + Initialize text creator with source content and optional filtering. + + :param data: Source text location or content - file path, URL, or raw text + :param filter_start: Starting marker or index for content filtering + :param filter_end: Ending marker or index for content filtering + """ self.data = data self.text = load_text(data) self.filtered_text = filter_text(self.text, filter_start, filter_end) self.words = split_text(self.filtered_text, split_punctuation=True) def create_text(self, start: int, length: int) -> str: + """ + Generate text by cycling through word tokens from the specified position. + + Creates deterministic text sequences by selecting consecutive tokens from + the preprocessed word list, wrapping around when reaching the end. + Maintains proper spacing and punctuation formatting. 
+ + :param start: Starting position in the token sequence + :param length: Number of tokens to include in generated text + :return: Generated text string with proper spacing and punctuation + """ text = "" for counter in range(length): index = (start + counter) % len(self.words) add_word = self.words[index] - if counter != 0 and not is_puncutation(add_word): + if counter != 0 and not is_punctuation(add_word): text += " " text += add_word diff --git a/tests/unit/mock_benchmark.py b/tests/unit/mock_benchmark.py index 81364fa1..29c092c8 100644 --- a/tests/unit/mock_benchmark.py +++ b/tests/unit/mock_benchmark.py @@ -6,13 +6,13 @@ GenerativeTextResponseStats, SynchronousProfile, ) -from guidellm.objects import StatusBreakdown from guidellm.request import GenerativeRequestLoaderDescription from guidellm.scheduler import ( GenerativeRequestsWorkerDescription, SchedulerRequestInfo, SynchronousStrategy, ) +from guidellm.utils import StatusBreakdown __all__ = ["mock_generative_benchmark"] diff --git a/tests/unit/objects/__init__.py b/tests/unit/objects/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/unit/utils/test_functions.py b/tests/unit/utils/test_functions.py new file mode 100644 index 00000000..3b353759 --- /dev/null +++ b/tests/unit/utils/test_functions.py @@ -0,0 +1,222 @@ +from __future__ import annotations + +from datetime import datetime + +import pytest + +from guidellm.utils.functions import ( + all_defined, + safe_add, + safe_divide, + safe_format_timestamp, + safe_getattr, + safe_multiply, +) + + +class TestAllDefined: + """Test suite for all_defined function.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("values", "expected"), + [ + ((1, 2, 3), True), + (("test", "hello"), True), + ((0, False, ""), True), + ((1, None, 3), False), + ((None,), False), + ((None, None), False), + ((), True), + ], + ) + def test_invocation(self, values, expected): + """Test all_defined with valid inputs.""" + result = all_defined(*values) + assert result == expected + + @pytest.mark.sanity + def test_mixed_types(self): + """Test all_defined with mixed data types.""" + result = all_defined(1, "test", [], {}, 0.0, False) + assert result is True + + result = all_defined(1, "test", None, {}) + assert result is False + + +class TestSafeGetattr: + """Test suite for safe_getattr function.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("obj", "attr", "default", "expected"), + [ + (None, "any_attr", "default_val", "default_val"), + (None, "any_attr", None, None), + ("test_string", "nonexistent", "default_val", "default_val"), + ], + ) + def test_invocation(self, obj, attr, default, expected): + """Test safe_getattr with valid inputs.""" + result = safe_getattr(obj, attr, default) + assert result == expected + + @pytest.mark.smoke + def test_with_object(self): + """Test safe_getattr with actual object attributes.""" + + class TestObj: + test_attr = "test_value" + + obj = TestObj() + result = safe_getattr(obj, "test_attr", "default") + assert result == "test_value" + + result = safe_getattr(obj, "missing_attr", "default") + assert result == "default" + + # Test with method attribute + result = safe_getattr("test_string", "upper", None) + assert callable(result) + assert result() == "TEST_STRING" + + +class TestSafeDivide: + """Test suite for safe_divide function.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("numerator", "denominator", "num_default", "den_default", "expected"), + [ + (10, 2, 0.0, 1.0, 5.0), + (None, 2, 6.0, 1.0, 3.0), + (10, None, 
0.0, 5.0, 2.0), + (None, None, 8.0, 4.0, 2.0), + (10, 0, 0.0, 1.0, 10 / 1e-10), + ], + ) + def test_invocation( + self, numerator, denominator, num_default, den_default, expected + ): + """Test safe_divide with valid inputs.""" + result = safe_divide(numerator, denominator, num_default, den_default) + assert result == pytest.approx(expected, rel=1e-6) + + @pytest.mark.sanity + def test_zero_division_protection(self): + """Test safe_divide protection against zero division.""" + result = safe_divide(10, 0) + assert result == 10 / 1e-10 + + result = safe_divide(5, None, den_default=0) + assert result == 5 / 1e-10 + + +class TestSafeMultiply: + """Test suite for safe_multiply function.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("values", "default", "expected"), + [ + ((2, 3, 4), 1.0, 24.0), + ((2, None, 4), 1.0, 8.0), + ((None, None), 5.0, 5.0), + ((), 3.0, 3.0), + ((2, 3, None, 5), 2.0, 60.0), + ], + ) + def test_invocation(self, values, default, expected): + """Test safe_multiply with valid inputs.""" + result = safe_multiply(*values, default=default) + assert result == expected + + @pytest.mark.sanity + def test_with_zero(self): + """Test safe_multiply with zero values.""" + result = safe_multiply(2, 0, 3, default=1.0) + assert result == 0.0 + + result = safe_multiply(None, 0, None, default=5.0) + assert result == 0.0 + + +class TestSafeAdd: + """Test suite for safe_add function.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("values", "signs", "default", "expected"), + [ + ((1, 2, 3), None, 0.0, 6.0), + ((1, None, 3), None, 5.0, 9.0), + ((10, 5), [1, -1], 0.0, 5.0), + ((None, None), [1, -1], 2.0, 0.0), + ((), None, 3.0, 3.0), + ((1, 2, 3), [1, 1, -1], 0.0, 0.0), + ], + ) + def test_invocation(self, values, signs, default, expected): + """Test safe_add with valid inputs.""" + result = safe_add(*values, signs=signs, default=default) + assert result == expected + + @pytest.mark.sanity + def test_invalid_signs_length(self): + """Test safe_add with invalid signs length.""" + with pytest.raises( + ValueError, match="Length of signs must match length of values" + ): + safe_add(1, 2, 3, signs=[1, -1]) + + @pytest.mark.sanity + def test_single_value(self): + """Test safe_add with single value.""" + result = safe_add(5, default=1.0) + assert result == 5.0 + + result = safe_add(None, default=3.0) + assert result == 3.0 + + +class TestSafeFormatTimestamp: + """Test suite for safe_format_timestamp function.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("timestamp", "format_", "default", "expected"), + [ + (1609459200.0, "%Y-%m-%d", "N/A", "2020-12-31"), + (1609459200.0, "%H:%M:%S", "N/A", "19:00:00"), + (None, "%H:%M:%S", "N/A", "N/A"), + (-1, "%H:%M:%S", "N/A", "N/A"), + (2**32, "%H:%M:%S", "N/A", "N/A"), + ], + ) + def test_invocation(self, timestamp, format_, default, expected): + """Test safe_format_timestamp with valid inputs.""" + result = safe_format_timestamp(timestamp, format_, default) + assert result == expected + + @pytest.mark.sanity + def test_edge_cases(self): + """Test safe_format_timestamp with edge case timestamps.""" + result = safe_format_timestamp(0.0, "%Y", "N/A") + assert result == "1969" + + result = safe_format_timestamp(1.0, "%Y", "N/A") + assert result == "1969" + + result = safe_format_timestamp(2**31 - 1, "%Y", "N/A") + expected_year = datetime.fromtimestamp(2**31 - 1).strftime("%Y") + assert result == expected_year + + @pytest.mark.sanity + def test_invalid_timestamp_ranges(self): + """Test safe_format_timestamp with invalid timestamp 
ranges.""" + result = safe_format_timestamp(2**31 + 1, "%Y", "ERROR") + assert result == "ERROR" + + result = safe_format_timestamp(-1000, "%Y", "ERROR") + assert result == "ERROR" diff --git a/tests/unit/utils/test_mixins.py b/tests/unit/utils/test_mixins.py new file mode 100644 index 00000000..cd8990de --- /dev/null +++ b/tests/unit/utils/test_mixins.py @@ -0,0 +1,245 @@ +from __future__ import annotations + +import pytest + +from guidellm.utils.mixins import InfoMixin + + +class TestInfoMixin: + """Test suite for InfoMixin.""" + + @pytest.fixture( + params=[ + {"attr_one": "test_value", "attr_two": 42}, + {"attr_one": "hello_world", "attr_two": 100, "attr_three": [1, 2, 3]}, + ], + ids=["basic_attributes", "extended_attributes"], + ) + def valid_instances(self, request): + """Fixture providing test data for InfoMixin.""" + constructor_args = request.param + + class TestClass(InfoMixin): + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + instance = TestClass(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_class_signatures(self): + """Test InfoMixin class signatures and methods.""" + assert hasattr(InfoMixin, "extract_from_obj") + assert callable(InfoMixin.extract_from_obj) + assert hasattr(InfoMixin, "create_info_dict") + assert callable(InfoMixin.create_info_dict) + assert hasattr(InfoMixin, "info") + assert isinstance(InfoMixin.info, property) + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test InfoMixin initialization through inheritance.""" + instance, constructor_args = valid_instances + assert isinstance(instance, InfoMixin) + for key, value in constructor_args.items(): + assert hasattr(instance, key) + assert getattr(instance, key) == value + + @pytest.mark.smoke + def test_info_property(self, valid_instances): + """Test InfoMixin.info property.""" + instance, constructor_args = valid_instances + result = instance.info + assert isinstance(result, dict) + assert "str" in result + assert "type" in result + assert "class" in result + assert "module" in result + assert "attributes" in result + assert result["type"] == "TestClass" + assert result["class"] == "TestClass" + assert isinstance(result["attributes"], dict) + for key, value in constructor_args.items(): + assert key in result["attributes"] + assert result["attributes"][key] == value + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("obj_data", "expected_attributes"), + [ + ({"name": "test", "value": 42}, {"name": "test", "value": 42}), + ({"data": [1, 2, 3], "flag": True}, {"data": [1, 2, 3], "flag": True}), + ({"nested": {"key": "value"}}, {"nested": {"key": "value"}}), + ], + ) + def test_create_info_dict(self, obj_data, expected_attributes): + """Test InfoMixin.create_info_dict class method.""" + + class SimpleObject: + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + obj = SimpleObject(**obj_data) + result = InfoMixin.create_info_dict(obj) + + assert isinstance(result, dict) + assert "str" in result + assert "type" in result + assert "class" in result + assert "module" in result + assert "attributes" in result + assert result["type"] == "SimpleObject" + assert result["class"] == "SimpleObject" + assert result["attributes"] == expected_attributes + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("obj_data", "expected_attributes"), + [ + ({"name": "test", "value": 42}, {"name": "test", "value": 42}), + ({"data": [1, 2, 3], "flag": True}, {"data": 
[1, 2, 3], "flag": True}), + ], + ) + def test_extract_from_obj_without_info(self, obj_data, expected_attributes): + """Test InfoMixin.extract_from_obj with objects without info method.""" + + class SimpleObject: + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + obj = SimpleObject(**obj_data) + result = InfoMixin.extract_from_obj(obj) + + assert isinstance(result, dict) + assert "str" in result + assert "type" in result + assert "class" in result + assert "module" in result + assert "attributes" in result + assert result["type"] == "SimpleObject" + assert result["class"] == "SimpleObject" + assert result["attributes"] == expected_attributes + + @pytest.mark.smoke + def test_extract_from_obj_with_info_method(self): + """Test InfoMixin.extract_from_obj with objects that have info method.""" + + class ObjectWithInfoMethod: + def info(self): + return {"custom": "info_method", "type": "custom_type"} + + obj = ObjectWithInfoMethod() + result = InfoMixin.extract_from_obj(obj) + + assert result == {"custom": "info_method", "type": "custom_type"} + + @pytest.mark.smoke + def test_extract_from_obj_with_info_property(self): + """Test InfoMixin.extract_from_obj with objects that have info property.""" + + class ObjectWithInfoProperty: + @property + def info(self): + return {"custom": "info_property", "type": "custom_type"} + + obj = ObjectWithInfoProperty() + result = InfoMixin.extract_from_obj(obj) + + assert result == {"custom": "info_property", "type": "custom_type"} + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("obj_type", "obj_value"), + [ + (str, "test_string"), + (int, 42), + (float, 3.14), + (list, [1, 2, 3]), + (dict, {"key": "value"}), + ], + ) + def test_extract_from_obj_builtin_types(self, obj_type, obj_value): + """Test InfoMixin.extract_from_obj with built-in types.""" + result = InfoMixin.extract_from_obj(obj_value) + + assert isinstance(result, dict) + assert "str" in result + assert "type" in result + assert result["type"] == obj_type.__name__ + assert result["str"] == str(obj_value) + + @pytest.mark.sanity + def test_extract_from_obj_without_dict(self): + """Test InfoMixin.extract_from_obj with objects without __dict__.""" + obj = 42 + result = InfoMixin.extract_from_obj(obj) + + assert isinstance(result, dict) + assert "attributes" in result + assert result["attributes"] == {} + assert result["type"] == "int" + assert result["str"] == "42" + + @pytest.mark.sanity + def test_extract_from_obj_with_private_attributes(self): + """Test InfoMixin.extract_from_obj filters private attributes.""" + + class ObjectWithPrivate: + def __init__(self): + self.public_attr = "public" + self._private_attr = "private" + self.__very_private = "very_private" + + obj = ObjectWithPrivate() + result = InfoMixin.extract_from_obj(obj) + + assert "public_attr" in result["attributes"] + assert result["attributes"]["public_attr"] == "public" + assert "_private_attr" not in result["attributes"] + assert "__very_private" not in result["attributes"] + + @pytest.mark.sanity + def test_extract_from_obj_complex_attributes(self): + """Test InfoMixin.extract_from_obj with complex attribute types.""" + + class ComplexObject: + def __init__(self): + self.simple_str = "test" + self.simple_int = 42 + self.simple_list = [1, 2, 3] + self.simple_dict = {"key": "value"} + self.complex_object = object() + + obj = ComplexObject() + result = InfoMixin.extract_from_obj(obj) + + attributes = result["attributes"] + assert attributes["simple_str"] == "test" + assert 
attributes["simple_int"] == 42 + assert attributes["simple_list"] == [1, 2, 3] + assert attributes["simple_dict"] == {"key": "value"} + assert isinstance(attributes["complex_object"], str) + + @pytest.mark.regression + def test_create_info_dict_consistency(self, valid_instances): + """Test InfoMixin.create_info_dict produces consistent results.""" + instance, _ = valid_instances + + result1 = InfoMixin.create_info_dict(instance) + result2 = InfoMixin.create_info_dict(instance) + + assert result1 == result2 + assert result1 is not result2 + + @pytest.mark.regression + def test_info_property_uses_create_info_dict(self, valid_instances): + """Test InfoMixin.info property uses create_info_dict method.""" + instance, _ = valid_instances + + info_result = instance.info + create_result = InfoMixin.create_info_dict(instance) + + assert info_result == create_result diff --git a/tests/unit/objects/test_statistics.py b/tests/unit/utils/test_statistics.py similarity index 90% rename from tests/unit/objects/test_statistics.py rename to tests/unit/utils/test_statistics.py index ede77175..fa8cccd0 100644 --- a/tests/unit/objects/test_statistics.py +++ b/tests/unit/utils/test_statistics.py @@ -704,82 +704,3 @@ def test_time_running_stats_update(): assert time_running_stats.rate_ms == pytest.approx( 3000 / (time.time() - time_running_stats.start_time), abs=0.1 ) - - -@pytest.mark.regression -def test_distribution_summary_concurrency_double_counting_regression(): - """Specific regression test for the double-counting bug in concurrency calculation. - - Before the fix, when events were merged due to epsilon, the deltas were summed - but then the active count wasn't properly accumulated, causing incorrect results. - - ### WRITTEN BY AI ### - """ - epsilon = 1e-6 - - # Create a scenario where multiple requests start at exactly the same time - # This should result in events being merged, testing the accumulation logic - same_start_time = 1.0 - requests = [ - (same_start_time, 3.0), - (same_start_time, 4.0), - (same_start_time, 5.0), - (same_start_time + epsilon / 3, 6.0), # Very close start (within epsilon) - ] - - distribution_summary = DistributionSummary.from_request_times( - requests, distribution_type="concurrency", epsilon=epsilon - ) - - # All requests start at the same time (or within epsilon), so they should - # all be considered concurrent from the start - # Expected timeline: - # - t=1.0-3.0: 4 concurrent requests - # - t=3.0-4.0: 3 concurrent requests - # - t=4.0-5.0: 2 concurrent requests - # - t=5.0-6.0: 1 concurrent request - - assert distribution_summary.max == 4.0 # All 4 requests concurrent at start - assert distribution_summary.min == 1.0 # 1 request still running at the end - - -@pytest.mark.sanity -def test_distribution_summary_concurrency_epsilon_edge_case(): - """Test the exact epsilon boundary condition. 
- - ### WRITTEN BY AI ### - """ - epsilon = 1e-6 - - # Test requests that are exactly epsilon apart - should be merged - requests_exactly_epsilon = [ - (1.0, 2.0), - (1.0 + epsilon, 2.5), # Exactly epsilon apart - (2.0, 2.5), # Another close request - ] - - dist_epsilon = DistributionSummary.from_request_times( - requests_exactly_epsilon, distribution_type="concurrency", epsilon=epsilon - ) - - # Should be treated as concurrent (merged events) - assert dist_epsilon.max == 2.0 - assert dist_epsilon.min == 2.0 - - # Test requests that are just over epsilon apart - should NOT be merged - requests_over_epsilon = [ - (1.0, 2.0), - (1.0 + epsilon * 1.1, 2.5), # Just over epsilon apart - (2.0, 2.5), # Another close request - ] - - dist_over_epsilon = DistributionSummary.from_request_times( - requests_over_epsilon, distribution_type="concurrency", epsilon=epsilon - ) - - # These should be treated separately, so max concurrency depends on overlap - # At t=1.0 to 1.0+epsilon*1.1: 1 concurrent - # At t=1.0+epsilon*1.1 to 2.0: 2 concurrent - # At t=2.0 to 2.5: 1 concurrent - assert dist_over_epsilon.max == 2.0 - assert dist_over_epsilon.min == 1.0 diff --git a/tests/unit/utils/test_text.py b/tests/unit/utils/test_text.py new file mode 100644 index 00000000..2f363c46 --- /dev/null +++ b/tests/unit/utils/test_text.py @@ -0,0 +1,531 @@ +from __future__ import annotations + +import gzip +import tempfile +from pathlib import Path +from unittest.mock import Mock, patch + +import httpx +import pytest + +from guidellm.utils.text import ( + MAX_PATH_LENGTH, + EndlessTextCreator, + clean_text, + filter_text, + format_value_display, + is_punctuation, + load_text, + split_text, + split_text_list_by_length, +) + + +def test_max_path_length(): + """Test that MAX_PATH_LENGTH is correctly defined.""" + assert isinstance(MAX_PATH_LENGTH, int) + assert MAX_PATH_LENGTH == 4096 + + +class TestFormatValueDisplay: + """Test suite for format_value_display.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ( + "value", + "label", + "units", + "total_characters", + "digits_places", + "decimal_places", + "expected", + ), + [ + (42.0, "test", "", None, None, None, "42 [info]test[/info]"), + (42.5, "test", "ms", None, None, 1, "42.5ms [info]test[/info]"), + (42.123, "test", "", None, 5, 2, " 42.12 [info]test[/info]"), + ( + 42.0, + "test", + "ms", + 30, + None, + 0, + " 42ms [info]test[/info]", + ), + ], + ) + def test_invocation( + self, + value, + label, + units, + total_characters, + digits_places, + decimal_places, + expected, + ): + """Test format_value_display with various parameters.""" + result = format_value_display( + value=value, + label=label, + units=units, + total_characters=total_characters, + digits_places=digits_places, + decimal_places=decimal_places, + ) + assert label in result + assert units in result + value_check = ( + str(int(value)) + if decimal_places == 0 + else ( + f"{value:.{decimal_places}f}" + if decimal_places is not None + else str(value) + ) + ) + assert value_check in result or str(value) in result + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("value", "label"), + [ + (None, "test"), + (42.0, None), + ("not_number", "test"), + ], + ) + def test_invocation_with_none_values(self, value, label): + """Test format_value_display with None/invalid inputs still works.""" + result = format_value_display(value, label) + assert isinstance(result, str) + if label is not None: + assert str(label) in result + if value is not None: + assert str(value) in result + + +class 
TestSplitTextListByLength: + """Test suite for split_text_list_by_length.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ( + "text_list", + "max_characters", + "pad_horizontal", + "pad_vertical", + "expected_structure", + ), + [ + ( + ["hello world", "test"], + 5, + False, + False, + [["hello", "world"], ["test"]], + ), + ( + ["short", "longer text"], + [5, 10], + True, + True, + [[" short"], ["longer", "text"]], + ), + ( + ["a", "b", "c"], + 10, + True, + True, + [[" a"], [" b"], [" c"]], + ), + ], + ) + def test_invocation( + self, + text_list, + max_characters, + pad_horizontal, + pad_vertical, + expected_structure, + ): + """Test split_text_list_by_length with various parameters.""" + result = split_text_list_by_length( + text_list, max_characters, pad_horizontal, pad_vertical + ) + assert len(result) == len(text_list) + if pad_vertical: + max_lines = max(len(lines) for lines in result) + assert all(len(lines) == max_lines for lines in result) + + @pytest.mark.sanity + def test_invalid_max_characters_length(self): + """Test split_text_list_by_length with mismatched max_characters length.""" + error_msg = "max_characters must be a list of the same length" + with pytest.raises(ValueError, match=error_msg): + split_text_list_by_length(["a", "b"], [5, 10, 15]) + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("text_list", "max_characters"), + [ + (None, 5), + (["test"], None), + (["test"], []), + ], + ) + def test_invalid_invocation(self, text_list, max_characters): + """Test split_text_list_by_length with invalid inputs.""" + with pytest.raises((TypeError, ValueError)): + split_text_list_by_length(text_list, max_characters) + + +class TestFilterText: + """Test suite for filter_text.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("text", "filter_start", "filter_end", "expected"), + [ + ("hello world test", "world", None, "world test"), + ("hello world test", None, "world", "hello "), + ("hello world test", "hello", "test", "hello world "), + ("hello world test", 6, 11, "world test"), + ("hello world test", 0, 5, "hello"), + ("hello world test", None, None, "hello world test"), + ], + ) + def test_invocation(self, text, filter_start, filter_end, expected): + """Test filter_text with various start and end markers.""" + result = filter_text(text, filter_start, filter_end) + assert result == expected + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("text", "filter_start", "filter_end"), + [ + ("hello", "notfound", None), + ("hello", None, "notfound"), + ("hello", "invalid_type", None), + ("hello", None, "invalid_type"), + ], + ) + def test_invalid_invocation(self, text, filter_start, filter_end): + """Test filter_text with invalid markers.""" + with pytest.raises((ValueError, TypeError)): + filter_text(text, filter_start, filter_end) + + +class TestCleanText: + """Test suite for clean_text.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("text", "expected"), + [ + ("hello world", "hello world"), + (" hello\n\nworld ", "hello world"), + ("hello\tworld\r\ntest", "hello world test"), + ("", ""), + (" ", ""), + ], + ) + def test_invocation(self, text, expected): + """Test clean_text with various whitespace scenarios.""" + result = clean_text(text) + assert result == expected + + @pytest.mark.sanity + @pytest.mark.parametrize( + "text", + [ + None, + 123, + ], + ) + def test_invalid_invocation(self, text): + """Test clean_text with invalid inputs.""" + with pytest.raises((TypeError, AttributeError)): + clean_text(text) + + +class TestSplitText: + """Test suite 
for split_text.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("text", "split_punctuation", "expected"), + [ + ("hello world", False, ["hello", "world"]), + ("hello, world!", True, ["hello", ",", "world", "!"]), + ("test.example", False, ["test.example"]), + ("test.example", True, ["test", ".", "example"]), + ("", False, []), + ], + ) + def test_invocation(self, text, split_punctuation, expected): + """Test split_text with various punctuation options.""" + result = split_text(text, split_punctuation) + assert result == expected + + @pytest.mark.sanity + @pytest.mark.parametrize( + "text", + [ + None, + 123, + ], + ) + def test_invalid_invocation(self, text): + """Test split_text with invalid inputs.""" + with pytest.raises((TypeError, AttributeError)): + split_text(text) + + +class TestLoadText: + """Test suite for load_text.""" + + @pytest.mark.smoke + def test_empty_data(self): + """Test load_text with empty data.""" + result = load_text("") + assert result == "" + + @pytest.mark.smoke + def test_raw_text(self): + """Test load_text with raw text that's not a file.""" + long_text = "a" * (MAX_PATH_LENGTH + 1) + result = load_text(long_text) + assert result == long_text + + @pytest.mark.smoke + def test_local_file(self): + """Test load_text with local file.""" + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".txt") as tmp: + test_content = "test file content" + tmp.write(test_content) + tmp.flush() + + result = load_text(tmp.name) + assert result == test_content + + Path(tmp.name).unlink() + + @pytest.mark.smoke + def test_gzipped_file(self): + """Test load_text with gzipped file.""" + with tempfile.NamedTemporaryFile(delete=False, suffix=".gz") as tmp: + test_content = "test gzipped content" + with gzip.open(tmp.name, "wt") as gzf: + gzf.write(test_content) + + result = load_text(tmp.name) + assert result == test_content + + Path(tmp.name).unlink() + + @pytest.mark.smoke + @patch("httpx.Client") + def test_url_loading(self, mock_client): + """Test load_text with HTTP URL.""" + mock_response = Mock() + mock_response.text = "url content" + mock_client.return_value.__enter__.return_value.get.return_value = mock_response + + result = load_text("http://example.com/test.txt") + assert result == "url content" + + @pytest.mark.smoke + @patch("guidellm.utils.text.files") + @patch("guidellm.utils.text.as_file") + def test_package_data_loading(self, mock_as_file, mock_files): + """Test load_text with package data.""" + mock_resource = Mock() + mock_files.return_value.joinpath.return_value = mock_resource + + mock_file = Mock() + mock_file.read.return_value = "package data content" + mock_as_file.return_value.__enter__.return_value = mock_file + + with patch("gzip.open") as mock_gzip: + mock_gzip.return_value.__enter__.return_value = mock_file + result = load_text("data:test.txt") + assert result == "package data content" + + @pytest.mark.sanity + def test_nonexistent_file(self): + """Test load_text with nonexistent file returns the path as raw text.""" + result = load_text("/nonexistent/path/file.txt") + assert result == "/nonexistent/path/file.txt" + + @pytest.mark.sanity + @patch("httpx.Client") + def test_url_error(self, mock_client): + """Test load_text with HTTP error.""" + mock_client.return_value.__enter__.return_value.get.side_effect = ( + httpx.HTTPStatusError("HTTP error", request=None, response=None) + ) + + with pytest.raises(httpx.HTTPStatusError): + load_text("http://example.com/error.txt") + + +class TestIsPunctuation: + """Test suite for is_puncutation.""" 
+ + @pytest.mark.smoke + @pytest.mark.parametrize( + ("text", "expected"), + [ + (".", True), + (",", True), + ("!", True), + ("?", True), + (";", True), + ("a", False), + ("1", False), + (" ", False), + ("ab", False), + ("", False), + ], + ) + def test_invocation(self, text, expected): + """Test is_punctuation with various characters.""" + result = is_punctuation(text) + assert result == expected + + @pytest.mark.sanity + @pytest.mark.parametrize( + "text", + [ + None, + 123, + ], + ) + def test_invalid_invocation(self, text): + """Test is_punctuation with invalid inputs.""" + with pytest.raises((TypeError, AttributeError)): + is_punctuation(text) + + +class TestEndlessTextCreator: + """Test suite for EndlessTextCreator.""" + + @pytest.fixture( + params=[ + { + "data": "hello world test", + "filter_start": None, + "filter_end": None, + }, + { + "data": "hello world test", + "filter_start": "world", + "filter_end": None, + }, + {"data": "one two three four", "filter_start": 0, "filter_end": 9}, + ], + ids=["no_filter", "string_filter", "index_filter"], + ) + def valid_instances(self, request): + """Fixture providing test data for EndlessTextCreator.""" + constructor_args = request.param + instance = EndlessTextCreator(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_class_signatures(self): + """Test EndlessTextCreator signatures and methods.""" + assert hasattr(EndlessTextCreator, "__init__") + assert hasattr(EndlessTextCreator, "create_text") + instance = EndlessTextCreator("test") + assert hasattr(instance, "data") + assert hasattr(instance, "text") + assert hasattr(instance, "filtered_text") + assert hasattr(instance, "words") + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test EndlessTextCreator initialization.""" + instance, constructor_args = valid_instances + assert isinstance(instance, EndlessTextCreator) + assert instance.data == constructor_args["data"] + assert isinstance(instance.text, str) + assert isinstance(instance.filtered_text, str) + assert isinstance(instance.words, list) + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("data", "filter_start", "filter_end"), + [ + ("test", "notfound", None), + ], + ) + def test_invalid_initialization_values(self, data, filter_start, filter_end): + """Test EndlessTextCreator with invalid initialization values.""" + with pytest.raises((TypeError, ValueError)): + EndlessTextCreator(data, filter_start, filter_end) + + @pytest.mark.smoke + def test_initialization_with_none(self): + """Test EndlessTextCreator handles None data gracefully.""" + instance = EndlessTextCreator(None) + assert isinstance(instance, EndlessTextCreator) + assert instance.data is None + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("start", "length", "expected_length"), + [ + (0, 5, 5), + (2, 3, 3), + (0, 0, 0), + ], + ) + def test_create_text(self, valid_instances, start, length, expected_length): + """Test EndlessTextCreator.create_text.""" + instance, constructor_args = valid_instances + result = instance.create_text(start, length) + assert isinstance(result, str) + if length > 0 and instance.words: + assert len(result) > 0 + + @pytest.mark.smoke + def test_create_text_cycling(self): + """Test EndlessTextCreator.create_text cycling behavior.""" + instance = EndlessTextCreator("one two three") + result1 = instance.create_text(0, 3) + result2 = instance.create_text(3, 3) + assert isinstance(result1, str) + assert isinstance(result2, str) + + @pytest.mark.sanity + 
@pytest.mark.parametrize( + ("start", "length"), + [ + ("invalid", 5), + (0, "invalid"), + ], + ) + def test_create_text_invalid(self, valid_instances, start, length): + """Test EndlessTextCreator.create_text with invalid inputs.""" + instance, constructor_args = valid_instances + with pytest.raises((TypeError, ValueError)): + instance.create_text(start, length) + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("start", "length", "min_length"), + [ + (-1, 5, 0), + (0, -1, 0), + ], + ) + def test_create_text_edge_cases(self, valid_instances, start, length, min_length): + """Test EndlessTextCreator.create_text with edge cases.""" + instance, constructor_args = valid_instances + result = instance.create_text(start, length) + assert isinstance(result, str) + assert len(result) >= min_length
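
As a usage sketch (not part of the diff above), the new helpers can be exercised roughly as follows, assuming the import paths created in this patch (`guidellm.utils.functions`, `guidellm.utils.mixins`, `guidellm.utils.text`); the concrete values and the `Worker` class are illustrative only:

```python
from guidellm.utils.functions import (
    all_defined,
    safe_add,
    safe_divide,
    safe_format_timestamp,
    safe_getattr,
    safe_multiply,
)
from guidellm.utils.mixins import InfoMixin
from guidellm.utils.text import EndlessTextCreator, format_value_display

# None-tolerant helpers: None falls back to the documented defaults,
# and division by zero is clamped rather than raising.
assert all_defined(1, "x", 0) is True
assert safe_getattr(None, "anything", default="fallback") == "fallback"
assert safe_divide(None, 4, num_default=2.0) == 0.5
assert safe_multiply(2, None, 3) == 6.0
assert safe_add(10, None, 3, signs=[1, -1, 1], default=0.0) == 13.0

# Timestamp formatting returns the default instead of raising on bad input.
print(safe_format_timestamp(None))          # "N/A"
print(safe_format_timestamp(1700000000.0))  # e.g. "22:13:20" (local time)


# InfoMixin exposes a structured .info dict of public attributes for debugging.
class Worker(InfoMixin):  # hypothetical class, for illustration only
    def __init__(self, rate: float):
        self.rate = rate


print(Worker(rate=2.5).info)  # {'type': 'Worker', ..., 'attributes': {'rate': 2.5}}

# Display formatting and deterministic text generation from guidellm.utils.text.
print(format_value_display(42.123, label="latency", units="ms", decimal_places=2))
creator = EndlessTextCreator("one two three four")  # non-path strings are used as raw text
print(creator.create_text(start=0, length=6))       # cycles: "one two three four one two"
```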