Skip to content

Commit 02b8633

Browse files
committed
FIX: fixed calling methods on CheckedSession subclasses
(thanks to Alix for finding the solution to this)
1 parent c807928 commit 02b8633

File tree

2 files changed

+27
-3
lines changed

2 files changed

+27
-3
lines changed

larray/core/checked.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,14 @@ def __new__(mcs, cls_name: str, bases: tuple[type[Any], ...], namespace: dict[st
116116
raw_annotations = namespace.get('__annotations__', {})
117117

118118
# tries to infer types for variables without type hints
119-
keys_to_infer_type = [key for key in namespace.keys() if key not in raw_annotations]
120-
keys_to_infer_type = [key for key in keys_to_infer_type if is_valid_field_name(key)]
121-
keys_to_infer_type = [key for key in keys_to_infer_type if key not in {'model_config', 'dict'}]
119+
keys_to_infer_type = [key for key in namespace.keys()
120+
if key not in raw_annotations]
121+
keys_to_infer_type = [key for key in keys_to_infer_type
122+
if is_valid_field_name(key)]
123+
keys_to_infer_type = [key for key in keys_to_infer_type
124+
if key not in {'model_config', 'dict', 'build'}]
125+
keys_to_infer_type = [key for key in keys_to_infer_type
126+
if not callable(namespace[key])]
122127
for key in keys_to_infer_type:
123128
value = namespace[key]
124129
raw_annotations[key] = type(value)

larray/tests/test_checked_session.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -715,5 +715,24 @@ def test_neg_cs(checkedsession):
715715
assert_array_nan_equal(neg_cs.h, -h)
716716

717717

718+
def test_checked_class_with_methods():
719+
a = Axis('a=a0,a1')
720+
721+
class CheckedSessionWithMethods(CheckedSession):
722+
arr: CheckedArray(a)
723+
724+
# Define a method which already exists in Session/CheckedSession
725+
def save(self, path=None, **kwargs):
726+
super().save(path, **kwargs)
727+
728+
def new_method(self):
729+
return True
730+
731+
array = ndtest(a)
732+
733+
cs = CheckedSessionWithMethods(arr=array)
734+
assert cs.new_method()
735+
736+
718737
if __name__ == "__main__":
719738
pytest.main()

0 commit comments

Comments
 (0)