From 972fdb2a14279574e8f39bdd3fdcd117f93de1da Mon Sep 17 00:00:00 2001 From: lwk <3098293798@qq.com> Date: Fri, 13 Mar 2026 21:42:56 +0800 Subject: [PATCH] Refactor schema modules + custom validators (kimi) --- schema/__init__.py | 965 +--------------------------------- schema/_schema_constants.py | 53 ++ schema/_schema_core.py | 454 ++++++++++++++++ schema/_schema_exceptions.py | 69 +++ schema/_schema_json_schema.py | 313 +++++++++++ schema/_schema_types.py | 254 +++++++++ test_custom_validators.py | 385 ++++++++++++++ 7 files changed, 1546 insertions(+), 947 deletions(-) create mode 100644 schema/_schema_constants.py create mode 100644 schema/_schema_core.py create mode 100644 schema/_schema_exceptions.py create mode 100644 schema/_schema_json_schema.py create mode 100644 schema/_schema_types.py create mode 100644 test_custom_validators.py diff --git a/schema/__init__.py b/schema/__init__.py index 3a918d0..2d5a800 100644 --- a/schema/__init__.py +++ b/schema/__init__.py @@ -2,39 +2,26 @@ obtained from config-files, forms, external services or command-line parsing, converted from JSON/YAML (or something else) to Python data-types.""" -import inspect -import re -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - Generic, - Iterable, - List, - NoReturn, - Sequence, - Set, - Sized, - Tuple, - Type, - TypeVar, - Union, - cast, -) - -# Use TYPE_CHECKING to determine the correct type hint but avoid runtime import errors -if TYPE_CHECKING: - # Only for type checking purposes, we import the standard ExitStack - from contextlib import ExitStack -else: - try: - from contextlib import ExitStack # Python 3.3 and later - except ImportError: - from contextlib2 import ExitStack # Python 2.x/3.0-3.2 fallback +__version__ = "0.7.8" +# Import all public classes and exceptions +from ._schema_core import Schema, Optional, Hook, Forbidden, Const +from ._schema_exceptions import ( + SchemaError, + SchemaWrongKeyError, + SchemaMissingKeyError, + SchemaForbiddenKeyError, + SchemaUnexpectedTypeError, + SchemaOnlyOneAllowedError, +) +from ._schema_types import ( + And, + Or, + Regex, + Use, + Literal, +) -__version__ = "0.7.8" __all__ = [ "Schema", "And", @@ -52,919 +39,3 @@ "SchemaUnexpectedTypeError", "SchemaOnlyOneAllowedError", ] - - -class SchemaError(Exception): - """Error during Schema validation.""" - - def __init__( - self, - autos: Union[Sequence[Union[str, None]], None], - errors: Union[List, str, None] = None, - ): - self.autos = autos if isinstance(autos, List) else [autos] - self.errors = errors if isinstance(errors, List) else [errors] - Exception.__init__(self, self.code) - - @property - def code(self) -> str: - """Remove duplicates in autos and errors list and combine them into a single message.""" - - def uniq(seq: Iterable[Union[str, None]]) -> List[str]: - """Utility function to remove duplicates while preserving the order.""" - seen: Set[str] = set() - unique_list: List[str] = [] - for x in seq: - if x is not None and x not in seen: - seen.add(x) - unique_list.append(x) - return unique_list - - data_set = uniq(self.autos) - error_list = uniq(self.errors) - - return "\n".join(error_list if error_list else data_set) - - -class SchemaWrongKeyError(SchemaError): - """Error Should be raised when an unexpected key is detected within the - data set being.""" - - pass - - -class SchemaMissingKeyError(SchemaError): - """Error should be raised when a mandatory key is not found within the - data set being validated""" - - pass - - -class SchemaOnlyOneAllowedError(SchemaError): - """Error should be raised when an only_one Or key has multiple matching candidates""" - - pass - - -class SchemaForbiddenKeyError(SchemaError): - """Error should be raised when a forbidden key is found within the - data set being validated, and its value matches the value that was specified""" - - pass - - -class SchemaUnexpectedTypeError(SchemaError): - """Error should be raised when a type mismatch is detected within the - data set being validated.""" - - pass - - -# Type variable to represent a Schema-like type -TSchema = TypeVar("TSchema", bound="Schema") - - -class And(Generic[TSchema]): - """ - Utility function to combine validation directives in AND Boolean fashion. - """ - - def __init__( - self, - *args: Any, - error: Union[str, None] = None, - ignore_extra_keys: bool = False, - schema: Union[Type[TSchema], None] = None, - ) -> None: - self._args: Tuple[Union[TSchema, Callable[..., Any]], ...] = args - self._error: Union[str, None] = error - self._ignore_extra_keys: bool = ignore_extra_keys - self._schema_class: Type[TSchema] = schema if schema is not None else Schema - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({', '.join(repr(a) for a in self._args)})" - - @property - def args(self) -> Tuple[Union[TSchema, Callable[..., Any]], ...]: - """The provided parameters""" - return self._args - - def validate(self, data: Any, **kwargs: Any) -> Any: - """ - Validate data using defined sub schema/expressions ensuring all - values are valid. - :param data: Data to be validated with sub defined schemas. - :return: Returns validated data. - """ - # Annotate sub_schema with the type returned by _build_schema - for sub_schema in self._build_schemas(): # type: TSchema - data = sub_schema.validate(data, **kwargs) - return data - - def _build_schemas(self) -> List[TSchema]: - return [self._build_schema(s) for s in self._args] - - def _build_schema(self, arg: Any) -> TSchema: - # Assume self._schema_class(arg, ...) returns an instance of TSchema - return self._schema_class( - arg, error=self._error, ignore_extra_keys=self._ignore_extra_keys - ) - - -class Or(And[TSchema]): - """Utility function to combine validation directives in a OR Boolean - fashion. - - If one wants to make an xor, one can provide only_one=True optional argument - to the constructor of this object. When a validation was performed for an - xor-ish Or instance and one wants to use it another time, one needs to call - reset() to put the match_count back to 0.""" - - def __init__( - self, - *args: Any, - only_one: bool = False, - **kwargs: Any, - ) -> None: - self.only_one: bool = only_one - self.match_count: int = 0 - super().__init__(*args, **kwargs) - - def reset(self) -> None: - failed: bool = self.match_count > 1 and self.only_one - self.match_count = 0 - if failed: - raise SchemaOnlyOneAllowedError( - ["There are multiple keys present from the %r condition" % self] - ) - - def validate(self, data: Any, **kwargs: Any) -> Any: - """ - Validate data using sub defined schema/expressions ensuring at least - one value is valid. - :param data: data to be validated by provided schema. - :return: return validated data if not validation - """ - autos: List[str] = [] - errors: List[Union[str, None]] = [] - for sub_schema in self._build_schemas(): - try: - validation: Any = sub_schema.validate(data, **kwargs) - self.match_count += 1 - if self.match_count > 1 and self.only_one: - break - return validation - except SchemaError as _x: - autos += _x.autos - errors += _x.errors - raise SchemaError( - ["%r did not validate %r" % (self, data)] + autos, - [self._error.format(data) if self._error else None] + errors, - ) - - -class Regex: - """ - Enables schema.py to validate string using regular expressions. - """ - - # Map all flags bits to a more readable description - NAMES = [ - "re.ASCII", - "re.DEBUG", - "re.VERBOSE", - "re.UNICODE", - "re.DOTALL", - "re.MULTILINE", - "re.LOCALE", - "re.IGNORECASE", - "re.TEMPLATE", - ] - - def __init__( - self, pattern_str: str, flags: int = 0, error: Union[str, None] = None - ) -> None: - self._pattern_str: str = pattern_str - flags_list = [ - Regex.NAMES[i] for i, f in enumerate(f"{flags:09b}") if f != "0" - ] # Name for each bit - - self._flags_names: str = ", flags=" + "|".join(flags_list) if flags_list else "" - self._pattern: re.Pattern = re.compile(pattern_str, flags=flags) - self._error: Union[str, None] = error - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self._pattern_str!r}{self._flags_names})" - - @property - def pattern_str(self) -> str: - """The pattern string for the represented regular expression""" - return self._pattern_str - - def validate(self, data: str, **kwargs: Any) -> str: - """ - Validates data using the defined regex. - :param data: Data to be validated. - :return: Returns validated data. - """ - e = self._error - - try: - if self._pattern.search(data): - return data - else: - error_message = ( - e.format(data) - if e - else f"{data!r} does not match {self._pattern_str!r}" - ) - raise SchemaError(error_message) - except TypeError: - error_message = ( - e.format(data) if e else f"{data!r} is not string nor buffer" - ) - raise SchemaError(error_message) - - -class Use: - """ - For more general use cases, you can use the Use class to transform - the data while it is being validated. - """ - - def __init__( - self, callable_: Callable[[Any], Any], error: Union[str, None] = None - ) -> None: - if not callable(callable_): - raise TypeError(f"Expected a callable, not {callable_!r}") - self._callable: Callable[[Any], Any] = callable_ - self._error: Union[str, None] = error - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self._callable!r})" - - def __call__(self, data: Any) -> Any: - """Make Use instances callable by delegating to the wrapped callable. - - This allows Use to work properly with And, Or, and other combinators - that expect callable arguments, while maintaining the validate() method - for the validator pattern. - """ - return self._callable(data) - - def validate(self, data: Any, **kwargs: Any) -> Any: - try: - return self._callable(data) - except SchemaError as x: - raise SchemaError( - [None] + x.autos, - [self._error.format(data) if self._error else None] + x.errors, - ) - except BaseException as x: - f = _callable_str(self._callable) - raise SchemaError( - "%s(%r) raised %r" % (f, data, x), - self._error.format(data) if self._error else None, - ) - - -COMPARABLE, CALLABLE, VALIDATOR, TYPE, DICT, ITERABLE = range(6) - - -def _priority(s: Any) -> int: - """Return priority for a given object.""" - if type(s) in (list, tuple, set, frozenset): - return ITERABLE - if isinstance(s, dict): - return DICT - if issubclass(type(s), type): - return TYPE - if isinstance(s, Literal): - return COMPARABLE - if hasattr(s, "validate"): - return VALIDATOR - if callable(s): - return CALLABLE - else: - return COMPARABLE - - -def _invoke_with_optional_kwargs(f: Callable[..., Any], **kwargs: Any) -> Any: - s = inspect.signature(f) - if len(s.parameters) == 0: - return f() - return f(**kwargs) - - -class Schema(object): - """ - Entry point of the library, use this class to instantiate validation - schema for the data that will be validated. - """ - - def __init__( - self, - schema: Any, - error: Union[str, None] = None, - ignore_extra_keys: bool = False, - name: Union[str, None] = None, - description: Union[str, None] = None, - as_reference: bool = False, - ) -> None: - self._schema: Any = schema - self._error: Union[str, None] = error - self._ignore_extra_keys: bool = ignore_extra_keys - self._name: Union[str, None] = name - self._description: Union[str, None] = description - self.as_reference: bool = as_reference - - if as_reference and name is None: - raise ValueError("Schema used as reference should have a name") - - def __repr__(self): - return "%s(%r)" % (self.__class__.__name__, self._schema) - - @property - def schema(self) -> Any: - return self._schema - - @property - def description(self) -> Union[str, None]: - return self._description - - @property - def name(self) -> Union[str, None]: - return self._name - - @property - def ignore_extra_keys(self) -> bool: - return self._ignore_extra_keys - - @staticmethod - def _dict_key_priority(s) -> float: - """Return priority for a given key object.""" - if isinstance(s, Hook): - return _priority(s._schema) - 0.5 - if isinstance(s, Optional): - return _priority(s._schema) + 0.5 - return _priority(s) - - @staticmethod - def _is_optional_type(s: Any) -> bool: - """Return True if the given key is optional (does not have to be found)""" - return any(isinstance(s, optional_type) for optional_type in [Optional, Hook]) - - def is_valid(self, data: Any, **kwargs: Dict[str, Any]) -> bool: - """Return whether the given data has passed all the validations - that were specified in the given schema. - """ - try: - self.validate(data, **kwargs) - except SchemaError: - return False - else: - return True - - def _prepend_schema_name(self, message: str) -> str: - """ - If a custom schema name has been defined, prepends it to the error - message that gets raised when a schema error occurs. - """ - if self._name: - message = "{0!r} {1!s}".format(self._name, message) - return message - - def validate(self, data: Any, **kwargs: Dict[str, Any]) -> Any: - Schema = self.__class__ - s: Any = self._schema - e: Union[str, None] = self._error - i: bool = self._ignore_extra_keys - - if isinstance(s, Literal): - s = s.schema - - flavor = _priority(s) - if flavor == ITERABLE: - data = Schema(type(s), error=e).validate(data, **kwargs) - o: Or = Or(*s, error=e, schema=Schema, ignore_extra_keys=i) - return type(data)(o.validate(d, **kwargs) for d in data) - if flavor == DICT: - exitstack = ExitStack() - data = Schema(dict, error=e).validate(data, **kwargs) - new: Dict = type(data)() # new - is a dict of the validated values - coverage: Set = set() # matched schema keys - # for each key and value find a schema entry matching them, if any - sorted_skeys = sorted(s, key=self._dict_key_priority) - for skey in sorted_skeys: - if hasattr(skey, "reset"): - exitstack.callback(skey.reset) - - with exitstack: - # Evaluate dictionaries last - data_items = sorted( - data.items(), key=lambda value: isinstance(value[1], dict) - ) - for key, value in data_items: - for skey in sorted_skeys: - svalue = s[skey] - try: - nkey = Schema(skey, error=e).validate(key, **kwargs) - except SchemaError: - pass - else: - if isinstance(skey, Hook): - # As the content of the value makes little sense for - # keys with a hook, we reverse its meaning: - # we will only call the handler if the value does match - # In the case of the forbidden key hook, - # we will raise the SchemaErrorForbiddenKey exception - # on match, allowing for excluding a key only if its - # value has a certain type, and allowing Forbidden to - # work well in combination with Optional. - try: - nvalue = Schema(svalue, error=e).validate( - value, **kwargs - ) - except SchemaError: - continue - skey.handler(nkey, data, e) - else: - try: - nvalue = Schema( - svalue, error=e, ignore_extra_keys=i - ).validate(value, **kwargs) - except SchemaError as x: - k = "Key '%s' error:" % nkey - message = self._prepend_schema_name(k) - raise SchemaError( - [message] + x.autos, - [e.format(data) if e else None] + x.errors, - ) - else: - new[nkey] = nvalue - coverage.add(skey) - break - required = set(k for k in s if not self._is_optional_type(k)) - if not required.issubset(coverage): - missing_keys = required - coverage - s_missing_keys = ", ".join( - repr(k) for k in sorted(missing_keys, key=repr) - ) - message = "Missing key%s: %s" % ( - _plural_s(missing_keys), - s_missing_keys, - ) - message = self._prepend_schema_name(message) - raise SchemaMissingKeyError(message, e.format(data) if e else None) - if not self._ignore_extra_keys and (len(new) != len(data)): - wrong_keys = set(data.keys()) - set(new.keys()) - s_wrong_keys = ", ".join(repr(k) for k in sorted(wrong_keys, key=repr)) - message = "Wrong key%s %s in %r" % ( - _plural_s(wrong_keys), - s_wrong_keys, - data, - ) - message = self._prepend_schema_name(message) - raise SchemaWrongKeyError(message, e.format(data) if e else None) - - # Apply default-having optionals that haven't been used: - defaults = ( - set(k for k in s if isinstance(k, Optional) and hasattr(k, "default")) - - coverage - ) - for default in defaults: - new[default.key] = ( - _invoke_with_optional_kwargs(default.default, **kwargs) - if callable(default.default) - else default.default - ) - - return new - if flavor == TYPE: - if isinstance(data, s) and not (isinstance(data, bool) and s == int): - return data - else: - message = "%r should be instance of %r" % (data, s.__name__) - message = self._prepend_schema_name(message) - raise SchemaUnexpectedTypeError(message, e.format(data) if e else None) - if flavor == VALIDATOR: - try: - return s.validate(data, **kwargs) - except SchemaError as x: - raise SchemaError( - [None] + x.autos, [e.format(data) if e else None] + x.errors - ) - except BaseException as x: - message = "%r.validate(%r) raised %r" % (s, data, x) - message = self._prepend_schema_name(message) - raise SchemaError(message, e.format(data) if e else None) - if flavor == CALLABLE: - f = _callable_str(s) - try: - if s(data): - return data - except SchemaError as x: - raise SchemaError( - [None] + x.autos, [e.format(data) if e else None] + x.errors - ) - except BaseException as x: - message = "%s(%r) raised %r" % (f, data, x) - message = self._prepend_schema_name(message) - raise SchemaError(message, e.format(data) if e else None) - message = "%s(%r) should evaluate to True" % (f, data) - message = self._prepend_schema_name(message) - raise SchemaError(message, e.format(data) if e else None) - if s == data: - return data - else: - message = "%r does not match %r" % (s, data) - message = self._prepend_schema_name(message) - raise SchemaError(message, e.format(data) if e else None) - - def json_schema( - self, schema_id: str, use_refs: bool = False, **kwargs: Any - ) -> Dict[str, Any]: - """Generate a draft-07 JSON schema dict representing the Schema. - This method must be called with a schema_id. - - :param schema_id: The value of the $id on the main schema - :param use_refs: Enable reusing object references in the resulting JSON schema. - Schemas with references are harder to read by humans, but are a lot smaller when there - is a lot of reuse - """ - - seen: Dict[int, Dict[str, Any]] = {} - definitions_by_name: Dict[str, Dict[str, Any]] = {} - - def _json_schema( - schema: "Schema", - is_main_schema: bool = True, - title: Union[str, None] = None, - description: Union[str, None] = None, - allow_reference: bool = True, - ) -> Dict[str, Any]: - def _create_or_use_ref(return_dict: Dict[str, Any]) -> Dict[str, Any]: - """If not already seen, return the provided part of the schema unchanged. - If already seen, give an id to the already seen dict and return a reference to the previous part - of the schema instead. - """ - if not use_refs or is_main_schema: - return return_schema - - hashed = hash(repr(sorted(return_dict.items()))) - if hashed not in seen: - seen[hashed] = return_dict - return return_dict - else: - id_str = "#" + str(hashed) - seen[hashed]["$id"] = id_str - return {"$ref": id_str} - - def _get_type_name(python_type: Type) -> str: - """Return the JSON schema name for a Python type""" - if python_type == str: - return "string" - elif python_type == int: - return "integer" - elif python_type == float: - return "number" - elif python_type == bool: - return "boolean" - elif python_type == list: - return "array" - elif python_type == dict: - return "object" - return "string" - - def _to_json_type(value: Any) -> Any: - """Attempt to convert a constant value (for "const" and "default") to a JSON serializable value""" - if value is None or type(value) in (str, int, float, bool, list, dict): - return value - - if type(value) in (tuple, set, frozenset): - return list(value) - - if isinstance(value, Literal): - return value.schema - - return str(value) - - def _to_schema(s: Any, ignore_extra_keys: bool) -> Schema: - if not isinstance(s, Schema): - return Schema(s, ignore_extra_keys=ignore_extra_keys) - - return s - - s: Any = schema.schema - i: bool = schema.ignore_extra_keys - flavor = _priority(s) - - return_schema: Dict[str, Any] = {} - - return_description: Union[str, None] = description or schema.description - if return_description: - return_schema["description"] = return_description - if title: - return_schema["title"] = title - - # Check if we have to create a common definition and use as reference - if allow_reference and schema.as_reference: - # Generate sub schema if not already done - if schema.name not in definitions_by_name: - definitions_by_name[ - cast(str, schema.name) - ] = {} # Avoid infinite loop - definitions_by_name[cast(str, schema.name)] = _json_schema( - schema, is_main_schema=False, allow_reference=False - ) - - return_schema["$ref"] = "#/definitions/" + cast(str, schema.name) - else: - if schema.name and not title: - return_schema["title"] = schema.name - - if flavor == TYPE: - # Handle type - return_schema["type"] = _get_type_name(s) - elif flavor == ITERABLE: - # Handle arrays or dict schema - - return_schema["type"] = "array" - if len(s) == 1: - return_schema["items"] = _json_schema( - _to_schema(s[0], i), is_main_schema=False - ) - elif len(s) > 1: - return_schema["items"] = _json_schema( - Schema(Or(*s)), is_main_schema=False - ) - elif isinstance(s, Or): - # Handle Or values - - # Check if we can use an enum - if all( - priority == COMPARABLE - for priority in [_priority(value) for value in s.args] - ): - or_values = [ - str(s) if isinstance(s, Literal) else s for s in s.args - ] - # All values are simple, can use enum or const - if len(or_values) == 1: - or_value = or_values[0] - if or_value is None: - return_schema["type"] = "null" - else: - return_schema["const"] = _to_json_type(or_value) - return return_schema - return_schema["enum"] = or_values - else: - # No enum, let's go with recursive calls - any_of_values = [] - for or_key in s.args: - new_value = _json_schema( - _to_schema(or_key, i), is_main_schema=False - ) - if new_value != {} and new_value not in any_of_values: - any_of_values.append(new_value) - if len(any_of_values) == 1: - # Only one representable condition remains, do not put under anyOf - return_schema.update(any_of_values[0]) - else: - return_schema["anyOf"] = any_of_values - elif isinstance(s, And): - # Handle And values - all_of_values = [] - for and_key in s.args: - new_value = _json_schema( - _to_schema(and_key, i), is_main_schema=False - ) - if new_value != {} and new_value not in all_of_values: - all_of_values.append(new_value) - if len(all_of_values) == 1: - # Only one representable condition remains, do not put under allOf - return_schema.update(all_of_values[0]) - else: - return_schema["allOf"] = all_of_values - elif flavor == COMPARABLE: - if s is None: - return_schema["type"] = "null" - else: - return_schema["const"] = _to_json_type(s) - elif flavor == VALIDATOR and type(s) == Regex: - return_schema["type"] = "string" - # JSON schema uses ECMAScript regex syntax - # Translating one to another is not easy, but this should work for simple cases - return_schema["pattern"] = re.sub( - r"\(\?P<[a-z\d_]+>", "(", s.pattern_str - ).replace("/", r"\/") - else: - if flavor != DICT: - # If not handled, do not check - return return_schema - - # Schema is a dict - - required_keys = [] - expanded_schema = {} - additional_properties = i - for key in s: - if isinstance(key, Hook): - continue - - def _key_allows_additional_properties(key: Any) -> bool: - """Check if a key is broad enough to allow additional properties""" - if isinstance(key, Optional): - return _key_allows_additional_properties(key.schema) - - return key == str or key == object - - def _get_key_title(key: Any) -> Union[str, None]: - """Get the title associated to a key (as specified in a Literal object). Return None if not a Literal""" - if isinstance(key, Optional): - return _get_key_title(key.schema) - - if isinstance(key, Literal): - return key.title - - return None - - def _get_key_description(key: Any) -> Union[str, None]: - """Get the description associated to a key (as specified in a Literal object). Return None if not a Literal""" - if isinstance(key, Optional): - return _get_key_description(key.schema) - - if isinstance(key, Literal): - return key.description - - return None - - def _get_key_name(key: Any) -> Any: - """Get the name of a key (as specified in a Literal object). Return the key unchanged if not a Literal""" - if isinstance(key, Optional): - return _get_key_name(key.schema) - - if isinstance(key, Literal): - return key.schema - - return key - - additional_properties = ( - additional_properties - or _key_allows_additional_properties(key) - ) - sub_schema = _to_schema(s[key], ignore_extra_keys=i) - key_name = _get_key_name(key) - - if isinstance(key_name, str): - if not isinstance(key, Optional): - required_keys.append(key_name) - expanded_schema[key_name] = _json_schema( - sub_schema, - is_main_schema=False, - title=_get_key_title(key), - description=_get_key_description(key), - ) - if isinstance(key, Optional) and hasattr(key, "default"): - expanded_schema[key_name]["default"] = _to_json_type( - _invoke_with_optional_kwargs(key.default, **kwargs) - if callable(key.default) - else key.default - ) - elif isinstance(key_name, Or): - # JSON schema does not support having a key named one name or another, so we just add both options - # This is less strict because we cannot enforce that one or the other is required - - for or_key in key_name.args: - expanded_schema[_get_key_name(or_key)] = _json_schema( - sub_schema, - is_main_schema=False, - description=_get_key_description(or_key), - ) - - return_schema.update( - { - "type": "object", - "properties": expanded_schema, - "required": required_keys, - "additionalProperties": additional_properties, - } - ) - - if is_main_schema: - return_schema.update( - { - "$id": schema_id, - "$schema": "http://json-schema.org/draft-07/schema#", - } - ) - if self._name: - return_schema["title"] = self._name - - if definitions_by_name: - return_schema["definitions"] = {} - for definition_name, definition in definitions_by_name.items(): - return_schema["definitions"][definition_name] = definition - - return _create_or_use_ref(return_schema) - - return _json_schema(self, True) - - -class Optional(Schema): - """Marker for an optional part of the validation Schema.""" - - _MARKER = object() - - def __init__(self, *args: Any, **kwargs: Any) -> None: - default: Any = kwargs.pop("default", self._MARKER) - super(Optional, self).__init__(*args, **kwargs) - if default is not self._MARKER: - if _priority(self._schema) != COMPARABLE: - raise TypeError( - "Optional keys with defaults must have simple, " - "predictable values, like literal strings or ints. " - f'"{self._schema!r}" is too complex.' - ) - self.default = default - self.key = str(self._schema) - - def __hash__(self) -> int: - return hash(self._schema) - - def __eq__(self, other: Any) -> bool: - return ( - self.__class__ is other.__class__ - and getattr(self, "default", self._MARKER) - == getattr(other, "default", self._MARKER) - and self._schema == other._schema - ) - - def reset(self) -> None: - if hasattr(self._schema, "reset"): - self._schema.reset() - - -class Hook(Schema): - def __init__(self, *args: Any, **kwargs: Any) -> None: - self.handler: Callable[..., Any] = kwargs.pop("handler", lambda *args: None) - super(Hook, self).__init__(*args, **kwargs) - self.key = self._schema - - -class Forbidden(Hook): - def __init__(self, *args: Any, **kwargs: Any) -> None: - kwargs["handler"] = self._default_function - super(Forbidden, self).__init__(*args, **kwargs) - - @staticmethod - def _default_function(nkey: Any, data: Any, error: Any) -> NoReturn: - raise SchemaForbiddenKeyError( - f"Forbidden key encountered: {nkey!r} in {data!r}", error - ) - - -class Literal: - def __init__( - self, - value: Any, - description: Union[str, None] = None, - title: Union[str, None] = None, - ) -> None: - self._schema: Any = value - self._description: Union[str, None] = description - self._title: Union[str, None] = title - - def __str__(self) -> str: - return str(self._schema) - - def __repr__(self) -> str: - return f'Literal("{self._schema}", description="{self._description or ""}")' - - @property - def description(self) -> Union[str, None]: - return self._description - - @property - def title(self) -> Union[str, None]: - return self._title - - @property - def schema(self) -> Any: - return self._schema - - -class Const(Schema): - def validate(self, data: Any, **kwargs: Any) -> Any: - super(Const, self).validate(data, **kwargs) - return data - - -def _callable_str(callable_: Callable[..., Any]) -> str: - if hasattr(callable_, "__name__"): - return callable_.__name__ - return str(callable_) - - -def _plural_s(sized: Sized) -> str: - return "s" if len(sized) > 1 else "" diff --git a/schema/_schema_constants.py b/schema/_schema_constants.py new file mode 100644 index 0000000..bbc7d27 --- /dev/null +++ b/schema/_schema_constants.py @@ -0,0 +1,53 @@ +"""Constants and utility functions for schema validation.""" + +import inspect +from typing import Any, Callable, Sized + +# Priority constants for schema types +COMPARABLE, CALLABLE, VALIDATOR, TYPE, DICT, ITERABLE = range(6) + + +def _callable_str(callable_: Callable[..., Any]) -> str: + """Get a string representation of a callable.""" + if hasattr(callable_, "__name__"): + return callable_.__name__ + return str(callable_) + + +def _plural_s(sized: Sized) -> str: + """Return 's' if the sized object has more than one element.""" + return "s" if len(sized) > 1 else "" + + +def _invoke_with_optional_kwargs(f: Callable[..., Any], **kwargs: Any) -> Any: + """Invoke a function with optional kwargs if it accepts them.""" + s = inspect.signature(f) + if len(s.parameters) == 0: + return f() + return f(**kwargs) + + +def _priority(s: Any) -> int: + """Return priority for a given object.""" + if type(s) in (list, tuple, set, frozenset): + return ITERABLE + if isinstance(s, dict): + return DICT + issubclass_ = False + try: + issubclass_ = issubclass(type(s), type) + except TypeError: + pass + if issubclass_: + return TYPE + # Import here to avoid circular imports + from ._schema_types import Literal + + if isinstance(s, Literal): + return COMPARABLE + if hasattr(s, "validate"): + return VALIDATOR + if callable(s): + return CALLABLE + else: + return COMPARABLE diff --git a/schema/_schema_core.py b/schema/_schema_core.py new file mode 100644 index 0000000..cdd9bae --- /dev/null +++ b/schema/_schema_core.py @@ -0,0 +1,454 @@ +"""Core Schema class with validation logic.""" + +from typing import ( + Any, + Callable, + Dict, + List, + NoReturn, + Set, + TypeVar, + Union, +) + +# Use TYPE_CHECKING to determine the correct type hint but avoid runtime import errors +try: + from contextlib import ExitStack # Python 3.3 and later +except ImportError: + from contextlib2 import ExitStack # Python 2.x/3.0-3.2 fallback + +from ._schema_constants import ( + CALLABLE, + COMPARABLE, + DICT, + ITERABLE, + TYPE, + VALIDATOR, + _invoke_with_optional_kwargs, + _plural_s, + _priority, +) +from ._schema_exceptions import ( + SchemaError, + SchemaForbiddenKeyError, + SchemaMissingKeyError, + SchemaUnexpectedTypeError, + SchemaWrongKeyError, +) + +# Type variable to represent a Schema-like type +TSchema = TypeVar("TSchema", bound="Schema") + + +class Schema(object): + """ + Entry point of the library, use this class to instantiate validation + schema for the data that will be validated. + """ + + # Registry for custom validators + _custom_validators: Dict[str, Callable[[Any], Any]] = {} + + def __init__( + self, + schema: Any, + error: Union[str, None] = None, + ignore_extra_keys: bool = False, + name: Union[str, None] = None, + description: Union[str, None] = None, + as_reference: bool = False, + ) -> None: + self._schema: Any = schema + self._error: Union[str, None] = error + self._ignore_extra_keys: bool = ignore_extra_keys + self._name: Union[str, None] = name + self._description: Union[str, None] = description + self.as_reference: bool = as_reference + + if as_reference and name is None: + raise ValueError("Schema used as reference should have a name") + + def __repr__(self): + return "%s(%r)" % (self.__class__.__name__, self._schema) + + @property + def schema(self) -> Any: + return self._schema + + @property + def description(self) -> Union[str, None]: + return self._description + + @property + def name(self) -> Union[str, None]: + return self._name + + @property + def ignore_extra_keys(self) -> bool: + return self._ignore_extra_keys + + @classmethod + def _dict_key_priority(cls, s) -> float: + """Return priority for a given key object.""" + if isinstance(s, Hook): + return _priority(s._schema) - 0.5 + if isinstance(s, Optional): + return _priority(s._schema) + 0.5 + return _priority(s) + + @classmethod + def _is_optional_type(cls, s: Any) -> bool: + """Return True if the given key is optional (does not have to be found)""" + return any(isinstance(s, optional_type) for optional_type in [Optional, Hook]) + + @classmethod + def register_validator(cls, name: str, validator: Callable[[Any], Any]) -> None: + """Register a custom validator by name. + + :param name: The name to register the validator under + :param validator: A callable that takes data and returns validated data or raises SchemaError + """ + if not callable(validator): + raise TypeError(f"Validator must be callable, got {type(validator)}") + cls._custom_validators[name] = validator + + @classmethod + def unregister_validator(cls, name: str) -> None: + """Unregister a custom validator by name. + + :param name: The name of the validator to unregister + """ + cls._custom_validators.pop(name, None) + + @classmethod + def get_registered_validators(cls) -> Dict[str, Callable[[Any], Any]]: + """Get a copy of all registered custom validators. + + :return: A dictionary mapping validator names to their callables + """ + return cls._custom_validators.copy() + + def is_valid(self, data: Any, **kwargs: Dict[str, Any]) -> bool: + """Return whether the given data has passed all the validations + that were specified in the given schema. + """ + try: + self.validate(data, **kwargs) + except SchemaError: + return False + else: + return True + + def _prepend_schema_name(self, message: str) -> str: + """ + If a custom schema name has been defined, prepends it to the error + message that gets raised when a schema error occurs. + """ + if self._name: + message = "{0!r} {1!s}".format(self._name, message) + return message + + def _validate_iterable(self, data: Any, s: Any, e: Union[str, None], i: bool, **kwargs: Any) -> Any: + """Validate iterable types (list, tuple, set, frozenset).""" + # Import here to avoid circular imports + from ._schema_types import Or + + SchemaClass = self.__class__ + data = SchemaClass(type(s), error=e).validate(data, **kwargs) + o: Or = Or(*s, error=e, schema=SchemaClass, ignore_extra_keys=i) + return type(data)(o.validate(d, **kwargs) for d in data) + + def _validate_dict(self, data: Any, s: Any, e: Union[str, None], i: bool, **kwargs: Any) -> Any: + """Validate dictionary types.""" + + SchemaClass = self.__class__ + exitstack = ExitStack() + data = SchemaClass(dict, error=e).validate(data, **kwargs) + new: Dict = type(data)() # new - is a dict of the validated values + coverage: Set = set() # matched schema keys + # for each key and value find a schema entry matching them, if any + sorted_skeys = sorted(s, key=self._dict_key_priority) + for skey in sorted_skeys: + if hasattr(skey, "reset"): + exitstack.callback(skey.reset) + + # Create kwargs for key validation with custom validators disabled + key_kwargs = kwargs.copy() + key_kwargs["_disable_custom_validators"] = True + + # Create clean kwargs for value validation (without internal params) + clean_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")} + + with exitstack: + # Evaluate dictionaries last + data_items = sorted( + data.items(), key=lambda value: isinstance(value[1], dict) + ) + for key, value in data_items: + for skey in sorted_skeys: + svalue = s[skey] + try: + # Use base Schema class for key validation to avoid issues with + # custom Schema subclasses that may not accept internal kwargs + nkey = Schema(skey, error=e).validate(key, **key_kwargs) + except SchemaError: + pass + else: + if isinstance(skey, Hook): + # As the content of the value makes little sense for + # keys with a hook, we reverse its meaning: + # we will only call the handler if the value does match + # In the case of the forbidden key hook, + # we will raise the SchemaErrorForbiddenKey exception + # on match, allowing for excluding a key only if its + # value has a certain type, and allowing Forbidden to + # work well in combination with Optional. + try: + nvalue = SchemaClass(svalue, error=e).validate( + value, **clean_kwargs + ) + except SchemaError: + continue + skey.handler(nkey, data, e) + else: + try: + nvalue = SchemaClass( + svalue, error=e, ignore_extra_keys=i + ).validate(value, **clean_kwargs) + except SchemaError as x: + k = "Key '%s' error:" % nkey + message = self._prepend_schema_name(k) + raise SchemaError( + [message] + x.autos, + [e.format(data) if e else None] + x.errors, + ) + else: + new[nkey] = nvalue + coverage.add(skey) + break + required = set(k for k in s if not self._is_optional_type(k)) + if not required.issubset(coverage): + missing_keys = required - coverage + s_missing_keys = ", ".join( + repr(k) for k in sorted(missing_keys, key=repr) + ) + message = "Missing key%s: %s" % ( + _plural_s(missing_keys), + s_missing_keys, + ) + message = self._prepend_schema_name(message) + raise SchemaMissingKeyError(message, e.format(data) if e else None) + if not self._ignore_extra_keys and (len(new) != len(data)): + wrong_keys = set(data.keys()) - set(new.keys()) + s_wrong_keys = ", ".join(repr(k) for k in sorted(wrong_keys, key=repr)) + message = "Wrong key%s %s in %r" % ( + _plural_s(wrong_keys), + s_wrong_keys, + data, + ) + message = self._prepend_schema_name(message) + raise SchemaWrongKeyError(message, e.format(data) if e else None) + + # Apply default-having optionals that haven't been used: + defaults = ( + set(k for k in s if isinstance(k, Optional) and hasattr(k, "default")) + - coverage + ) + for default in defaults: + new[default.key] = ( + _invoke_with_optional_kwargs(default.default, **kwargs) + if callable(default.default) + else default.default + ) + + return new + + def _validate_type(self, data: Any, s: Any, e: Union[str, None], **kwargs: Any) -> Any: + """Validate type constraints.""" + if isinstance(data, s) and not (isinstance(data, bool) and s == int): + return data + else: + message = "%r should be instance of %r" % (data, s.__name__) + message = self._prepend_schema_name(message) + raise SchemaUnexpectedTypeError(message, e.format(data) if e else None) + + def _validate_validator(self, data: Any, s: Any, e: Union[str, None], **kwargs: Any) -> Any: + """Validate using a validator object (has validate method).""" + # Remove internal kwargs before calling external validator's validate method + clean_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")} + try: + return s.validate(data, **clean_kwargs) + except SchemaError as x: + raise SchemaError( + [None] + x.autos, [e.format(data) if e else None] + x.errors + ) + except BaseException as x: + message = "%r.validate(%r) raised %r" % (s, data, x) + message = self._prepend_schema_name(message) + raise SchemaError(message, e.format(data) if e else None) + + def _validate_callable(self, data: Any, s: Any, e: Union[str, None], **kwargs: Any) -> Any: + """Validate using a callable function.""" + from ._schema_constants import _callable_str + + f = _callable_str(s) + try: + if s(data): + return data + except SchemaError as x: + raise SchemaError( + [None] + x.autos, [e.format(data) if e else None] + x.errors + ) + except BaseException as x: + message = "%s(%r) raised %r" % (f, data, x) + message = self._prepend_schema_name(message) + raise SchemaError(message, e.format(data) if e else None) + message = "%s(%r) should evaluate to True" % (f, data) + message = self._prepend_schema_name(message) + raise SchemaError(message, e.format(data) if e else None) + + def _validate_comparable(self, data: Any, s: Any, e: Union[str, None], **kwargs: Any) -> Any: + """Validate by direct comparison.""" + if s == data: + return data + else: + message = "%r does not match %r" % (s, data) + message = self._prepend_schema_name(message) + raise SchemaError(message, e.format(data) if e else None) + + def _validate_custom(self, data: Any, validator_name: str, e: Union[str, None], **kwargs: Any) -> Any: + """Validate using a registered custom validator.""" + if validator_name not in self._custom_validators: + raise SchemaError( + f"Unknown custom validator: {validator_name!r}", + e.format(data) if e else None + ) + validator = self._custom_validators[validator_name] + try: + return validator(data) + except SchemaError as x: + raise SchemaError( + [None] + x.autos, [e.format(data) if e else None] + x.errors + ) + except BaseException as x: + from ._schema_constants import _callable_str + + f = _callable_str(validator) + message = "%s(%r) raised %r" % (f, data, x) + message = self._prepend_schema_name(message) + raise SchemaError(message, e.format(data) if e else None) + + def validate(self, data: Any, **kwargs: Any) -> Any: + """Validate data against the schema. + + This method dispatches to specific validation methods based on the + schema type (flavor). + """ + Schema = self.__class__ + s: Any = self._schema + e: Union[str, None] = self._error + i: bool = self._ignore_extra_keys + + # Import here to avoid circular imports + from ._schema_types import Literal + + if isinstance(s, Literal): + s = s.schema + + # Check if s is a string that refers to a custom validator + # Skip if _disable_custom_validators is set (used for dict key validation) + if isinstance(s, str) and s in self._custom_validators and not kwargs.get("_disable_custom_validators"): + return self._validate_custom(data, s, e, **kwargs) + + # Remove internal kwargs before dispatching to avoid passing them to nested schemas + # that might not accept them (e.g., user-defined Schema subclasses) + clean_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")} + + flavor = _priority(s) + if flavor == ITERABLE: + return self._validate_iterable(data, s, e, i, **clean_kwargs) + if flavor == DICT: + return self._validate_dict(data, s, e, i, **clean_kwargs) + if flavor == TYPE: + return self._validate_type(data, s, e, **clean_kwargs) + if flavor == VALIDATOR: + return self._validate_validator(data, s, e, **clean_kwargs) + if flavor == CALLABLE: + return self._validate_callable(data, s, e, **clean_kwargs) + # COMPARABLE and other cases + return self._validate_comparable(data, s, e, **clean_kwargs) + + def json_schema( + self, schema_id: str, use_refs: bool = False, **kwargs: Any + ) -> Dict[str, Any]: + """Generate a draft-07 JSON schema dict representing the Schema. + This method must be called with a schema_id. + + :param schema_id: The value of the $id on the main schema + :param use_refs: Enable reusing object references in the resulting JSON schema. + Schemas with references are harder to read by humans, but are a lot smaller when there + is a lot of reuse + """ + from ._schema_json_schema import JsonSchemaGenerator + + generator = JsonSchemaGenerator(self, schema_id, use_refs, **kwargs) + return generator.generate() + + +class Optional(Schema): + """Marker for an optional part of the validation Schema.""" + + _MARKER = object() + + def __init__(self, *args: Any, **kwargs: Any) -> None: + default: Any = kwargs.pop("default", self._MARKER) + super(Optional, self).__init__(*args, **kwargs) + if default is not self._MARKER: + if _priority(self._schema) != COMPARABLE: + raise TypeError( + "Optional keys with defaults must have simple, " + "predictable values, like literal strings or ints. " + f'"{self._schema!r}" is too complex.' + ) + self.default = default + self.key = str(self._schema) + + def __hash__(self) -> int: + return hash(self._schema) + + def __eq__(self, other: Any) -> bool: + return ( + self.__class__ is other.__class__ + and getattr(self, "default", self._MARKER) + == getattr(other, "default", self._MARKER) + and self._schema == other._schema + ) + + def reset(self) -> None: + if hasattr(self._schema, "reset"): + self._schema.reset() + + +class Hook(Schema): + def __init__(self, *args: Any, **kwargs: Any) -> None: + self.handler: Callable[..., Any] = kwargs.pop("handler", lambda *args: None) + super(Hook, self).__init__(*args, **kwargs) + self.key = self._schema + + +class Forbidden(Hook): + def __init__(self, *args: Any, **kwargs: Any) -> None: + kwargs["handler"] = self._default_function + super(Forbidden, self).__init__(*args, **kwargs) + + @staticmethod + def _default_function(nkey: Any, data: Any, error: Any) -> NoReturn: + raise SchemaForbiddenKeyError( + f"Forbidden key encountered: {nkey!r} in {data!r}", error + ) + + +class Const(Schema): + def validate(self, data: Any, **kwargs: Any) -> Any: + super(Const, self).validate(data, **kwargs) + return data diff --git a/schema/_schema_exceptions.py b/schema/_schema_exceptions.py new file mode 100644 index 0000000..f10871f --- /dev/null +++ b/schema/_schema_exceptions.py @@ -0,0 +1,69 @@ +"""Exception classes for schema validation.""" + +from typing import Iterable, List, Sequence, Set, Union + + +class SchemaError(Exception): + """Error during Schema validation.""" + + def __init__( + self, + autos: Union[Sequence[Union[str, None]], None], + errors: Union[List, str, None] = None, + ): + self.autos = autos if isinstance(autos, List) else [autos] + self.errors = errors if isinstance(errors, List) else [errors] + Exception.__init__(self, self.code) + + @property + def code(self) -> str: + """Remove duplicates in autos and errors list and combine them into a single message.""" + + def uniq(seq: Iterable[Union[str, None]]) -> List[str]: + """Utility function to remove duplicates while preserving the order.""" + seen: Set[str] = set() + unique_list: List[str] = [] + for x in seq: + if x is not None and x not in seen: + seen.add(x) + unique_list.append(x) + return unique_list + + data_set = uniq(self.autos) + error_list = uniq(self.errors) + + return "\n".join(error_list if error_list else data_set) + + +class SchemaWrongKeyError(SchemaError): + """Error Should be raised when an unexpected key is detected within the + data set being.""" + + pass + + +class SchemaMissingKeyError(SchemaError): + """Error should be raised when a mandatory key is not found within the + data set being validated""" + + pass + + +class SchemaOnlyOneAllowedError(SchemaError): + """Error should be raised when an only_one Or key has multiple matching candidates""" + + pass + + +class SchemaForbiddenKeyError(SchemaError): + """Error should be raised when a forbidden key is found within the + data set being validated, and its value matches the value that was specified""" + + pass + + +class SchemaUnexpectedTypeError(SchemaError): + """Error should be raised when a type mismatch is detected within the + data set being validated.""" + + pass diff --git a/schema/_schema_json_schema.py b/schema/_schema_json_schema.py new file mode 100644 index 0000000..2364f69 --- /dev/null +++ b/schema/_schema_json_schema.py @@ -0,0 +1,313 @@ +"""JSON Schema generation logic.""" + +import re +from typing import Any, Dict, Union, cast + +from ._schema_constants import ( + COMPARABLE, + DICT, + ITERABLE, + TYPE, + VALIDATOR, + _invoke_with_optional_kwargs, + _priority, +) +from ._schema_core import Optional, Hook +from ._schema_types import And, Literal, Or, Regex + + +class JsonSchemaGenerator: + """Generator for JSON Schema output.""" + + def __init__( + self, + schema: "Schema", + schema_id: str, + use_refs: bool = False, + **kwargs: Any, + ) -> None: + self.schema = schema + self.schema_id = schema_id + self.use_refs = use_refs + self.kwargs = kwargs + self.seen: Dict[int, Dict[str, Any]] = {} + self.definitions_by_name: Dict[str, Dict[str, Any]] = {} + + def generate(self) -> Dict[str, Any]: + """Generate the JSON Schema.""" + return self._json_schema(self.schema, True) + + def _json_schema( + self, + schema: "Schema", + is_main_schema: bool = True, + title: Union[str, None] = None, + description: Union[str, None] = None, + allow_reference: bool = True, + ) -> Dict[str, Any]: + def _create_or_use_ref(return_dict: Dict[str, Any]) -> Dict[str, Any]: + """If not already seen, return the provided part of the schema unchanged. + If already seen, give an id to the already seen dict and return a reference to the previous part + of the schema instead. + """ + if not self.use_refs or is_main_schema: + return return_schema + + hashed = hash(repr(sorted(return_dict.items()))) + if hashed not in self.seen: + self.seen[hashed] = return_dict + return return_dict + else: + id_str = "#" + str(hashed) + self.seen[hashed]["$id"] = id_str + return {"$ref": id_str} + + def _get_type_name(python_type: type) -> str: + """Return the JSON schema name for a Python type""" + if python_type == str: + return "string" + elif python_type == int: + return "integer" + elif python_type == float: + return "number" + elif python_type == bool: + return "boolean" + elif python_type == list: + return "array" + elif python_type == dict: + return "object" + return "string" + + def _to_json_type(value: Any) -> Any: + """Attempt to convert a constant value (for "const" and "default") to a JSON serializable value""" + if value is None or type(value) in (str, int, float, bool, list, dict): + return value + + if type(value) in (tuple, set, frozenset): + return list(value) + + if isinstance(value, Literal): + return value.schema + + return str(value) + + def _to_schema(s: Any, ignore_extra_keys: bool) -> "Schema": + from ._schema_core import Schema + + if not isinstance(s, Schema): + return Schema(s, ignore_extra_keys=ignore_extra_keys) + + return s + + s: Any = schema.schema + i: bool = schema.ignore_extra_keys + flavor = _priority(s) + + return_schema: Dict[str, Any] = {} + + return_description: Union[str, None] = description or schema.description + if return_description: + return_schema["description"] = return_description + if title: + return_schema["title"] = title + + # Check if we have to create a common definition and use as reference + if allow_reference and schema.as_reference: + # Generate sub schema if not already done + if schema.name not in self.definitions_by_name: + self.definitions_by_name[ + cast(str, schema.name) + ] = {} # Avoid infinite loop + self.definitions_by_name[cast(str, schema.name)] = self._json_schema( + schema, is_main_schema=False, allow_reference=False + ) + + return_schema["$ref"] = "#/definitions/" + cast(str, schema.name) + else: + if schema.name and not title: + return_schema["title"] = schema.name + + if flavor == TYPE: + # Handle type + return_schema["type"] = _get_type_name(s) + elif flavor == ITERABLE: + # Handle arrays or dict schema + + return_schema["type"] = "array" + if len(s) == 1: + return_schema["items"] = self._json_schema( + _to_schema(s[0], i), is_main_schema=False + ) + elif len(s) > 1: + return_schema["items"] = self._json_schema( + _to_schema(Or(*s), i), is_main_schema=False + ) + elif isinstance(s, Or): + # Handle Or values + + # Check if we can use an enum + if all( + priority == COMPARABLE + for priority in [_priority(value) for value in s.args] + ): + or_values = [ + str(s) if isinstance(s, Literal) else s for s in s.args + ] + # All values are simple, can use enum or const + if len(or_values) == 1: + or_value = or_values[0] + if or_value is None: + return_schema["type"] = "null" + else: + return_schema["const"] = _to_json_type(or_value) + return return_schema + return_schema["enum"] = or_values + else: + # No enum, let's go with recursive calls + any_of_values = [] + for or_key in s.args: + new_value = self._json_schema( + _to_schema(or_key, i), is_main_schema=False + ) + if new_value != {} and new_value not in any_of_values: + any_of_values.append(new_value) + if len(any_of_values) == 1: + # Only one representable condition remains, do not put under anyOf + return_schema.update(any_of_values[0]) + else: + return_schema["anyOf"] = any_of_values + elif isinstance(s, And): + # Handle And values + all_of_values = [] + for and_key in s.args: + new_value = self._json_schema( + _to_schema(and_key, i), is_main_schema=False + ) + if new_value != {} and new_value not in all_of_values: + all_of_values.append(new_value) + if len(all_of_values) == 1: + # Only one representable condition remains, do not put under allOf + return_schema.update(all_of_values[0]) + else: + return_schema["allOf"] = all_of_values + elif flavor == COMPARABLE: + if s is None: + return_schema["type"] = "null" + else: + return_schema["const"] = _to_json_type(s) + elif flavor == VALIDATOR and type(s) == Regex: + return_schema["type"] = "string" + # JSON schema uses ECMAScript regex syntax + # Translating one to another is not easy, but this should work for simple cases + return_schema["pattern"] = re.sub( + r"\(\?P<[a-z\d_]+>", "(", s.pattern_str + ).replace("/", r"\/") + else: + if flavor != DICT: + # If not handled, do not check + return return_schema + + # Schema is a dict + + required_keys = [] + expanded_schema = {} + additional_properties = i + for key in s: + if isinstance(key, Hook): + continue + + def _key_allows_additional_properties(key: Any) -> bool: + """Check if a key is broad enough to allow additional properties""" + if isinstance(key, Optional): + return _key_allows_additional_properties(key.schema) + + return key == str or key == object + + def _get_key_title(key: Any) -> Union[str, None]: + """Get the title associated to a key (as specified in a Literal object). Return None if not a Literal""" + if isinstance(key, Optional): + return _get_key_title(key.schema) + + if isinstance(key, Literal): + return key.title + + return None + + def _get_key_description(key: Any) -> Union[str, None]: + """Get the description associated to a key (as specified in a Literal object). Return None if not a Literal""" + if isinstance(key, Optional): + return _get_key_description(key.schema) + + if isinstance(key, Literal): + return key.description + + return None + + def _get_key_name(key: Any) -> Any: + """Get the name of a key (as specified in a Literal object). Return the key unchanged if not a Literal""" + if isinstance(key, Optional): + return _get_key_name(key.schema) + + if isinstance(key, Literal): + return key.schema + + return key + + additional_properties = ( + additional_properties + or _key_allows_additional_properties(key) + ) + sub_schema = _to_schema(s[key], ignore_extra_keys=i) + key_name = _get_key_name(key) + + if isinstance(key_name, str): + if not isinstance(key, Optional): + required_keys.append(key_name) + expanded_schema[key_name] = self._json_schema( + sub_schema, + is_main_schema=False, + title=_get_key_title(key), + description=_get_key_description(key), + ) + if isinstance(key, Optional) and hasattr(key, "default"): + expanded_schema[key_name]["default"] = _to_json_type( + _invoke_with_optional_kwargs(key.default, **self.kwargs) + if callable(key.default) + else key.default + ) + elif isinstance(key_name, Or): + # JSON schema does not support having a key named one name or another, so we just add both options + # This is less strict because we cannot enforce that one or the other is required + + for or_key in key_name.args: + expanded_schema[_get_key_name(or_key)] = self._json_schema( + sub_schema, + is_main_schema=False, + description=_get_key_description(or_key), + ) + + return_schema.update( + { + "type": "object", + "properties": expanded_schema, + "required": required_keys, + "additionalProperties": additional_properties, + } + ) + + if is_main_schema: + return_schema.update( + { + "$id": self.schema_id, + "$schema": "http://json-schema.org/draft-07/schema#", + } + ) + if self.schema.name: + return_schema["title"] = self.schema.name + + if self.definitions_by_name: + return_schema["definitions"] = {} + for definition_name, definition in self.definitions_by_name.items(): + return_schema["definitions"][definition_name] = definition + + return _create_or_use_ref(return_schema) diff --git a/schema/_schema_types.py b/schema/_schema_types.py new file mode 100644 index 0000000..bd8a74b --- /dev/null +++ b/schema/_schema_types.py @@ -0,0 +1,254 @@ +"""Type definitions and marker classes for schema validation.""" + +from typing import Any, Callable, Generic, List, Tuple, Type, TypeVar, Union + +from ._schema_exceptions import SchemaError + +# Type variable to represent a Schema-like type +TSchema = TypeVar("TSchema", bound="Schema") + + +class And(Generic[TSchema]): + """ + Utility function to combine validation directives in AND Boolean fashion. + """ + + def __init__( + self, + *args: Any, + error: Union[str, None] = None, + ignore_extra_keys: bool = False, + schema: Union[Type[TSchema], None] = None, + ) -> None: + self._args: Tuple[Union[TSchema, Callable[..., Any]], ...] = args + self._error: Union[str, None] = error + self._ignore_extra_keys: bool = ignore_extra_keys + self._schema_class: Type[TSchema] = schema if schema is not None else Schema + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({', '.join(repr(a) for a in self._args)})" + + @property + def args(self) -> Tuple[Union[TSchema, Callable[..., Any]], ...]: + """The provided parameters""" + return self._args + + def validate(self, data: Any, **kwargs: Any) -> Any: + """ + Validate data using defined sub schema/expressions ensuring all + values are valid. + :param data: Data to be validated with sub defined schemas. + :return: Returns validated data. + """ + # Annotate sub_schema with the type returned by _build_schema + for sub_schema in self._build_schemas(): # type: TSchema + data = sub_schema.validate(data, **kwargs) + return data + + def _build_schemas(self) -> List[TSchema]: + return [self._build_schema(s) for s in self._args] + + def _build_schema(self, arg: Any) -> TSchema: + # Assume self._schema_class(arg, ...) returns an instance of TSchema + return self._schema_class( + arg, error=self._error, ignore_extra_keys=self._ignore_extra_keys + ) + + +class Or(And[TSchema]): + """Utility function to combine validation directives in a OR Boolean + fashion. + + If one wants to make an xor, one can provide only_one=True optional argument + to the constructor of this object. When a validation was performed for an + xor-ish Or instance and one wants to use it another time, one needs to call + reset() to put the match_count back to 0.""" + + def __init__( + self, + *args: Any, + only_one: bool = False, + **kwargs: Any, + ) -> None: + self.only_one: bool = only_one + self.match_count: int = 0 + super().__init__(*args, **kwargs) + + def reset(self) -> None: + failed: bool = self.match_count > 1 and self.only_one + self.match_count = 0 + if failed: + from ._schema_exceptions import SchemaOnlyOneAllowedError + + raise SchemaOnlyOneAllowedError( + ["There are multiple keys present from the %r condition" % self] + ) + + def validate(self, data: Any, **kwargs: Any) -> Any: + """ + Validate data using sub defined schema/expressions ensuring at least + one value is valid. + :param data: data to be validated by provided schema. + :return: return validated data if not validation + """ + autos: List[str] = [] + errors: List[Union[str, None]] = [] + for sub_schema in self._build_schemas(): + try: + validation: Any = sub_schema.validate(data, **kwargs) + self.match_count += 1 + if self.match_count > 1 and self.only_one: + break + return validation + except SchemaError as _x: + autos += _x.autos + errors += _x.errors + raise SchemaError( + ["%r did not validate %r" % (self, data)] + autos, + [self._error.format(data) if self._error else None] + errors, + ) + + +class Regex: + """ + Enables schema.py to validate string using regular expressions. + """ + + # Map all flags bits to a more readable description + NAMES = [ + "re.ASCII", + "re.DEBUG", + "re.VERBOSE", + "re.UNICODE", + "re.DOTALL", + "re.MULTILINE", + "re.LOCALE", + "re.IGNORECASE", + "re.TEMPLATE", + ] + + def __init__( + self, pattern_str: str, flags: int = 0, error: Union[str, None] = None + ) -> None: + import re + + self._pattern_str: str = pattern_str + flags_list = [ + Regex.NAMES[i] for i, f in enumerate(f"{flags:09b}") if f != "0" + ] # Name for each bit + + self._flags_names: str = ", flags=" + "|".join(flags_list) if flags_list else "" + self._pattern: re.Pattern = re.compile(pattern_str, flags=flags) + self._error: Union[str, None] = error + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self._pattern_str!r}{self._flags_names})" + + @property + def pattern_str(self) -> str: + """The pattern string for the represented regular expression""" + return self._pattern_str + + def validate(self, data: str, **kwargs: Any) -> str: + """ + Validates data using the defined regex. + :param data: Data to be validated. + :return: Returns validated data. + """ + e = self._error + + try: + if self._pattern.search(data): + return data + else: + error_message = ( + e.format(data) + if e + else f"{data!r} does not match {self._pattern_str!r}" + ) + raise SchemaError(error_message) + except TypeError: + error_message = ( + e.format(data) if e else f"{data!r} is not string nor buffer" + ) + raise SchemaError(error_message) + + +class Use: + """ + For more general use cases, you can use the Use class to transform + the data while it is being validated. + """ + + def __init__( + self, callable_: Callable[[Any], Any], error: Union[str, None] = None + ) -> None: + if not callable(callable_): + raise TypeError(f"Expected a callable, not {callable_!r}") + self._callable: Callable[[Any], Any] = callable_ + self._error: Union[str, None] = error + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self._callable!r})" + + def __call__(self, data: Any) -> Any: + """Make Use instances callable by delegating to the wrapped callable. + + This allows Use to work properly with And, Or, and other combinators + that expect callable arguments, while maintaining the validate() method + for the validator pattern. + """ + return self._callable(data) + + def validate(self, data: Any, **kwargs: Any) -> Any: + try: + return self._callable(data) + except SchemaError as x: + from ._schema_constants import _callable_str + + raise SchemaError( + [None] + x.autos, + [self._error.format(data) if self._error else None] + x.errors, + ) + except BaseException as x: + from ._schema_constants import _callable_str + + f = _callable_str(self._callable) + raise SchemaError( + "%s(%r) raised %r" % (f, data, x), + self._error.format(data) if self._error else None, + ) + + +class Literal: + def __init__( + self, + value: Any, + description: Union[str, None] = None, + title: Union[str, None] = None, + ) -> None: + self._schema: Any = value + self._description: Union[str, None] = description + self._title: Union[str, None] = title + + def __str__(self) -> str: + return str(self._schema) + + def __repr__(self) -> str: + return f'Literal("{self._schema}", description="{self._description or ""}")' + + @property + def description(self) -> Union[str, None]: + return self._description + + @property + def title(self) -> Union[str, None]: + return self._title + + @property + def schema(self) -> Any: + return self._schema + + +# Forward import to avoid circular imports +from ._schema_core import Schema diff --git a/test_custom_validators.py b/test_custom_validators.py new file mode 100644 index 0000000..9ea6139 --- /dev/null +++ b/test_custom_validators.py @@ -0,0 +1,385 @@ +"""Tests for custom validator registration feature.""" + +import re +from typing import Any + +import pytest + +from schema import Schema, SchemaError + + +class TestCustomValidatorRegistration: + """Test cases for Schema.register_validator functionality.""" + + def setup_method(self): + """Clear registered validators before each test.""" + # Store original validators + self._original_validators = Schema.get_registered_validators() + # Clear all validators + for name in list(self._original_validators.keys()): + Schema.unregister_validator(name) + + def teardown_method(self): + """Restore original validators after each test.""" + # Clear current validators + for name in list(Schema.get_registered_validators().keys()): + Schema.unregister_validator(name) + # Restore original validators + for name, validator in self._original_validators.items(): + Schema.register_validator(name, validator) + + def test_register_simple_validator(self): + """Test registering a simple validator function.""" + def email_validator(value: Any) -> str: + if not isinstance(value, str): + raise SchemaError(f"Expected string, got {type(value).__name__}") + if "@" not in value: + raise SchemaError(f"Invalid email: {value}") + return value.lower() + + Schema.register_validator("email_validator", email_validator) + + # Test valid email + schema = Schema({"email": "email_validator"}) + result = schema.validate({"email": "Test@Example.com"}) + assert result == {"email": "test@example.com"} + + # Test invalid email (missing @) + with pytest.raises(SchemaError) as exc_info: + schema.validate({"email": "invalid-email"}) + assert "Invalid email" in str(exc_info.value) + + def test_register_validator_with_type_check(self): + """Test registering a validator that checks types.""" + def positive_int_validator(value: Any) -> int: + if not isinstance(value, int): + raise SchemaError(f"Expected int, got {type(value).__name__}") + if value <= 0: + raise SchemaError(f"Expected positive integer, got {value}") + return value + + Schema.register_validator("positive_int", positive_int_validator) + + schema = Schema({"count": "positive_int"}) + + # Valid positive integer + assert schema.validate({"count": 5}) == {"count": 5} + + # Zero should fail + with pytest.raises(SchemaError): + schema.validate({"count": 0}) + + # Negative should fail + with pytest.raises(SchemaError): + schema.validate({"count": -1}) + + # Non-integer should fail + with pytest.raises(SchemaError): + schema.validate({"count": "5"}) + + def test_register_validator_with_transformation(self): + """Test registering a validator that transforms data.""" + def trim_string_validator(value: Any) -> str: + if not isinstance(value, str): + raise SchemaError(f"Expected string, got {type(value).__name__}") + return value.strip() + + Schema.register_validator("trimmed", trim_string_validator) + + schema = Schema({"name": "trimmed"}) + result = schema.validate({"name": " John Doe "}) + assert result == {"name": "John Doe"} + + def test_multiple_custom_validators(self): + """Test using multiple custom validators in one schema.""" + def email_validator(value: Any) -> str: + if not isinstance(value, str) or "@" not in value: + raise SchemaError(f"Invalid email: {value}") + return value.lower() + + def phone_validator(value: Any) -> str: + if not isinstance(value, str): + raise SchemaError(f"Expected string, got {type(value).__name__}") + # Simple phone validation: must contain at least 10 digits + digits = re.sub(r"\D", "", value) + if len(digits) < 10: + raise SchemaError(f"Invalid phone number: {value}") + return digits + + Schema.register_validator("email_validator", email_validator) + Schema.register_validator("phone_validator", phone_validator) + + schema = Schema({ + "email": "email_validator", + "phone": "phone_validator", + }) + + result = schema.validate({ + "email": "User@Example.com", + "phone": "(555) 123-4567", + }) + assert result == { + "email": "user@example.com", + "phone": "5551234567", + } + + def test_unregister_validator(self): + """Test unregistering a validator.""" + def dummy_validator(value: Any) -> Any: + return value + + Schema.register_validator("dummy", dummy_validator) + assert "dummy" in Schema.get_registered_validators() + + # Test that registered validator works + schema = Schema({"field": "dummy"}) + result = schema.validate({"field": "value"}) + assert result == {"field": "value"} + + Schema.unregister_validator("dummy") + assert "dummy" not in Schema.get_registered_validators() + + # After unregistering, the validator name is treated as a literal string + schema2 = Schema({"field": "dummy"}) + with pytest.raises(SchemaError) as exc_info: + schema2.validate({"field": "value"}) + # Now "dummy" is treated as a literal string to compare against + assert "does not match" in str(exc_info.value) + + def test_get_registered_validators(self): + """Test getting registered validators.""" + def validator1(value: Any) -> Any: + return value + + def validator2(value: Any) -> Any: + return value + + Schema.register_validator("validator1", validator1) + Schema.register_validator("validator2", validator2) + + validators = Schema.get_registered_validators() + assert "validator1" in validators + assert "validator2" in validators + assert validators["validator1"] is validator1 + assert validators["validator2"] is validator2 + + # Returned dict should be a copy + validators.clear() + assert "validator1" in Schema.get_registered_validators() + + def test_register_non_callable_raises_error(self): + """Test that registering a non-callable raises TypeError.""" + with pytest.raises(TypeError) as exc_info: + Schema.register_validator("not_callable", "string_value") + assert "callable" in str(exc_info.value).lower() + + def test_custom_validator_with_nested_schema(self): + """Test custom validators in nested schemas.""" + def url_validator(value: Any) -> str: + if not isinstance(value, str): + raise SchemaError(f"Expected string, got {type(value).__name__}") + if not value.startswith(("http://", "https://")): + raise SchemaError(f"Invalid URL: {value}") + return value + + Schema.register_validator("url", url_validator) + + schema = Schema({ + "website": "url", + "social": { + "twitter": "url", + }, + }) + + result = schema.validate({ + "website": "https://example.com", + "social": { + "twitter": "https://twitter.com/user", + }, + }) + assert result["website"] == "https://example.com" + assert result["social"]["twitter"] == "https://twitter.com/user" + + def test_custom_validator_in_list(self): + """Test custom validators with list schemas.""" + def non_empty_string_validator(value: Any) -> str: + if not isinstance(value, str): + raise SchemaError(f"Expected string, got {type(value).__name__}") + if not value.strip(): + raise SchemaError("String cannot be empty") + return value + + Schema.register_validator("non_empty", non_empty_string_validator) + + schema = Schema(["non_empty"]) + + # Valid list + result = schema.validate(["hello", "world"]) + assert result == ["hello", "world"] + + # Empty string should fail + with pytest.raises(SchemaError): + schema.validate(["hello", "", "world"]) + + def test_custom_validator_with_error_message(self): + """Test custom validator with custom error message in schema.""" + def strict_positive_validator(value: Any) -> int: + if not isinstance(value, int) or value <= 0: + raise SchemaError(f"Value must be positive integer") + return value + + Schema.register_validator("strict_positive", strict_positive_validator) + + schema = Schema( + {"count": "strict_positive"}, + error="Validation failed for count" + ) + + with pytest.raises(SchemaError) as exc_info: + schema.validate({"count": -5}) + # The custom error message should be included + assert "Validation failed" in str(exc_info.value) + + def test_validator_returning_transformed_value(self): + """Test that validator can return transformed values.""" + def uppercase_validator(value: Any) -> str: + if not isinstance(value, str): + raise SchemaError(f"Expected string") + return value.upper() + + Schema.register_validator("uppercase", uppercase_validator) + + schema = Schema({"code": "uppercase"}) + result = schema.validate({"code": "abc123"}) + assert result == {"code": "ABC123"} + + def test_unknown_validator_raises_error(self): + """Test that using an unknown validator name as root schema falls back to literal comparison.""" + # When an unregistered string is used as root schema, it's treated as a literal value + schema = Schema("nonexistent_validator") + + # Matching the literal value should work + result = schema.validate("nonexistent_validator") + assert result == "nonexistent_validator" + + # Non-matching value should raise SchemaError + with pytest.raises(SchemaError) as exc_info: + schema.validate("some_value") + assert "does not match" in str(exc_info.value) + + def test_validator_name_same_as_key_name(self): + """Test that validator name same as field name doesn't affect key matching. + + This tests the fix for the issue where registering a validator with the same + name as a field would incorrectly trigger validation during key matching. + """ + def email_validator(value: Any) -> str: + if not isinstance(value, str): + raise SchemaError(f"Expected string, got {type(value).__name__}") + if "@" not in value: + raise SchemaError(f"Invalid email: {value}") + return value.lower() + + # Register a validator with the same name as a common field name + Schema.register_validator("email", email_validator) + + # Use "email" as the key name, and "email" (the validator) as the value schema + # The key "email" should be matched as a literal string, not as a validator + schema = Schema({"email": "email"}) + + # This should work: key "email" matches field name "email", + # and value is validated using the "email" validator + result = schema.validate({"email": "Test@Example.com"}) + assert result == {"email": "test@example.com"} + + # Invalid email should still fail + with pytest.raises(SchemaError) as exc_info: + schema.validate({"email": "invalid-email"}) + assert "Invalid email" in str(exc_info.value) + + # Test with different key name but same validator + schema2 = Schema({"user_email": "email"}) + result2 = schema2.validate({"user_email": "User@Example.com"}) + assert result2 == {"user_email": "user@example.com"} + + +class TestCustomValidatorEdgeCases: + """Edge case tests for custom validators.""" + + def setup_method(self): + """Clear registered validators before each test.""" + self._original_validators = Schema.get_registered_validators() + for name in list(self._original_validators.keys()): + Schema.unregister_validator(name) + + def teardown_method(self): + """Restore original validators after each test.""" + for name in list(Schema.get_registered_validators().keys()): + Schema.unregister_validator(name) + for name, validator in self._original_validators.items(): + Schema.register_validator(name, validator) + + def test_validator_raising_schemaerror(self): + """Test validator that raises SchemaError directly.""" + def failing_validator(value: Any) -> Any: + raise SchemaError("Custom error message") + + Schema.register_validator("failing", failing_validator) + + schema = Schema({"field": "failing"}) + with pytest.raises(SchemaError) as exc_info: + schema.validate({"field": "value"}) + assert "Custom error message" in str(exc_info.value) + + def test_validator_with_none_value(self): + """Test validator that handles None values.""" + def nullable_string_validator(value: Any) -> Any: + if value is None: + return None + if not isinstance(value, str): + raise SchemaError(f"Expected string or None") + return value + + Schema.register_validator("nullable_string", nullable_string_validator) + + schema = Schema({"name": "nullable_string"}) + + assert schema.validate({"name": None}) == {"name": None} + assert schema.validate({"name": "test"}) == {"name": "test"} + + with pytest.raises(SchemaError): + schema.validate({"name": 123}) + + def test_validator_name_collision(self): + """Test that registering with same name overwrites previous validator.""" + def validator1(value: Any) -> str: + return "v1" + + def validator2(value: Any) -> str: + return "v2" + + Schema.register_validator("test_validator", validator1) + Schema.register_validator("test_validator", validator2) + + schema = Schema({"field": "test_validator"}) + result = schema.validate({"field": "any_value"}) + assert result == {"field": "v2"} + + def test_validator_with_complex_return_type(self): + """Test validator returning complex types.""" + def parse_list_validator(value: Any) -> list: + if isinstance(value, list): + return value + if isinstance(value, str): + return [item.strip() for item in value.split(",")] + raise SchemaError(f"Cannot convert to list: {value}") + + Schema.register_validator("comma_list", parse_list_validator) + + schema = Schema({"tags": "comma_list"}) + + result = schema.validate({"tags": "python, schema, validation"}) + assert result == {"tags": ["python", "schema", "validation"]} + + result = schema.validate({"tags": ["a", "b", "c"]}) + assert result == {"tags": ["a", "b", "c"]}