Skip to content

Commit 3843eab

Browse files
committed
FIX/ENH/CLN: simplified CheckedSession code and avoid using pydantic private API (closes #1151)
also fixes properties on subclasses of CheckedSession (closes #1152)
1 parent 7e401b3 commit 3843eab

File tree

6 files changed

+155
-163
lines changed

6 files changed

+155
-163
lines changed

doc/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ pandas >=0.20
44
matplotlib
55
tables # ==pytables
66
openpyxl
7-
pydantic >=2.12
7+
pydantic ==2.12
88

99
# dependencies to actually build the documentation
1010
sphinx ==5.3.0

doc/source/changes/version_0_35.rst.inc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ Backward incompatible changes
2424
is closed by the user. To revert to the previous behavior, use show=False.
2525

2626
* Using :py:obj:`CheckedSession`, :py:obj:`CheckedParameters` or
27-
:py:obj:`CheckedArray` now requires installing pydantic >= 2.12
27+
:py:obj:`CheckedArray` now requires installing pydantic >= 2
2828
(closes :issue:`1075`).
2929

3030

environment.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@ dependencies:
99
- openpyxl
1010
- xlsxwriter
1111
- pytest >=6
12-
- pydantic >= 2.12
12+
- pydantic >=2

larray/core/checked.py

Lines changed: 142 additions & 159 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1-
from abc import ABCMeta
1+
from types import FunctionType
2+
from typing import Type, Any, Dict, Set, Annotated
23
import warnings
34

45
import numpy as np
56

6-
from typing import Type, Any, Dict, Set, no_type_check, Annotated
7-
87
from larray.core.axis import AxisCollection
98
from larray.core.array import Array, full
109
from larray.core.session import Session
@@ -14,6 +13,8 @@ class NotLoaded:
1413
pass
1514

1615

16+
NOT_LOADED = NotLoaded()
17+
1718
try:
1819
import pydantic
1920
except ImportError:
@@ -37,9 +38,29 @@ def __init__(self, *args, **kwargs):
3738
raise NotImplementedError("CheckedParameters class cannot be instantiated "
3839
"because pydantic is not installed")
3940
else:
40-
from pydantic import ConfigDict, BeforeValidator, ValidationInfo, TypeAdapter, ValidationError
41+
from pydantic import (
42+
ConfigDict, BeforeValidator, ValidationInfo, TypeAdapter,
43+
ValidationError, BaseModel
44+
)
4145
from pydantic_core import PydanticUndefined
4246

47+
from pydantic.fields import ComputedFieldInfo
48+
49+
# should more or less match pydantic's default ignored types found
50+
# in pydantic at:
51+
# from pydantic._internal._model_construction import default_ignored_types
52+
# PYDANTIC_IGNORED_TYPES = default_ignored_types()
53+
PYDANTIC_IGNORED_TYPES = (
54+
FunctionType,
55+
property,
56+
classmethod,
57+
staticmethod,
58+
# PydanticDescriptorProxy,
59+
ComputedFieldInfo,
60+
# TypeAliasType, # from `typing_extensions`
61+
)
62+
63+
4364
def CheckedArray(axes: AxisCollection, dtype: np.dtype = float) -> Type[Array]:
4465
"""
4566
Represents a constrained array. It is intended to only be used along with :py:class:`CheckedSession`.
@@ -99,115 +120,37 @@ def validate_array(value: Any, info: ValidationInfo) -> Array:
99120
return Annotated[Array, BeforeValidator(validate_array)]
100121

101122

102-
class AbstractCheckedSession:
103-
pass
123+
# this is a trick to avoid using pydantic internal API. It is mostly
124+
# equivalent to:
125+
# from pydantic._internal._model_construction import ModelMetaclass
126+
ModelMetaclass = type(BaseModel)
104127

128+
# metaclass to dynamically add type annotations for
129+
# variables defined without type hints in CheckedSession subclasses.
130+
# This allows defining constant class variables (e.g. axes), without having
131+
# to explicitly add type hints, which would feel redundant.
132+
class LArrayModelMetaclass(ModelMetaclass):
133+
def __new__(mcs, cls_name: str, bases: tuple[type[Any], ...],
134+
namespace: dict[str, Any], **kwargs):
105135

106-
# Simplified version of the ModelMetaclass class from pydantic:
107-
# https://github.com/pydantic/pydantic/blob/v2.12.0/pydantic/_internal/_model_construction.py
108-
class ModelMetaclass(ABCMeta):
109-
@no_type_check # noqa C901
110-
def __new__(mcs, cls_name: str, bases: tuple[type[Any], ...], namespace: dict[str, Any], **kwargs: Any):
111-
from pydantic._internal._config import ConfigWrapper
112-
from pydantic._internal._decorators import DecoratorInfos
113-
from pydantic._internal._namespace_utils import NsResolver
114-
from pydantic._internal._fields import is_valid_field_name
115-
from pydantic._internal._model_construction import (inspect_namespace, set_model_fields,
116-
complete_model_class, set_default_hash_func)
117-
136+
# any type hints defined in the class body will land in
137+
# __annotations__ (this is not pydantic-specific) but
138+
# __annotations__ is only defined if there are type hints
118139
raw_annotations = namespace.get('__annotations__', {})
119-
120-
# tries to infer types for variables without type hints
121-
keys_to_infer_type = [key for key in namespace.keys()
122-
if key not in raw_annotations]
123-
keys_to_infer_type = [key for key in keys_to_infer_type
124-
if is_valid_field_name(key)]
125-
keys_to_infer_type = [key for key in keys_to_infer_type
126-
if key not in {'model_config', 'dict', 'build'}]
127-
keys_to_infer_type = [key for key in keys_to_infer_type
128-
if not callable(namespace[key])]
129-
for key in keys_to_infer_type:
130-
value = namespace[key]
131-
raw_annotations[key] = type(value)
132-
133-
base_field_names, class_vars, base_private_attributes = mcs._collect_bases_data(bases)
134-
135-
config_wrapper = ConfigWrapper.for_model(bases, namespace, raw_annotations, kwargs)
136-
namespace['model_config'] = config_wrapper.config_dict
137-
private_attributes = inspect_namespace(namespace, raw_annotations, config_wrapper.ignored_types,
138-
class_vars, base_field_names)
139-
140-
namespace['__class_vars__'] = class_vars
141-
namespace['__private_attributes__'] = {**base_private_attributes, **private_attributes}
142-
143-
cls = super().__new__(mcs, cls_name, bases, namespace, **kwargs)
144-
145-
cls.__pydantic_decorators__ = DecoratorInfos.build(cls)
146-
cls.__pydantic_decorators__.update_from_config(config_wrapper)
147-
148-
cls.__pydantic_generic_metadata__ = {'origin': None, 'args': (), 'parameters': None}
149-
cls.__pydantic_root_model__ = False
150-
cls.__pydantic_complete__ = False
151-
152-
# create a copy of raw_annotations since cls.__annotations__ points to it and
153-
# cls.__annotations__ must not be polluted before calling set_model_fields() later
154-
cls.__fields_annotations__ = {k: v for k, v in raw_annotations.items()}
155-
for base in reversed(bases):
156-
if issubclass(base, AbstractCheckedSession) and base != AbstractCheckedSession:
157-
base_fields_annotations = getattr(base, '__fields_annotations__', {})
158-
for k, v in base_fields_annotations.items():
159-
if k not in cls.__fields_annotations__:
160-
cls.__fields_annotations__[k] = v
161-
162-
# preserve `__set_name__` protocol defined in https://peps.python.org/pep-0487
163-
# for attributes not in `namespace` (e.g. private attributes)
164-
for name, obj in private_attributes.items():
165-
obj.__set_name__(cls, name)
166-
167-
ns_resolver = NsResolver()
168-
set_model_fields(cls, config_wrapper=config_wrapper, ns_resolver=ns_resolver)
169-
complete_model_class(cls, config_wrapper, ns_resolver, raise_errors=False, call_on_complete_hook=False)
170-
171-
if config_wrapper.frozen and '__hash__' not in namespace:
172-
set_default_hash_func(cls, bases)
173-
174-
return cls
175-
176-
@staticmethod
177-
def _collect_bases_data(bases: tuple[type[Any], ...]) -> tuple[set[str], set[str], dict[str, Any]]:
178-
from pydantic.fields import ModelPrivateAttr
179-
180-
field_names: set[str] = set()
181-
class_vars: set[str] = set()
182-
private_attributes: dict[str, ModelPrivateAttr] = {}
183-
for base in bases:
184-
if issubclass(base, AbstractCheckedSession) and base is not AbstractCheckedSession:
185-
# model_fields might not be defined yet in the case of generics, so we use getattr here:
186-
field_names.update(getattr(base, '__pydantic_fields__', {}).keys())
187-
class_vars.update(base.__class_vars__)
188-
private_attributes.update(base.__private_attributes__)
189-
return field_names, class_vars, private_attributes
190-
191-
@property
192-
def __pydantic_fields_complete__(self) -> bool:
193-
"""Whether the fields where successfully collected (i.e. type hints were successfully resolves).
194-
195-
This is a private attribute, not meant to be used outside Pydantic.
196-
"""
197-
if '__pydantic_fields__' not in self.__dict__:
198-
return False
199-
200-
field_infos = self.__pydantic_fields__
201-
return all(field_info._complete for field_info in field_infos.values())
202-
203-
def __dir__(self) -> list[str]:
204-
attributes = list(super().__dir__())
205-
if '__fields__' in attributes:
206-
attributes.remove('__fields__')
207-
return attributes
208-
209-
210-
class CheckedSession(Session, AbstractCheckedSession, metaclass=ModelMetaclass):
140+
type_annotations = {
141+
key: type(value)
142+
for key, value in namespace.items()
143+
if not (key in raw_annotations or
144+
key.startswith('_') or
145+
isinstance(value, PYDANTIC_IGNORED_TYPES))
146+
}
147+
if type_annotations:
148+
namespace = namespace.copy()
149+
namespace['__annotations__'] = raw_annotations | type_annotations
150+
return super().__new__(mcs, cls_name, bases, namespace)
151+
152+
153+
class CheckedSession(Session, BaseModel, metaclass=LArrayModelMetaclass):
211154
"""
212155
Class intended to be inherited by user defined classes in which the variables of a model are declared.
213156
Each declared variable is constrained by a type defined explicitly or deduced from the given default value
@@ -374,10 +317,16 @@ class CheckedSession(Session, AbstractCheckedSession, metaclass=ModelMetaclass):
374317
dumping population ... done
375318
dumping undeclared_var ... done
376319
"""
377-
model_config = ConfigDict(arbitrary_types_allowed=True, validate_default=True, extra='allow',
378-
validate_assignment=True, frozen=False)
320+
model_config = ConfigDict(
321+
arbitrary_types_allowed=True,
322+
validate_default=True,
323+
extra='allow',
324+
validate_assignment=True,
325+
frozen=False
326+
)
379327

380328
def __init__(self, *args, meta=None, **kwargs):
329+
# initialize an empty Session
381330
Session.__init__(self, meta=meta)
382331

383332
# create an intermediate Session object to not call the __setattr__
@@ -386,14 +335,15 @@ def __init__(self, *args, meta=None, **kwargs):
386335
# TODO: refactor Session.load() to use a private function which returns the handler directly
387336
# so that we can get the items out of it and avoid this
388337
input_data = dict(Session(*args, **kwargs))
389-
390338
# --- declared variables
391-
for name, field in self.__pydantic_fields__.items():
392-
value = input_data.pop(name, NotLoaded())
339+
for name, field in self.__class__.model_fields.items():
340+
value = input_data.pop(name, NOT_LOADED)
393341

394-
if isinstance(value, NotLoaded):
342+
if value is NOT_LOADED:
395343
if field.default is PydanticUndefined:
396-
warnings.warn(f"No value passed for the declared variable '{name}'", stacklevel=2)
344+
warnings.warn(f"No value passed for the declared variable '{name}'",
345+
stacklevel=2)
346+
# we actually use NOT_LOADED as the value
397347
self.__setattr__(name, value, skip_frozen=True, skip_validation=True)
398348
else:
399349
self.__setattr__(name, field.default, skip_frozen=True)
@@ -405,61 +355,94 @@ def __init__(self, *args, meta=None, **kwargs):
405355
self.__setattr__(name, value, skip_frozen=True, stacklevel=2)
406356

407357
# code of the method below has been partly borrowed from pydantic.BaseModel.__setattr__()
408-
def _check_key_value(self, name: str, value: Any, skip_frozen: bool, skip_validation: bool,
358+
def _check_key_value(self, name: str,
359+
value: Any,
360+
skip_frozen: bool,
361+
skip_validation: bool,
409362
stacklevel: int) -> Any:
410-
config = self.model_config
411-
if not config['extra'] and name not in self.__pydantic_fields__:
412-
raise ValueError(f"Variable '{name}' is not declared in '{self.__class__.__name__}'. "
413-
f"Adding undeclared variables is forbidden. "
414-
f"List of declared variables is: {list(self.__pydantic_fields__.keys())}.")
415-
if not skip_frozen and config['frozen']:
416-
raise TypeError(f"Cannot change the value of the variable '{name}' since '{self.__class__.__name__}' "
363+
if skip_validation:
364+
return value
365+
366+
cls = self.__class__
367+
cls_name = cls.__name__
368+
model_config = cls.model_config
369+
if model_config['frozen'] and not skip_frozen:
370+
raise TypeError(f"Cannot change the value of the variable '{name}' since '{cls_name}' "
417371
f"is immutable and does not support item assignment")
418-
if name in self.__pydantic_fields__:
419-
if not skip_validation:
420-
try:
421-
field_type = self.__fields_annotations__.get(name, None)
422-
if field_type is None:
423-
return value
424-
# see https://docs.pydantic.dev/2.12/concepts/types/#custom-types
425-
# for more details about TypeAdapter
426-
adapter = TypeAdapter(field_type, config=self.model_config)
427-
value = adapter.validate_python(value, context={'name': name})
428-
except ValidationError as e:
429-
error = e.errors()[0]
430-
msg = f"Error while assigning value to variable '{name}':\n"
431-
if error['type'] == 'is_instance_of':
432-
msg += error['msg']
433-
msg += f". Got input value of type '{type(value).__name__}'."
434-
raise TypeError(msg)
435-
if error['type'] == 'value_error':
436-
msg += error['ctx']['error'].args[0]
437-
else:
438-
msg += error['msg']
439-
raise ValueError(msg)
440372

441-
else:
442-
warnings.warn(f"'{name}' is not declared in '{self.__class__.__name__}'",
443-
stacklevel=stacklevel + 1)
373+
model_fields = cls.model_fields
374+
if name not in model_fields:
375+
if model_config['extra']:
376+
warnings.warn(f"'{name}' is not declared in '{cls_name}'",
377+
stacklevel=stacklevel + 1)
378+
return value
379+
else:
380+
raise ValueError(f"Variable '{name}' is not declared in '{cls_name}'. "
381+
f"Adding undeclared variables is forbidden. "
382+
f"List of declared variables is: {list(model_fields.keys())}.")
383+
384+
field_info = model_fields[name]
385+
field_type = field_info.annotation
386+
if field_type is None:
387+
return value
388+
389+
# Annotated[T, x] => field_info.metadata == (x,)
390+
if field_info.metadata:
391+
# recreate the Annotated type that CheckedArray
392+
# initially created, because TypeAdapter needs the
393+
# metadata (the validator function) to actually
394+
# validate more than just the value type. I wonder
395+
# if the type isn't available as-is somewhere in
396+
# the field_info structure...
397+
field_type = Annotated[field_type, *field_info.metadata]
398+
399+
# see https://docs.pydantic.dev/2.12/concepts/types/#custom-types
400+
# for more details about TypeAdapter
401+
adapter = TypeAdapter(field_type, config=self.model_config)
402+
try:
403+
value = adapter.validate_python(value, context={'name': name})
404+
except ValidationError as e:
405+
error = e.errors()[0]
406+
msg = f"Error while assigning value to variable '{name}':\n"
407+
if error['type'] == 'is_instance_of':
408+
msg += error['msg']
409+
msg += f". Got input value of type '{type(value).__name__}'."
410+
raise TypeError(msg)
411+
if error['type'] == 'value_error':
412+
msg += error['ctx']['error'].args[0]
413+
else:
414+
msg += error['msg']
415+
raise ValueError(msg)
444416
return value
445417

446418
def _update_from_iterable(self, it):
447419
for k, v in it:
448420
self.__setitem__(k, v, stacklevel=3)
449421

450422
def __setitem__(self, key, value, skip_frozen=False, skip_validation=False, stacklevel=1):
451-
if key != 'meta':
452-
value = self._check_key_value(key, value, skip_frozen, skip_validation, stacklevel=stacklevel + 1)
453-
# we need to keep the attribute in sync
454-
object.__setattr__(self, key, value)
455-
self._objects[key] = value
423+
if key == 'meta':
424+
raise ValueError(
425+
"Sessions cannot contain any object named 'meta'. "
426+
"To modify the session metadata, use "
427+
"'session.meta = value' instead.")
428+
value = self._check_key_value(key, value, skip_frozen, skip_validation, stacklevel=stacklevel + 1)
429+
# we need to keep the attribute in sync
430+
# TODO: I don't think this is specific to CheckedSession, so either
431+
# we should do it in Session too or not do it here.
432+
object.__setattr__(self, key, value)
433+
self._objects[key] = value
456434

457435
def __setattr__(self, key, value, skip_frozen=False, skip_validation=False, stacklevel=1):
458-
if key != 'meta':
459-
value = self._check_key_value(key, value, skip_frozen, skip_validation, stacklevel=stacklevel + 1)
460-
# we need to keep the attribute in sync
436+
if key == 'meta':
461437
object.__setattr__(self, key, value)
462-
Session.__setattr__(self, key, value)
438+
return
439+
440+
value = self._check_key_value(key, value, skip_frozen, skip_validation, stacklevel=stacklevel + 1)
441+
# we need to keep the attribute in sync
442+
# TODO: I don't think this is specific to CheckedSession, so either
443+
# we should do it in Session too or not do it here.
444+
object.__setattr__(self, key, value)
445+
self._objects[key] = value
463446

464447
def __getstate__(self) -> Dict[str, Any]:
465448
return {'__dict__': self.__dict__}

0 commit comments

Comments
 (0)