diff --git a/fastapi_assets/request_validators/path_validator.py b/fastapi_assets/request_validators/path_validator.py new file mode 100644 index 0000000..954b239 --- /dev/null +++ b/fastapi_assets/request_validators/path_validator.py @@ -0,0 +1,291 @@ +"""Module providing the PathValidator for validating path parameters in FastAPI.""" +import re +from typing import Any, Callable, List, Optional, Union +from fastapi import Path +from fastapi_assets.core.base_validator import BaseValidator, ValidationError + + +class PathValidator(BaseValidator): + r""" + A general-purpose dependency for validating path parameters in FastAPI. + + It validates path parameters with additional constraints like allowed values, + regex patterns, string length checks, numeric bounds, and custom validators. + + .. code-block:: python + from fastapi import FastAPI + from fastapi_assets.path_validator import PathValidator + + app = FastAPI() + + # Create reusable validators + item_id_validator = PathValidator( + gt=0, + lt=1000, + on_error_detail="Item ID must be between 1 and 999" + ) + + username_validator = PathValidator( + min_length=5, + max_length=15, + pattern=r"^[a-zA-Z0-9]+$", + on_error_detail="Username must be 5-15 alphanumeric characters" + ) + + @app.get("/items/{item_id}") + def get_item(item_id: int = item_id_validator): + return {"item_id": item_id} + + @app.get("/users/{username}") + def get_user(username: str = username_validator): + return {"username": username} + """ + + def __init__( + self, + default: Any = ..., + *, + allowed_values: Optional[List[Any]] = None, + pattern: Optional[str] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + gt: Optional[Union[int, float]] = None, + lt: Optional[Union[int, float]] = None, + ge: Optional[Union[int, float]] = None, + le: Optional[Union[int, float]] = None, + validator: Optional[Callable[[Any], bool]] = None, + on_error_detail: Optional[Union[str, Callable[[Any], str]]] = None, + # Standard Path() parameters + title: Optional[str] = None, + description: Optional[str] = None, + alias: Optional[str] = None, + deprecated: Optional[bool] = None, + **path_kwargs : Any + ) -> None: + """ + Initializes the PathValidator. + + Args: + default: Default value for the path parameter (usually ... for required). + allowed_values: List of allowed values for the parameter. + pattern: Regex pattern the parameter must match (for strings). + min_length: Minimum length for string parameters. + max_length: Maximum length for string parameters. + gt: Value must be greater than this (for numeric parameters). + lt: Value must be less than this (for numeric parameters). + ge: Value must be greater than or equal to this. + le: Value must be less than or equal to this. + validator: Custom validation function that takes the value and returns bool. + on_error_detail: Custom error message for validation failures. + title: Title for API documentation. + description: Description for API documentation. + alias: Alternative parameter name. + deprecated: Whether the parameter is deprecated. + **path_kwargs: Additional arguments passed to FastAPI's Path(). + """ + # Call super() with default error handling + super().__init__( + status_code=400, + error_detail=on_error_detail or "Path parameter validation failed." + ) + + # Store validation rules + self._allowed_values = allowed_values + self._pattern = re.compile(pattern) if pattern else None + self._min_length = min_length + self._max_length = max_length + self._gt = gt + self._lt = lt + self._ge = ge + self._le = le + self._custom_validator = validator + + # Store the underlying FastAPI Path parameter + # This preserves all standard Path() features (title, description, etc.) + self._path_param = Path( + default, + title=title, + description=description, + alias=alias, + deprecated=deprecated, + gt=gt, + lt=lt, + ge=ge, + le=le, + **path_kwargs + ) + + def __call__(self, value: Any = None) -> Any: + """ + FastAPI dependency entry point for path validation. + + Args: + value: The path parameter value extracted from the URL. + + Returns: + The validated path parameter value. + + Raises: + HTTPException: If validation fails. + """ + # If value is None, it means FastAPI will inject the actual path parameter + # This happens because FastAPI handles the Path() dependency internally + if value is None: + # Return a dependency that FastAPI will use + async def dependency(param_value: Any = self._path_param) -> Any: + return self._validate(param_value) + return dependency + + # If value is provided (for testing), validate directly + return self._validate(value) + + def _validate(self, value: Any) -> Any: + """ + Runs all validation checks on the parameter value. + + Args: + value: The path parameter value to validate. + + Returns: + The validated value. + + Raises: + HTTPException: If any validation check fails. + """ + try: + self._validate_allowed_values(value) + self._validate_pattern(value) + self._validate_length(value) + self._validate_numeric_bounds(value) + self._validate_custom(value) + except ValidationError as e: + # Convert ValidationError to HTTPException + self._raise_error( + status_code=e.status_code, + detail=str(e.detail) + ) + + return value + + def _validate_allowed_values(self, value: Any) -> None: + """ + Checks if the value is in the list of allowed values. + + Args: + value: The parameter value to check. + + Raises: + ValidationError: If the value is not in allowed_values. + """ + if self._allowed_values is None: + return # No validation rule set + + if value not in self._allowed_values: + detail = ( + f"Value '{value}' is not allowed. " + f"Allowed values are: {', '.join(map(str, self._allowed_values))}" + ) + raise ValidationError(detail=detail, status_code=400) + + def _validate_pattern(self, value: Any) -> None: + """ + Checks if the string value matches the required regex pattern. + + Args: + value: The parameter value to check. + + Raises: + ValidationError: If the value doesn't match the pattern. + """ + if self._pattern is None: + return # No validation rule set + + if not isinstance(value, str): + return # Pattern validation only applies to strings + + if not self._pattern.match(value): + detail = ( + f"Value '{value}' does not match the required pattern: " + f"{self._pattern.pattern}" + ) + raise ValidationError(detail=detail, status_code=400) + + def _validate_length(self, value: Any) -> None: + """ + Checks if the string length is within the specified bounds. + + Args: + value: The parameter value to check. + + Raises: + ValidationError: If the length is out of bounds. + """ + if not isinstance(value, str): + return # Length validation only applies to strings + + value_len = len(value) + + if self._min_length is not None and value_len < self._min_length: + detail = ( + f"Value '{value}' is too short. " + f"Minimum length is {self._min_length} characters." + ) + raise ValidationError(detail=detail, status_code=400) + + if self._max_length is not None and value_len > self._max_length: + detail = ( + f"Value '{value}' is too long. " + f"Maximum length is {self._max_length} characters." + ) + raise ValidationError(detail=detail, status_code=400) + + def _validate_numeric_bounds(self, value: Any) -> None: + """ + Checks if numeric values satisfy gt, lt, ge, le constraints. + + Args: + value: The parameter value to check. + + Raises: + ValidationError: If the value is out of the specified bounds. + """ + if not isinstance(value, (int, float)): + return # Numeric validation only applies to numbers + + if self._gt is not None and value <= self._gt: + detail = f"Value must be greater than {self._gt}" + raise ValidationError(detail=detail, status_code=400) + + if self._lt is not None and value >= self._lt: + detail = f"Value must be less than {self._lt}" + raise ValidationError(detail=detail, status_code=400) + + if self._ge is not None and value < self._ge: + detail = f"Value must be greater than or equal to {self._ge}" + raise ValidationError(detail=detail, status_code=400) + + if self._le is not None and value > self._le: + detail = f"Value must be less than or equal to {self._le}" + raise ValidationError(detail=detail, status_code=400) + + def _validate_custom(self, value: Any) -> None: + """ + Runs a custom validation function if provided. + + Args: + value: The parameter value to check. + + Raises: + ValidationError: If the custom validator returns False or raises an exception. + """ + if self._custom_validator is None: + return # No custom validator set + + try: + if not self._custom_validator(value): + detail = f"Custom validation failed for value '{value}'" + raise ValidationError(detail=detail, status_code=400) + except Exception as e: + # If the validator itself raises an exception, catch it + detail = f"Custom validation error: {str(e)}" + raise ValidationError(detail=detail, status_code=400) \ No newline at end of file diff --git a/fastapi_assets/validators/header_validator.py b/fastapi_assets/validators/header_validator.py new file mode 100644 index 0000000..df7de14 --- /dev/null +++ b/fastapi_assets/validators/header_validator.py @@ -0,0 +1,275 @@ +"""HeaderValidator for validating HTTP headers in FastAPI.""" +import re +from typing import Any, Callable, Dict, List, Optional, Union, Pattern +from fastapi_assets.core.base_validator import BaseValidator, ValidationError +from fastapi import Header + + +# Predefined format patterns for common header validation use cases +_FORMAT_PATTERNS: Dict[str, str] = { + "uuid4": r"^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$", + "email": r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$", + "bearer_token": r"^Bearer [a-zA-Z0-9\-._~+/]+=*$", + "datetime": r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}(\.\d+)?(Z|[+-]\d{2}:\d{2})?$", + "alphanumeric": r"^[a-zA-Z0-9]+$", + "api_key": r"^[a-zA-Z0-9]{32,}$", +} + + +class HeaderValidator(BaseValidator): + r""" + A general-purpose dependency for validating HTTP request headers in FastAPI. + + It extends FastAPI's built-in Header with additional validation capabilities + including pattern matching, format validation, allowed values, and custom validators. + + .. code-block:: python + from fastapi import FastAPI + from fastapi_assets.validators.header_validator import HeaderValidator + + app = FastAPI() + + # Validate API key header with pattern + api_key_validator = HeaderValidator( + alias="X-API-Key", + pattern=r"^[a-zA-Z0-9]{32}$", + required=True, + on_error_detail="Invalid API key format" + ) + + # Validate authorization header with bearer token format + auth_validator = HeaderValidator( + alias="Authorization", + format="bearer_token", + required=True + ) + + # Validate custom header with allowed values + version_validator = HeaderValidator( + alias="X-API-Version", + allowed_values=["v1", "v2", "v3"], + required=False, + default="v1" + ) + + @app.get("/secure") + def secure_endpoint( + api_key: str = api_key_validator, + auth: str = auth_validator, + version: str = version_validator + ): + return {"message": "Access granted", "version": version} + """ + + def __init__( + self, + default: Any = ..., + *, + alias: Optional[str] = None, + convert_underscores: bool = True, + pattern: Optional[str] = None, + format: Optional[str] = None, + allowed_values: Optional[List[str]] = None, + validator: Optional[Callable[[str], bool]] = None, + required: Optional[bool] = None, + on_error_detail: Optional[Union[str, Callable[[Any], str]]] = None, + title: Optional[str] = None, + description: Optional[str] = None, + **header_kwargs: Any + ) -> None: + # Call super() with default error handling + super().__init__( + status_code=400, + error_detail=on_error_detail or "Header validation failed." + ) + + # Determine if header is required + if required is None: + self._required = default is ... + else: + self._required = required + + # Store validation rules + self._allowed_values = allowed_values + self._custom_validator = validator + + # Define type hints for attributes + self._pattern: Optional[Pattern[str]] = None + self._format_name: Optional[str] = None + + # Handle pattern and format keys + if pattern and format: + raise ValueError("Cannot specify both 'pattern' and 'format'. Choose one.") + + if format: + if format not in _FORMAT_PATTERNS: + raise ValueError( + f"Unknown format '{format}'. " + f"Available formats: {', '.join(_FORMAT_PATTERNS.keys())}" + ) + self._pattern = re.compile(_FORMAT_PATTERNS[format], re.IGNORECASE) + self._format_name = format + elif pattern: + self._pattern = re.compile(pattern) + self._format_name = None + else: + self._pattern = None + self._format_name = None + + # Store the underlying FastAPI Header parameter + self._header_param = Header( + default, + alias=alias, + convert_underscores=convert_underscores, + title=title, + description=description, + **header_kwargs + ) + + # Store custom error detail + self._on_error_detail = on_error_detail + def __call__(self, header_value: Optional[str] = None) -> Any: + """ + FastAPI dependency entry point for header validation. + + Args: + header_value: The header value extracted from the request. + + Returns: + The validated header value. + + Raises: + HTTPException: If validation fails. + """ + # If value is None, return a dependency that FastAPI will use + if header_value is None: + def dependency(value: Optional[str] = self._header_param) -> Optional[str]: + return self._validate(value) + return dependency + + # If value is provided (for testing), validate directly + return self._validate(header_value) + + def _validate(self, value: Optional[str]) -> Optional[str]: + """ + Runs all validation checks on the header value. + + Args: + value: The header value to validate. + + Returns: + The validated value. + + Raises: + HTTPException: If any validation check fails. + """ + try: + self._validate_required(value) + except ValidationError as e: + self._raise_error( + value= value, + status_code=e.status_code, + detail=str(e.detail) + ) + if value is None or value == "": + return value or "" + try: + self._validate_allowed_values(value) + self._validate_pattern(value) + self._validate_custom(value) + + except ValidationError as e: + # Convert ValidationError to HTTPException + self._raise_error( + value= value, + status_code=e.status_code, + detail=str(e.detail) + ) + + return value + + def _validate_required(self, value: Optional[str]) -> None: + """ + Checks if the header is present when required. + + Args: + value: The header value to check. + + Raises: + ValidationError: If the header is required but missing. + """ + if self._required and (value is None or value == ""): + detail = self._on_error_detail or "Required header is missing." + if callable(detail): + detail_str = detail(value) + else: + detail_str = str(detail) + + raise ValidationError(detail=detail_str, status_code=400) + + def _validate_allowed_values(self, value: str) -> None: + """ + Checks if the value is in the list of allowed values. + + Args: + value: The header value to check. + + Raises: + ValidationError: If the value is not in allowed_values. + """ + if self._allowed_values is None: + return # No validation rule set + + if value not in self._allowed_values: + detail = ( + f"Header value '{value}' is not allowed. " + f"Allowed values are: {', '.join(self._allowed_values)}" + ) + raise ValidationError(detail=detail, status_code=400) + + def _validate_pattern(self, value: str) -> None: + """ + Checks if the header value matches the required regex pattern. + + Args: + value: The header value to check. + + Raises: + ValidationError: If the value doesn't match the pattern. + """ + if self._pattern is None: + return # No validation rule set + + if not self._pattern.match(value): + if self._format_name: + detail = ( + f"Header value does not match the required format: '{self._format_name}'" + ) + else: + detail = ( + f"Header value '{value}' does not match the required pattern: " + f"{self._pattern.pattern}" + ) + raise ValidationError(detail=detail, status_code=400) + + def _validate_custom(self, value: str) -> None: + """ + Runs a custom validation function if provided. + + Args: + value: The header value to check. + + Raises: + ValidationError: If the custom validator returns False or raises an exception. + """ + if self._custom_validator is None: + return # No custom validator set + + try: + if not self._custom_validator(value): + detail = f"Custom validation failed for header value '{value}'" + raise ValidationError(detail=detail, status_code=400) + except Exception as e: + # If the validator itself raises an exception, catch it + detail = f"Custom validation error: {str(e)}" + raise ValidationError(detail=detail, status_code=400) \ No newline at end of file diff --git a/tests/test_header_validator.py b/tests/test_header_validator.py new file mode 100644 index 0000000..b8b7f21 --- /dev/null +++ b/tests/test_header_validator.py @@ -0,0 +1,373 @@ +""" +Tests for the HeaderValidator class. +""" + +import pytest +from fastapi import HTTPException +from fastapi_assets.core.base_validator import ValidationError +from fastapi_assets.validators.header_validator import HeaderValidator + + +# --- Fixtures --- + +@pytest.fixture +def base_validator(): + """Returns a basic HeaderValidator with no rules.""" + return HeaderValidator() + + +@pytest.fixture +def required_validator(): + """Returns a HeaderValidator with required=True.""" + return HeaderValidator(required=True) + + +@pytest.fixture +def pattern_validator(): + """Returns a HeaderValidator with pattern validation.""" + return HeaderValidator(pattern=r"^[a-zA-Z0-9]{32}$") + + +@pytest.fixture +def format_validator(): + """Returns a HeaderValidator with bearer_token format.""" + return HeaderValidator(format="bearer_token") + + +@pytest.fixture +def allowed_values_validator(): + """Returns a HeaderValidator with allowed values.""" + return HeaderValidator(allowed_values=["v1", "v2", "v3"]) + + +@pytest.fixture +def custom_validator_obj(): + """Returns a HeaderValidator with custom validator function.""" + def is_even_length(val: str) -> bool: + return len(val) % 2 == 0 + + return HeaderValidator(validator=is_even_length) + + +# --- Test Classes --- + +class TestHeaderValidatorInit: + """Tests for the HeaderValidator's __init__ method.""" + + def test_init_defaults(self): + """Tests that all validation rules are None by default.""" + validator = HeaderValidator() + assert validator._allowed_values is None + assert validator._pattern is None + assert validator._custom_validator is None + assert validator._format_name is None + + def test_init_required_true(self): + """Tests that required flag is stored correctly.""" + validator = HeaderValidator(required=True) + assert validator._required is True + + def test_init_required_false(self): + """Tests that required can be set to False.""" + validator = HeaderValidator(required=False, default="default_value") + assert validator._required is False + + def test_init_pattern_compilation(self): + """Tests that pattern is compiled to regex.""" + pattern = r"^[A-Z0-9]+$" + validator = HeaderValidator(pattern=pattern) + assert validator._pattern is not None + assert validator._pattern.pattern == pattern + + def test_init_format_uuid4(self): + """Tests that format='uuid4' is recognized.""" + validator = HeaderValidator(format="uuid4") + assert validator._format_name == "uuid4" + assert validator._pattern is not None + + def test_init_format_email(self): + """Tests that format='email' is recognized.""" + validator = HeaderValidator(format="email") + assert validator._format_name == "email" + assert validator._pattern is not None + + def test_init_format_bearer_token(self): + """Tests that format='bearer_token' is recognized.""" + validator = HeaderValidator(format="bearer_token") + assert validator._format_name == "bearer_token" + assert validator._pattern is not None + + def test_init_invalid_format(self): + """Tests that invalid format raises ValueError.""" + with pytest.raises(ValueError, match="Unknown format"): + HeaderValidator(format="invalid_format") + + def test_init_pattern_and_format_conflict(self): + """Tests that both pattern and format cannot be specified.""" + with pytest.raises(ValueError, match="Cannot specify both"): + HeaderValidator(pattern=r"^test$", format="uuid4") + + def test_init_allowed_values(self): + """Tests that allowed values are stored correctly.""" + values = ["alpha", "beta", "gamma"] + validator = HeaderValidator(allowed_values=values) + assert validator._allowed_values == values + + def test_init_custom_validator_function(self): + """Tests that custom validator function is stored.""" + def is_positive(val: str) -> bool: + return val.startswith("+") + + validator = HeaderValidator(validator=is_positive) + assert validator._custom_validator is not None + assert validator._custom_validator("+test") is True + assert validator._custom_validator("-test") is False + + def test_init_custom_error_detail(self): + """Tests that custom error detail is stored.""" + custom_msg = "Invalid header value" + validator = HeaderValidator(on_error_detail=custom_msg) + assert validator._on_error_detail == custom_msg + + def test_init_alias(self): + """Tests that alias for header name is set.""" + validator = HeaderValidator(alias="X-API-Key") + assert validator._header_param is not None + + +class TestHeaderValidatorValidateRequired: + """Tests for the _validate_required method.""" + + def test_required_with_value(self, required_validator): + """Tests required validation passes when value is present.""" + try: + required_validator._validate_required("some_value") + except ValidationError: + pytest.fail("Required validation failed with valid value") + + def test_required_missing_value(self, required_validator): + """Tests required validation fails when value is None.""" + with pytest.raises(ValidationError) as e: + required_validator._validate_required(None) + + assert e.value.status_code == 400 + assert "missing" in e.value.detail.lower() + + def test_required_empty_string(self, required_validator): + """Tests required validation fails with empty string.""" + with pytest.raises(ValidationError): + required_validator._validate_required("") + + def test_not_required_with_none(self, base_validator): + """Tests validation passes when not required and value is None.""" + base_validator._required = False + try: + base_validator._validate_required(None) + except ValidationError: + pytest.fail("Non-required validation should pass with None") + + +class TestHeaderValidatorValidateAllowedValues: + """Tests for the _validate_allowed_values method.""" + + def test_allowed_values_no_rule(self, base_validator): + """Tests that no validation happens when no allowed_values rule.""" + try: + base_validator._validate_allowed_values("any_value") + except ValidationError: + pytest.fail("Validation failed with no rule set") + + def test_allowed_values_valid(self, allowed_values_validator): + """Tests allowed value passes validation.""" + try: + allowed_values_validator._validate_allowed_values("v1") + except ValidationError: + pytest.fail("Valid allowed value failed") + + def test_allowed_values_invalid(self, allowed_values_validator): + """Tests invalid allowed value raises error.""" + with pytest.raises(ValidationError) as e: + allowed_values_validator._validate_allowed_values("v4") + + assert e.value.status_code == 400 + assert "not allowed" in e.value.detail.lower() + + def test_allowed_values_all_options(self, allowed_values_validator): + """Tests all allowed values individually.""" + for value in ["v1", "v2", "v3"]: + try: + allowed_values_validator._validate_allowed_values(value) + except ValidationError: + pytest.fail(f"Valid allowed value '{value}' failed") + + def test_allowed_values_case_sensitive(self, allowed_values_validator): + """Tests that allowed values are case-sensitive.""" + with pytest.raises(ValidationError): + allowed_values_validator._validate_allowed_values("V1") + + +class TestHeaderValidatorValidatePattern: + """Tests for the _validate_pattern method.""" + + def test_pattern_no_rule(self, base_validator): + """Tests validation passes with no pattern rule.""" + try: + base_validator._validate_pattern("anything") + except ValidationError: + pytest.fail("Validation failed with no pattern rule") + + def test_pattern_valid_match(self, pattern_validator): + """Tests pattern matches valid value.""" + try: + pattern_validator._validate_pattern("abcdefghijklmnopqrstuvwxyz123456") + except ValidationError: + pytest.fail("Valid pattern match failed") + + def test_pattern_invalid_match(self, pattern_validator): + """Tests pattern fails on invalid value.""" + with pytest.raises(ValidationError) as e: + pattern_validator._validate_pattern("short") + + assert e.value.status_code == 400 + assert "does not match" in e.value.detail.lower() + + def test_pattern_format_uuid4_valid(self): + """Tests uuid4 format validation passes.""" + validator = HeaderValidator(format="uuid4") + valid_uuid = "550e8400-e29b-41d4-a716-446655440000" + try: + validator._validate_pattern(valid_uuid) + except ValidationError: + pytest.fail("Valid UUID4 failed") + + def test_pattern_format_uuid4_invalid(self): + """Tests uuid4 format validation fails.""" + validator = HeaderValidator(format="uuid4") + with pytest.raises(ValidationError) as e: + validator._validate_pattern("not-a-uuid") + + assert "format" in e.value.detail.lower() + + def test_pattern_format_bearer_token_valid(self, format_validator): + """Tests bearer token format validation passes.""" + try: + format_validator._validate_pattern("Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9") + except ValidationError: + pytest.fail("Valid bearer token failed") + + def test_pattern_format_bearer_token_invalid(self, format_validator): + """Tests bearer token format validation fails.""" + with pytest.raises(ValidationError): + format_validator._validate_pattern("InvalidToken") + + def test_pattern_format_email_valid(self): + """Tests email format validation passes.""" + validator = HeaderValidator(format="email") + try: + validator._validate_pattern("user@example.com") + except ValidationError: + pytest.fail("Valid email failed") + + def test_pattern_format_email_invalid(self): + """Tests email format validation fails.""" + validator = HeaderValidator(format="email") + with pytest.raises(ValidationError): + validator._validate_pattern("not-an-email") + + +class TestHeaderValidatorValidateCustom: + """Tests for the _validate_custom method.""" + + def test_custom_no_validator(self, base_validator): + """Tests validation passes with no custom validator.""" + try: + base_validator._validate_custom("any_value") + except ValidationError: + pytest.fail("Validation failed with no custom validator") + + def test_custom_validator_valid(self, custom_validator_obj): + """Tests custom validator passes on valid input.""" + try: + custom_validator_obj._validate_custom("even") # 4 chars + except ValidationError: + pytest.fail("Valid custom validation failed") + + def test_custom_validator_invalid(self, custom_validator_obj): + """Tests custom validator fails on invalid input.""" + with pytest.raises(ValidationError) as e: + custom_validator_obj._validate_custom("odd") # 3 chars + + assert e.value.status_code == 400 + # Accept either failure message depending on your validator code + assert ( + "custom validation failed" in e.value.detail.lower() + or "custom validation error" in e.value.detail.lower() + ) + + + def test_custom_validator_exception(self): + """Tests custom validator exception is caught.""" + def buggy_validator(val: str) -> bool: + raise ValueError("Unexpected error") + + validator = HeaderValidator(validator=buggy_validator) + with pytest.raises(ValidationError) as e: + validator._validate_custom("test") + + assert "custom validation error" in e.value.detail.lower() + + +class TestHeaderValidatorValidate: + """Tests for the main _validate method.""" + + def test_validate_valid_header(self): + """Tests full validation pipeline with valid header.""" + validator = HeaderValidator( + required=True, + allowed_values=["api", "web"], + pattern=r"^[a-z]+$" + ) + try: + result = validator._validate("api") + assert result == "api" + except ValidationError: + pytest.fail("Valid header failed validation") + + def test_validate_fails_required(self): + """Tests validation fails on required check.""" + validator = HeaderValidator(required=True) + with pytest.raises(HTTPException): + validator._validate(None) + + def test_validate_fails_allowed_values(self): + """Tests validation fails on allowed values check.""" + validator = HeaderValidator(allowed_values=["good"]) + with pytest.raises(HTTPException): + validator._validate("bad") + + def test_validate_fails_pattern(self): + """Tests validation fails on pattern check.""" + validator = HeaderValidator(pattern=r"^[0-9]+$") + with pytest.raises(HTTPException): + validator._validate("abc") + + def test_validate_fails_custom(self): + """Tests validation fails on custom validator.""" + def no_spaces(val: str) -> bool: + return " " not in val + + validator = HeaderValidator(validator=no_spaces) + with pytest.raises(HTTPException): + validator._validate("has space") + + def test_validate_empty_optional_header(self): + """Tests optional header with empty string passes.""" + validator = HeaderValidator(required=False) + result = validator._validate("") + assert result == "" + + def test_validate_none_optional_header(self): + """Tests optional header with None passes.""" + validator = HeaderValidator(required=False) + result = validator._validate(None) + assert result is None or result == "" \ No newline at end of file diff --git a/tests/test_path_validator.py b/tests/test_path_validator.py new file mode 100644 index 0000000..e4a31a5 --- /dev/null +++ b/tests/test_path_validator.py @@ -0,0 +1,261 @@ +""" +tests for the PathValidator class. +""" +from fastapi import HTTPException +import pytest +from fastapi_assets.core.base_validator import ValidationError +from fastapi_assets.request_validators.path_validator import PathValidator + +# Fixtures for common PathValidator configurations +@pytest.fixture +def base_validator(): + """Returns a basic PathValidator with no rules.""" + return PathValidator() + +@pytest.fixture +def numeric_validator(): + """Returns a PathValidator configured for numeric validation.""" + return PathValidator(gt=0, lt=1000) + +@pytest.fixture +def string_validator(): + """Returns a PathValidator configured for string validation.""" + return PathValidator( + min_length=3, + max_length=15, + pattern=r"^[a-zA-Z0-9_]+$" + ) + +@pytest.fixture +def allowed_values_validator(): + """Returns a PathValidator with allowed values.""" + return PathValidator( + allowed_values=["active", "inactive", "pending"] + ) + +# Test class for constructor __init__ behavior +class TestPathValidatorInit: + def test_init_defaults(self): + """Tests that all validation rules are None by default.""" + validator = PathValidator() + assert validator._allowed_values is None + assert validator._pattern is None + assert validator._min_length is None + assert validator._max_length is None + assert validator._gt is None + assert validator._lt is None + assert validator._ge is None + assert validator._le is None + assert validator._custom_validator is None + + def test_init_allowed_values(self): + """Tests that allowed_values are stored correctly.""" + values = ["active", "inactive"] + validator = PathValidator(allowed_values=values) + assert validator._allowed_values == values + + def test_init_pattern_compilation(self): + """Tests that regex pattern is compiled.""" + pattern = r"^[a-z0-9]+$" + validator = PathValidator(pattern=pattern) + assert validator._pattern is not None + assert validator._pattern.pattern == pattern + + def test_init_numeric_bounds(self): + """Tests that numeric bounds are stored correctly.""" + validator = PathValidator(gt=0, lt=100, ge=1, le=99) + assert validator._gt == 0 + assert validator._lt == 100 + assert validator._ge == 1 + assert validator._le == 99 + + def test_init_length_bounds(self): + """Tests that length bounds are stored correctly.""" + validator = PathValidator(min_length=5, max_length=20) + assert validator._min_length == 5 + assert validator._max_length == 20 + + def test_init_custom_error_detail(self): + """Tests that custom error messages are stored.""" + custom_error = "Invalid path parameter" + validator = PathValidator(on_error_detail=custom_error) + # _error_detail attribute holds error message + assert validator.error_detail == custom_error or custom_error in str(validator.__dict__) + + def test_init_custom_validator_function(self): + """Tests that custom validator function is stored.""" + def is_even(x): return x % 2 == 0 + validator = PathValidator(validator=is_even) + # Validate custom function works + assert validator._custom_validator(4) is True + assert validator._custom_validator(3) is False + + def test_init_fastapi_path_creation(self): + """Tests that internal FastAPI Path object is created.""" + validator = PathValidator( + title="Item ID", + description="The unique identifier", + gt=0, + lt=1000 + ) + assert validator._path_param is not None + + def test_init_combined_rules(self): + """Tests initialization with multiple combined rules.""" + validator = PathValidator( + min_length=3, + max_length=20, + pattern=r"^[a-zA-Z]+$", + title="Category", + description="Product category slug" + ) + assert validator._min_length == 3 + assert validator._max_length == 20 + assert validator._pattern is not None + +# Validation method tests +class TestPathValidatorValidateAllowedValues: + def test_allowed_values_no_rule(self, base_validator): + """Validation should pass if no rule is set.""" + try: + base_validator._validate_allowed_values("any_value") + except ValidationError: + pytest.fail("Validation failed when no rule was set.") + + def test_allowed_values_valid(self, allowed_values_validator): + """Test valid allowed value.""" + try: + allowed_values_validator._validate_allowed_values("active") + except ValidationError: + pytest.fail("Failed on valid allowed value.") + + def test_allowed_values_invalid(self, allowed_values_validator): + """Test invalid allowed value raises ValidationError.""" + with pytest.raises(ValidationError): + allowed_values_validator._validate_allowed_values("deleted") + +class TestPathValidatorValidatePattern: + def test_pattern_no_rule(self, base_validator): + """Validation passes when no pattern rule.""" + try: + base_validator._validate_pattern("anything@123!@#") + except ValidationError: + pytest.fail("Validation failed when no pattern rule.") + + def test_pattern_valid_match(self, string_validator): + """Valid pattern match.""" + try: + string_validator._validate_pattern("user_123") + except ValidationError: + pytest.fail("Validation failed on valid pattern.") + + def test_pattern_invalid_match(self, string_validator): + """Invalid pattern raises ValidationError.""" + with pytest.raises(ValidationError): + string_validator._validate_pattern("user@123") + + def test_pattern_non_string_ignored(self, string_validator): + """Skip pattern validation for non-strings.""" + try: + string_validator._validate_pattern(123) + except ValidationError: + pytest.fail("Pattern validation should not apply to non-strings.") + + def test_pattern_email_like(self): + """Email pattern with valid and invalid cases.""" + validator = PathValidator(pattern=r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$") + try: + validator._validate_pattern("user.name+tag@example.com") + except ValidationError: + pytest.fail("Valid email-like pattern failed") + with pytest.raises(ValidationError): + validator._validate_pattern("user@domain") # missing TLD + +# Length validation tests +class TestPathValidatorValidateLength: + def test_length_no_rule(self, base_validator): + """Validation passes when no length rule.""" + try: + base_validator._validate_length("x") + base_validator._validate_length("longer") + except ValidationError: + pytest.fail("Failed no length rule.") + + def test_length_valid_within_bounds(self, string_validator): + """Valid length within bounds.""" + try: + string_validator._validate_length("hello") + except ValidationError: + pytest.fail("Failed valid length.") + + def test_length_too_short(self, string_validator): + """Fails if shorter than min_length.""" + with pytest.raises(ValidationError): + string_validator._validate_length("ab") + + def test_length_too_long(self, string_validator): + """Fails if longer than max_length.""" + with pytest.raises(ValidationError): + string_validator._validate_length("a"*20) + +# Numeric bounds validation +class TestPathValidatorValidateNumericBounds: + def test_no_rule(self, base_validator): + try: + base_validator._validate_numeric_bounds(999) + base_validator._validate_numeric_bounds(-999) + except ValidationError: + pytest.fail("Failed no numeric rule.") + + def test_gt_lt(self, numeric_validator): + try: + numeric_validator._validate_numeric_bounds(1) + numeric_validator._validate_numeric_bounds(999) + except ValidationError: + pytest.fail("Failed valid bounds.") + with pytest.raises(ValidationError): + numeric_validator._validate_numeric_bounds(0) + + def test_ge_le(self): + validator = PathValidator(ge=0, le=10) + try: + validator._validate_numeric_bounds(0) + validator._validate_numeric_bounds(10) + except ValidationError: + pytest.fail("Failed boundary values.") + with pytest.raises(ValidationError): + validator._validate_numeric_bounds(-1) + +# Custom validation tests +class TestPathValidatorValidateCustom: + def test_no_custom_validator(self, base_validator): + try: + base_validator._validate_custom("test") + except ValidationError: + pytest.fail("Failed with no custom validator.") + def test_valid_custom(self): + def is_even(x): return x % 2 == 0 + v = PathValidator(validator=is_even) + try: + v._validate_custom(4) + except ValidationError: + pytest.fail("Valid custom validation failed.") + def test_invalid_custom(self): + def is_even(x): return x % 2 == 0 + v = PathValidator(validator=is_even) + with pytest.raises(ValidationError): + v._validate_custom(3) + +# Integration of multiple validations +class TestPathValidatorIntegration: + def test_combined_valid(self): + v = PathValidator(allowed_values=["ok"], pattern=r"^ok$", min_length=2, max_length=2) + try: + v._validate("ok") + except ValidationError: + pytest.fail("Valid data failed validation.") + + def test_fail_in_combined(self): + v = PathValidator(allowed_values=["ok"], pattern=r"^ok$", min_length=2, max_length=2) + with pytest.raises(HTTPException): + v._validate("no")