1- from abc import ABCMeta
1+ from types import FunctionType
2+ from typing import Type , Any , Dict , Set , Annotated
23import warnings
34
45import numpy as np
56
6- from typing import Type , Any , Dict , Set , no_type_check , Annotated
7-
87from larray .core .axis import AxisCollection
98from larray .core .array import Array , full
109from larray .core .session import Session
@@ -14,6 +13,8 @@ class NotLoaded:
1413 pass
1514
1615
16+ NOT_LOADED = NotLoaded ()
17+
1718try :
1819 import pydantic
1920except ImportError :
@@ -37,9 +38,29 @@ def __init__(self, *args, **kwargs):
3738 raise NotImplementedError ("CheckedParameters class cannot be instantiated "
3839 "because pydantic is not installed" )
3940else :
40- from pydantic import ConfigDict , BeforeValidator , ValidationInfo , TypeAdapter , ValidationError
41+ from pydantic import (
42+ ConfigDict , BeforeValidator , ValidationInfo , TypeAdapter ,
43+ ValidationError , BaseModel
44+ )
4145 from pydantic_core import PydanticUndefined
4246
47+ from pydantic .fields import ComputedFieldInfo
48+
49+ # should more or less match pydantic's default ignored types found
50+ # in pydantic at:
51+ # from pydantic._internal._model_construction import default_ignored_types
52+ # PYDANTIC_IGNORED_TYPES = default_ignored_types()
53+ PYDANTIC_IGNORED_TYPES = (
54+ FunctionType ,
55+ property ,
56+ classmethod ,
57+ staticmethod ,
58+ # PydanticDescriptorProxy,
59+ ComputedFieldInfo ,
60+ # TypeAliasType, # from `typing_extensions`
61+ )
62+
63+
4364 def CheckedArray (axes : AxisCollection , dtype : np .dtype = float ) -> Type [Array ]:
4465 """
4566 Represents a constrained array. It is intended to only be used along with :py:class:`CheckedSession`.
@@ -99,115 +120,37 @@ def validate_array(value: Any, info: ValidationInfo) -> Array:
99120 return Annotated [Array , BeforeValidator (validate_array )]
100121
101122
102- class AbstractCheckedSession :
103- pass
123+ # this is a trick to avoid using pydantic internal API. It is mostly
124+ # equivalent to:
125+ # from pydantic._internal._model_construction import ModelMetaclass
126+ ModelMetaclass = type (BaseModel )
104127
128+ # metaclass to dynamically add type annotations for
129+ # variables defined without type hints in CheckedSession subclasses.
130+ # This allows defining constant class variables (e.g. axes), without having
131+ # to explicitly add type hints, which would feel redundant.
132+ class LArrayModelMetaclass (ModelMetaclass ):
133+ def __new__ (mcs , cls_name : str , bases : tuple [type [Any ], ...],
134+ namespace : dict [str , Any ], ** kwargs ):
105135
106- # Simplified version of the ModelMetaclass class from pydantic:
107- # https://github.com/pydantic/pydantic/blob/v2.12.0/pydantic/_internal/_model_construction.py
108- class ModelMetaclass (ABCMeta ):
109- @no_type_check # noqa C901
110- def __new__ (mcs , cls_name : str , bases : tuple [type [Any ], ...], namespace : dict [str , Any ], ** kwargs : Any ):
111- from pydantic ._internal ._config import ConfigWrapper
112- from pydantic ._internal ._decorators import DecoratorInfos
113- from pydantic ._internal ._namespace_utils import NsResolver
114- from pydantic ._internal ._fields import is_valid_field_name
115- from pydantic ._internal ._model_construction import (inspect_namespace , set_model_fields ,
116- complete_model_class , set_default_hash_func )
117-
136+ # any type hints defined in the class body will land in
137+ # __annotations__ (this is not pydantic-specific) but
138+ # __annotations__ is only defined if there are type hints
118139 raw_annotations = namespace .get ('__annotations__' , {})
119-
120- # tries to infer types for variables without type hints
121- keys_to_infer_type = [key for key in namespace .keys ()
122- if key not in raw_annotations ]
123- keys_to_infer_type = [key for key in keys_to_infer_type
124- if is_valid_field_name (key )]
125- keys_to_infer_type = [key for key in keys_to_infer_type
126- if key not in {'model_config' , 'dict' , 'build' }]
127- keys_to_infer_type = [key for key in keys_to_infer_type
128- if not callable (namespace [key ])]
129- for key in keys_to_infer_type :
130- value = namespace [key ]
131- raw_annotations [key ] = type (value )
132-
133- base_field_names , class_vars , base_private_attributes = mcs ._collect_bases_data (bases )
134-
135- config_wrapper = ConfigWrapper .for_model (bases , namespace , raw_annotations , kwargs )
136- namespace ['model_config' ] = config_wrapper .config_dict
137- private_attributes = inspect_namespace (namespace , raw_annotations , config_wrapper .ignored_types ,
138- class_vars , base_field_names )
139-
140- namespace ['__class_vars__' ] = class_vars
141- namespace ['__private_attributes__' ] = {** base_private_attributes , ** private_attributes }
142-
143- cls = super ().__new__ (mcs , cls_name , bases , namespace , ** kwargs )
144-
145- cls .__pydantic_decorators__ = DecoratorInfos .build (cls )
146- cls .__pydantic_decorators__ .update_from_config (config_wrapper )
147-
148- cls .__pydantic_generic_metadata__ = {'origin' : None , 'args' : (), 'parameters' : None }
149- cls .__pydantic_root_model__ = False
150- cls .__pydantic_complete__ = False
151-
152- # create a copy of raw_annotations since cls.__annotations__ points to it and
153- # cls.__annotations__ must not be polluted before calling set_model_fields() later
154- cls .__fields_annotations__ = {k : v for k , v in raw_annotations .items ()}
155- for base in reversed (bases ):
156- if issubclass (base , AbstractCheckedSession ) and base != AbstractCheckedSession :
157- base_fields_annotations = getattr (base , '__fields_annotations__' , {})
158- for k , v in base_fields_annotations .items ():
159- if k not in cls .__fields_annotations__ :
160- cls .__fields_annotations__ [k ] = v
161-
162- # preserve `__set_name__` protocol defined in https://peps.python.org/pep-0487
163- # for attributes not in `namespace` (e.g. private attributes)
164- for name , obj in private_attributes .items ():
165- obj .__set_name__ (cls , name )
166-
167- ns_resolver = NsResolver ()
168- set_model_fields (cls , config_wrapper = config_wrapper , ns_resolver = ns_resolver )
169- complete_model_class (cls , config_wrapper , ns_resolver , raise_errors = False , call_on_complete_hook = False )
170-
171- if config_wrapper .frozen and '__hash__' not in namespace :
172- set_default_hash_func (cls , bases )
173-
174- return cls
175-
176- @staticmethod
177- def _collect_bases_data (bases : tuple [type [Any ], ...]) -> tuple [set [str ], set [str ], dict [str , Any ]]:
178- from pydantic .fields import ModelPrivateAttr
179-
180- field_names : set [str ] = set ()
181- class_vars : set [str ] = set ()
182- private_attributes : dict [str , ModelPrivateAttr ] = {}
183- for base in bases :
184- if issubclass (base , AbstractCheckedSession ) and base is not AbstractCheckedSession :
185- # model_fields might not be defined yet in the case of generics, so we use getattr here:
186- field_names .update (getattr (base , '__pydantic_fields__' , {}).keys ())
187- class_vars .update (base .__class_vars__ )
188- private_attributes .update (base .__private_attributes__ )
189- return field_names , class_vars , private_attributes
190-
191- @property
192- def __pydantic_fields_complete__ (self ) -> bool :
193- """Whether the fields where successfully collected (i.e. type hints were successfully resolves).
194-
195- This is a private attribute, not meant to be used outside Pydantic.
196- """
197- if '__pydantic_fields__' not in self .__dict__ :
198- return False
199-
200- field_infos = self .__pydantic_fields__
201- return all (field_info ._complete for field_info in field_infos .values ())
202-
203- def __dir__ (self ) -> list [str ]:
204- attributes = list (super ().__dir__ ())
205- if '__fields__' in attributes :
206- attributes .remove ('__fields__' )
207- return attributes
208-
209-
210- class CheckedSession (Session , AbstractCheckedSession , metaclass = ModelMetaclass ):
140+ type_annotations = {
141+ key : type (value )
142+ for key , value in namespace .items ()
143+ if not (key in raw_annotations or
144+ key .startswith ('_' ) or
145+ isinstance (value , PYDANTIC_IGNORED_TYPES ))
146+ }
147+ if type_annotations :
148+ namespace = namespace .copy ()
149+ namespace ['__annotations__' ] = raw_annotations | type_annotations
150+ return super ().__new__ (mcs , cls_name , bases , namespace )
151+
152+
153+ class CheckedSession (Session , BaseModel , metaclass = LArrayModelMetaclass ):
211154 """
212155 Class intended to be inherited by user defined classes in which the variables of a model are declared.
213156 Each declared variable is constrained by a type defined explicitly or deduced from the given default value
@@ -374,10 +317,16 @@ class CheckedSession(Session, AbstractCheckedSession, metaclass=ModelMetaclass):
374317 dumping population ... done
375318 dumping undeclared_var ... done
376319 """
377- model_config = ConfigDict (arbitrary_types_allowed = True , validate_default = True , extra = 'allow' ,
378- validate_assignment = True , frozen = False )
320+ model_config = ConfigDict (
321+ arbitrary_types_allowed = True ,
322+ validate_default = True ,
323+ extra = 'allow' ,
324+ validate_assignment = True ,
325+ frozen = False
326+ )
379327
380328 def __init__ (self , * args , meta = None , ** kwargs ):
329+ # initialize an empty Session
381330 Session .__init__ (self , meta = meta )
382331
383332 # create an intermediate Session object to not call the __setattr__
@@ -386,14 +335,15 @@ def __init__(self, *args, meta=None, **kwargs):
386335 # TODO: refactor Session.load() to use a private function which returns the handler directly
387336 # so that we can get the items out of it and avoid this
388337 input_data = dict (Session (* args , ** kwargs ))
389-
390338 # --- declared variables
391- for name , field in self .__pydantic_fields__ .items ():
392- value = input_data .pop (name , NotLoaded () )
339+ for name , field in self .__class__ . model_fields .items ():
340+ value = input_data .pop (name , NOT_LOADED )
393341
394- if isinstance ( value , NotLoaded ) :
342+ if value is NOT_LOADED :
395343 if field .default is PydanticUndefined :
396- warnings .warn (f"No value passed for the declared variable '{ name } '" , stacklevel = 2 )
344+ warnings .warn (f"No value passed for the declared variable '{ name } '" ,
345+ stacklevel = 2 )
346+ # we actually use NOT_LOADED as the value
397347 self .__setattr__ (name , value , skip_frozen = True , skip_validation = True )
398348 else :
399349 self .__setattr__ (name , field .default , skip_frozen = True )
@@ -405,61 +355,94 @@ def __init__(self, *args, meta=None, **kwargs):
405355 self .__setattr__ (name , value , skip_frozen = True , stacklevel = 2 )
406356
407357 # code of the method below has been partly borrowed from pydantic.BaseModel.__setattr__()
408- def _check_key_value (self , name : str , value : Any , skip_frozen : bool , skip_validation : bool ,
358+ def _check_key_value (self , name : str ,
359+ value : Any ,
360+ skip_frozen : bool ,
361+ skip_validation : bool ,
409362 stacklevel : int ) -> Any :
410- config = self .model_config
411- if not config ['extra' ] and name not in self .__pydantic_fields__ :
412- raise ValueError (f"Variable '{ name } ' is not declared in '{ self .__class__ .__name__ } '. "
413- f"Adding undeclared variables is forbidden. "
414- f"List of declared variables is: { list (self .__pydantic_fields__ .keys ())} ." )
415- if not skip_frozen and config ['frozen' ]:
416- raise TypeError (f"Cannot change the value of the variable '{ name } ' since '{ self .__class__ .__name__ } ' "
363+ if skip_validation :
364+ return value
365+
366+ cls = self .__class__
367+ cls_name = cls .__name__
368+ model_config = cls .model_config
369+ if model_config ['frozen' ] and not skip_frozen :
370+ raise TypeError (f"Cannot change the value of the variable '{ name } ' since '{ cls_name } ' "
417371 f"is immutable and does not support item assignment" )
418- if name in self .__pydantic_fields__ :
419- if not skip_validation :
420- try :
421- field_type = self .__fields_annotations__ .get (name , None )
422- if field_type is None :
423- return value
424- # see https://docs.pydantic.dev/2.12/concepts/types/#custom-types
425- # for more details about TypeAdapter
426- adapter = TypeAdapter (field_type , config = self .model_config )
427- value = adapter .validate_python (value , context = {'name' : name })
428- except ValidationError as e :
429- error = e .errors ()[0 ]
430- msg = f"Error while assigning value to variable '{ name } ':\n "
431- if error ['type' ] == 'is_instance_of' :
432- msg += error ['msg' ]
433- msg += f". Got input value of type '{ type (value ).__name__ } '."
434- raise TypeError (msg )
435- if error ['type' ] == 'value_error' :
436- msg += error ['ctx' ]['error' ].args [0 ]
437- else :
438- msg += error ['msg' ]
439- raise ValueError (msg )
440372
441- else :
442- warnings .warn (f"'{ name } ' is not declared in '{ self .__class__ .__name__ } '" ,
443- stacklevel = stacklevel + 1 )
373+ model_fields = cls .model_fields
374+ if name not in model_fields :
375+ if model_config ['extra' ]:
376+ warnings .warn (f"'{ name } ' is not declared in '{ cls_name } '" ,
377+ stacklevel = stacklevel + 1 )
378+ return value
379+ else :
380+ raise ValueError (f"Variable '{ name } ' is not declared in '{ cls_name } '. "
381+ f"Adding undeclared variables is forbidden. "
382+ f"List of declared variables is: { list (model_fields .keys ())} ." )
383+
384+ field_info = model_fields [name ]
385+ field_type = field_info .annotation
386+ if field_type is None :
387+ return value
388+
389+ # Annotated[T, x] => field_info.metadata == (x,)
390+ if field_info .metadata :
391+ # recreate the Annotated type that CheckedArray
392+ # initially created, because TypeAdapter needs the
393+ # metadata (the validator function) to actually
394+ # validate more than just the value type. I wonder
395+ # if the type isn't available as-is somewhere in
396+ # the field_info structure...
397+ field_type = Annotated [field_type , * field_info .metadata ]
398+
399+ # see https://docs.pydantic.dev/2.12/concepts/types/#custom-types
400+ # for more details about TypeAdapter
401+ adapter = TypeAdapter (field_type , config = self .model_config )
402+ try :
403+ value = adapter .validate_python (value , context = {'name' : name })
404+ except ValidationError as e :
405+ error = e .errors ()[0 ]
406+ msg = f"Error while assigning value to variable '{ name } ':\n "
407+ if error ['type' ] == 'is_instance_of' :
408+ msg += error ['msg' ]
409+ msg += f". Got input value of type '{ type (value ).__name__ } '."
410+ raise TypeError (msg )
411+ if error ['type' ] == 'value_error' :
412+ msg += error ['ctx' ]['error' ].args [0 ]
413+ else :
414+ msg += error ['msg' ]
415+ raise ValueError (msg )
444416 return value
445417
446418 def _update_from_iterable (self , it ):
447419 for k , v in it :
448420 self .__setitem__ (k , v , stacklevel = 3 )
449421
450422 def __setitem__ (self , key , value , skip_frozen = False , skip_validation = False , stacklevel = 1 ):
451- if key != 'meta' :
452- value = self ._check_key_value (key , value , skip_frozen , skip_validation , stacklevel = stacklevel + 1 )
453- # we need to keep the attribute in sync
454- object .__setattr__ (self , key , value )
455- self ._objects [key ] = value
423+ if key == 'meta' :
424+ raise ValueError (
425+ "Sessions cannot contain any object named 'meta'. "
426+ "To modify the session metadata, use "
427+ "'session.meta = value' instead." )
428+ value = self ._check_key_value (key , value , skip_frozen , skip_validation , stacklevel = stacklevel + 1 )
429+ # we need to keep the attribute in sync
430+ # TODO: I don't think this is specific to CheckedSession, so either
431+ # we should do it in Session too or not do it here.
432+ object .__setattr__ (self , key , value )
433+ self ._objects [key ] = value
456434
457435 def __setattr__ (self , key , value , skip_frozen = False , skip_validation = False , stacklevel = 1 ):
458- if key != 'meta' :
459- value = self ._check_key_value (key , value , skip_frozen , skip_validation , stacklevel = stacklevel + 1 )
460- # we need to keep the attribute in sync
436+ if key == 'meta' :
461437 object .__setattr__ (self , key , value )
462- Session .__setattr__ (self , key , value )
438+ return
439+
440+ value = self ._check_key_value (key , value , skip_frozen , skip_validation , stacklevel = stacklevel + 1 )
441+ # we need to keep the attribute in sync
442+ # TODO: I don't think this is specific to CheckedSession, so either
443+ # we should do it in Session too or not do it here.
444+ object .__setattr__ (self , key , value )
445+ self ._objects [key ] = value
463446
464447 def __getstate__ (self ) -> Dict [str , Any ]:
465448 return {'__dict__' : self .__dict__ }
0 commit comments