diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 3532e81a8e..14fba775cf 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -24,6 +24,7 @@ cast, overload, ) +from unittest import mock from pydantic import BaseModel, EmailStr from pydantic.fields import FieldInfo as PydanticFieldInfo @@ -52,7 +53,14 @@ from sqlalchemy.orm.instrumentation import is_instrumented from sqlalchemy.sql.schema import MetaData from sqlalchemy.sql.sqltypes import LargeBinary, Time, Uuid -from typing_extensions import Literal, TypeAlias, deprecated, get_origin +from typing_extensions import ( + Annotated, + Literal, + TypeAlias, + deprecated, + get_args, + get_origin, +) from ._compat import ( # type: ignore[attr-defined] IS_PYDANTIC_V2, @@ -111,71 +119,60 @@ def __dataclass_transform__( class FieldInfo(PydanticFieldInfo): def __init__(self, default: Any = Undefined, **kwargs: Any) -> None: - primary_key = kwargs.pop("primary_key", False) - nullable = kwargs.pop("nullable", Undefined) - foreign_key = kwargs.pop("foreign_key", Undefined) - ondelete = kwargs.pop("ondelete", Undefined) - unique = kwargs.pop("unique", False) - index = kwargs.pop("index", Undefined) - sa_type = kwargs.pop("sa_type", Undefined) - sa_column = kwargs.pop("sa_column", Undefined) - sa_column_args = kwargs.pop("sa_column_args", Undefined) - sa_column_kwargs = kwargs.pop("sa_column_kwargs", Undefined) - if sa_column is not Undefined: - if sa_column_args is not Undefined: - raise RuntimeError( - "Passing sa_column_args is not supported when " - "also passing a sa_column" - ) - if sa_column_kwargs is not Undefined: - raise RuntimeError( - "Passing sa_column_kwargs is not supported when " - "also passing a sa_column" - ) - if primary_key is not Undefined: - raise RuntimeError( - "Passing primary_key is not supported when " - "also passing a sa_column" - ) - if nullable is not Undefined: - raise RuntimeError( - "Passing nullable is not supported when also passing a sa_column" - ) - if foreign_key is not Undefined: - raise RuntimeError( - "Passing foreign_key is not supported when " - "also passing a sa_column" - ) - if ondelete is not Undefined: - raise RuntimeError( - "Passing ondelete is not supported when also passing a sa_column" - ) - if unique is not Undefined: - raise RuntimeError( - "Passing unique is not supported when also passing a sa_column" - ) - if index is not Undefined: - raise RuntimeError( - "Passing index is not supported when also passing a sa_column" - ) - if sa_type is not Undefined: - raise RuntimeError( - "Passing sa_type is not supported when also passing a sa_column" - ) - if ondelete is not Undefined: - if foreign_key is Undefined: + sqlmodel_attributes = ( + "primary_key", + "nullable", + "foreign_key", + "ondelete", + "unique", + "index", + "sa_type", + "sa_column", + "sa_column_args", + "sa_column_kwargs", + ) + sqlmodel_attributes_set = {} + for attr in sqlmodel_attributes: + value = kwargs.pop(attr, Undefined) + if value is not Undefined: + sqlmodel_attributes_set[attr] = value + + if "sa_column" in sqlmodel_attributes_set: + unsupported_with_sa_column = ( + "sa_column_args", + "sa_column_kwargs", + "primary_key", + "nullable", + "foreign_key", + "ondelete", + "unique", + "index", + "sa_type", + ) + for attr in unsupported_with_sa_column: + if attr in sqlmodel_attributes_set: + raise RuntimeError( + f"Passing {attr} is not supported when also passing a sa_column" + ) + if "ondelete" in sqlmodel_attributes_set: + if "foreign_key" not in sqlmodel_attributes_set: raise RuntimeError("ondelete can only be used with foreign_key") + super().__init__(default=default, **kwargs) - self.primary_key = primary_key - self.nullable = nullable - self.foreign_key = foreign_key - self.ondelete = ondelete - self.unique = unique - self.index = index - self.sa_type = sa_type - self.sa_column = sa_column - self.sa_column_args = sa_column_args - self.sa_column_kwargs = sa_column_kwargs + self._attributes_set.update(sqlmodel_attributes_set) + + self.primary_key = sqlmodel_attributes_set.get("primary_key", False) + self.nullable = sqlmodel_attributes_set.get("nullable", Undefined) + self.foreign_key = sqlmodel_attributes_set.get("foreign_key", Undefined) + self.ondelete = sqlmodel_attributes_set.get("ondelete", Undefined) + self.unique = sqlmodel_attributes_set.get("unique", False) + self.index = sqlmodel_attributes_set.get("index", Undefined) + self.sa_type = sqlmodel_attributes_set.get("sa_type", Undefined) + self.sa_column = sqlmodel_attributes_set.get("sa_column", Undefined) + self.sa_column_args = sqlmodel_attributes_set.get("sa_column_args", Undefined) + self.sa_column_kwargs = sqlmodel_attributes_set.get( + "sa_column_kwargs", Undefined + ) class RelationshipInfo(Representation): @@ -215,33 +212,33 @@ def __init__( def Field( default: Any = Undefined, *, - default_factory: Optional[NoArgAnyCallable] = None, - alias: Optional[str] = None, - title: Optional[str] = None, - description: Optional[str] = None, + default_factory: Optional[NoArgAnyCallable] = Undefined, + alias: Optional[str] = Undefined, + title: Optional[str] = Undefined, + description: Optional[str] = Undefined, exclude: Union[ AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any - ] = None, + ] = Undefined, include: Union[ AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any - ] = None, - const: Optional[bool] = None, - gt: Optional[float] = None, - ge: Optional[float] = None, - lt: Optional[float] = None, - le: Optional[float] = None, - multiple_of: Optional[float] = None, - max_digits: Optional[int] = None, - decimal_places: Optional[int] = None, - min_items: Optional[int] = None, - max_items: Optional[int] = None, - unique_items: Optional[bool] = None, - min_length: Optional[int] = None, - max_length: Optional[int] = None, - allow_mutation: bool = True, - regex: Optional[str] = None, - discriminator: Optional[str] = None, - repr: bool = True, + ] = Undefined, + const: Optional[bool] = Undefined, + gt: Optional[float] = Undefined, + ge: Optional[float] = Undefined, + lt: Optional[float] = Undefined, + le: Optional[float] = Undefined, + multiple_of: Optional[float] = Undefined, + max_digits: Optional[int] = Undefined, + decimal_places: Optional[int] = Undefined, + min_items: Optional[int] = Undefined, + max_items: Optional[int] = Undefined, + unique_items: Optional[bool] = Undefined, + min_length: Optional[int] = Undefined, + max_length: Optional[int] = Undefined, + allow_mutation: bool = Undefined, + regex: Optional[str] = Undefined, + discriminator: Optional[str] = Undefined, + repr: bool = Undefined, primary_key: Union[bool, UndefinedType] = Undefined, foreign_key: Any = Undefined, unique: Union[bool, UndefinedType] = Undefined, @@ -260,33 +257,33 @@ def Field( def Field( default: Any = Undefined, *, - default_factory: Optional[NoArgAnyCallable] = None, - alias: Optional[str] = None, - title: Optional[str] = None, - description: Optional[str] = None, + default_factory: Optional[NoArgAnyCallable] = Undefined, + alias: Optional[str] = Undefined, + title: Optional[str] = Undefined, + description: Optional[str] = Undefined, exclude: Union[ AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any - ] = None, + ] = Undefined, include: Union[ AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any - ] = None, - const: Optional[bool] = None, - gt: Optional[float] = None, - ge: Optional[float] = None, - lt: Optional[float] = None, - le: Optional[float] = None, - multiple_of: Optional[float] = None, - max_digits: Optional[int] = None, - decimal_places: Optional[int] = None, - min_items: Optional[int] = None, - max_items: Optional[int] = None, - unique_items: Optional[bool] = None, - min_length: Optional[int] = None, - max_length: Optional[int] = None, - allow_mutation: bool = True, - regex: Optional[str] = None, - discriminator: Optional[str] = None, - repr: bool = True, + ] = Undefined, + const: Optional[bool] = Undefined, + gt: Optional[float] = Undefined, + ge: Optional[float] = Undefined, + lt: Optional[float] = Undefined, + le: Optional[float] = Undefined, + multiple_of: Optional[float] = Undefined, + max_digits: Optional[int] = Undefined, + decimal_places: Optional[int] = Undefined, + min_items: Optional[int] = Undefined, + max_items: Optional[int] = Undefined, + unique_items: Optional[bool] = Undefined, + min_length: Optional[int] = Undefined, + max_length: Optional[int] = Undefined, + allow_mutation: bool = Undefined, + regex: Optional[str] = Undefined, + discriminator: Optional[str] = Undefined, + repr: bool = Undefined, primary_key: Union[bool, UndefinedType] = Undefined, foreign_key: str, ondelete: Union[OnDeleteType, UndefinedType] = Undefined, @@ -314,68 +311,68 @@ def Field( def Field( default: Any = Undefined, *, - default_factory: Optional[NoArgAnyCallable] = None, - alias: Optional[str] = None, - title: Optional[str] = None, - description: Optional[str] = None, + default_factory: Optional[NoArgAnyCallable] = Undefined, + alias: Optional[str] = Undefined, + title: Optional[str] = Undefined, + description: Optional[str] = Undefined, exclude: Union[ AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any - ] = None, + ] = Undefined, include: Union[ AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any - ] = None, - const: Optional[bool] = None, - gt: Optional[float] = None, - ge: Optional[float] = None, - lt: Optional[float] = None, - le: Optional[float] = None, - multiple_of: Optional[float] = None, - max_digits: Optional[int] = None, - decimal_places: Optional[int] = None, - min_items: Optional[int] = None, - max_items: Optional[int] = None, - unique_items: Optional[bool] = None, - min_length: Optional[int] = None, - max_length: Optional[int] = None, - allow_mutation: bool = True, - regex: Optional[str] = None, - discriminator: Optional[str] = None, - repr: bool = True, + ] = Undefined, + const: Optional[bool] = Undefined, + gt: Optional[float] = Undefined, + ge: Optional[float] = Undefined, + lt: Optional[float] = Undefined, + le: Optional[float] = Undefined, + multiple_of: Optional[float] = Undefined, + max_digits: Optional[int] = Undefined, + decimal_places: Optional[int] = Undefined, + min_items: Optional[int] = Undefined, + max_items: Optional[int] = Undefined, + unique_items: Optional[bool] = Undefined, + min_length: Optional[int] = Undefined, + max_length: Optional[int] = Undefined, + allow_mutation: bool = Undefined, + regex: Optional[str] = Undefined, + discriminator: Optional[str] = Undefined, + repr: bool = Undefined, sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore - schema_extra: Optional[Dict[str, Any]] = None, + schema_extra: Optional[Dict[str, Any]] = Undefined, ) -> Any: ... def Field( default: Any = Undefined, *, - default_factory: Optional[NoArgAnyCallable] = None, - alias: Optional[str] = None, - title: Optional[str] = None, - description: Optional[str] = None, + default_factory: Optional[NoArgAnyCallable] = Undefined, + alias: Optional[str] = Undefined, + title: Optional[str] = Undefined, + description: Optional[str] = Undefined, exclude: Union[ AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any - ] = None, + ] = Undefined, include: Union[ AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any - ] = None, - const: Optional[bool] = None, - gt: Optional[float] = None, - ge: Optional[float] = None, - lt: Optional[float] = None, - le: Optional[float] = None, - multiple_of: Optional[float] = None, - max_digits: Optional[int] = None, - decimal_places: Optional[int] = None, - min_items: Optional[int] = None, - max_items: Optional[int] = None, - unique_items: Optional[bool] = None, - min_length: Optional[int] = None, - max_length: Optional[int] = None, - allow_mutation: bool = True, - regex: Optional[str] = None, - discriminator: Optional[str] = None, - repr: bool = True, + ] = Undefined, + const: Optional[bool] = Undefined, + gt: Optional[float] = Undefined, + ge: Optional[float] = Undefined, + lt: Optional[float] = Undefined, + le: Optional[float] = Undefined, + multiple_of: Optional[float] = Undefined, + max_digits: Optional[int] = Undefined, + decimal_places: Optional[int] = Undefined, + min_items: Optional[int] = Undefined, + max_items: Optional[int] = Undefined, + unique_items: Optional[bool] = Undefined, + min_length: Optional[int] = Undefined, + max_length: Optional[int] = Undefined, + allow_mutation: bool = Undefined, + regex: Optional[str] = Undefined, + discriminator: Optional[str] = Undefined, + repr: bool = Undefined, primary_key: Union[bool, UndefinedType] = Undefined, foreign_key: Any = Undefined, ondelete: Union[OnDeleteType, UndefinedType] = Undefined, @@ -475,6 +472,67 @@ def Relationship( return relationship_info +def _merge_field_infos_and_update_class_dict(annotations, class_dict): + # Merge the FieldInfo into a single FieldInfo to prevent pydantic from dropping sqlmodel's attributes. + # Also update the class_dict if a class's default values include FieldInfo instances. + # Pydantic will automatically merge FieldInfo instances, but the merged object will be a `pydantic.FieldInfo`. All + # the sqlmodel-specific attributes are lost. + new_annotations = {} + updated = [] + for attr, annotation in annotations.items(): + # EARLY CONTINUE + if get_origin(annotation) is not Annotated: + new_annotations[attr] = annotation + continue + type_hint, *annotation_args = get_args(annotation) + # TODO: Check for other subclasses of `pydantic.FieldInfo`? + field_infos = [arg for arg in annotation_args if isinstance(arg, FieldInfo)] + # EARLY CONTINUE + if not field_infos: + # No FieldInfo in the annotation. Leave it as is. + new_annotations[attr] = annotation + continue + non_field_info_annotation_args = [ + arg for arg in annotation_args if not isinstance(arg, FieldInfo) + ] + # If the default value is a FieldInfo, include that too. + default_field_info = None + if isinstance(class_dict.get(attr), FieldInfo): + default_field_info = class_dict.get(attr) + field_infos.append(default_field_info) + + # HACK: Quick and dirty way of merging into a new sqlmodel.FieldInfo object. + with mock.patch("pydantic.fields.FieldInfo", FieldInfo): + field_info = FieldInfo.merge_field_infos(*field_infos) + + if not default_field_info: + # some_attr: Annotated[SomeType, FieldInfo(...), FieldInfo(...)] + # becomes + # some_attr: Annotated[SomeType, FieldInfo(...)] + new_annotations[attr] = Annotated[ + type_hint, field_info, *non_field_info_annotation_args + ] + elif not non_field_info_annotation_args: + # some_attr: Annotated[SomeType, FieldInfo(...)] = FieldInfo(...) + # becomes + # some_attr: SomeType = FieldInfo(...) + new_annotations[attr] = type_hint + class_dict[attr] = field_info + else: + # some_attr: Annotated[SomeType, FieldInfo(...), other_annotation] = FieldInfo(...) + # becomes + # some_attr: Annotated[SomeType, other_annotation] = FieldInfo(...) + new_annotations[attr] = Annotated[ + type_hint, *non_field_info_annotation_args + ] + class_dict[attr] = field_info + + for attr in updated: + if attr not in class_dict: + continue + return new_annotations + + @__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo)) class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): __sqlmodel_relationships__: Dict[str, RelationshipInfo] @@ -519,6 +577,9 @@ def __new__( relationship_annotations[k] = v else: pydantic_annotations[k] = v + pydantic_annotations = _merge_field_infos_and_update_class_dict( + pydantic_annotations, dict_for_pydantic + ) dict_used = { **dict_for_pydantic, "__weakref__": None, @@ -538,6 +599,12 @@ def __new__( config_kwargs = { key: kwargs[key] for key in kwargs.keys() & allowed_config_kwargs } + # Also include pydantic's internal kwargs + config_kwargs.update( + (key, value) + for key, value in kwargs.items() + if key.startswith("__pydantic_") + ) new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs) new_cls.__annotations__ = { **relationship_annotations, diff --git a/tests/test_pydantic/test_field.py b/tests/test_pydantic/test_field.py index 9d7bc77625..96903a1692 100644 --- a/tests/test_pydantic/test_field.py +++ b/tests/test_pydantic/test_field.py @@ -1,9 +1,12 @@ +import datetime from decimal import Decimal -from typing import Optional, Union +from typing import Annotated, Optional, Union import pytest +import sqlalchemy as sa from pydantic import ValidationError from sqlmodel import Field, SQLModel +from sqlmodel.main import FieldInfo from typing_extensions import Literal @@ -55,3 +58,55 @@ class Model(SQLModel): instance = Model(id=123, foo="bar") assert "foo=" not in repr(instance) + + +def test_field_merging(): + sa_type = sa.DATETIME + + MyDateTime = Annotated[ + datetime.datetime, + Field(sa_type=sa_type), + ] + + class Model(SQLModel): + value: Annotated[ + MyDateTime, + Field(default_factory=datetime.datetime.now), + Field(description="some-description", title="some-title"), + Field(index=True), + ] = Field(nullable=False) + + assert Model.model_json_schema() == { + "properties": { + "value": { + "description": "some-description", + "format": "date-time", + "title": "some-title", + "type": "string", + } + }, + "title": "Model", + "type": "object", + } + expected_field = Field( + sa_type=sa.DATETIME, + default_factory=datetime.datetime.now, + description="some-description", + title="some-title", + index=True, + nullable=False, + ) + actual_field = Model.model_fields["value"] + assert isinstance(actual_field, FieldInfo) + + comp_attrs = [ + "sa_type", + "default_factory", + "description", + "title", + "index", + "nullable", + ] + for attr in comp_attrs: + assert getattr(actual_field, attr) == getattr(expected_field, attr) + assert getattr(actual_field, attr) is not None