From 5cc49e544e074c590db485f8da894730a2abbb7f Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Wed, 25 Jan 2023 17:04:56 -0500 Subject: [PATCH 001/205] fix setup.py confict --- python/lib/core/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/lib/core/setup.py b/python/lib/core/setup.py index 3f485cccf..e69c55a56 100644 --- a/python/lib/core/setup.py +++ b/python/lib/core/setup.py @@ -20,6 +20,6 @@ author_email='', url='', license='', - install_requires=[], + install_requires=["pydantic"], packages=find_namespace_packages(exclude=['dmod.test', 'schemas', 'ssl', 'src']) ) From 7ad970053072fd0e54358429da51d95eee9f2088 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 9 Jan 2023 13:32:51 -0500 Subject: [PATCH 002/205] add PydanticEnum. validated by enum member name. Subtypes of this enum variant that are embedded in a pydantic model will be: - coerced into an enum instance using member name (case insensitive) - and expose member names (upper case) in model json schema. --- python/lib/core/dmod/core/enum.py | 85 +++++++++++++++++++++++++++++++ 1 file changed, 85 insertions(+) create mode 100644 python/lib/core/dmod/core/enum.py diff --git a/python/lib/core/dmod/core/enum.py b/python/lib/core/dmod/core/enum.py new file mode 100644 index 000000000..221664d21 --- /dev/null +++ b/python/lib/core/dmod/core/enum.py @@ -0,0 +1,85 @@ +from enum import Enum +from pydantic.fields import ModelField +from pprint import pformat + +from typing import Any, Dict, Union + +# inspiration from https://github.com/pydantic/pydantic/issues/598 +class PydanticEnum(Enum): + """ + Subtypes of this enum variant that are embedded in a pydantic model will be: + - coerced into an enum instance using member name (case insensitive) + - and expose member names (upper case) in model json schema. + + + Example: + ```python + class PowerState(PydanticEnum): + OFF = 0 + ON = 1 + + class Appliance(pydantic.BaseModel): + power_state: PowerState + ... 
+ + Appliance(power_state=PowerState.ON) + Appliance(power_state="ON") + Appliance(power_state="on") + + Appliance(power_state=1) # invalid + ``` + + Note, `PydanticEnum` subtypes with member names that case-intensively match will yield + undesirable behavior. + """ + + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any], field: ModelField) -> None: + """Method used by pydantic to populate json schema fields and their associated types.""" + # display enum field names as field options + if "enum" in field_schema: + field_schema["enum"] = [f.name.upper() for f in field.type_] + field_schema["type"] = "string" + + @classmethod + def __get_validators__(cls): + """Method used by pydantic to retrieve a class's validators.""" + yield cls.validate + + @classmethod + def validate(cls, v: Union[Enum, str]): + """ + Method used by pydantic to validate and potentially coerce a `v` into a `cls` enum type. + + Coercion from a `str` into a `cls` enum instance is performed _case-insensitively_ based on + the `cls` enum's `name` fields. For example, enum Foo with member `bar = 1` is coercible by + providing `"bar"`, _not_ `1`. + + Example: + ```python + class Foo(PydanticEnum): + bar = 1 + + class Model(pydantic.BaseModel): + foo: Foo + + Model(foo=Foo.bar) # valid + Model(foo="bar") # valid + Model(foo="BAR") # valid + + Model(foo=1) # invalid + ``` + """ + if isinstance(v, cls): + return v + + v = str(v).upper() + + for name, value in cls.__members__.items(): + if name.upper() == v: + return value + + error_message = pformat( + f"Invalid Enum field. 
Field {v!r} is not a member of {set(cls.__members__)}" + ) + raise ValueError(error_message) From f931db62be05e9b10e4d200858b701ac8075e21a Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 9 Jan 2023 13:38:45 -0500 Subject: [PATCH 003/205] add PydanticEnum unittests --- python/lib/core/dmod/test/test_enum.py | 44 ++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) create mode 100644 python/lib/core/dmod/test/test_enum.py diff --git a/python/lib/core/dmod/test/test_enum.py b/python/lib/core/dmod/test/test_enum.py new file mode 100644 index 000000000..8d6e6834c --- /dev/null +++ b/python/lib/core/dmod/test/test_enum.py @@ -0,0 +1,44 @@ +import unittest +import enum +from pydantic import BaseModel + +from ..core.enum import PydanticEnum + + +class SomeEnum(PydanticEnum): + foo = 1 + bar = 2 + baz = 3 + + +class SomeModel(BaseModel): + some_enum: SomeEnum + + +class TestEnumValidateByNameMixIn(unittest.TestCase): + def test_instantiate_model_with_enum_field_name(self): + model = SomeModel(some_enum="foo") + self.assertEqual(model.some_enum, SomeEnum.foo) + + def test_instantiate_model_with_enum_instance(self): + model = SomeModel(some_enum=SomeEnum.foo) + self.assertEqual(model.some_enum, SomeEnum.foo) + + def test_raises_ValueError_instantiate_model_with_bad_enum_field_name(self): + with self.assertRaises(ValueError): + SomeModel(some_enum="missing_field") + + def test_raises_ValueError_instantiate_model_with_bad_enum_instance(self): + class BadEnum(enum.Enum): + bad = 1 + + with self.assertRaises(ValueError): + SomeModel(some_enum=BadEnum.bad) + + def test_enum_names_in_json_schema(self): + schema = SomeModel.schema() + some_enum_schema = schema["definitions"]["SomeEnum"] + self.assertEqual(some_enum_schema["type"], "string") + + enum_field_names = [member.name.upper() for member in SomeEnum] + self.assertListEqual(enum_field_names, some_enum_schema["enum"]) From 899b8a0350428953efe7dda40b330ef78ae0fb4c Mon Sep 17 00:00:00 2001 From: Austin Raney Date: 
Mon, 9 Jan 2023 14:28:46 -0500 Subject: [PATCH 004/205] Serializable subclasses pydantic.BaseModel and update docs --- python/lib/core/dmod/core/serializable.py | 82 +++++++++++++++++++++-- 1 file changed, 75 insertions(+), 7 deletions(-) diff --git a/python/lib/core/dmod/core/serializable.py b/python/lib/core/dmod/core/serializable.py index cde509b4d..acd1b0ad2 100644 --- a/python/lib/core/dmod/core/serializable.py +++ b/python/lib/core/dmod/core/serializable.py @@ -1,14 +1,28 @@ from abc import ABC, abstractmethod from numbers import Number -from typing import Callable, Dict, Type, Union +from enum import Enum +from typing import Any, Callable, ClassVar, Dict, Type, TYPE_CHECKING, Union, Optional +from pydantic import BaseModel import json +if TYPE_CHECKING: + from pydantic.typing import ( + AbstractSetIntStr, + MappingIntStrAny, + ) -class Serializable(ABC): + +class Serializable(BaseModel, ABC): """ An interface class for an object that can be serialized to a dictionary-like format (i.e., potentially a JSON object) and JSON string format based directly from dumping the aforementioned dictionary-like representation. + Subtypes of `Serializable` should specify their fields following + [`pydantic.BaseModel`](https://docs.pydantic.dev/usage/models/) semantics (see example below). + Notably, `to_dict` and `to_json` will exclude `None` fields and serialize fields using any + provided aliases (i.e. `pydantic.Field(alias="some_alias")`). Also, enum subtypes are + serialized using their member `name` property. + Objects of this type will also used the JSON string format as their default string representation. While not strictly enforced (because this probably isn't possible), it is HIGHLY recommended that instance @@ -24,9 +38,30 @@ class Serializable(ABC): its ::attribute:`_SERIAL_DATETIME_STR_FORMAT` class attribute. 
Note that the actual parsing/serialization logic is left entirely to the subtypes, as many will not need it (and thus should not have to worry about implement another method or have their superclass bloated by importing the ``datetime`` package). + + Example: + ``` + # specify field as class variable, specify final type using type hint. + # pydantic will try to coerce a field into the specified type, if it can't, a + # `pydantic.ValidationError` is raised. + + class User(Serializable): + id: int + username: str + email: str # more appropriately, `pydantic.EmailStr` + + >>> user = User(id=1, username="uncle_sam", email="uncle_sam@fake.gov") + >>> user.to_dict() # {"id": 1, "username": "uncle_sam", "email": "uncle_sam@fake.gov"} + >>> user.to_json() # '{"id": 1, "username": "uncle_sam", "email": "uncle_sam@fake.gov"}' + ``` """ - _SERIAL_DATETIME_STR_FORMAT = '%Y-%m-%d %H:%M:%S' + _SERIAL_DATETIME_STR_FORMAT: ClassVar[str] = '%Y-%m-%d %H:%M:%S' + + # global pydantic options + class Config: + # fields can be populated using their given name or provided alias + allow_population_by_field_name = True @classmethod def _get_invalid_type_message(cls): @@ -158,11 +193,11 @@ def parse_simple_serialized(cls, json_obj: dict, key: str, expected_type: Type, # If we get this far, then return the converted value return converted_value - @abstractmethod def to_dict(self) -> Dict[str, Union[str, Number, dict, list]]: """ Get the representation of this instance as a serialized dictionary or dictionary-like object (e.g., a JSON - object). + object). Field's are serialized using an alias, if provided. Field's that are `None` are + excluded from serialization. Since the returned value must be serializable and JSON-like, key and value types are restricted. 
In particular, the returned value type, which this docstring will call ``D``, must adhere to the criteria defined below: @@ -180,7 +215,7 @@ def to_dict(self) -> Dict[str, Union[str, Number, dict, list]]: The representation of this instance as a serialized dictionary or dictionary-like object, with valid types of keys and values. """ - pass + return self.dict(exclude_none=True, by_alias=True) def __str__(self): return str(self.to_json()) @@ -196,6 +231,39 @@ def to_json(self) -> str: """ return json.dumps(self.to_dict(), sort_keys=True) + @classmethod + def _get_value( + cls, + v: Any, + to_dict: bool, + by_alias: bool, + include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]], + exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]], + exclude_unset: bool, + exclude_defaults: bool, + exclude_none: bool, + ) -> Any: + """ + Method used by pydantic to serialize field values. + + Override how `enum.Enum` subclasses are serialized by pydantic. Enums are serialized using + their member name, not their value. 
+ """ + # serialize enum's using their name property + if isinstance(v, Enum) and not getattr(cls.Config, "use_enum_values", False): + return v.name + + return super()._get_value( + v, + to_dict=to_dict, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + include=include, + exclude=exclude, + exclude_none=exclude_none, + ) + class SerializedDict(Serializable): """ @@ -263,4 +331,4 @@ def factory_init_from_deserialized_json(cls, json_obj: dict): return None def __init__(self, *args, **kwargs): - super(BasicResultIndicator, self).__init__(*args, **kwargs) \ No newline at end of file + super(BasicResultIndicator, self).__init__(*args, **kwargs) From 65668333ae28e33534816218f31f84a680d88e39 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 9 Jan 2023 14:34:54 -0500 Subject: [PATCH 005/205] add deprecated decorator function --- .../lib/core/dmod/core/decorators/__init__.py | 1 + .../dmod/core/decorators/decorator_functions.py | 17 +++++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/python/lib/core/dmod/core/decorators/__init__.py b/python/lib/core/dmod/core/decorators/__init__.py index da8bc7c13..82b1d09bb 100644 --- a/python/lib/core/dmod/core/decorators/__init__.py +++ b/python/lib/core/dmod/core/decorators/__init__.py @@ -8,6 +8,7 @@ from .decorator_functions import initializer from .decorator_functions import additional_parameter +from .decorator_functions import deprecated from .message_handlers import socket_handler from .message_handlers import client_message_handler diff --git a/python/lib/core/dmod/core/decorators/decorator_functions.py b/python/lib/core/dmod/core/decorators/decorator_functions.py index 4af7abb25..13f3e25d3 100644 --- a/python/lib/core/dmod/core/decorators/decorator_functions.py +++ b/python/lib/core/dmod/core/decorators/decorator_functions.py @@ -2,6 +2,8 @@ Defines common decorators """ import typing +from warnings import warn +from functools import wraps from .decorator_constants 
import * @@ -77,3 +79,18 @@ def additional_parameter(function): if not hasattr(function, ADDITIONAL_PARAMETER_ATTRIBUTE): setattr(function, ADDITIONAL_PARAMETER_ATTRIBUTE, True) return function + + +def deprecated(deprecation_message: str): + def function_to_deprecate(fn): + + @wraps(fn) + def wrapper(*args, **kwargs): + warn(deprecation_message, DeprecationWarning) + return fn(*args, **kwargs) + + return wrapper + + return function_to_deprecate + + From e53390d63cb2260424af702bf836e9b059a2f1ac Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 9 Jan 2023 14:35:19 -0500 Subject: [PATCH 006/205] add unit test to verify deprecated warning is raised by decorator --- python/lib/core/dmod/test/test_decorator.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 python/lib/core/dmod/test/test_decorator.py diff --git a/python/lib/core/dmod/test/test_decorator.py b/python/lib/core/dmod/test/test_decorator.py new file mode 100644 index 000000000..02b1fd4d2 --- /dev/null +++ b/python/lib/core/dmod/test/test_decorator.py @@ -0,0 +1,13 @@ +import unittest +from ..core.decorators import deprecated + +DEPRECATION_MESSAGE = "test is deprecated" + +@deprecated(DEPRECATION_MESSAGE) +def deprecated_function(): + ... 
+ +class TestDeprecatedDecorator(unittest.TestCase): + def test_raises_deprecated_warning(self): + with self.assertWarns(DeprecationWarning): + deprecated_function() From 5a877f1dc2067688fdb1c910982e0e249f536133 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 9 Jan 2023 14:37:43 -0500 Subject: [PATCH 007/205] add deprecation warning to parse_simple_serialized --- python/lib/core/dmod/core/serializable.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/lib/core/dmod/core/serializable.py b/python/lib/core/dmod/core/serializable.py index acd1b0ad2..6bec05433 100644 --- a/python/lib/core/dmod/core/serializable.py +++ b/python/lib/core/dmod/core/serializable.py @@ -5,6 +5,8 @@ from pydantic import BaseModel import json +from .decorators import deprecated + if TYPE_CHECKING: from pydantic.typing import ( AbstractSetIntStr, @@ -100,6 +102,7 @@ def get_datetime_str_format(cls): return cls._SERIAL_DATETIME_STR_FORMAT @classmethod + @deprecated("In the future this will be removed. Use pydantic type hints, validators, or root validators instead.") def parse_simple_serialized(cls, json_obj: dict, key: str, expected_type: Type, required_present: bool = True, converter: Callable = None): """ From c399880dd4d8db5de2b96a6a458f37caeb915713 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 9 Jan 2023 14:46:11 -0500 Subject: [PATCH 008/205] fix docstring typo --- python/lib/core/dmod/core/serializable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/lib/core/dmod/core/serializable.py b/python/lib/core/dmod/core/serializable.py index 6bec05433..2c194d6fc 100644 --- a/python/lib/core/dmod/core/serializable.py +++ b/python/lib/core/dmod/core/serializable.py @@ -36,7 +36,7 @@ class Serializable(BaseModel, ABC): An exception to the aforementioned recommendation is the ::class:`datetime.datetime` type. 
Subtype attributes of ::class:`datetime.datetime` type should be parsed and serialized using the pattern returned by the ::method:`get_datetime_str_format` class method. A reasonable default is provided in the base interface class, but - the pattern can be adjusted eitehr by overriding the class method directly or by having a subtypes set/override + the pattern can be adjusted either by overriding the class method directly or by having a subtypes set/override its ::attribute:`_SERIAL_DATETIME_STR_FORMAT` class attribute. Note that the actual parsing/serialization logic is left entirely to the subtypes, as many will not need it (and thus should not have to worry about implement another method or have their superclass bloated by importing the ``datetime`` package). From d92bfb3ac9c523acd94edf9e0485e9568e8d09e9 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 9 Jan 2023 14:49:33 -0500 Subject: [PATCH 009/205] refactor SerializedDict, ResultIndicator, and BasicResultIndicator --- python/lib/core/dmod/core/serializable.py | 33 ++++++----------------- 1 file changed, 8 insertions(+), 25 deletions(-) diff --git a/python/lib/core/dmod/core/serializable.py b/python/lib/core/dmod/core/serializable.py index 2c194d6fc..fe78a23ba 100644 --- a/python/lib/core/dmod/core/serializable.py +++ b/python/lib/core/dmod/core/serializable.py @@ -2,7 +2,7 @@ from numbers import Number from enum import Enum from typing import Any, Callable, ClassVar, Dict, Type, TYPE_CHECKING, Union, Optional -from pydantic import BaseModel +from pydantic import BaseModel, Field import json from .decorators import deprecated @@ -272,16 +272,11 @@ class SerializedDict(Serializable): """ A basic encapsulation of a dictionary as a ::class:`Serializable`. 
""" + base_dict: dict @classmethod def factory_init_from_deserialized_json(cls, json_obj: dict): - return cls(json_obj) - - def __init__(self, base_dict: dict): - self.base_dict = base_dict - - def to_dict(self) -> dict: - return self.base_dict + return cls(**json_obj) class ResultIndicator(Serializable, ABC): @@ -307,18 +302,9 @@ class ResultIndicator(Serializable, ABC): An optional, more detailed explanation of the result, which by default is an empty string. """ - - def __init__(self, success: bool, reason: str, message: str = '', *args, **kwargs): - super(ResultIndicator, self).__init__(*args, **kwargs) - self.success: bool = success - """ Whether this indicates a successful result. """ - self.reason: str = reason - """ A very short, high-level summary of the result. """ - self.message: str = message - """ An optional, more detailed explanation of the result, which by default is an empty string. """ - - def to_dict(self) -> dict: - return {'success': self.success, 'reason': self.reason, 'message': self.message} + success: bool = Field(description="Whether this indicates a successful result.") + reason: str = Field(description="A very short, high-level summary of the result.") + message: str = Field("", description="An optional, more detailed explanation of the result, which by default is an empty string.") class BasicResultIndicator(ResultIndicator): @@ -329,9 +315,6 @@ class BasicResultIndicator(ResultIndicator): @classmethod def factory_init_from_deserialized_json(cls, json_obj: dict): try: - return cls(success=json_obj['success'], reason=json_obj['reason'], message=json_obj['message']) - except Exception as e: + return cls(**json_obj) + except Exception: return None - - def __init__(self, *args, **kwargs): - super(BasicResultIndicator, self).__init__(*args, **kwargs) From 7f6b917df2c6a920ca7baed1b087261166671e4e Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 9 Jan 2023 14:50:55 -0500 Subject: [PATCH 010/205] refactor AllocationParadigm --- 
python/lib/core/dmod/core/execution.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/lib/core/dmod/core/execution.py b/python/lib/core/dmod/core/execution.py index f8f99f50b..9286b037b 100644 --- a/python/lib/core/dmod/core/execution.py +++ b/python/lib/core/dmod/core/execution.py @@ -1,8 +1,9 @@ -from enum import Enum from typing import Optional +from .enum import PydanticEnum -class AllocationParadigm(Enum): + +class AllocationParadigm(PydanticEnum): """ Representation of the ways compute assets may be combined to fulfill a total required asset amount for a task. From 744d62bc6b448521b55a22cb59c1f22c56d8b811 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 9 Jan 2023 14:52:33 -0500 Subject: [PATCH 011/205] refactor meta_data enums --- python/lib/core/dmod/core/meta_data.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/lib/core/dmod/core/meta_data.py b/python/lib/core/dmod/core/meta_data.py index 9c74c9fe0..1bfb04195 100644 --- a/python/lib/core/dmod/core/meta_data.py +++ b/python/lib/core/dmod/core/meta_data.py @@ -1,13 +1,13 @@ -from enum import Enum from datetime import datetime +from .enum import PydanticEnum from .serializable import Serializable from numbers import Number from typing import Any, Dict, List, Optional, Set, Type, Union from collections.abc import Iterable -class StandardDatasetIndex(Enum): +class StandardDatasetIndex(PydanticEnum): UNKNOWN = (-1, Any) TIME = (0, datetime) @@ -35,7 +35,7 @@ def get_for_name(cls, name_str: str) -> 'StandardDatasetIndex': return StandardDatasetIndex.UNKNOWN -class DataFormat(Enum): +class DataFormat(PydanticEnum): """ Supported data format types for data needed or produced by workflow execution tasks. @@ -710,7 +710,7 @@ def to_dict(self) -> Dict[str, Union[str, Number, dict, list]]: return serial -class DataCategory(Enum): +class DataCategory(PydanticEnum): """ The general category values for different data. 
""" From 9b415b2b22f1827e865b52286499de5b81084fa2 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 9 Jan 2023 15:03:29 -0500 Subject: [PATCH 012/205] refactor ContinuousRestriction to use pydantic --- python/lib/core/dmod/core/meta_data.py | 100 ++++++++++++------------- 1 file changed, 46 insertions(+), 54 deletions(-) diff --git a/python/lib/core/dmod/core/meta_data.py b/python/lib/core/dmod/core/meta_data.py index 1bfb04195..37504e815 100644 --- a/python/lib/core/dmod/core/meta_data.py +++ b/python/lib/core/dmod/core/meta_data.py @@ -5,6 +5,7 @@ from numbers import Number from typing import Any, Dict, List, Optional, Set, Type, Union from collections.abc import Iterable +from pydantic import root_validator, validator, PyObject, Field class StandardDatasetIndex(PydanticEnum): @@ -34,6 +35,11 @@ def get_for_name(cls, name_str: str) -> 'StandardDatasetIndex': return value return StandardDatasetIndex.UNKNOWN +def _validate_variable_is_known(cls, variable: StandardDatasetIndex) -> StandardDatasetIndex: + if variable == StandardDatasetIndex.UNKNOWN: + raise ValueError("Invalid value for {} variable: {}".format(cls.__name__, variable)) + return variable + class DataFormat(PydanticEnum): """ @@ -215,6 +221,37 @@ class ContinuousRestriction(Serializable): """ A filtering component, typically applied as a restriction on a domain, by a continuous range of values of a variable. 
""" + variable: StandardDatasetIndex + begin: datetime + end: datetime + datetime_pattern: Optional[str] + subclass: Optional[PyObject] = Field(exclude=True) + + @root_validator(pre=True) + def coerce_times_if_datetime_pattern(cls, values): + datetime_ptr = values.get("datetime_pattern") + + if datetime_ptr is not None: + # If there is a datetime pattern, then expect begin and end to parse properly to datetime objects + begin = values["begin"] + end = values["end"] + + if not isinstance(begin, datetime): + values["begin"] = datetime.strptime(begin, datetime_ptr) + + if not isinstance(end, datetime): + values["end"] = datetime.strptime(end, datetime_ptr) + return values + + @root_validator() + def validate_start_before_end(cls, values): + if values["begin"] > values["end"]: + raise RuntimeError("Cannot have {} with begin value larger than end.".format(cls.__name__)) + + return values + + # validate variable is not UNKNOWN variant + _validate_variable = validator("variable", allow_reuse=True)(_validate_variable_is_known) @classmethod def convert_truncated_serial_form(cls, truncated_json_obj: dict, datetime_format: Optional[str] = None) -> dict: @@ -251,61 +288,21 @@ def convert_truncated_serial_form(cls, truncated_json_obj: dict, datetime_format @classmethod def factory_init_from_deserialized_json(cls, json_obj: dict): - datetime_ptr = json_obj["datetime_pattern"] if "datetime_pattern" in json_obj else None - try: - variable = StandardDatasetIndex.get_for_name(json_obj['variable']) - if variable == StandardDatasetIndex.UNKNOWN: - raise RuntimeError( - "Unrecognized continuous restriction serialize variable: {}".format(json_obj['variable'])) - # Handle simple case, which currently means non-datetime item (i.e., no pattern included) - if datetime_ptr is None: - return cls(variable=variable, begin=json_obj["begin"], end=json_obj["end"]) - - # If there is a datetime pattern, then expect begin and end to parse properly to datetime objects - begin = 
datetime.strptime(json_obj["begin"], datetime_ptr) - end = datetime.strptime(json_obj["end"], datetime_ptr) - - # Use this type if that's what the JSON specifies is the Serializable subtype - if cls.__name__ == json_obj["subclass"]: - return cls(variable=variable, begin=begin, end=end, datetime_pattern=datetime_ptr) - - # Try to initialize the right subclass type, or fall back if appropriate to the base type - # TODO: consider adding something for recursive search for subclass, not just immediate children types - # Use nested try, because we want to fall back to cls type if no subclass attempt or subclass attempt fails + if "subclass" in json_obj: try: for subclass in cls.__subclasses__(): if subclass.__name__ == json_obj["subclass"]: - return subclass(variable=variable, begin=begin, end=end, datetime_pattern=datetime_ptr) + return subclass(**json_obj) except: pass - # Fall back if needed - return cls(variable=variable, begin=begin, end=end, datetime_pattern=datetime_ptr) + try: + return cls(**json_obj) except: return None - def __init__(self, variable: Union[str, StandardDatasetIndex], begin, end, datetime_pattern: Optional[str] = None): - self.variable = StandardDatasetIndex.get_for_name(variable) if isinstance(variable, str) else variable - if self.variable == StandardDatasetIndex.UNKNOWN: - raise ValueError("Invalid value for {} variable: {}".format(self.__class__.__name__, variable)) - if begin > end: - raise RuntimeError("Cannot have {} with begin value larger than end.".format(self.__class__.__name__)) - self.begin = begin - self.end = end - self._datetime_pattern = datetime_pattern - - def __eq__(self, other): - if self.__class__ == other.__class__ or isinstance(other, self.__class__): - return self.variable == other.variable and self.begin == other.begin and self.end == other.end \ - and self._datetime_pattern == other._datetime_pattern - elif isinstance(self, other.__class__): - return other.__eq__(self) - else: - return False - def __hash__(self): - 
str_func = lambda x: str(x) if self._datetime_pattern is None else datetime.strptime(x, self._datetime_pattern) - hash('{}-{}-{}'.format(self.variable.name, str_func(self.begin), str_func(self.end))) + return hash('{}-{}-{}'.format(self.variable.name, self.begin, self.end)) def contains(self, other: 'ContinuousRestriction') -> bool: """ @@ -330,16 +327,11 @@ def contains(self, other: 'ContinuousRestriction') -> bool: return self.begin <= other.begin and self.end >= other.end def to_dict(self) -> Dict[str, Union[str, Number, dict, list]]: - serial = dict() - serial["variable"] = self.variable.name + serial = self.dict(exclude_none=True) serial["subclass"] = self.__class__.__name__ - if self._datetime_pattern is not None: - serial["datetime_pattern"] = self._datetime_pattern - serial["begin"] = self.begin.strftime(self._datetime_pattern) - serial["end"] = self.end.strftime(self._datetime_pattern) - else: - serial["begin"] = self.begin - serial["end"] = self.end + if self.datetime_pattern is not None: + serial["begin"] = self.begin.strftime(self.datetime_pattern) + serial["end"] = self.end.strftime(self.datetime_pattern) return serial From b633f08a5a60b3f2e190392150a603065f606928 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 9 Jan 2023 15:10:19 -0500 Subject: [PATCH 013/205] add default implimentation for factory_init_from_deserialized_json --- python/lib/core/dmod/core/serializable.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/python/lib/core/dmod/core/serializable.py b/python/lib/core/dmod/core/serializable.py index fe78a23ba..7e708651f 100644 --- a/python/lib/core/dmod/core/serializable.py +++ b/python/lib/core/dmod/core/serializable.py @@ -1,7 +1,7 @@ -from abc import ABC, abstractmethod +from abc import ABC from numbers import Number from enum import Enum -from typing import Any, Callable, ClassVar, Dict, Type, TYPE_CHECKING, Union, Optional +from typing import Any, Callable, ClassVar, Dict, Type, TypeVar, 
TYPE_CHECKING, Union, Optional from pydantic import BaseModel, Field import json @@ -13,6 +13,8 @@ MappingIntStrAny, ) +Self = TypeVar("Self", bound="Serializable") + class Serializable(BaseModel, ABC): """ @@ -71,8 +73,7 @@ def _get_invalid_type_message(cls): return invalid_type_msg @classmethod - @abstractmethod - def factory_init_from_deserialized_json(cls, json_obj: dict): + def factory_init_from_deserialized_json(cls: Self, json_obj: dict) -> Optional[Self]: """ Factory create a new instance of this type based on a JSON object dictionary deserialized from received JSON. @@ -84,7 +85,10 @@ def factory_init_from_deserialized_json(cls, json_obj: dict): ------- A new object of this type instantiated from the deserialize JSON object dictionary """ - pass + try: + return cls(**json_obj) + except: + return None @classmethod def get_datetime_str_format(cls): From e7673fffa5889a3a0930802c5fdc7f4461f21ad4 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 9 Jan 2023 15:15:53 -0500 Subject: [PATCH 014/205] add comment to SerializedDict --- python/lib/core/dmod/core/serializable.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/lib/core/dmod/core/serializable.py b/python/lib/core/dmod/core/serializable.py index 7e708651f..cc3fa8363 100644 --- a/python/lib/core/dmod/core/serializable.py +++ b/python/lib/core/dmod/core/serializable.py @@ -279,7 +279,8 @@ class SerializedDict(Serializable): base_dict: dict @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict): + def factory_init_from_deserialized_json(cls: Self, json_obj: dict) -> Self: + # NOTE: could raise. 
return type has fewer constraints return cls(**json_obj) From 775e642d2dfdc68b17870705934404c489b284f1 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 9 Jan 2023 15:16:10 -0500 Subject: [PATCH 015/205] refactor BasicResultIndicator --- python/lib/core/dmod/core/serializable.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/python/lib/core/dmod/core/serializable.py b/python/lib/core/dmod/core/serializable.py index cc3fa8363..9d464fb06 100644 --- a/python/lib/core/dmod/core/serializable.py +++ b/python/lib/core/dmod/core/serializable.py @@ -316,10 +316,3 @@ class BasicResultIndicator(ResultIndicator): """ Bare-bones, concrete implementation of ::class:`ResultIndicator`. """ - - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict): - try: - return cls(**json_obj) - except Exception: - return None From 56dc33412b28c1228eb325a5e738f46396de4322 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 9 Jan 2023 15:21:42 -0500 Subject: [PATCH 016/205] refactor DiscreteRestriction --- python/lib/core/dmod/core/meta_data.py | 38 ++++++++++---------------- 1 file changed, 15 insertions(+), 23 deletions(-) diff --git a/python/lib/core/dmod/core/meta_data.py b/python/lib/core/dmod/core/meta_data.py index 37504e815..567f87d4f 100644 --- a/python/lib/core/dmod/core/meta_data.py +++ b/python/lib/core/dmod/core/meta_data.py @@ -5,7 +5,8 @@ from numbers import Number from typing import Any, Dict, List, Optional, Set, Type, Union from collections.abc import Iterable -from pydantic import root_validator, validator, PyObject, Field +from collections import OrderedDict +from pydantic import root_validator, validator, PyObject, Field, StrictStr, StrictFloat, StrictInt class StandardDatasetIndex(PydanticEnum): @@ -342,35 +343,29 @@ class DiscreteRestriction(Serializable): Note that an empty list for the ::attribute:`values` property implies a restriction of all possible values being required. 
This is reflected by the :method:`is_all_possible_values` property. """ + variable: StandardDatasetIndex + values: Union[List[StrictStr], List[StrictFloat], List[StrictInt]] + + # validate variable is not UNKNOWN variant + _validate_variable = validator("variable", allow_reuse=True)(_validate_variable_is_known) + @classmethod def factory_init_from_deserialized_json(cls, json_obj: dict): try: - variable = StandardDatasetIndex.get_for_name(json_obj["variable"]) - if variable == StandardDatasetIndex.UNKNOWN: - return None - return cls(variable=variable, values=json_obj["values"]) + cls(**json_obj) except: return None def __init__(self, variable: Union[str, StandardDatasetIndex], values: Union[List[str], List[Number]], allow_reorder: bool = True, - remove_duplicates: bool = True): - self.variable = StandardDatasetIndex.get_for_name(variable) if isinstance(variable, str) else variable - if self.variable == StandardDatasetIndex.UNKNOWN: - raise ValueError("Invalid value for {} variable: {}".format(self.__class__.__name__, variable)) - self.values: Union[List[str], List[Number]] = list(set(values)) if remove_duplicates else values + remove_duplicates: bool = True, **kwargs): + super().__init__(variable=variable, values=values, **kwargs) + if remove_duplicates: + self.values = list(OrderedDict.fromkeys(self.values)) if allow_reorder: self.values.sort() - def __eq__(self, other): - if self.__class__ == other.__class__ or isinstance(other, self.__class__): - return self.variable == other.variable and self.values == other.values - elif isinstance(self, other.__class__): - return other.__eq__(self) - else: - return False - - def __hash__(self): - hash('{}-{}'.format(self.variable.name, ','.join([str(v) for v in self.values]))) + def __hash__(self) -> int: + return hash('{}-{}'.format(self.variable.name, ','.join([str(v) for v in self.values]))) def contains(self, other: 'DiscreteRestriction') -> bool: """ @@ -425,9 +420,6 @@ def is_all_possible_values(self) -> bool: """ return 
self.values is not None and len(self.values) == 0 - def to_dict(self) -> Dict[str, Union[str, Number, dict, list]]: - return {"variable": self.variable.name, "values": self.values} - class DataDomain(Serializable): """ From d5d0f09d849fd051077bb80a61f04fe434bb6c0a Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 9 Jan 2023 15:38:24 -0500 Subject: [PATCH 017/205] refactor DataDomain --- python/lib/core/dmod/core/meta_data.py | 174 +++++++++---------------- 1 file changed, 64 insertions(+), 110 deletions(-) diff --git a/python/lib/core/dmod/core/meta_data.py b/python/lib/core/dmod/core/meta_data.py index 567f87d4f..9379075bb 100644 --- a/python/lib/core/dmod/core/meta_data.py +++ b/python/lib/core/dmod/core/meta_data.py @@ -425,30 +425,54 @@ class DataDomain(Serializable): """ A domain for a dataset, with domain-defining values contained by one or more discrete and/or continuous components. """ + data_format: DataFormat = Field( + description="The format for the data in this domain, which contains details like the indices and other data fields." 
+ ) + continuous_restrictions: Optional[List[ContinuousRestriction]] = Field( + description="Map of the continuous restrictions defining this domain, keyed by variable name.", + alias="continuous", + default_factory=list + ) + discrete_restrictions: Optional[List[DiscreteRestriction]] = Field( + description="Map of the discrete restrictions defining this domain, keyed by variable name.", + alias="discrete", + default_factory=list + ) + custom_data_fields: Optional[Dict[str, Union[str, int, float, Any]]] = Field( + description=("This will either be directly from the format, if its format specifies any fields, or from a custom fields" + "attribute that may be set during initialization (but is ignored when the format specifies fields)."), + alias="data_fields" + ) + + @validator("custom_data_fields") + def validate_data_fields(cls, values): + def handle_type_map(t): + if t == "str" or t == str: + return str + elif t == "int" or t == int: + return int + elif t == "float" or t == float: + return float + # maintain reference to a passed in python type or subtype + elif isinstance(t, type): + return t + return Any + + return {k: handle_type_map(v) for k, v in values.items()} + + @root_validator() + def validate_sufficient_restrictions(cls, values): + continuous_restrictions = values.get("continuous_restrictions", []) + discrete_restrictions = values.get("discrete_restrictions", []) + if len(continuous_restrictions) + len(discrete_restrictions) == 0: + msg = "Cannot create {} without at least one finite continuous or discrete restriction" + raise RuntimeError(msg.format(cls.__name__)) + return values @classmethod def factory_init_from_deserialized_json(cls, json_obj: dict): try: - data_format = DataFormat.get_for_name(json_obj["data_format"]) - continuous = [ContinuousRestriction.factory_init_from_deserialized_json(c) for c in json_obj["continuous"]] - discrete = [DiscreteRestriction.factory_init_from_deserialized_json(d) for d in json_obj["discrete"]] - if 'data_fields' 
in json_obj: - data_fields = dict() - for key in json_obj['data_fields']: - val = json_obj['data_fields'][key] - if val == 'str': - data_fields[key] = str - elif val == 'int': - data_fields[key] = int - elif val == 'float': - data_fields[key] = float - else: - data_fields[key] = Any - else: - data_fields = None - - return cls(data_format=data_format, continuous_restrictions=continuous, discrete_restrictions=discrete, - custom_data_fields=data_fields) + return cls(**json_obj) except: return None @@ -521,12 +545,6 @@ def factory_init_from_restriction_collections(cls, data_format: DataFormat, **kw continuous_restrictions=None if len(continuous) == 0 else continuous, discrete_restrictions=None if len(discrete) == 0 else discrete) - def __eq__(self, other): - return self.__class__ == other.__class__ and self.data_format == other.data_format \ - and self.continuous_restrictions == other.continuous_restrictions \ - and self.discrete_restrictions == other.discrete_restrictions \ - and self._custom_data_fields == other._custom_data_fields - def __hash__(self): if self._custom_data_fields is None: cu = '' @@ -538,27 +556,6 @@ def __hash__(self): ','.join([str(hash(self.discrete_restrictions[k])) for k in sorted(self.discrete_restrictions)]), cu)) - def __init__(self, data_format: DataFormat, continuous_restrictions: Optional[List[ContinuousRestriction]] = None, - discrete_restrictions: Optional[List[DiscreteRestriction]] = None, - custom_data_fields: Optional[Dict[str, Type]] = None): - self._data_format = data_format - self._continuous_restrictions = dict() - self._discrete_restrictions = dict() - self._custom_data_fields = custom_data_fields - """ Extra attribute for custom data fields when format does not specify all data fields (ignore when format does specify). 
""" - - if continuous_restrictions is not None: - for c in continuous_restrictions: - self._continuous_restrictions[c.variable] = c - - if discrete_restrictions is not None: - for d in discrete_restrictions: - self._discrete_restrictions[d.variable] = d - - if len(self._continuous_restrictions) + len(self._discrete_restrictions) == 0: - msg = "Cannot create {} without at least one finite continuous or discrete restriction" - raise RuntimeError(msg.format(self.__class__.__name__)) - def _extends_continuous_restriction(self, continuous_restriction: ContinuousRestriction) -> bool: idx = continuous_restriction.variable return idx in self.continuous_restrictions and self.continuous_restrictions[idx].contains(continuous_restriction) @@ -596,30 +593,6 @@ def contains(self, other: Union[ContinuousRestriction, DiscreteRestriction, 'Dat return False return True - @property - def continuous_restrictions(self) -> Dict[StandardDatasetIndex, ContinuousRestriction]: - """ - Map of the continuous restrictions defining this domain, keyed by variable name. - - Returns - ------- - Dict[str, ContinuousRestriction] - Map of the continuous restrictions defining this domain, keyed by variable name. - """ - return self._continuous_restrictions - - @property - def discrete_restrictions(self) -> Dict[StandardDatasetIndex, DiscreteRestriction]: - """ - Map of the discrete restrictions defining this domain, keyed by variable name. - - Returns - ------- - Dict[str, DiscreteRestriction] - Map of the discrete restrictions defining this domain, keyed by variable name. - """ - return self._discrete_restrictions - @property def data_fields(self) -> Dict[str, Type]: """ @@ -633,23 +606,9 @@ def data_fields(self) -> Dict[str, Type]: """ if self.data_format.data_fields is None: - return self._custom_data_fields + return self.custom_data_fields else: - return self._data_format.data_fields - - @property - def data_format(self) -> DataFormat: - """ - The format for data in this domain. 
- - The format for the data in this domain, which contains details like the indices and other data fields. - - Returns - ------- - DataFormat - The format for data in this domain. - """ - return self._data_format + return self.data_format.data_fields @property def indices(self) -> List[str]: @@ -664,34 +623,29 @@ def indices(self) -> List[str]: List[str] List of the string forms of the ::class:`StandardDataIndex` indices that define this domain. """ - return self._data_format.indices + return self.data_format.indices - def to_dict(self) -> Dict[str, Union[str, Number, dict, list]]: + def dict(self, **kwargs) -> dict: """ - Serialize to a dictionary. - - Serialize this instance to a dictionary, with there being two top-level list items. These are made from the - the contained ::class:`ContinuousRestriction` and ::class:`DiscreteRestriction` objects - - Returns - ------- + `data_fields` is excluded from dict if `self.data_format.data_fields` is None. + called by `to_dict` and `to_json`. """ - serial = {"data_format": self._data_format.name, - "continuous": [component.to_dict() for idx, component in self.continuous_restrictions.items()], - "discrete": [component.to_dict() for idx, component in self.discrete_restrictions.items()]} + # TODO: aaraney, handle encoding type (int, float, etc.) 
as str + by_alias = kwargs.pop("by_alias") if "by_alias" in kwargs else True + if self.data_format.data_fields is None: - serial['data_fields'] = dict() - for key in self._custom_data_fields: - if self._custom_data_fields[key] == str: - serial['data_fields'][key] = 'str' - elif self._custom_data_fields[key] == int: - serial['data_fields'][key] = 'int' - elif self._custom_data_fields[key] == float: - serial['data_fields'][key] = 'float' - else: - serial['data_fields'][key] = 'Any' - return serial + return super().dict(by_alias=by_alias, **kwargs) + + exclude = {"custom_data_fields"} + + # merge exclude fields and excludes from kwargs + if "exclude" in kwargs: + values = kwargs.pop("exclude") + if values is not None: + exclude = {*exclude, *values} + + return super().dict(by_alias=by_alias, exclude=exclude, **kwargs) class DataCategory(PydanticEnum): From 9924c14bf026f0329f9a73412399b36004c2ab4f Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 9 Jan 2023 15:43:43 -0500 Subject: [PATCH 018/205] refactor DataRequirement --- python/lib/core/dmod/core/meta_data.py | 140 +++---------------------- 1 file changed, 12 insertions(+), 128 deletions(-) diff --git a/python/lib/core/dmod/core/meta_data.py b/python/lib/core/dmod/core/meta_data.py index 9379075bb..f9605d8c6 100644 --- a/python/lib/core/dmod/core/meta_data.py +++ b/python/lib/core/dmod/core/meta_data.py @@ -671,20 +671,19 @@ class TimeRange(ContinuousRestriction): """ Encapsulated representation of a time range. 
""" - - def __init__(self, begin: Union[str, datetime], end: Union[str, datetime], datetime_pattern: Optional[str] = None, - **kwargs): - dt_ptrn = self.get_datetime_str_format() if datetime_pattern is None else datetime_pattern - super(TimeRange, self).__init__(variable=StandardDatasetIndex.TIME, - begin=begin if isinstance(begin, datetime) else datetime.strptime(begin, dt_ptrn), - end=end if isinstance(end, datetime) else datetime.strptime(end, dt_ptrn), - datetime_pattern=dt_ptrn) + variable: StandardDatasetIndex = Field(StandardDatasetIndex.TIME, const=True) class DataRequirement(Serializable): """ A definition of a particular data requirement needed for an execution task. """ + category: DataCategory + domain: DataDomain + fulfilled_access_at: Optional[str] = Field(description="The location at which the fulfilling dataset for this requirement is accessible, if the dataset known.") + fulfilled_by: Optional[str] = Field(description="The name of the dataset that will fulfill this, if it is known.") + is_input: bool = Field(..., description="Whether this represents required input data, as opposed to a requirement for storing output data.") + size: Optional[int] _KEY_CATEGORY = 'category' """ Serialization dictionary JSON key for ::attribute:`category` property value. """ @@ -715,129 +714,14 @@ def factory_init_from_deserialized_json(cls, json_obj: dict) -> Optional['DataRe A deserialized ::class:`DataRequirement` instance, or return ``None`` if the JSON is not valid. 
""" try: - domain = DataDomain.factory_init_from_deserialized_json(json_obj[cls._KEY_DOMAIN]) - category = DataCategory.get_for_name(json_obj[cls._KEY_CATEGORY]) - is_input = json_obj[cls._KEY_IS_INPUT] - - opt_kwargs_w_defaults = dict() - if cls._KEY_FULFILLED_BY in json_obj: - opt_kwargs_w_defaults['fulfilled_by'] = json_obj[cls._KEY_FULFILLED_BY] - if cls._KEY_SIZE in json_obj: - opt_kwargs_w_defaults['size'] = json_obj[cls._KEY_SIZE] - if cls._KEY_FULFILLED_ACCESS_AT in json_obj: - opt_kwargs_w_defaults['fulfilled_access_at'] = json_obj[cls._KEY_FULFILLED_ACCESS_AT] - - return cls(domain=domain, is_input=is_input, category=category, **opt_kwargs_w_defaults) + return cls(**json_obj) except: return None - def __eq__(self, other): - return self.__class__ == other.__class__ and self.domain == other.domain and self.is_input == other.is_input \ - and self.category == other.category - def __hash__(self): return hash('{}-{}-{}'.format(hash(self.domain), self.is_input, self.category)) - def __init__(self, domain: DataDomain, is_input: bool, category: DataCategory, size: Optional[int] = None, - fulfilled_by: Optional[str] = None, fulfilled_access_at: Optional[str] = None): - self._domain = domain - self._is_input = is_input - self._category = category - self._size = size - self._fulfilled_by = fulfilled_by - self._fulfilled_access_at = fulfilled_access_at - - @property - def category(self) -> DataCategory: - """ - The ::class:`DataCategory` of data required. - - Returns - ------- - DataCategory - The category of data required. - """ - return self._category - - @property - def domain(self) -> DataDomain: - """ - The (restricted) domain of the data that is required. - - Returns - ------- - DataDomain - The (restricted) domain of the data that is required. - """ - return self._domain - - @property - def fulfilled_access_at(self) -> Optional[str]: - """ - The location at which the fulfilling dataset for this requirement is accessible, if the dataset known. 
- - Returns - ------- - Optional[str] - The location at which the fulfilling dataset for this requirement is accessible, if known, or ``None`` - otherwise. - """ - return self._fulfilled_access_at - - @fulfilled_access_at.setter - def fulfilled_access_at(self, location: str): - self._fulfilled_access_at = location - - @property - def fulfilled_by(self) -> Optional[str]: - """ - The name of the dataset that will fulfill this, if it is known. - - Returns - ------- - Optional[str] - The name of the dataset that will fulfill this, if it is known; ``None`` otherwise. - """ - return self._fulfilled_by - - @fulfilled_by.setter - def fulfilled_by(self, name: str): - self._fulfilled_by = name - - @property - def is_input(self) -> bool: - """ - Whether this represents required input data, as opposed to a requirement for storing output data. - - Returns - ------- - bool - Whether this represents required input data. - """ - return self._is_input - - @property - def size(self) -> Optional[int]: - """ - The size of the required data, if it is known. - - This is particularly important (though still not strictly required) for an output data requirement; i.e., a - requirement to store output data somewhere. - - Returns - ------- - Optional[int] - he size of the required data, if it is known, or ``None`` otherwise. 
- """ - return self._size - - def to_dict(self) -> Dict[str, Union[str, Number, dict, list]]: - serial = {self._KEY_DOMAIN: self.domain.to_dict(), self._KEY_IS_INPUT: self.is_input, - self._KEY_CATEGORY: self.category.name} - if self.size is not None: - serial[self._KEY_SIZE] = self.size - if self.fulfilled_by is not None: - serial[self._KEY_FULFILLED_BY] = self.fulfilled_by - if self.fulfilled_access_at is not None: - serial[self._KEY_FULFILLED_ACCESS_AT] = self.fulfilled_access_at - return serial + def dict(self, *, **kwargs) -> dict: + exclude_unset = True if kwargs.get("exclude_unset") is None else False + kwargs["exclude_unset"] = exclude_unset + return super().dict(**kwargs) From 8bd316bd869c0c8e5304dec2d9a489b2ef29305d Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 9 Jan 2023 15:45:28 -0500 Subject: [PATCH 019/205] remove unnecessary factory_init_from_deserialized_json's --- python/lib/core/dmod/core/meta_data.py | 34 -------------------------- 1 file changed, 34 deletions(-) diff --git a/python/lib/core/dmod/core/meta_data.py b/python/lib/core/dmod/core/meta_data.py index f9605d8c6..d205db56d 100644 --- a/python/lib/core/dmod/core/meta_data.py +++ b/python/lib/core/dmod/core/meta_data.py @@ -349,13 +349,6 @@ class DiscreteRestriction(Serializable): # validate variable is not UNKNOWN variant _validate_variable = validator("variable", allow_reuse=True)(_validate_variable_is_known) - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict): - try: - cls(**json_obj) - except: - return None - def __init__(self, variable: Union[str, StandardDatasetIndex], values: Union[List[str], List[Number]], allow_reorder: bool = True, remove_duplicates: bool = True, **kwargs): super().__init__(variable=variable, values=values, **kwargs) @@ -469,13 +462,6 @@ def validate_sufficient_restrictions(cls, values): raise RuntimeError(msg.format(cls.__name__)) return values - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict): - try: 
- return cls(**json_obj) - except: - return None - @classmethod def factory_init_from_restriction_collections(cls, data_format: DataFormat, **kwargs) -> 'DataDomain': """ @@ -698,26 +684,6 @@ class DataRequirement(Serializable): _KEY_SIZE = 'size' """ Serialization dictionary JSON key for ::attribute:`size` property value. """ - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict) -> Optional['DataRequirement']: - """ - Deserialize the given JSON to a ::class:`DataRequirement` instance, or return ``None`` if it is not valid. - - Parameters - ---------- - json_obj : dict - The JSON to be deserialized. - - Returns - ------- - Optional[DataRequirement] - A deserialized ::class:`DataRequirement` instance, or return ``None`` if the JSON is not valid. - """ - try: - return cls(**json_obj) - except: - return None - def __hash__(self): return hash('{}-{}-{}'.format(hash(self.domain), self.is_input, self.category)) From f61d5dc6efe3c994426491a907f6cb8d12aaa387 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 9 Jan 2023 15:46:48 -0500 Subject: [PATCH 020/205] remove unneeded class variables --- python/lib/core/dmod/core/meta_data.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/python/lib/core/dmod/core/meta_data.py b/python/lib/core/dmod/core/meta_data.py index d205db56d..021cbcbcb 100644 --- a/python/lib/core/dmod/core/meta_data.py +++ b/python/lib/core/dmod/core/meta_data.py @@ -671,19 +671,6 @@ class DataRequirement(Serializable): is_input: bool = Field(..., description="Whether this represents required input data, as opposed to a requirement for storing output data.") size: Optional[int] - _KEY_CATEGORY = 'category' - """ Serialization dictionary JSON key for ::attribute:`category` property value. """ - _KEY_DOMAIN = 'domain' - """ Serialization dictionary JSON key for ::attribute:`domain_params` property value. 
""" - _KEY_FULFILLED_ACCESS_AT = 'fulfilled_access_at' - """ Serialization dictionary JSON key for ::attribute:`fulfilled_access_at` property value. """ - _KEY_FULFILLED_BY = 'fulfilled_by' - """ Serialization dictionary JSON key for ::attribute:`fulfilled_by` property value. """ - _KEY_IS_INPUT = 'is_input' - """ Serialization dictionary JSON key for ::attribute:`is_input` property value. """ - _KEY_SIZE = 'size' - """ Serialization dictionary JSON key for ::attribute:`size` property value. """ - def __hash__(self): return hash('{}-{}-{}'.format(hash(self.domain), self.is_input, self.category)) From 665c4151f27ab19a90e4c54065a56befdb4a198a Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 9 Jan 2023 15:47:42 -0500 Subject: [PATCH 021/205] remove ellipse --- python/lib/core/dmod/core/meta_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/lib/core/dmod/core/meta_data.py b/python/lib/core/dmod/core/meta_data.py index 021cbcbcb..244e8cade 100644 --- a/python/lib/core/dmod/core/meta_data.py +++ b/python/lib/core/dmod/core/meta_data.py @@ -668,7 +668,7 @@ class DataRequirement(Serializable): domain: DataDomain fulfilled_access_at: Optional[str] = Field(description="The location at which the fulfilling dataset for this requirement is accessible, if the dataset known.") fulfilled_by: Optional[str] = Field(description="The name of the dataset that will fulfill this, if it is known.") - is_input: bool = Field(..., description="Whether this represents required input data, as opposed to a requirement for storing output data.") + is_input: bool = Field(description="Whether this represents required input data, as opposed to a requirement for storing output data.") size: Optional[int] def __hash__(self): From adb5c03412d6d618a227db5c0885d5888b624076 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 9 Jan 2023 15:48:50 -0500 Subject: [PATCH 022/205] refactor dataset module enums --- python/lib/core/dmod/core/dataset.py | 4 ++-- 1 file 
changed, 2 insertions(+), 2 deletions(-) diff --git a/python/lib/core/dmod/core/dataset.py b/python/lib/core/dmod/core/dataset.py index b1e5ca2aa..56dc28c03 100644 --- a/python/lib/core/dmod/core/dataset.py +++ b/python/lib/core/dmod/core/dataset.py @@ -5,13 +5,13 @@ from datetime import datetime, timedelta from .serializable import Serializable, ResultIndicator -from enum import Enum +from .enum import PydanticEnum from numbers import Number from typing import Any, Callable, Dict, FrozenSet, List, Optional, Set, Tuple, Type, Union from uuid import UUID, uuid4 -class DatasetType(Enum): +class DatasetType(PydanticEnum): UNKNOWN = (-1, False, lambda dataset: None) OBJECT_STORE = (0, True, lambda dataset: dataset.name) FILESYSTEM = (1, True, lambda dataset: dataset.access_location) From 3fe7c46a4933f7545f818214dadd15a899c66047 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 9 Jan 2023 16:06:01 -0500 Subject: [PATCH 023/205] refactor Dataset --- python/lib/core/dmod/core/dataset.py | 396 +++++++-------------------- 1 file changed, 98 insertions(+), 298 deletions(-) diff --git a/python/lib/core/dmod/core/dataset.py b/python/lib/core/dmod/core/dataset.py index 56dc28c03..e6d765d1f 100644 --- a/python/lib/core/dmod/core/dataset.py +++ b/python/lib/core/dmod/core/dataset.py @@ -7,7 +7,8 @@ from .serializable import Serializable, ResultIndicator from .enum import PydanticEnum from numbers import Number -from typing import Any, Callable, Dict, FrozenSet, List, Optional, Set, Tuple, Type, Union +from typing import Any, Callable, ClassVar, Dict, FrozenSet, List, Optional, Set, Tuple, Type, Union +from pydantic import Field from uuid import UUID, uuid4 @@ -59,22 +60,69 @@ class Dataset(Serializable): Rrepresentation of the descriptive metadata for a grouped collection of data. 
""" - _SERIAL_DATETIME_STR_FORMAT = '%Y-%m-%d %H:%M:%S' - - _KEY_ACCESS_LOCATION = 'access_location' - _KEY_CREATED_ON = 'create_on' - _KEY_DATA_CATEGORY = 'data_category' - _KEY_DATA_DOMAIN = 'data_domain' - _KEY_DERIVED_FROM = 'derived_from' - _KEY_DERIVATIONS = 'derivations' - _KEY_DESCRIPTION = 'description' - _KEY_EXPIRES = 'expires' - _KEY_IS_READ_ONLY = 'is_read_only' - _KEY_LAST_UPDATE = 'last_updated' - _KEY_MANAGER_UUID = 'manager_uuid' - _KEY_NAME = 'name' - _KEY_TYPE = 'type' - _KEY_UUID = 'uuid' + _SERIAL_DATETIME_STR_FORMAT: ClassVar = '%Y-%m-%d %H:%M:%S' + name: str = Field(description="The name for this dataset, which also should be a unique identifier.") + # QUESTION: should this be optional? see factory_init_from_deserialized_json + category: Optional[DataCategory] = Field(None, alias="data_category", description="The ::class:`DataCategory` type value for this instance.") + # QUESTION: should this be optional? see factory_init_from_deserialized_json + data_domain: Optional[DataDomain] + dataset_type: DatasetType = Field(DatasetType.UNKNOWN, alias="type") + access_location: str = Field(description="String representation of the location at which this dataset is accessible.") + uuid: Optional[UUID] = Field(default_factory=uuid4) + # manager can only be passed as constructed DatasetManager subtype. Manager not included in `dict` or `json` deserialization. + # TODO: don't include `manager` in `Dataset.schema()`. Inclusion is not reflective of the de/serialization behavior. 
+ manager: Optional['DatasetManager'] = Field(exclude=True) + manager_uuid: Optional[UUID] + is_read_only: bool = Field(True, description="Whether this is a dataset that can only be read from.") + description: Optional[str] + expires: Optional[datetime] = Field(description='The time after which a dataset may "expire" and be removed, or ``None`` if the dataset is not temporary.') + derived_from: Optional[str] = Field(description="The name of the dataset from which this dataset was derived, if it is known to have been derived.") + derivations: Optional[List[str]] = Field(default_factory=list, description="""List of names of datasets which were derived from this dataset.\n + Note that it is not guaranteed that any such dataset still exist and/or are still available.""") + created_on: Optional[datetime] = Field(description="When this dataset was created, or ``None`` if that is not known.") + last_updated: Optional[datetime] + + @validator("created_on", "last_updated", "expires", pre=True) + def parse_dates(cls, v): + if v is None: + return None + + if isinstance(v, datetime): + return v + + # NOTE: could raise: + # - TypeError: if `v` or `cls.get_datetime_str_format` is not `str` + # - ValueError: if `v` cannot be coerced into `datetime` object + return datetime.strptime(v, cls.get_datetime_str_format()) + + @validator("created_on", "last_updated", "expires") + def drop_microseconds(cls, v: datetime): + return v.replace(microsecond=0) + + @validator("manager", pre=True) + def drop_manager_if_not_constructed_subtype(cls, value): + # manager can only be passed as constructed DatasetManager subtype + if isinstance(value, DatasetManager): + return value + return None + + @root_validator() + def set_manager_uuid(cls, values) -> dict: + manager: Optional[DatasetManager] = values["manager"] + # give preference to `manager.uuid` otherwise use specified `manager_uuid` + if manager is not None: + # pydantic will not validate this, so we need to check it + if not 
isinstance(manager.uuid, UUID): + raise ValueError(f"Expected UUID got {type(manager.uuid)}") + values["manager_uuid"] = manager.uuid + + return values + + class Config: + # NOTE: re-validate when any field is re-assigned (i.e. `model.foo = 12`) + # TODO: in future deprecate setting properties unless through a setter method + validate_assignment = True + arbitrary_types_allowed = True # TODO: move this (and something more to better automatically handle Serializable subtypes) to Serializable directly @classmethod @@ -84,60 +132,32 @@ def _date_parse_helper(cls, json_obj: dict, key: str) -> Optional[datetime]: else: return None - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict): - try: - manager_uuid = UUID(json_obj[cls._KEY_MANAGER_UUID]) if cls._KEY_MANAGER_UUID in json_obj else None - return cls(name=json_obj[cls._KEY_NAME], - category=DataCategory.get_for_name(json_obj[cls._KEY_DATA_CATEGORY]), - data_domain=DataDomain.factory_init_from_deserialized_json(json_obj[cls._KEY_DATA_DOMAIN]), - dataset_type=DatasetType.get_for_name(json_obj[cls._KEY_TYPE]), - access_location=json_obj[cls._KEY_ACCESS_LOCATION], - description=json_obj.get(cls._KEY_DESCRIPTION, None), - uuid=UUID(json_obj[cls._KEY_UUID]), - manager_uuid=manager_uuid, - is_read_only=json_obj[cls._KEY_IS_READ_ONLY], - expires=cls._date_parse_helper(json_obj, cls._KEY_EXPIRES), - derived_from=json_obj[cls._KEY_DERIVED_FROM] if cls._KEY_DERIVED_FROM in json_obj else None, - derivations=json_obj[cls._KEY_DERIVATIONS] if cls._KEY_DERIVATIONS in json_obj else [], - created_on=cls._date_parse_helper(json_obj, cls._KEY_CREATED_ON), - last_updated=cls._date_parse_helper(json_obj, cls._KEY_LAST_UPDATE)) - except Exception as e: - return None - - def __eq__(self, other): - return isinstance(other, Dataset) and self.name == other.name and self.category == other.category \ - and self.dataset_type == other.dataset_type and self.data_domain == other.data_domain \ - and self.access_location == 
other.access_location and self.is_read_only == other.is_read_only \ - and self.created_on == other.created_on + # TODO: Remove after draft review + # @classmethod + # def factory_init_from_deserialized_json(cls, json_obj: dict): + # try: + # manager_uuid = UUID(json_obj[cls._KEY_MANAGER_UUID]) if cls._KEY_MANAGER_UUID in json_obj else None + # return cls(name=json_obj[cls._KEY_NAME], + # category=DataCategory.get_for_name(json_obj[cls._KEY_DATA_CATEGORY]), + # data_domain=DataDomain.factory_init_from_deserialized_json(json_obj[cls._KEY_DATA_DOMAIN]), + # dataset_type=DatasetType.get_for_name(json_obj[cls._KEY_TYPE]), + # access_location=json_obj[cls._KEY_ACCESS_LOCATION], + # description=json_obj.get(cls._KEY_DESCRIPTION, None), + # uuid=UUID(json_obj[cls._KEY_UUID]), + # manager_uuid=manager_uuid, + # is_read_only=json_obj[cls._KEY_IS_READ_ONLY], + # expires=cls._date_parse_helper(json_obj, cls._KEY_EXPIRES), + # derived_from=json_obj[cls._KEY_DERIVED_FROM] if cls._KEY_DERIVED_FROM in json_obj else None, + # derivations=json_obj[cls._KEY_DERIVATIONS] if cls._KEY_DERIVATIONS in json_obj else [], + # created_on=cls._date_parse_helper(json_obj, cls._KEY_CREATED_ON), + # last_updated=cls._date_parse_helper(json_obj, cls._KEY_LAST_UPDATE)) + # except Exception as e: + # return None def __hash__(self): return hash(','.join([self.__class__.__name__, self.name, self.category.name, str(hash(self.data_domain)), self.access_location, str(self.is_read_only), str(hash(self.created_on))])) - def __init__(self, name: str, category: DataCategory, data_domain: DataDomain, dataset_type: DatasetType, - access_location: str, uuid: Optional[UUID] = None, manager: Optional['DatasetManager'] = None, - manager_uuid: Optional[UUID] = None, is_read_only: bool = True, description: Optional[str] = None, expires: Optional[datetime] = None, - derived_from: Optional[str] = None, derivations: Optional[List[str]] = None, - created_on: Optional[datetime] = None, last_updated: Optional[datetime] = 
None): - self._name = name - self._category = category - self._data_domain = data_domain - self._dataset_type = dataset_type - self._access_location = access_location - self._uuid = uuid4() if uuid is None else uuid - self._manager = manager - self._manager_uuid = manager.uuid if manager is not None else manager_uuid - self._description = description - self._is_read_only = is_read_only - self._expires = expires if expires is None else expires.replace(microsecond=0) - self._derived_from = derived_from - self._derivations = derivations if derivations is not None else list() - self._created_on = created_on if created_on is None else created_on.replace(microsecond=0) - self._last_updated = last_updated if last_updated is None else last_updated.replace(microsecond=0) - # TODO: have manager handle the logic - #retention_strategy - def _set_expires(self, new_expires: datetime): """ "Private" function to set the ::attribute:`expires` property. @@ -150,60 +170,8 @@ def _set_expires(self, new_expires: datetime): new_expires : datetime The new value for ::attribute:`expires`. """ - self._expires = new_expires - # n = datetime.now() - # n.astimezone().tzinfo.tzname(n.astimezone()) - self._last_updated = datetime.now() - - @property - def access_location(self) -> str: - """ - String representation of the location at which this dataset is accessible. - - Depending on the subtype, this may be the string form of a URL, URI, or basic filesystem path. - - Returns - ------- - str - String representation of the location at which this dataset is accessible. - """ - return self._access_location - - @property - def category(self) -> DataCategory: - """ - The ::class:`DataCategory` type value for this instance. - - Returns - ------- - DataCategory - The ::class:`DataCategory` type value for this instance. - """ - return self._category - - @property - def created_on(self) -> Optional[datetime]: - """ - When this dataset was created, or ``None`` if that is not known. 
- - Returns - ------- - Optional[datetime] - When this dataset was created, or ``None`` if that is not known. - """ - return self._created_on - - @property - def data_domain(self) -> DataDomain: - """ - The data domain for this instance. - - Returns - ------- - DataDomain - The ::class:`DataDomain` for this instance. - """ - return self._data_domain + self.expires = new_expires + self.last_updated = datetime.now() @property def data_format(self) -> DataFormat: @@ -217,53 +185,6 @@ def data_format(self) -> DataFormat: """ return self.data_domain.data_format - @property - def dataset_type(self) -> DatasetType: - return self._dataset_type - - @property - def derivations(self) -> List[str]: - """ - List of names of datasets which were derived from this dataset. - - Note that it is not guaranteed that any such dataset still exist and/or are still available. - - Returns - ------- - List[str] - List of names of datasets which were derived from this dataset. - """ - return self._derivations - - @property - def derived_from(self) -> Optional[str]: - """ - The name of the dataset from which this dataset was derived, if it is known to have been derived. - - Returns - ------- - Optional[str] - The name of the dataset from which this dataset was derived, or ``None`` if this dataset is not known to - have been derived. - """ - return self._derived_from - - @property - def description(self) -> Optional[str]: - """ - An optional string description of this dataset. - - Returns - ------- - Optional[str] - An optional string description of this dataset. - """ - return self._description - - @description.setter - def description(self, desc: Optional[str]): - self._description = desc - @property def docker_mount(self) -> str: """ @@ -289,22 +210,6 @@ def docker_mount(self) -> str: else: return result - @property - def expires(self) -> Optional[datetime]: - """ - The time after which a dataset may "expire" and be removed, or ``None`` if the dataset is not temporary. 
- - A dataset may be temporary, meaning its availability and validity cannot be assumed perpetually; e.g., the data - may be removed from storage. This property indicates the time through which availability and validity is - guaranteed. - - Returns - ------- - Optional[datetime] - The time after which a dataset may "expire" and be removed, or ``None`` if the dataset is not temporary. - """ - return self._expires - def extend_life(self, value: Union[datetime, timedelta]) -> bool: """ Extend the expiration of this dataset. @@ -335,7 +240,7 @@ def extend_life(self, value: Union[datetime, timedelta]) -> bool: if not self.is_temporary: return False elif isinstance(value, timedelta): - self._set_expires(self._expires + value) + self._set_expires(self.expires + value) return True elif isinstance(value, datetime) and self.expires < value: self._set_expires(value) @@ -355,18 +260,6 @@ def fields(self) -> Dict[str, Type]: """ return self.data_domain.data_fields - @property - def is_read_only(self) -> bool: - """ - Whether this is a dataset that can only be read from. - - Returns - ------- - bool - Whether this is a dataset that can only be read from. - """ - return self._is_read_only - @property def is_temporary(self) -> bool: """ @@ -382,64 +275,6 @@ def is_temporary(self) -> bool: """ return self.expires is not None - @property - def last_updated(self) -> Optional[datetime]: - """ - When this dataset was last updated, or ``None`` if that is not known. - - Note that this includes adjustments to metadata, including the value for ::attribute:`expires`. - - Returns - ------- - Optional[datetime] - When this dataset was last updated, or ``None`` if that is not known. - """ - return self._last_updated - - @property - def manager(self) -> 'DatasetManager': - """ - The ::class:`DatasetManager` for this instance. - - Returns - ------- - DatasetManager - The ::class:`DatasetManager` for this instance. 
- """ - return self._manager - - @manager.setter - def manager(self, manager: 'DatasetManager'): - self._manager = manager - self._manager_uuid = manager.uuid - - @property - def manager_uuid(self) -> UUID: - """ - The UUID of the ::class:`DatasetManager` for this instance. - - Returns - ------- - DatasetManager - The UUID of the ::class:`DatasetManager` for this instance. - """ - return self._manager_uuid - - @property - def name(self) -> str: - """ - The name for this dataset, which also should be a unique identifier. - - Every dataset in the domain of all datasets known to this instance's ::attribute:`manager` must have a unique - name value. - - Returns - ------- - str - The dataset's unique name. - """ - return self._name - @property def time_range(self) -> Optional[TimeRange]: """ @@ -456,51 +291,16 @@ def time_range(self) -> Optional[TimeRange]: tr = self.data_domain.continuous_restrictions[StandardDatasetIndex.TIME] return tr if isinstance(tr, TimeRange) else TimeRange(begin=tr.begin, end=tr.end, variable=tr.variable) - @property - def uuid(self) -> UUID: - """ - The UUID for this instance. + def _get_exclude_fields(self) -> Set[str]: + """Set of fields to exclude during deserialization if they are some None variant (e.g. '', 0, None)""" + candidates = ("manager_uuid", "expires", "derived_from", "derivations", "description", "created_on", "last_updated") + return {f for f in candidates if not self.__getattribute__(f)} - Returns - ------- - UUID - The UUID for this instance. - """ - return self._uuid - - def to_dict(self) -> Dict[str, Union[str, Number, dict, list]]: - """ - Get the serial form of this instance as a dictionary object. - - Returns - ------- - Dict[str, Union[str, Number, dict, list]] - The serialized form of this instance. 
- """ - serial = dict() - serial[self._KEY_NAME] = self.name - serial[self._KEY_DATA_CATEGORY] = self.category.name - serial[self._KEY_DATA_DOMAIN] = self.data_domain.to_dict() - serial[self._KEY_TYPE] = self.dataset_type.name - # TODO: unit test this - serial[self._KEY_ACCESS_LOCATION] = self.access_location - serial[self._KEY_UUID] = str(self.uuid) - serial[self._KEY_IS_READ_ONLY] = self.is_read_only - if self.manager_uuid is not None: - serial[self._KEY_MANAGER_UUID] = str(self.manager_uuid) - if self.expires is not None: - serial[self._KEY_EXPIRES] = self.expires.strftime(self.get_datetime_str_format()) - if self.derived_from is not None: - serial[self._KEY_DERIVED_FROM] = self.derived_from - if len(self.derivations) > 0: - serial[self._KEY_DERIVATIONS] = self.derivations - if self.description is not None: - serial[self._KEY_DESCRIPTION] = self.description - if self.created_on is not None: - serial[self._KEY_CREATED_ON] = self.created_on.strftime(self.get_datetime_str_format()) - if self.last_updated is not None: - serial[self._KEY_LAST_UPDATE] = self.last_updated.strftime(self.get_datetime_str_format()) - return serial + def dict(self, **kwargs) -> dict: + # if exclude is set, ignore this _get_exclude_fields() + exclude = self._get_exclude_fields() if kwargs.get("exclude", False) is False else kwargs["exclude"] + kwargs["exclude"] = exclude + return super().dict(**kwargs) class DatasetUser(ABC): From 12010b6c8db75cdd680bf903a80fee9c96458a4f Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 9 Jan 2023 16:09:50 -0500 Subject: [PATCH 024/205] add bug fix todo's --- python/lib/core/dmod/core/dataset.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/python/lib/core/dmod/core/dataset.py b/python/lib/core/dmod/core/dataset.py index e6d765d1f..23f4f904c 100644 --- a/python/lib/core/dmod/core/dataset.py +++ b/python/lib/core/dmod/core/dataset.py @@ -90,9 +90,6 @@ def parse_dates(cls, v): if isinstance(v, datetime): return v - # 
NOTE: could raise: - # - TypeError: if `v` or `cls.get_datetime_str_format` is not `str` - # - ValueError: if `v` cannot be coerced into `datetime` object return datetime.strptime(v, cls.get_datetime_str_format()) @validator("created_on", "last_updated", "expires") @@ -124,14 +121,6 @@ class Config: validate_assignment = True arbitrary_types_allowed = True - # TODO: move this (and something more to better automatically handle Serializable subtypes) to Serializable directly - @classmethod - def _date_parse_helper(cls, json_obj: dict, key: str) -> Optional[datetime]: - if key in json_obj: - return datetime.strptime(json_obj[key], cls.get_datetime_str_format()) - else: - return None - # TODO: Remove after draft review # @classmethod # def factory_init_from_deserialized_json(cls, json_obj: dict): @@ -240,8 +229,10 @@ def extend_life(self, value: Union[datetime, timedelta]) -> bool: if not self.is_temporary: return False elif isinstance(value, timedelta): + # TODO: Fix bug. expires could be None self._set_expires(self.expires + value) return True + # TODO: Fix bug. 
expires could be None elif isinstance(value, datetime) and self.expires < value: self._set_expires(value) return True From f71e965476b0e2608631cc1df2b00fc0c2d7b314 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 9 Jan 2023 16:12:51 -0500 Subject: [PATCH 025/205] fix **kwargs bug --- python/lib/core/dmod/core/meta_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/lib/core/dmod/core/meta_data.py b/python/lib/core/dmod/core/meta_data.py index 244e8cade..130717b11 100644 --- a/python/lib/core/dmod/core/meta_data.py +++ b/python/lib/core/dmod/core/meta_data.py @@ -674,7 +674,7 @@ class DataRequirement(Serializable): def __hash__(self): return hash('{}-{}-{}'.format(hash(self.domain), self.is_input, self.category)) - def dict(self, *, **kwargs) -> dict: + def dict(self, **kwargs) -> dict: exclude_unset = True if kwargs.get("exclude_unset") is None else False kwargs["exclude_unset"] = exclude_unset return super().dict(**kwargs) From 4f2fbe3c87a60488044f07b3f16d225548fcd9fb Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 9 Jan 2023 16:13:28 -0500 Subject: [PATCH 026/205] clean up imports --- python/lib/core/dmod/core/dataset.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/lib/core/dmod/core/dataset.py b/python/lib/core/dmod/core/dataset.py index 23f4f904c..78dd497f2 100644 --- a/python/lib/core/dmod/core/dataset.py +++ b/python/lib/core/dmod/core/dataset.py @@ -6,9 +6,8 @@ from .serializable import Serializable, ResultIndicator from .enum import PydanticEnum -from numbers import Number from typing import Any, Callable, ClassVar, Dict, FrozenSet, List, Optional, Set, Tuple, Type, Union -from pydantic import Field +from pydantic import Field, validator, root_validator from uuid import UUID, uuid4 From 87401b42ce8b2b9f2160f87c2791fc6618bc6e89 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 9 Jan 2023 16:17:11 -0500 Subject: [PATCH 027/205] update unittests. 
no breaking changes made --- python/lib/core/dmod/test/test_data_requirement.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/lib/core/dmod/test/test_data_requirement.py b/python/lib/core/dmod/test/test_data_requirement.py index 4cbf4e2c5..3d3ff52a2 100644 --- a/python/lib/core/dmod/test/test_data_requirement.py +++ b/python/lib/core/dmod/test/test_data_requirement.py @@ -30,15 +30,15 @@ def test_to_dict_0_a(self): requirement = self.example_reqs[ex] as_dict = requirement.to_dict() self.assertTrue(isinstance(as_dict, dict)) - self.assertTrue(DataRequirement._KEY_DOMAIN in as_dict) + self.assertTrue("domain" in as_dict) def test_to_dict_0_b(self): ex = 0 requirement = self.example_reqs[ex] as_dict = requirement.to_dict() self.assertTrue(requirement.is_input) - self.assertTrue(isinstance(as_dict[DataRequirement._KEY_IS_INPUT], bool)) - self.assertTrue(as_dict[DataRequirement._KEY_IS_INPUT]) + self.assertTrue(isinstance(as_dict["is_input"], bool)) + self.assertTrue(as_dict["is_input"]) def test_factory_init_from_deserialized_json_0_a(self): """ From 7339714dfd17cfe835368cc3c41e33ac2b37b58d Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 9 Jan 2023 16:24:07 -0500 Subject: [PATCH 028/205] serialize UUID using explict cast to str --- python/lib/core/dmod/core/dataset.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/python/lib/core/dmod/core/dataset.py b/python/lib/core/dmod/core/dataset.py index 78dd497f2..2740a830a 100644 --- a/python/lib/core/dmod/core/dataset.py +++ b/python/lib/core/dmod/core/dataset.py @@ -290,7 +290,12 @@ def dict(self, **kwargs) -> dict: # if exclude is set, ignore this _get_exclude_fields() exclude = self._get_exclude_fields() if kwargs.get("exclude", False) is False else kwargs["exclude"] kwargs["exclude"] = exclude - return super().dict(**kwargs) + + serial = super().dict(**kwargs) + + # serialize uuid + serial["uuid"] = str(self.uuid) + return serial class DatasetUser(ABC): From 
215de3bce19ba723f29d881bd0c72e64d1ac4d51 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 9 Jan 2023 16:24:32 -0500 Subject: [PATCH 029/205] update dataset unittest --- python/lib/core/dmod/test/test_dataset.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/python/lib/core/dmod/test/test_dataset.py b/python/lib/core/dmod/test/test_dataset.py index 0548b395c..68f13ac4d 100644 --- a/python/lib/core/dmod/test/test_dataset.py +++ b/python/lib/core/dmod/test/test_dataset.py @@ -60,14 +60,14 @@ def setUp(self) -> None: discrete_restrictions=[self.example_catchment_restrictions[i]])) self.example_datasets.append(self._init_dataset_example(i)) date_fmt = Dataset.get_datetime_str_format() - self.example_data.append({Dataset._KEY_NAME: self.gen_dataset_name(i), - Dataset._KEY_DATA_DOMAIN: self.example_domains[i].to_dict(), - Dataset._KEY_DATA_CATEGORY: self.example_categories[i].name, - Dataset._KEY_TYPE: self.example_types[i].name, - Dataset._KEY_UUID: str(self.example_datasets[i].uuid), - Dataset._KEY_ACCESS_LOCATION: 'location_{}'.format(i), - Dataset._KEY_IS_READ_ONLY: False, - Dataset._KEY_CREATED_ON: self._created_on.strftime(date_fmt), + self.example_data.append({"name": self.gen_dataset_name(i), + "data_domain": self.example_domains[i].to_dict(), + "data_category": self.example_categories[i].name, + "type": self.example_types[i].name, + "uuid": str(self.example_datasets[i].uuid), + "access_location": 'location_{}'.format(i), + "is_read_only": False, + "created_on": self._created_on, # NOTE: breaking change }) def test_factory_init_from_deserialized_json_0_a(self): From e8dd16a8b40862dcf958b6d9bb6a0cc8ae43bcc6 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 9 Jan 2023 16:44:55 -0500 Subject: [PATCH 030/205] adjust factory init logic in ContinuousRestriction --- python/lib/core/dmod/core/meta_data.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/python/lib/core/dmod/core/meta_data.py 
b/python/lib/core/dmod/core/meta_data.py index 130717b11..d06462b8f 100644 --- a/python/lib/core/dmod/core/meta_data.py +++ b/python/lib/core/dmod/core/meta_data.py @@ -254,6 +254,11 @@ def validate_start_before_end(cls, values): # validate variable is not UNKNOWN variant _validate_variable = validator("variable", allow_reuse=True)(_validate_variable_is_known) + def __eq__(self, o: object) -> bool: + if not isinstance(o, ContinuousRestriction): + return False + return self.variable == o.variable and self.begin == o.begin and self.end == o.end + @classmethod def convert_truncated_serial_form(cls, truncated_json_obj: dict, datetime_format: Optional[str] = None) -> dict: """ @@ -291,8 +296,15 @@ def convert_truncated_serial_form(cls, truncated_json_obj: dict, datetime_format def factory_init_from_deserialized_json(cls, json_obj: dict): if "subclass" in json_obj: try: + subclass_str = json_obj["subclass"] + + if subclass_str == cls.__name__: + json_obj["subclass"] = cls + return subclass(**json_obj) + for subclass in cls.__subclasses__(): - if subclass.__name__ == json_obj["subclass"]: + if subclass.__name__ == subclass_str: + json_obj["subclass"] = subclass return subclass(**json_obj) except: pass From 48308b066d182cb836d6cd5a3b5cd18b13a7e322 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 9 Jan 2023 16:45:18 -0500 Subject: [PATCH 031/205] add meta_data module unit tests --- python/lib/core/dmod/test/test_meta_data.py | 212 ++++++++++++++++++++ 1 file changed, 212 insertions(+) create mode 100644 python/lib/core/dmod/test/test_meta_data.py diff --git a/python/lib/core/dmod/test/test_meta_data.py b/python/lib/core/dmod/test/test_meta_data.py new file mode 100644 index 000000000..4843e9ab7 --- /dev/null +++ b/python/lib/core/dmod/test/test_meta_data.py @@ -0,0 +1,212 @@ +import unittest +from datetime import datetime + +from ..core.meta_data import ( + ContinuousRestriction, + DiscreteRestriction, + StandardDatasetIndex, + DataDomain, + DataFormat, + TimeRange, + 
DataCategory, + DataRequirement, +) + +from typing import Any + + +class TestContinuousRestriction(unittest.TestCase): + def test_custom_datetime_pattern(self): + o = ContinuousRestriction( + begin="2020-01-01", + end="2020-01-02", + variable="TIME", + datetime_pattern="%Y-%m-%d", + ) + self.assertEqual(o.variable, StandardDatasetIndex.TIME) + + def test_custom_datetime_pattern_should_fail(self): + with self.assertRaises(RuntimeError): + ContinuousRestriction( + begin="2020-01-01", + end="2019-12-31", + variable="TIME", + datetime_pattern="%Y-%m-%d", + ) + + def test_create_from_python_objects(self): + begin = datetime(2020, 1, 1) + end = datetime(2020, 1, 2) + o = ContinuousRestriction( + begin=begin, end=end, variable=StandardDatasetIndex.TIME + ) + self.assertEqual(o.begin, begin) + self.assertEqual(o.end, end) + self.assertEqual(o.variable, StandardDatasetIndex.TIME) + + def test_create_fails_with_invalid_variable(self): + begin = datetime(2020, 1, 1) + end = datetime(2020, 1, 2) + with self.assertRaises(ValueError): + ContinuousRestriction( + begin=begin, end=end, variable=StandardDatasetIndex.UNKNOWN + ) + + def test_eq(self): + begin = datetime(2020, 1, 1) + end = datetime(2020, 1, 2) + o1 = ContinuousRestriction( + begin=begin, end=end, variable=StandardDatasetIndex.TIME + ) + o2 = ContinuousRestriction( + begin=begin, end=end, variable=StandardDatasetIndex.TIME + ) + + self.assertEqual(o1, o2) + + def test_hash(self): + begin = datetime(2020, 1, 1) + end = datetime(2020, 1, 2) + var = StandardDatasetIndex.TIME + expected_hash = hash(f"{var.name}-{begin}-{end}") + o_hash = hash(ContinuousRestriction(variable=var, begin=begin, end=end)) + self.assertEqual(expected_hash, o_hash) + + def test_to_dict(self): + begin = "2020-01-01" + end = "2020-01-02" + d = ContinuousRestriction( + begin=begin, + end=end, + variable="TIME", + datetime_pattern="%Y-%m-%d", + ).to_dict() + self.assertEqual(d["begin"], begin) + self.assertEqual(d["end"], end) + 
self.assertEqual(d["variable"], StandardDatasetIndex.TIME.name) + + def test_factory_init_from_deserialized_json(self): + deserialied = {"begin": 0, "end": 1, "variable": "TIME"} + o1 = ContinuousRestriction.factory_init_from_deserialized_json(deserialied) + + deserialied = { + "begin": 0, + "end": 1, + "variable": "TIME", + "subclass": "ContinuousRestriction", + } + o2 = ContinuousRestriction.factory_init_from_deserialized_json(deserialied) + self.assertEqual(o1, o2) + + def test_to_json(self): + import json + begin = "2020-01-01" + end = "2020-01-02" + d = json.loads( + ContinuousRestriction( + begin=begin, + end=end, + variable="TIME", + datetime_pattern="%Y-%m-%d", + ).to_json() + ) + + self.assertEqual(d["begin"], begin) + self.assertEqual(d["end"], end) + self.assertEqual(d["variable"], StandardDatasetIndex.TIME.name) + + +class TestDiscreteRestriction(unittest.TestCase): + def test_duplicate_values_removed(self): + o = DiscreteRestriction( + variable="TIME", values=[1, 1, 1], remove_duplicates=True + ) + self.assertListEqual(o.values, [1]) + + def test_values_reordered(self): + values = [3, 2, 1] + o = DiscreteRestriction(variable="TIME", values=values, allow_reorder=True) + self.assertListEqual(o.values, values[::-1]) + + def test_values_removed_not_reordered(self): + values = [3, 3, 2, 1] + o = DiscreteRestriction( + variable="TIME", + values=values, + allow_reorder=False, + remove_duplicates=True, + ) + self.assertListEqual(o.values, values[1:]) + + def test_values_reordered_not_removed(self): + values = [3, 3, 2, 1] + o = DiscreteRestriction( + variable="TIME", + values=values, + allow_reorder=True, + remove_duplicates=False, + ) + self.assertListEqual(o.values, values[::-1]) + + +class TestDataDomain(unittest.TestCase): + def test_it_works(self): + disc_rest = DiscreteRestriction( + variable=StandardDatasetIndex.DATA_ID, values=["0"] + ) + o = DataDomain( + data_format=DataFormat.AORC_CSV, + discrete_restrictions=[disc_rest], + data_fields=dict(a="str", 
b="float", c="int", d="datetime"), + ) + self.assertEqual(o.custom_data_fields["a"], str) + self.assertEqual(o.custom_data_fields["b"], float) + self.assertEqual(o.custom_data_fields["c"], int) + self.assertEqual(o.custom_data_fields["d"], Any) + + def test_init_fails_if_insufficient_restrictions(self): + with self.assertRaises(RuntimeError): + DataDomain( + data_format=DataFormat.AORC_CSV, + continuous_restrictions=[], + discrete_restrictions=[], + ) + + with self.assertRaises(RuntimeError): + DataDomain(data_format=DataFormat.AORC_CSV) + + def test_factory_init_from_deserialized_json(self): + data = { + "data_format": "AORC_CSV", + "continuous_restrictions": [], + "discrete_restrictions": [{"variable": "DATA_ID", "values": ["0"]}], + } + o = DataDomain.factory_init_from_deserialized_json(data) + self.assertEqual(o.data_format.name, "AORC_CSV") + + +class TestTimeRange(unittest.TestCase): + def test_begin_cannot_come_after_end(self): + with self.assertRaises(RuntimeError): + TimeRange(begin=1, end=0) + + def test_cannot_provide_non_time_variable(self): + with self.assertRaises(RuntimeError): + TimeRange(variable=StandardDatasetIndex.DATA_ID, begin=1, end=0) + +class TestDataRequirement(unittest.TestCase): + def test_unset_fields_are_excluded_in_serialized_dict(self): + domain = DataDomain( + data_format=DataFormat.AORC_CSV, + discrete_restrictions=[ + DiscreteRestriction(variable=StandardDatasetIndex.DATA_ID, values=["0"]) + ], + ) + + d = DataRequirement( + domain=domain, is_input=True, category=DataCategory.CONFIG + ).to_dict() + self.assertNotIn("size", d) + self.assertNotIn("fulfilled_by", d) + self.assertNotIn("fulfilled_access_at", d) + From 1f8a24026db49c55375a77ef09d5e3c75ca24f15 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Tue, 10 Jan 2023 09:34:19 -0500 Subject: [PATCH 032/205] remove unnecessary return by_alias check. now handled by superclass, Serializable. 
--- python/lib/core/dmod/core/meta_data.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/python/lib/core/dmod/core/meta_data.py b/python/lib/core/dmod/core/meta_data.py index d06462b8f..9e8dcd6d2 100644 --- a/python/lib/core/dmod/core/meta_data.py +++ b/python/lib/core/dmod/core/meta_data.py @@ -630,11 +630,6 @@ def dict(self, **kwargs) -> dict: called by `to_dict` and `to_json`. """ # TODO: aaraney, handle encoding type (int, float, etc.) as str - by_alias = kwargs.pop("by_alias") if "by_alias" in kwargs else True - - if self.data_format.data_fields is None: - return super().dict(by_alias=by_alias, **kwargs) - exclude = {"custom_data_fields"} # merge exclude fields and excludes from kwargs @@ -643,7 +638,7 @@ def dict(self, **kwargs) -> dict: if values is not None: exclude = {*exclude, *values} - return super().dict(by_alias=by_alias, exclude=exclude, **kwargs) + return super().dict(exclude=exclude, **kwargs) class DataCategory(PydanticEnum): From 8e691f9812ede6527cc4e0cc757de7d0518f1de5 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Tue, 10 Jan 2023 09:35:04 -0500 Subject: [PATCH 033/205] handle coercing bool str repr into python type --- python/lib/core/dmod/core/meta_data.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/lib/core/dmod/core/meta_data.py b/python/lib/core/dmod/core/meta_data.py index 9e8dcd6d2..0e6a9ef8d 100644 --- a/python/lib/core/dmod/core/meta_data.py +++ b/python/lib/core/dmod/core/meta_data.py @@ -458,6 +458,8 @@ def handle_type_map(t): return int elif t == "float" or t == float: return float + elif t == "bool" or t == bool: + return bool # maintain reference to a passed in python type or subtype elif isinstance(t, type): return t From 82053ea879935c32908aa2d05f171665d5453b9a Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Tue, 10 Jan 2023 09:35:57 -0500 Subject: [PATCH 034/205] add method that encodes python built in types as str (i.e. 
'int') and all other as 'Any' --- python/lib/core/dmod/core/meta_data.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/python/lib/core/dmod/core/meta_data.py b/python/lib/core/dmod/core/meta_data.py index 0e6a9ef8d..7e26b2aea 100644 --- a/python/lib/core/dmod/core/meta_data.py +++ b/python/lib/core/dmod/core/meta_data.py @@ -625,6 +625,13 @@ def indices(self) -> List[str]: """ return self.data_format.indices + @staticmethod + def _encode_py_type(o: type) -> str: + """Return string representation of a built in type (e.g. 'int') or 'Any'.""" + if o in {str, int, float, bool}: + return o.__name__ + return "Any" + def dict(self, **kwargs) -> dict: """ `data_fields` is excluded from dict if `self.data_format.data_fields` is None. From 50d1b8b7d57d9e644247013c846dc3ab789d44f7 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Tue, 10 Jan 2023 11:20:16 -0500 Subject: [PATCH 035/205] correctly deserialize DataDomain's custom_data_fields field --- python/lib/core/dmod/core/meta_data.py | 49 +++++++++++++++++++++----- 1 file changed, 41 insertions(+), 8 deletions(-) diff --git a/python/lib/core/dmod/core/meta_data.py b/python/lib/core/dmod/core/meta_data.py index 7e26b2aea..a55f6c687 100644 --- a/python/lib/core/dmod/core/meta_data.py +++ b/python/lib/core/dmod/core/meta_data.py @@ -638,16 +638,49 @@ def dict(self, **kwargs) -> dict: called by `to_dict` and `to_json`. """ - # TODO: aaraney, handle encoding type (int, float, etc.) 
as str - exclude = {"custom_data_fields"} + DATA_FIELDS_KEY = "custom_data_fields" + DATA_FIELDS_ALIAS_KEY = "data_fields" + exclude = {DATA_FIELDS_KEY} # merge exclude fields and excludes from kwargs - if "exclude" in kwargs: - values = kwargs.pop("exclude") - if values is not None: - exclude = {*exclude, *values} - - return super().dict(exclude=exclude, **kwargs) + kwarg_exclude: Optional[Set[str]] = kwargs.get("exclude") + if kwarg_exclude is not None: + exclude = {*exclude, *kwarg_exclude} + + # cases when "custom_data_fields" is excluded + if ( + self.data_format.data_fields is None + or ( + kwarg_exclude is not None + and DATA_FIELDS_KEY in kwarg_exclude + ) + ): + # overwrite existing exclude with, potentially, merged version. + kwargs["exclude"] = exclude + return super().dict(**kwargs) + + # serialize "custom_data_fields" python types + custom_data_fields = ( + {k: self._encode_py_type(v) for k, v in self.custom_data_fields.items()} + if self.custom_data_fields is not None + # need this to support `exclude_none=False` + else None + ) + + # exclude "custom_data_fields" and potentially other fields + kwargs["exclude"] = exclude + + serial = super().dict(**kwargs) + + # case: `by_alias` is True + if kwargs.get("by_alias", False): + # reincorporate "custom_data_fields" using it's alias + serial[DATA_FIELDS_ALIAS_KEY] = custom_data_fields + return serial + + # reincorporate "custom_data_fields" using it's name + serial[DATA_FIELDS_KEY] = custom_data_fields + return serial class DataCategory(PydanticEnum): From 25ad0891a649a700965a72383e29b3fce44e2404 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Tue, 10 Jan 2023 12:00:06 -0500 Subject: [PATCH 036/205] ensure restriction fields have empty list by default --- python/lib/core/dmod/core/meta_data.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/lib/core/dmod/core/meta_data.py b/python/lib/core/dmod/core/meta_data.py index a55f6c687..215441a25 100644 --- a/python/lib/core/dmod/core/meta_data.py 
+++ b/python/lib/core/dmod/core/meta_data.py @@ -449,6 +449,12 @@ class DataDomain(Serializable): alias="data_fields" ) + @validator("continuous_restrictions", "discrete_restrictions", always=True) + def _validate_restriction_default(cls, value): + if value is None: + return [] + return value + @validator("custom_data_fields") def validate_data_fields(cls, values): def handle_type_map(t): From ad5caca7bdb5d5f98001d55ae3d45b2d502f3fd3 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Tue, 10 Jan 2023 12:14:14 -0500 Subject: [PATCH 037/205] add data domain unit tests --- python/lib/core/dmod/test/test_meta_data.py | 28 +++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/python/lib/core/dmod/test/test_meta_data.py b/python/lib/core/dmod/test/test_meta_data.py index 4843e9ab7..17c3914ec 100644 --- a/python/lib/core/dmod/test/test_meta_data.py +++ b/python/lib/core/dmod/test/test_meta_data.py @@ -184,6 +184,34 @@ def test_factory_init_from_deserialized_json(self): o = DataDomain.factory_init_from_deserialized_json(data) self.assertEqual(o.data_format.name, "AORC_CSV") + def test_to_dict(self): + input_data_fields = {"a": "int", "b": "float", "c": "bool", "d": "str", "e": "flux_capacitor"} + expected_serialized_data_fields = {"a": "int", "b": "float", "c": "bool", "d": "str", "e": "Any"} + data = { + "data_format": "AORC_CSV", + "continuous": [], + "discrete": [{"variable": "DATA_ID", "values": ["0"]}], + } + input_data = data.copy() + input_data["data_fields"] = input_data_fields + + expected_data = data.copy() + expected_data["data_fields"] = expected_serialized_data_fields + + # better error detection if this fails + o = DataDomain(**input_data) + serial = o.to_dict() + self.assertDictEqual(serial, expected_data) + + def test_factory_init_from_restriction_collections(self): + catchment_id = ["12"] + o = DataDomain.factory_init_from_restriction_collections(data_format=DataFormat.AORC_CSV, CATCHMENT_ID=catchment_id) + 
self.assertListEqual(o.discrete_restrictions[0].values, catchment_id) + + def test_factory_init_from_restriction_collections_fail_for_mismatching_index_field(self): + with self.assertRaises(RuntimeError): + DataDomain.factory_init_from_restriction_collections(data_format=DataFormat.AORC_CSV, DATA_ID=["12"]) + class TestTimeRange(unittest.TestCase): def test_begin_cannot_come_after_end(self): From c0df060610a6573294c9aab8aaa98cbb0cfae369 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Tue, 10 Jan 2023 16:37:58 -0500 Subject: [PATCH 038/205] add pydantic to root requirements.txt --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index f3142c57c..a7a99a3e5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -29,3 +29,4 @@ channels channels-redis Pint django_rq +pydantic From 7a63bc85404af1320a640fe57903d5cc0c17bc3a Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Tue, 10 Jan 2023 16:53:04 -0500 Subject: [PATCH 039/205] DataDomain, handle case when custom_data_fields=None --- python/lib/core/dmod/core/meta_data.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/lib/core/dmod/core/meta_data.py b/python/lib/core/dmod/core/meta_data.py index 215441a25..d2756e39e 100644 --- a/python/lib/core/dmod/core/meta_data.py +++ b/python/lib/core/dmod/core/meta_data.py @@ -471,6 +471,9 @@ def handle_type_map(t): return t return Any + if values is None: + return None + return {k: handle_type_map(v) for k, v in values.items()} @root_validator() From a5292aba6f2d787c69c28fddf90f6ed814e714fb Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Tue, 10 Jan 2023 17:45:55 -0500 Subject: [PATCH 040/205] remove erroneous TODO comments --- python/lib/core/dmod/core/dataset.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/lib/core/dmod/core/dataset.py b/python/lib/core/dmod/core/dataset.py index 2740a830a..388a2beac 100644 --- a/python/lib/core/dmod/core/dataset.py +++ b/python/lib/core/dmod/core/dataset.py @@ 
-228,10 +228,8 @@ def extend_life(self, value: Union[datetime, timedelta]) -> bool: if not self.is_temporary: return False elif isinstance(value, timedelta): - # TODO: Fix bug. expires could be None self._set_expires(self.expires + value) return True - # TODO: Fix bug. expires could be None elif isinstance(value, datetime) and self.expires < value: self._set_expires(value) return True From 6332e9609b986725ab6676794960d35de28bdcf5 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Tue, 10 Jan 2023 21:40:30 -0500 Subject: [PATCH 041/205] @robertbartel confirmed in #239 that category and data_domain are not optional fields --- python/lib/core/dmod/core/dataset.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/lib/core/dmod/core/dataset.py b/python/lib/core/dmod/core/dataset.py index 388a2beac..8bf5329c4 100644 --- a/python/lib/core/dmod/core/dataset.py +++ b/python/lib/core/dmod/core/dataset.py @@ -61,10 +61,8 @@ class Dataset(Serializable): _SERIAL_DATETIME_STR_FORMAT: ClassVar = '%Y-%m-%d %H:%M:%S' name: str = Field(description="The name for this dataset, which also should be a unique identifier.") - # QUESTION: should this be optional? see factory_init_from_deserialized_json - category: Optional[DataCategory] = Field(None, alias="data_category", description="The ::class:`DataCategory` type value for this instance.") - # QUESTION: should this be optional? 
see factory_init_from_deserialized_json - data_domain: Optional[DataDomain] + category: DataCategory = Field(None, alias="data_category", description="The ::class:`DataCategory` type value for this instance.") + data_domain: DataDomain dataset_type: DatasetType = Field(DatasetType.UNKNOWN, alias="type") access_location: str = Field(description="String representation of the location at which this dataset is accessible.") uuid: Optional[UUID] = Field(default_factory=uuid4) From 13947837d715fd6ed929793db3d7d14fc6f0272a Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Tue, 10 Jan 2023 21:42:10 -0500 Subject: [PATCH 042/205] remove commented out Dataset factory_init_from_deserialized_json --- python/lib/core/dmod/core/dataset.py | 22 ---------------------- 1 file changed, 22 deletions(-) diff --git a/python/lib/core/dmod/core/dataset.py b/python/lib/core/dmod/core/dataset.py index 8bf5329c4..c3b5feb25 100644 --- a/python/lib/core/dmod/core/dataset.py +++ b/python/lib/core/dmod/core/dataset.py @@ -118,28 +118,6 @@ class Config: validate_assignment = True arbitrary_types_allowed = True - # TODO: Remove after draft review - # @classmethod - # def factory_init_from_deserialized_json(cls, json_obj: dict): - # try: - # manager_uuid = UUID(json_obj[cls._KEY_MANAGER_UUID]) if cls._KEY_MANAGER_UUID in json_obj else None - # return cls(name=json_obj[cls._KEY_NAME], - # category=DataCategory.get_for_name(json_obj[cls._KEY_DATA_CATEGORY]), - # data_domain=DataDomain.factory_init_from_deserialized_json(json_obj[cls._KEY_DATA_DOMAIN]), - # dataset_type=DatasetType.get_for_name(json_obj[cls._KEY_TYPE]), - # access_location=json_obj[cls._KEY_ACCESS_LOCATION], - # description=json_obj.get(cls._KEY_DESCRIPTION, None), - # uuid=UUID(json_obj[cls._KEY_UUID]), - # manager_uuid=manager_uuid, - # is_read_only=json_obj[cls._KEY_IS_READ_ONLY], - # expires=cls._date_parse_helper(json_obj, cls._KEY_EXPIRES), - # derived_from=json_obj[cls._KEY_DERIVED_FROM] if cls._KEY_DERIVED_FROM in json_obj 
else None, - # derivations=json_obj[cls._KEY_DERIVATIONS] if cls._KEY_DERIVATIONS in json_obj else [], - # created_on=cls._date_parse_helper(json_obj, cls._KEY_CREATED_ON), - # last_updated=cls._date_parse_helper(json_obj, cls._KEY_LAST_UPDATE)) - # except Exception as e: - # return None - def __hash__(self): return hash(','.join([self.__class__.__name__, self.name, self.category.name, str(hash(self.data_domain)), self.access_location, str(self.is_read_only), str(hash(self.created_on))])) From 22ed48b74f597c6252fa89b56d9399b651e8c423 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Tue, 17 Jan 2023 20:55:46 -0500 Subject: [PATCH 043/205] fix DataDomain's dict method. related to #245 --- python/lib/core/dmod/core/meta_data.py | 59 +++++++++++++------------- 1 file changed, 30 insertions(+), 29 deletions(-) diff --git a/python/lib/core/dmod/core/meta_data.py b/python/lib/core/dmod/core/meta_data.py index d2756e39e..ef683b3d6 100644 --- a/python/lib/core/dmod/core/meta_data.py +++ b/python/lib/core/dmod/core/meta_data.py @@ -641,7 +641,17 @@ def _encode_py_type(o: type) -> str: return o.__name__ return "Any" - def dict(self, **kwargs) -> dict: + def dict( + self, + *, + include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + by_alias: bool = True, # Note this follows Serializable convention + skip_defaults: Optional[bool] = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False + ) -> Dict[str, Union[str, int]]: """ `data_fields` is excluded from dict if `self.data_format.data_fields` is None. 
@@ -649,45 +659,36 @@ def dict(self, **kwargs) -> dict: """ DATA_FIELDS_KEY = "custom_data_fields" DATA_FIELDS_ALIAS_KEY = "data_fields" - exclude = {DATA_FIELDS_KEY} - - # merge exclude fields and excludes from kwargs - kwarg_exclude: Optional[Set[str]] = kwargs.get("exclude") - if kwarg_exclude is not None: - exclude = {*exclude, *kwarg_exclude} - - # cases when "custom_data_fields" is excluded - if ( - self.data_format.data_fields is None - or ( - kwarg_exclude is not None - and DATA_FIELDS_KEY in kwarg_exclude + + exclude = exclude or set() + + exclude_data_fields = DATA_FIELDS_KEY in exclude + exclude.add(DATA_FIELDS_KEY) + + serial = super().dict( + include=include, + exclude=exclude, + by_alias=by_alias, + skip_defaults=skip_defaults, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, ) - ): - # overwrite existing exclude with, potentially, merged version. - kwargs["exclude"] = exclude - return super().dict(**kwargs) + + if exclude_data_fields or self.data_format.data_fields: + return serial # serialize "custom_data_fields" python types custom_data_fields = ( {k: self._encode_py_type(v) for k, v in self.custom_data_fields.items()} if self.custom_data_fields is not None - # need this to support `exclude_none=False` - else None + else dict() ) - # exclude "custom_data_fields" and potentially other fields - kwargs["exclude"] = exclude - - serial = super().dict(**kwargs) - - # case: `by_alias` is True - if kwargs.get("by_alias", False): - # reincorporate "custom_data_fields" using it's alias + if by_alias: serial[DATA_FIELDS_ALIAS_KEY] = custom_data_fields return serial - # reincorporate "custom_data_fields" using it's name serial[DATA_FIELDS_KEY] = custom_data_fields return serial From 2ecec46466c2dce8ba07ec64d6d738d1f14d4c80 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Wed, 18 Jan 2023 22:04:59 -0500 Subject: [PATCH 044/205] revert DataRequirement's __eq__ implimentation --- 
python/lib/core/dmod/core/meta_data.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/python/lib/core/dmod/core/meta_data.py b/python/lib/core/dmod/core/meta_data.py index ef683b3d6..20d708e2c 100644 --- a/python/lib/core/dmod/core/meta_data.py +++ b/python/lib/core/dmod/core/meta_data.py @@ -730,6 +730,14 @@ class DataRequirement(Serializable): is_input: bool = Field(description="Whether this represents required input data, as opposed to a requirement for storing output data.") size: Optional[int] + def __eq__(self, other: object) -> bool: + return ( + self.__class__ == other.__class__ + and self.domain == other.domain + and self.is_input == other.is_input + and self.category == other.category + ) + def __hash__(self): return hash('{}-{}-{}'.format(hash(self.domain), self.is_input, self.category)) From d58536f4ca82e7e06526518da31d62e35856d709 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Wed, 18 Jan 2023 22:31:02 -0500 Subject: [PATCH 045/205] add note to remove DataDomain.custom_data_fields in future. relates note to #245 --- python/lib/core/dmod/core/meta_data.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/lib/core/dmod/core/meta_data.py b/python/lib/core/dmod/core/meta_data.py index 20d708e2c..ee900f977 100644 --- a/python/lib/core/dmod/core/meta_data.py +++ b/python/lib/core/dmod/core/meta_data.py @@ -443,6 +443,7 @@ class DataDomain(Serializable): alias="discrete", default_factory=list ) + # NOTE: remove this field after #239 is merged. will close #245. 
custom_data_fields: Optional[Dict[str, Union[str, int, float, Any]]] = Field( description=("This will either be directly from the format, if its format specifies any fields, or from a custom fields" "attribute that may be set during initialization (but is ignored when the format specifies fields)."), From f4758f16aab32c101772f6a8758e2e59bbe34d80 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Wed, 18 Jan 2023 22:32:04 -0500 Subject: [PATCH 046/205] exclude custom_data_fields during serialization if empty T variant. This breaks with Serializable's convention to only exclude `None` value fields. --- python/lib/core/dmod/core/meta_data.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/lib/core/dmod/core/meta_data.py b/python/lib/core/dmod/core/meta_data.py index ee900f977..8a6c3cde0 100644 --- a/python/lib/core/dmod/core/meta_data.py +++ b/python/lib/core/dmod/core/meta_data.py @@ -676,7 +676,9 @@ def dict( exclude_none=exclude_none, ) - if exclude_data_fields or self.data_format.data_fields: + # NOTE: `custom_data_fields` is excluded if it is a empty T variant. This breaks with + # Serializable's convention to only exclude `None` value fields. + if exclude_data_fields or self.data_format.data_fields or not self.custom_data_fields: return serial # serialize "custom_data_fields" python types From ec617ff311e9aa06ebe11a9fb4a0cc9a80e2711f Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 23 Jan 2023 13:57:48 -0500 Subject: [PATCH 047/205] add field_serializers Config option to Serializable The field_serializers Config option is a apping of field name to callable that changes the default serialized form of a field (i.e. to_dict, to_json, dict, json). This is often helpful when a field requires a use case specific representation (i.e. datetime) or is not JSON serializable. For example, if a field is a datetime type, this feature enables changing how that datetime object is serialized (e.g. ISO8601 with only seconds). 
The main intent of this feature is to discourage subclasses from overriding `dict` to implement use case specific serialization and provide a pathway to achieve this. --- python/lib/core/dmod/core/serializable.py | 154 ++++++++++++++++++++++ 1 file changed, 154 insertions(+) diff --git a/python/lib/core/dmod/core/serializable.py b/python/lib/core/dmod/core/serializable.py index 9d464fb06..6a1852071 100644 --- a/python/lib/core/dmod/core/serializable.py +++ b/python/lib/core/dmod/core/serializable.py @@ -2,7 +2,10 @@ from numbers import Number from enum import Enum from typing import Any, Callable, ClassVar, Dict, Type, TypeVar, TYPE_CHECKING, Union, Optional +from typing_extensions import TypeAlias from pydantic import BaseModel, Field +from functools import lru_cache +import inspect import json from .decorators import deprecated @@ -11,9 +14,17 @@ from pydantic.typing import ( AbstractSetIntStr, MappingIntStrAny, + DictStrAny ) Self = TypeVar("Self", bound="Serializable") +M = TypeVar("M", bound="Serializable") +T = TypeVar("T") +R = Union[str, int, float, bool, None] + +FnSerializer: TypeAlias = Callable[[T], R] +SelfFieldSerializer: TypeAlias = Callable[[M, T], R] +FieldSerializer = Union[SelfFieldSerializer[M, Any], FnSerializer[Any]] class Serializable(BaseModel, ABC): @@ -66,6 +77,37 @@ class User(Serializable): class Config: # fields can be populated using their given name or provided alias allow_population_by_field_name = True + field_serializers: Dict[str, FieldSerializer[M]] = {} + """ + Mapping of field name to callable that changes the default serialized form of a field. + This is often helpful when a field requires a use case specific representation (i.e. + datetime) or is not JSON serializable. 
+ + Callables can be specified as either: + (value: T) -> R or + (self: M, value: T) -> R + where: + T is the field type + M is an instance of the Serializable subtype + R is the, json serializable, return type of the transformation + + Example: + + class Observation(Serializable): + value: float + value_time: datetime.datetime + value_unit: str + + class Config: + field_serializers = { + "value_time": lambda value_time: value_time.isoformat(timespec="seconds") + } + + o = Observation(value=42.0, value_time=datetime(2020, 1, 1), value_unit="m") + expect = {"value": 42.0, "value_time": "2020-01-01T00:00:00", "value_unit": "m"} + + assert o.dict() == expect + """ @classmethod def _get_invalid_type_message(cls): @@ -224,6 +266,30 @@ def to_dict(self) -> Dict[str, Union[str, Number, dict, list]]: """ return self.dict(exclude_none=True, by_alias=True) + def dict( + self, + *, + include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + by_alias: bool = False, + skip_defaults: Optional[bool] = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + ) -> "DictStrAny": + serial = super().dict( + include=include, + exclude=exclude, + by_alias=by_alias, + skip_defaults=skip_defaults, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + + transformers = _collect_field_transformers(type(self)) + return _transform_fields(self, transformers, serial, by_alias=by_alias) + def __str__(self): return str(self.to_json()) @@ -316,3 +382,91 @@ class BasicResultIndicator(ResultIndicator): """ Bare-bones, concrete implementation of ::class:`ResultIndicator`. """ + +# NOTE: function below are intentionally not methods on `Serializable` to avoid subclasses +# overriding their behavior. 
+ +@lru_cache +def _collect_field_transformers(cls: Type[M]) -> Dict[str, FieldSerializer[M]]: + transformers: Dict[str, FieldSerializer[M]] = {} + + # base case + if cls == Serializable: + return transformers + + super_classes = cls.__mro__ + base_class_index = super_classes.index(Serializable) + + # index 0 is the calling cls try and merge `field_serializers` from superclasses up until + # Base class (stopping condition). merge in reverse order of mro so child class + # `field_serializers` override superclasses `field_serializers`. + for s in super_classes[1:base_class_index][::-1]: + if not issubclass(s, Serializable): + continue + + # doesn't have a Config class or Config.field_serializers + if not hasattr(s, "Config") and not hasattr(s.Config, "field_serializers"): + continue + + transformers.update(_collect_field_transformers(s)) + + # has Config class and Config.field_serializers + if hasattr(cls, "Config") and hasattr(cls.Config, "field_serializers"): + transformers.update(cls.Config.field_serializers) + + return transformers + + +def _get_field_alias(cls: Type[M], field_name: str) -> str: + # NOTE: KeyError will raise if field_name does not exist + return cls.__fields__[field_name].alias + + +def _transform_fields( + self: M, + transformers: Dict[str, FieldSerializer[M]], + serial: Dict[str, Any], + by_alias: bool = False, +) -> Dict[str, Any]: + for field, transform in transformers.items(): + if by_alias: + field = _get_field_alias(type(self), field) + + if field not in serial: + # TODO: field could have been excluded. need to consider what to do if invalid + # serial key was provided. + continue + + if not inspect.isfunction(transform): + error_message = ( + f"non-callable field_transformer provided for field {field!r}." 
+ "\n\n" + "field_transformers should be specified as either:" + "\n" + "\t(value: T) -> R\n" + "\t(self: M, value: T) -> R\n" + "where:\n" + "\tT is the field type\n" + "\tM is an instance of the Serializable subtype\n" + "\tR is the, json serializable, return type of the transformation" + ) + raise ValueError(error_message) + + sig = inspect.signature(transform) + + if len(sig.parameters) == 1: + serial[field] = transform(serial[field]) + + elif len(sig.parameters) == 2: + serial[field] = transform(self, serial[field]) + + else: + error_message = ( + f"unsupported parameter length for field_transformer callable, {field!r}." + "\n\n" + "field_transformer's take either 1 or 2 parameters, (value: T) or (self, value: T),\n" + "where T is the type of the field." + ) + raise RuntimeError(error_message) + + return serial From 80a8e229e97486e0d225549099c768d46d72514d Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 23 Jan 2023 14:03:33 -0500 Subject: [PATCH 048/205] fix type hints --- python/lib/core/dmod/core/serializable.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/lib/core/dmod/core/serializable.py b/python/lib/core/dmod/core/serializable.py index 6a1852071..59331a8f6 100644 --- a/python/lib/core/dmod/core/serializable.py +++ b/python/lib/core/dmod/core/serializable.py @@ -2,7 +2,7 @@ from numbers import Number from enum import Enum from typing import Any, Callable, ClassVar, Dict, Type, TypeVar, TYPE_CHECKING, Union, Optional -from typing_extensions import TypeAlias +from typing_extensions import Self, TypeAlias from pydantic import BaseModel, Field from functools import lru_cache import inspect @@ -17,7 +17,6 @@ DictStrAny ) -Self = TypeVar("Self", bound="Serializable") M = TypeVar("M", bound="Serializable") T = TypeVar("T") R = Union[str, int, float, bool, None] @@ -115,7 +114,7 @@ def _get_invalid_type_message(cls): return invalid_type_msg @classmethod - def factory_init_from_deserialized_json(cls: Self, json_obj: dict) -> 
Optional[Self]: + def factory_init_from_deserialized_json(cls: Type[Self], json_obj: dict) -> Optional[Self]: """ Factory create a new instance of this type based on a JSON object dictionary deserialized from received JSON. From 070b57a1cc35ff867b5990e0512aeec79b00f0b6 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 23 Jan 2023 14:38:40 -0500 Subject: [PATCH 049/205] test Serializable field_serializer Config option --- .../test_serializable_field_serializers.py | 279 ++++++++++++++++++ 1 file changed, 279 insertions(+) create mode 100644 python/lib/core/dmod/test/test_serializable_field_serializers.py diff --git a/python/lib/core/dmod/test/test_serializable_field_serializers.py b/python/lib/core/dmod/test/test_serializable_field_serializers.py new file mode 100644 index 000000000..648df8bda --- /dev/null +++ b/python/lib/core/dmod/test/test_serializable_field_serializers.py @@ -0,0 +1,279 @@ +import unittest +from typing import List +from pydantic import SecretStr +from datetime import date + +from ..core.serializable import Serializable + + +class Country(Serializable): + name: str + phone_code: int + + class Config: + field_serializers = {"name": lambda s: s.upper()} + + +class Address(Serializable): + post_code: int + country: Country + + +class CardDetails(Serializable): + number: SecretStr + expires: date + + +class Hobby(Serializable): + name: str + info: str + + class Config: + fields = {"name": {"alias": "NAME"}} + + +class User(Serializable): + first_name: str + second_name: str + address: Address + card_details: CardDetails + hobbies: List[Hobby] + + class Config: + field_serializers = {"first_name": lambda f: f.upper()} + + +class A(Serializable): + field: str + + class Config: + field_serializers = {"field": lambda s: s.lower()} + + +class B(Serializable): + field: str + + class Config: + field_serializers = {"field": lambda s: s.upper()} + + +class C(Serializable): + a: A + b: B + + +class D(Serializable): + field: str + + class Config: + 
field_serializers = {"field": lambda s: s.upper()} + + +class E(D): + class Config: + field_serializers = {"field": lambda s: s} + + +class F(Serializable): + a: str + + class Config: + field_serializers = {"a": lambda s: s.lower()} + + +class G(Serializable): + b: str + + class Config: + field_serializers = {"b": lambda s: s.upper()} + + +class H(F, G): + ... + + +class I(G, F): + ... + + +class J(Serializable): + a: int + + +class K(Serializable): + j: J + + class Config: + field_serializers = {"j": lambda self, _: self.j.a} + + +class L(Serializable): + a: str + + class Config: + field_serializers = {"a": 12} + + +class M(Serializable): + a: str + + class Config: + field_serializers = {"a": lambda a, b, c: (a, b, c)} + + +class N(Serializable): + a: str + + class Config: + field_serializers = {"a": lambda: "should fail"} + + +class RootModel(Serializable): + __root__: int + + class Config: + field_serializers = {"__root__": lambda s: s ** 2} + + +def user_fixture() -> User: + return User( + first_name="John", + second_name="Doe", + address=Address(post_code=123456, country=Country(name="usa", phone_code=1)), + card_details=CardDetails(number=4212934504460000, expires=date(2020, 5, 1)), + hobbies=[ + Hobby(name="Programming", info="Writing code and stuff"), + Hobby(name="Gaming", info="Hell Yeah!!!"), + ], + ) + + +class TestFieldSerializerConfigOption(unittest.TestCase): + def test_exclude_keys_User(self): + user = user_fixture() + + exclude_keys = { + "second_name": True, + "address": {"post_code": True, "country": {"phone_code"}}, + "card_details": True, + # You can exclude fields from specific members of a tuple/list by index: + "hobbies": {-1: {"info"}}, + } + + expect = { + "first_name": "JOHN", + "address": {"country": {"name": "USA"}}, + "hobbies": [ + { + "name": "Programming", + "info": "Writing code and stuff", + }, + {"name": "Gaming"}, + ], + } + + self.assertDictEqual(user.dict(exclude=exclude_keys), expect) + + def test_include_keys_User(self): + 
user = user_fixture() + + include_keys = { + "first_name": True, + "address": {"country": {"name"}}, + "hobbies": {0: True, -1: {"name"}}, + } + + expect = { + "first_name": "JOHN", + "address": {"country": {"name": "USA"}}, + "hobbies": [ + { + "name": "Programming", + "info": "Writing code and stuff", + }, + {"name": "Gaming"}, + ], + } + + self.assertDictEqual(user.dict(include=include_keys), expect) + + def test_exclude_keys_by_alias_User(self): + user = user_fixture() + + exclude_keys = { + "second_name": True, + "address": {"post_code": True, "country": {"phone_code"}}, + "card_details": True, + # You can exclude fields from specific members of a tuple/list by index: + "hobbies": {-1: {"info"}}, + } + + expect = { + "first_name": "JOHN", + "address": {"country": {"name": "USA"}}, + "hobbies": [ + { + "NAME": "Programming", + "info": "Writing code and stuff", + }, + {"NAME": "Gaming"}, + ], + } + + self.assertDictEqual(user.dict(exclude=exclude_keys, by_alias=True), expect) + + def test_composed_fields_dont_mangle_C(self): + o = C(a=A(field="A"), b=B(field="b")) + + expect = {"a": {"field": "a"}, "b": {"field": "B"}} + self.assertDictEqual(o.dict(), expect) + + def test_override_in_subclass_D_E(self): + o = D(field="a") + self.assertEqual(o.dict()["field"], "A") + + subclass_o = E(field="a") + + self.assertEqual(subclass_o.dict()["field"], "a") + + def test_root_model_RootModel(self): + o = RootModel(__root__=12) + self.assertEqual(o.dict()["__root__"], 144) + + def test_multi_inheritance_H_I(self): + # H(F, G) + h = H(a="a", b="b") + # I(G, H) + i = I(a="a", b="b") + + expect = { + "a": "a", + "b": "B", + } + + self.assertDictEqual(h.dict(), expect) + self.assertDictEqual(i.dict(), expect) + + def test_pull_up_K(self): + o = K(j=J(a=12)) + + expect = {"j": 12} + + self.assertDictEqual(o.dict(), expect) + + def test_raises_value_error_L(self): + o = L(a="a") + with self.assertRaises(ValueError): + o.dict() + + def 
test_raises_runtime_error_too_many_params_M(self): + o = M(a="a") + + with self.assertRaises(RuntimeError): + o.dict() + + def test_raises_runtime_error_too_few_params_N(self): + o = N(a="a") + + with self.assertRaises(RuntimeError): + o.dict() From 2cce097bd59742e841136e39e59b516350708c31 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 23 Jan 2023 15:01:51 -0500 Subject: [PATCH 050/205] simplify Dataset's dict override with config field_serializer --- python/lib/core/dmod/core/dataset.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/python/lib/core/dmod/core/dataset.py b/python/lib/core/dmod/core/dataset.py index c3b5feb25..78634294f 100644 --- a/python/lib/core/dmod/core/dataset.py +++ b/python/lib/core/dmod/core/dataset.py @@ -117,6 +117,7 @@ class Config: # TODO: in future deprecate setting properties unless through a setter method validate_assignment = True arbitrary_types_allowed = True + field_serializers = {"uuid": lambda f: str(f)} def __hash__(self): return hash(','.join([self.__class__.__name__, self.name, self.category.name, str(hash(self.data_domain)), @@ -265,11 +266,11 @@ def dict(self, **kwargs) -> dict: exclude = self._get_exclude_fields() if kwargs.get("exclude", False) is False else kwargs["exclude"] kwargs["exclude"] = exclude - serial = super().dict(**kwargs) + return super().dict(**kwargs) - # serialize uuid - serial["uuid"] = str(self.uuid) - return serial + # # serialize uuid + # serial["uuid"] = str(self.uuid) + # return serial class DatasetUser(ABC): From 899507024c49e789e85498c14fbd4df6c7f9a6df Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 23 Jan 2023 15:03:47 -0500 Subject: [PATCH 051/205] refactor ContinuousRestriction to use config field_serializers --- python/lib/core/dmod/core/meta_data.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/python/lib/core/dmod/core/meta_data.py b/python/lib/core/dmod/core/meta_data.py index 8a6c3cde0..8e612c173 100644 
--- a/python/lib/core/dmod/core/meta_data.py +++ b/python/lib/core/dmod/core/meta_data.py @@ -254,6 +254,18 @@ def validate_start_before_end(cls, values): # validate variable is not UNKNOWN variant _validate_variable = validator("variable", allow_reuse=True)(_validate_variable_is_known) + class Config: + def _serialize_datetime(self: "ContinuousRestriction", value: datetime) -> str: + if self.datetime_pattern is not None: + return value.strftime(self.datetime_pattern) + return str(value) + + field_serializers = { + "begin": _serialize_datetime, + "end": _serialize_datetime, + "subclass": lambda s: s.__class__.__name__ + } + def __eq__(self, o: object) -> bool: if not isinstance(o, ContinuousRestriction): return False @@ -339,14 +351,6 @@ def contains(self, other: 'ContinuousRestriction') -> bool: else: return self.begin <= other.begin and self.end >= other.end - def to_dict(self) -> Dict[str, Union[str, Number, dict, list]]: - serial = self.dict(exclude_none=True) - serial["subclass"] = self.__class__.__name__ - if self.datetime_pattern is not None: - serial["begin"] = self.begin.strftime(self.datetime_pattern) - serial["end"] = self.end.strftime(self.datetime_pattern) - return serial - class DiscreteRestriction(Serializable): """ From e29100752fc392476163c89df5d5d417d8a5628d Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 23 Jan 2023 15:56:51 -0500 Subject: [PATCH 052/205] DataDomain now factory inits continuous_restrictions --- python/lib/core/dmod/core/meta_data.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/python/lib/core/dmod/core/meta_data.py b/python/lib/core/dmod/core/meta_data.py index 8e612c173..53269adcf 100644 --- a/python/lib/core/dmod/core/meta_data.py +++ b/python/lib/core/dmod/core/meta_data.py @@ -3,7 +3,7 @@ from .enum import PydanticEnum from .serializable import Serializable from numbers import Number -from typing import Any, Dict, List, Optional, Set, Type, Union +from typing import Any, 
Dict, List, Literal, Optional, Set, Type, Union from collections.abc import Iterable from collections import OrderedDict from pydantic import root_validator, validator, PyObject, Field, StrictStr, StrictFloat, StrictInt @@ -226,10 +226,19 @@ class ContinuousRestriction(Serializable): begin: datetime end: datetime datetime_pattern: Optional[str] - subclass: Optional[PyObject] = Field(exclude=True) + subclass: PyObject @root_validator(pre=True) def coerce_times_if_datetime_pattern(cls, values): + subclass_str = values.get("subclass") + + if subclass_str is None: + values["subclass"] = cls + + if isinstance(subclass_str, str): + if subclass_str == cls.__name__: + values["subclass"] = cls + datetime_ptr = values.get("datetime_pattern") if datetime_ptr is not None: @@ -263,7 +272,7 @@ def _serialize_datetime(self: "ContinuousRestriction", value: datetime) -> str: field_serializers = { "begin": _serialize_datetime, "end": _serialize_datetime, - "subclass": lambda s: s.__class__.__name__ + "subclass": lambda value: value.__name__ } def __eq__(self, o: object) -> bool: @@ -312,7 +321,7 @@ def factory_init_from_deserialized_json(cls, json_obj: dict): if subclass_str == cls.__name__: json_obj["subclass"] = cls - return subclass(**json_obj) + return cls(**json_obj) for subclass in cls.__subclasses__(): if subclass.__name__ == subclass_str: @@ -454,6 +463,12 @@ class DataDomain(Serializable): alias="data_fields" ) + @validator("continuous_restrictions", pre=True, each_item=True) + def _factory_init_continuous_restrictions(cls, value): + if isinstance(value, ContinuousRestriction): + return value + return ContinuousRestriction.factory_init_from_deserialized_json(value) + @validator("continuous_restrictions", "discrete_restrictions", always=True) def _validate_restriction_default(cls, value): if value is None: From 4ac88007b15230d1ccaf4512aeb76e36eec94711 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 23 Jan 2023 15:57:11 -0500 Subject: [PATCH 053/205] remove commented 
code --- python/lib/core/dmod/core/dataset.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/python/lib/core/dmod/core/dataset.py b/python/lib/core/dmod/core/dataset.py index 78634294f..142bd88b5 100644 --- a/python/lib/core/dmod/core/dataset.py +++ b/python/lib/core/dmod/core/dataset.py @@ -268,10 +268,6 @@ def dict(self, **kwargs) -> dict: return super().dict(**kwargs) - # # serialize uuid - # serial["uuid"] = str(self.uuid) - # return serial - class DatasetUser(ABC): """ From 489b787e9bae7bfd179e0417098aa650292efb7c Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 23 Jan 2023 16:06:32 -0500 Subject: [PATCH 054/205] fix test that was failing for the right reason --- python/lib/core/dmod/test/test_meta_data.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/lib/core/dmod/test/test_meta_data.py b/python/lib/core/dmod/test/test_meta_data.py index 17c3914ec..b3c9a43e1 100644 --- a/python/lib/core/dmod/test/test_meta_data.py +++ b/python/lib/core/dmod/test/test_meta_data.py @@ -188,7 +188,8 @@ def test_to_dict(self): input_data_fields = {"a": "int", "b": "float", "c": "bool", "d": "str", "e": "flux_capacitor"} expected_serialized_data_fields = {"a": "int", "b": "float", "c": "bool", "d": "str", "e": "Any"} data = { - "data_format": "AORC_CSV", + # NOTE: NGEN_OUTPUT data_fields = None. 
+ "data_format": "NGEN_OUTPUT", "continuous": [], "discrete": [{"variable": "DATA_ID", "values": ["0"]}], } From 34f712e1ee0eee85f8d7c594aab006a891841bcb Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Thu, 26 Jan 2023 13:28:01 -0500 Subject: [PATCH 055/205] bump dmod.core patch version to 0.4.2 --- python/lib/core/dmod/core/_version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/lib/core/dmod/core/_version.py b/python/lib/core/dmod/core/_version.py index b703f5c96..a98734733 100644 --- a/python/lib/core/dmod/core/_version.py +++ b/python/lib/core/dmod/core/_version.py @@ -1 +1 @@ -__version__ = '0.4.1' \ No newline at end of file +__version__ = '0.4.2' From 0b8028289b57c27864a5a97371d70317e26691d7 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Tue, 24 Jan 2023 09:33:10 -0500 Subject: [PATCH 056/205] fix comm setup.py merge conflict --- python/lib/communication/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/lib/communication/setup.py b/python/lib/communication/setup.py index 839d3e133..30289f00d 100644 --- a/python/lib/communication/setup.py +++ b/python/lib/communication/setup.py @@ -22,6 +22,6 @@ license='', include_package_data=True, #install_requires=['websockets', 'jsonschema'],vi - install_requires=['dmod-core>=0.1.2', 'websockets>=8.1', 'jsonschema', 'redis'], + install_requires=['dmod-core>=0.1.2', 'websockets>=8.1', 'jsonschema', 'redis', 'pydantic'], packages=find_namespace_packages(include=['dmod.*'], exclude=['dmod.test']) ) From 208675b5027aaadaae333bf35da570443ceaebbe Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Wed, 11 Jan 2023 15:31:25 -0500 Subject: [PATCH 057/205] dmod.communication enum's now subclass PydanticEnum --- .../dmod/communication/dataset_management_message.py | 8 ++++---- python/lib/communication/dmod/communication/message.py | 6 +++--- .../communication/dmod/communication/metadata_message.py | 5 +++-- python/lib/communication/dmod/communication/session.py | 4 
++-- 4 files changed, 12 insertions(+), 11 deletions(-) diff --git a/python/lib/communication/dmod/communication/dataset_management_message.py b/python/lib/communication/dmod/communication/dataset_management_message.py index d148b6d0e..c9e275401 100644 --- a/python/lib/communication/dmod/communication/dataset_management_message.py +++ b/python/lib/communication/dmod/communication/dataset_management_message.py @@ -2,12 +2,12 @@ from dmod.core.serializable import Serializable from .maas_request import ExternalRequest, ExternalRequestResponse from dmod.core.meta_data import DataCategory, DataDomain, DataFormat, DataRequirement +from dmod.core.enum import PydanticEnum from numbers import Number -from enum import Enum from typing import Dict, Optional, Union, List -class QueryType(Enum): +class QueryType(PydanticEnum): LIST_FILES = 1 GET_CATEGORY = 2 GET_FORMAT = 3 @@ -65,7 +65,7 @@ def to_dict(self) -> Dict[str, Union[str, Number, dict, list]]: return serial -class ManagementAction(Enum): +class ManagementAction(PydanticEnum): """ Type enumerating the standard actions that can be requested via ::class:`DatasetManagementMessage`. """ @@ -670,4 +670,4 @@ def factory_create(cls, dataset_mgmt_response: DatasetManagementResponse) -> 'Ma MaaSDatasetManagementResponse Factory-created analog of this instance type. 
""" - return cls.factory_init_from_deserialized_json(dataset_mgmt_response.to_dict()) \ No newline at end of file + return cls.factory_init_from_deserialized_json(dataset_mgmt_response.to_dict()) diff --git a/python/lib/communication/dmod/communication/message.py b/python/lib/communication/dmod/communication/message.py index bad2e4869..1def75632 100644 --- a/python/lib/communication/dmod/communication/message.py +++ b/python/lib/communication/dmod/communication/message.py @@ -1,12 +1,12 @@ from abc import ABC -from enum import Enum from typing import Type from dmod.core.serializable import Serializable, ResultIndicator +from dmod.core.enum import PydanticEnum #FIXME make an independent enum of model request types??? -class MessageEventType(Enum): +class MessageEventType(PydanticEnum): SESSION_INIT = 1 MODEL_EXEC_REQUEST = 2 @@ -36,7 +36,7 @@ class MessageEventType(Enum): INVALID = -1 -class InitRequestResponseReason(Enum): +class InitRequestResponseReason(PydanticEnum): """ Values for the ``reason`` attribute in responses to ``AbstractInitRequest`` messages. """ diff --git a/python/lib/communication/dmod/communication/metadata_message.py b/python/lib/communication/dmod/communication/metadata_message.py index 193fa5ee4..9769ac2f5 100644 --- a/python/lib/communication/dmod/communication/metadata_message.py +++ b/python/lib/communication/dmod/communication/metadata_message.py @@ -1,10 +1,11 @@ from .message import AbstractInitRequest, MessageEventType, Response -from enum import Enum from numbers import Number from typing import Dict, Optional, Union +from dmod.core.enum import PydanticEnum -class MetadataPurpose(Enum): + +class MetadataPurpose(PydanticEnum): CONNECT = 1, """ The metadata relates to the opening of a connection. 
""" DISCONNECT = 2, diff --git a/python/lib/communication/dmod/communication/session.py b/python/lib/communication/dmod/communication/session.py index d6adbec17..69fec5d38 100644 --- a/python/lib/communication/dmod/communication/session.py +++ b/python/lib/communication/dmod/communication/session.py @@ -3,13 +3,13 @@ import random from .message import AbstractInitRequest, MessageEventType, Response from dmod.core.serializable import Serializable +from dmod.core.enum import PydanticEnum from abc import ABC, abstractmethod -from enum import Enum from numbers import Number from typing import Dict, Optional, Union -class SessionInitFailureReason(Enum): +class SessionInitFailureReason(PydanticEnum): AUTHENTICATION_SYS_FAIL = 1, # some error other than bad credentials prevented successful user authentication AUTHENTICATION_DENIED = 2, # the user's asserted identity was not authenticated due to the provided credentials USER_NOT_AUTHORIZED = 3, # the user was authenticated, but does not have authorized permission for a session From 7734ac844ee23ee13cee381a0bb7a03945a8326d Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Wed, 11 Jan 2023 15:35:54 -0500 Subject: [PATCH 058/205] refactor Message --- python/lib/communication/dmod/communication/message.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/python/lib/communication/dmod/communication/message.py b/python/lib/communication/dmod/communication/message.py index 1def75632..bc587c560 100644 --- a/python/lib/communication/dmod/communication/message.py +++ b/python/lib/communication/dmod/communication/message.py @@ -1,5 +1,5 @@ from abc import ABC -from typing import Type +from typing import ClassVar, Type from dmod.core.serializable import Serializable, ResultIndicator from dmod.core.enum import PydanticEnum @@ -62,7 +62,7 @@ class Message(Serializable, ABC): Class representing communication message of some kind between parts of the NWM MaaS system. 
""" - event_type: MessageEventType = None + event_type: ClassVar[MessageEventType] = MessageEventType.INVALID """ :class:`MessageEventType`: the event type for this message implementation """ @classmethod @@ -77,9 +77,6 @@ def get_message_event_type(cls) -> MessageEventType: """ return cls.event_type - def __init__(self, *args, **kwargs): - pass - class AbstractInitRequest(Message, ABC): """ From c5c13332949dd1280b96a217727deb1f1e0834d2 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Wed, 11 Jan 2023 15:36:41 -0500 Subject: [PATCH 059/205] remove unnessesary __init__ in AbstractInitRequest --- python/lib/communication/dmod/communication/message.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/python/lib/communication/dmod/communication/message.py b/python/lib/communication/dmod/communication/message.py index bc587c560..9bf3f53d1 100644 --- a/python/lib/communication/dmod/communication/message.py +++ b/python/lib/communication/dmod/communication/message.py @@ -89,9 +89,6 @@ class AbstractInitRequest(Message, ABC): interactions. 
""" - def __int__(self, *args, **kwargs): - super(AbstractInitRequest, self).__int__(*args, **kwargs) - class Response(ResultIndicator, Message, ABC): """ From b0da87d92fd7f7e3e4468e1f15e7d1457d2f3b91 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Wed, 11 Jan 2023 16:06:24 -0500 Subject: [PATCH 060/205] refactor Response --- .../dmod/communication/message.py | 45 +------------------ 1 file changed, 1 insertion(+), 44 deletions(-) diff --git a/python/lib/communication/dmod/communication/message.py b/python/lib/communication/dmod/communication/message.py index 9bf3f53d1..598a9b1af 100644 --- a/python/lib/communication/dmod/communication/message.py +++ b/python/lib/communication/dmod/communication/message.py @@ -118,7 +118,7 @@ class Response(ResultIndicator, Message, ABC): """ - response_to_type = AbstractInitRequest + response_to_type: ClassVar[Type[AbstractInitRequest]] = AbstractInitRequest """ The type of :class:`AbstractInitRequest` for which this type is the response""" @classmethod @@ -153,31 +153,6 @@ def _factory_init_data_attribute(cls, json_obj: dict): except Exception as e: return None - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict): - """ - Factory create a new instance of this type based on a JSON object dictionary deserialized from received JSON. - - Parameters - ---------- - json_obj - - Returns - ------- - response_obj : Response - A new object of this type instantiated from the deserialize JSON object dictionary, or none if the provided - parameter could not be used to instantiated a new object. 
- - See Also - ------- - _factory_init_data_attribute - """ - try: - return cls(success=json_obj['success'], reason=json_obj['reason'], message=json_obj['message'], - data=cls._factory_init_data_attribute(json_obj)) - except Exception as e: - return None - @classmethod def get_message_event_type(cls) -> MessageEventType: """ @@ -205,24 +180,6 @@ def get_response_to_type(cls) -> Type[AbstractInitRequest]: """ return cls.response_to_type - def __init__(self, data=None, *args, **kwargs): - super(Response, self).__init__(*args, **kwargs) - self.data = data - - def __eq__(self, other): - return self.success == other.success and self.reason == other.reason and self.message == other.message \ - and self.data == other.data - - def to_dict(self) -> dict: - serial = super(Response, self).to_dict() - if self.data is None: - serial['data'] = {} - elif isinstance(self.data, dict): - serial['data'] = self.data - else: - serial['data'] = self.data.to_dict() - return serial - class InvalidMessage(AbstractInitRequest): """ From f790e27d5f43805e751be45a8a25301ee50e7b36 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Wed, 11 Jan 2023 16:07:11 -0500 Subject: [PATCH 061/205] add TODO. 
fill in Union of data field types once refactored --- python/lib/communication/dmod/communication/message.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/lib/communication/dmod/communication/message.py b/python/lib/communication/dmod/communication/message.py index 598a9b1af..427ffd88c 100644 --- a/python/lib/communication/dmod/communication/message.py +++ b/python/lib/communication/dmod/communication/message.py @@ -121,6 +121,9 @@ class Response(ResultIndicator, Message, ABC): response_to_type: ClassVar[Type[AbstractInitRequest]] = AbstractInitRequest """ The type of :class:`AbstractInitRequest` for which this type is the response""" + # TODO: aaraney, make this union + # data: Union[] + @classmethod def _factory_init_data_attribute(cls, json_obj: dict): """ From 7f726e35e4671c309c8fc8a46bf9411a924767d0 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Wed, 11 Jan 2023 16:09:31 -0500 Subject: [PATCH 062/205] refactor InvalidMessage --- .../dmod/communication/message.py | 29 ++----------------- 1 file changed, 3 insertions(+), 26 deletions(-) diff --git a/python/lib/communication/dmod/communication/message.py b/python/lib/communication/dmod/communication/message.py index 427ffd88c..5083f45fd 100644 --- a/python/lib/communication/dmod/communication/message.py +++ b/python/lib/communication/dmod/communication/message.py @@ -1,5 +1,5 @@ from abc import ABC -from typing import ClassVar, Type +from typing import Any, ClassVar, Dict, Type from dmod.core.serializable import Serializable, ResultIndicator from dmod.core.enum import PydanticEnum @@ -190,33 +190,10 @@ class InvalidMessage(AbstractInitRequest): type. """ - event_type: MessageEventType = MessageEventType.INVALID + event_type: ClassVar[MessageEventType] = MessageEventType.INVALID """ :class:`MessageEventType`: the type of ``MessageEventType`` for which this message is applicable. 
""" - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict): - """ - Factory create a new instance of this type based on a JSON object dictionary deserialized from received JSON. - - Parameters - ---------- - json_obj - - Returns - ------- - A new object of this type instantiated from the deserialize JSON object dictionary, or none if the provided - parameter could not be used to instantiated a new object. - """ - try: - return cls(content=json_obj['content']) - except: - return None - - def __init__(self, content: dict): - self.content = content - - def to_dict(self) -> dict: - return {'content': self.content} + content: Dict[str, Any] class InvalidMessageResponse(Response): From b29751a3ebab86de38beda627918873ec3f49902 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Wed, 11 Jan 2023 16:43:44 -0500 Subject: [PATCH 063/205] refactor Response and its subtypes --- .../dmod/communication/message.py | 75 ++++++------------- 1 file changed, 24 insertions(+), 51 deletions(-) diff --git a/python/lib/communication/dmod/communication/message.py b/python/lib/communication/dmod/communication/message.py index 5083f45fd..f9f08a76b 100644 --- a/python/lib/communication/dmod/communication/message.py +++ b/python/lib/communication/dmod/communication/message.py @@ -1,5 +1,6 @@ from abc import ABC -from typing import Any, ClassVar, Dict, Type +from typing import Any, ClassVar, Dict, Literal, Optional, Type +from pydantic import Field from dmod.core.serializable import Serializable, ResultIndicator from dmod.core.enum import PydanticEnum @@ -121,40 +122,7 @@ class Response(ResultIndicator, Message, ABC): response_to_type: ClassVar[Type[AbstractInitRequest]] = AbstractInitRequest """ The type of :class:`AbstractInitRequest` for which this type is the response""" - # TODO: aaraney, make this union - # data: Union[] - - @classmethod - def _factory_init_data_attribute(cls, json_obj: dict): - """ - Initialize the argument value for a constructor param used to set 
the :attr:`data` attribute appropriate for - this type, given the parent JSON object, which may mean simply returning the value or may mean deserializing the - value to some object type, depending on the implementation. - - The intent is for this to be used by :meth:`factory_init_from_deserialized_json`, where initialization logic for - the value to be set as :attr:`data` from the provided param may vary depending on the particular class. - - In the default implementation, the value found at the 'data' key is simply directly returned, or None is - returned if the 'data' key is not found. - - Parameters - ---------- - json_obj : dict - the parent JSON object containing the desired data value under the 'data' key - - Returns - ------- - data : dict - the resulting data value object - - See Also - ------- - factory_init_from_deserialized_json - """ - try: - return json_obj['data'] - except Exception as e: - return None + data: Optional[Serializable] @classmethod def get_message_event_type(cls) -> MessageEventType: @@ -198,27 +166,32 @@ class InvalidMessage(AbstractInitRequest): class InvalidMessageResponse(Response): - response_to_type = InvalidMessage + response_to_type: ClassVar[Type[AbstractInitRequest]] = InvalidMessage """ The type of :class:`AbstractInitRequest` for which this type is the response""" - def __init__(self, data=None): - super().__init__(success=False, - reason='Invalid Request Message', - message='Request message was not formatted as any known valid type', - data=data) + success = False + reason: Literal["Invalid Request message"] = "Invalid Request message" + message: Literal["Request message was not formatted as any known valid type"] = "Request message was not formatted as any known valid type" + data: Optional[Serializable] + def __init__(self, data: Optional[Serializable]=None, **kwargs): + super().__init__(data=data) + + +class HttpCode(Serializable): + http_code: int = Field(ge=100, le=599) class ErrorResponse(Response): """ A response to 
inform a client of an error that has occured within a request """ - def __init__(self, message: str, http_code: int = None): - if not http_code: - http_code = 500 - - if not isinstance(http_code, int): - try: - http_code = int(float(http_code)) - except: - http_code = str(http_code) - super().__init__(success=False, reason="Error", message=message, data={"http_code": http_code}) + success = False + reason: Literal["Error"] = "Error" + data: HttpCode = Field(default_factory=lambda: HttpCode(http_code=500)) + + def __init__(self, message: str, http_code: int = None, **kwargs): + if http_code is None: + super().__init__(message=message) + return + + super().__init__(message=message, data={"http_code": http_code}) From 96e4fceb24622d15b9d01ac93cb680074341c6f3 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Wed, 11 Jan 2023 16:56:26 -0500 Subject: [PATCH 064/205] refactor MetadataMessage --- .../dmod/communication/metadata_message.py | 130 +++++------------- 1 file changed, 38 insertions(+), 92 deletions(-) diff --git a/python/lib/communication/dmod/communication/metadata_message.py b/python/lib/communication/dmod/communication/metadata_message.py index 9769ac2f5..1f5c790c7 100644 --- a/python/lib/communication/dmod/communication/metadata_message.py +++ b/python/lib/communication/dmod/communication/metadata_message.py @@ -1,6 +1,7 @@ from .message import AbstractInitRequest, MessageEventType, Response from numbers import Number -from typing import Dict, Optional, Union +from typing import ClassVar, Dict, Optional, Union +from pydantic import Field, root_validator from dmod.core.enum import PydanticEnum @@ -28,30 +29,44 @@ def get_value_for_name(cls, name_str: str) -> Optional['MetadataPurpose']: class MetadataMessage(AbstractInitRequest): - event_type: MessageEventType = MessageEventType.METADATA + event_type: ClassVar[MessageEventType] = MessageEventType.INVALID - _purpose_serial_key = 'purpose' - _description_serial_key = 'description' - _metadata_follows_serial_key 
= 'additional_metadata' - _config_changes_serial_key = 'config_changes' - _config_change_dict_type_key = 'config_value_dict_type' + purpose: MetadataPurpose + description: Optional[str] + metadata_follows: bool = Field( + False, + alias="additional_metadata", + description=( + "An indication of whether there is more metadata the sender needs to communicate beyond what is contained in this" + "message, thus letting the receiver know whether it should continue receiving after sending the response to this." + ), + ) - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict) -> Optional['MetadataMessage']: - if cls._purpose_serial_key not in json_obj: - return None - purpose = MetadataPurpose.get_value_for_name(json_obj[cls._purpose_serial_key]) - if purpose is None: - return None - if cls._metadata_follows_serial_key in json_obj: - metadata_follows = json_obj[cls._metadata_follows_serial_key] - else: - # default to False for this, as this is pretty safe assumption if we don't see it explicit - metadata_follows = False - description = json_obj[cls._description_serial_key] if cls._description_serial_key in json_obj else None - cfg_changes = json_obj[cls._config_changes_serial_key] if cls._config_changes_serial_key in json_obj else None - return cls(purpose=purpose, description=description, metadata_follows=metadata_follows, - config_changes=cfg_changes) + """ + A dictionary, keyed by strings, representing some configurable setting(s) that need their value(s) changed. + + This will mainly be applicable when the purpose property is ``CHANGE_CONFIG``, and frequently can otherwise be + left to/expected to be ``None``. However, it should not be ``None`` when the purpose is ``CHANGE_CONFIG``. + + Note that the main dictionary can contain nested dictionaries also. These should essentially be the serialized + representations of ::class:`Serializable` object. 
While the type hinting does not explicitly note this due to + the recursive nature of the definition, nested dictionaries at any depth should have string keys and values of + one of the types allowed for values in the top-level dictionary. + + It is recommended that an additional value be added to such nested dictionaries, under the key returned by + ::method:`get_config_change_dict_type_key`. This should be the string representation of the class type of the + nested, serialized object. + """ + config_changes: Optional[Dict[str, Union[None, str, bool, int, float, dict, list]]] = Field(description="A dictionary, keyed by strings, representing some configurable setting(s) that need their value(s) changed.") + + @root_validator() + def validate_purpose(cls, values): + if values["purpose"] == MetadataPurpose.CHANGE_CONFIG and not values["config_changes"]: + raise RuntimeError('Invalid {} initialization, setting {} to {} but without any config changes.'.format( + cls.__class__, values["purpose"].__class__, values["purpose"].name)) + return values + + _config_change_dict_type_key: ClassVar[str] = 'config_value_dict_type' @classmethod def get_config_change_dict_type_key(cls) -> str: @@ -69,75 +84,6 @@ def get_config_change_dict_type_key(cls) -> str: """ return cls._config_change_dict_type_key - def __init__(self, purpose: MetadataPurpose, description: Optional[str] = None, metadata_follows: bool = False, - config_changes: Optional[Dict[str, Union[None, str, bool, Number, dict, list]]] = None): - self._purpose = purpose - self._description = description - self._metadata_follows = metadata_follows - self._config_changes = config_changes - if self._purpose == MetadataPurpose.CHANGE_CONFIG and not self._config_changes: - raise RuntimeError('Invalid {} initialization, setting {} to {} but without any config changes.'.format( - self.__class__, self._purpose.__class__, self._purpose.name)) - - @property - def config_changes(self) -> Optional[Dict[str, Union[None, str, bool, 
Number, dict, list]]]: - """ - A dictionary, keyed by strings, representing some configurable setting(s) that need their value(s) changed. - - This will mainly be applicable when the purpose property is ``CHANGE_CONFIG``, and frequently can otherwise be - left to/expected to be ``None``. However, it should not be ``None`` when the purpose is ``CHANGE_CONFIG``. - - Note that the main dictionary can contain nested dictionaries also. These should essentially be the serialized - representations of ::class:`Serializable` object. While the type hinting does not explicitly note this due to - the recursive nature of the definition, nested dictionaries at any depth should have string keys and values of - one of the types allowed for values in the top-level dictionary. - - It is recommended that an additional value be added to such nested dictionaries, under the key returned by - ::method:`get_config_change_dict_type_key`. This should be the string representation of the class type of the - nested, serialized object. - - Returns - ------- - Optional[Dict[str, Union[None, str, bool, Number, dict]]] - A dictionary, keyed by strings, representing some configurable setting(s) that need their value(s) changed. - """ - # This should get handled in __init__ but put here anyway - if self._purpose == MetadataPurpose.CHANGE_CONFIG and not self._config_changes: - raise RuntimeError('Invalid {} initialization, setting {} to {} but without any config changes.'.format( - self.__class__, self._purpose.__class__, self._purpose.name)) - return self._config_changes - - @property - def description(self) -> Optional[str]: - return self._description - - @property - def metadata_follows(self) -> bool: - """ - An indication of whether there is more metadata the sender needs to communicate beyond what is contained in this - message, thus letting the receiver know whether it should continue receiving after sending the response to this. 
- - Returns - ------- - bool - An indication of whether there is more metadata the sender needs to communicate beyond what is contained in - this message, thus letting the receiver know whether it should continue receiving after sending the response - to this. - """ - return self._metadata_follows - - @property - def purpose(self) -> MetadataPurpose: - return self._purpose - - def to_dict(self) -> dict: - result = {self._purpose_serial_key: self.purpose.name, self._metadata_follows_serial_key: self.metadata_follows} - if self.description: - result[self._description_serial_key] = self.description - if self.config_changes: - result[self._config_changes_serial_key] = self.config_changes - return result - class MetadataResponse(Response): """ From facf9a160da621e8c50e7247366652262051b3ea Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Wed, 11 Jan 2023 17:38:49 -0500 Subject: [PATCH 065/205] add MetadataSignal type --- .../dmod/communication/metadata_message.py | 21 +++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/python/lib/communication/dmod/communication/metadata_message.py b/python/lib/communication/dmod/communication/metadata_message.py index 1f5c790c7..17f6b1571 100644 --- a/python/lib/communication/dmod/communication/metadata_message.py +++ b/python/lib/communication/dmod/communication/metadata_message.py @@ -1,8 +1,9 @@ from .message import AbstractInitRequest, MessageEventType, Response from numbers import Number -from typing import ClassVar, Dict, Optional, Union +from typing import ClassVar, Dict, Optional, Type, Union from pydantic import Field, root_validator +from dmod.core.serializable import Serializable from dmod.core.enum import PydanticEnum @@ -27,7 +28,23 @@ def get_value_for_name(cls, name_str: str) -> Optional['MetadataPurpose']: return None -class MetadataMessage(AbstractInitRequest): +class MetadataSignal(Serializable): + purpose: MetadataPurpose + metadata_follows: bool + + class Config: + fields = { + 
"metadata_follows": { + "alias": "additional_metadata", + "description": ( + "An indication of whether there is more metadata the sender needs to communicate beyond what is contained in this" + "message, thus letting the receiver know whether it should continue receiving after sending the response to this." + ), + } + } + + +class MetadataMessage(MetadataSignal, AbstractInitRequest): event_type: ClassVar[MessageEventType] = MessageEventType.INVALID From 99d33934c1c87a2642c9a560a5f7073ee02e9ef1 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Wed, 11 Jan 2023 17:39:08 -0500 Subject: [PATCH 066/205] refactor metadata message and response --- .../dmod/communication/metadata_message.py | 26 +++++-------------- 1 file changed, 7 insertions(+), 19 deletions(-) diff --git a/python/lib/communication/dmod/communication/metadata_message.py b/python/lib/communication/dmod/communication/metadata_message.py index 17f6b1571..3256dede0 100644 --- a/python/lib/communication/dmod/communication/metadata_message.py +++ b/python/lib/communication/dmod/communication/metadata_message.py @@ -48,17 +48,9 @@ class MetadataMessage(MetadataSignal, AbstractInitRequest): event_type: ClassVar[MessageEventType] = MessageEventType.INVALID - purpose: MetadataPurpose description: Optional[str] - metadata_follows: bool = Field( - False, - alias="additional_metadata", - description=( - "An indication of whether there is more metadata the sender needs to communicate beyond what is contained in this" - "message, thus letting the receiver know whether it should continue receiving after sending the response to this." - ), - ) + config_changes: Optional[Dict[str, Union[None, str, bool, int, float, dict, list]]] = Field(description="A dictionary, keyed by strings, representing some configurable setting(s) that need their value(s) changed.") """ A dictionary, keyed by strings, representing some configurable setting(s) that need their value(s) changed. 
@@ -74,7 +66,6 @@ class MetadataMessage(MetadataSignal, AbstractInitRequest): ::method:`get_config_change_dict_type_key`. This should be the string representation of the class type of the nested, serialized object. """ - config_changes: Optional[Dict[str, Union[None, str, bool, int, float, dict, list]]] = Field(description="A dictionary, keyed by strings, representing some configurable setting(s) that need their value(s) changed.") @root_validator() def validate_purpose(cls, values): @@ -107,9 +98,8 @@ class MetadataResponse(Response): The subtype of ::class:`Response` appropriate for ::class:`MetadataMessage` objects. """ - _metadata_follows_serial_key = MetadataMessage._metadata_follows_serial_key - _purpose_serial_key = MetadataMessage._purpose_serial_key - response_to_type = MetadataMessage + response_to_type: ClassVar[Type[AbstractInitRequest]] = MetadataMessage + data: MetadataSignal @classmethod def factory_create(cls, success: bool, reason: str, purpose: MetadataPurpose, expect_more: bool, message: str = ''): @@ -129,16 +119,14 @@ def factory_create(cls, success: bool, reason: str, purpose: MetadataPurpose, ex ------- """ - data = {cls._purpose_serial_key: purpose.name, cls._metadata_follows_serial_key: expect_more} - return cls(success=success, reason=reason, data=data, message=message) + data = MetadataSignal(purpose=purpose, metadata_follows=expect_more) - def __init__(self, success: bool, reason: str, data: dict, message: str = ''): - super().__init__(success=success, reason=reason, message=message, data=data) + return cls(success=success, reason=reason, data=data, message=message) @property def metadata_follows(self) -> bool: - return self.data[self._metadata_follows_serial_key] + return self.data.metadata_follows @property def purpose(self) -> MetadataPurpose: - return MetadataPurpose.get_value_for_name(self.data[self._purpose_serial_key]) + return self.data.purpose From 7958201018210f7e72595a9e96a5d49adbc94217 Mon Sep 17 00:00:00 2001 From: Austin 
Raney Date: Thu, 12 Jan 2023 13:40:40 -0500 Subject: [PATCH 067/205] add generate_secret helper function --- .../communication/dmod/communication/session.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/python/lib/communication/dmod/communication/session.py b/python/lib/communication/dmod/communication/session.py index 69fec5d38..20aa9a71b 100644 --- a/python/lib/communication/dmod/communication/session.py +++ b/python/lib/communication/dmod/communication/session.py @@ -6,7 +6,20 @@ from dmod.core.enum import PydanticEnum from abc import ABC, abstractmethod from numbers import Number -from typing import Dict, Optional, Union +from typing import ClassVar, Dict, Optional, List, Type, Union +from pydantic import Field, IPvAnyAddress, validator, root_validator + + +def _generate_secret() -> str: + """Generate random sha256 session secret. + + Returns + ------- + str + sha256 digest + """ + random.seed() + return hashlib.sha256(str(random.random()).encode('utf-8')).hexdigest() class SessionInitFailureReason(PydanticEnum): From 8f7a1a91bba37ff16496a94c91a8d46f458fbc53 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Thu, 12 Jan 2023 13:41:39 -0500 Subject: [PATCH 068/205] refactor Sesssion --- .../dmod/communication/session.py | 127 ++++-------------- 1 file changed, 26 insertions(+), 101 deletions(-) diff --git a/python/lib/communication/dmod/communication/session.py b/python/lib/communication/dmod/communication/session.py index 20aa9a71b..635a334f1 100644 --- a/python/lib/communication/dmod/communication/session.py +++ b/python/lib/communication/dmod/communication/session.py @@ -40,53 +40,33 @@ class Session(Serializable): be made, and potentially other communication may take place. 
""" - _DATETIME_FORMAT = '%Y-%m-%d %H:%M:%S.%f' + _DATETIME_FORMAT: ClassVar[str] = '%Y-%m-%d %H:%M:%S.%f' - _full_equality_attributes = ['session_id', 'session_secret', 'created', 'last_accessed'] + session_id: int = Field(description="The unique identifier for this session.") + # QUESTION: we are using UUID4's elsewhere, do we want to use that instead here? Or perhaps a ULID? + session_secret: str = Field(default_factory=_generate_secret, min_length=64, max_length=64, description="The unique random secret for this session.") + created: datetime.datetime = Field(default_factory=datetime.datetime.now, description="The date and time this session was created.") + last_accessed: datetime.datetime = Field(default_factory=datetime.datetime.now) + + _full_equality_attributes: ClassVar[List[str]]= ['session_id', 'session_secret', 'created', 'last_accessed'] """ list of str: the names of attributes/properties to include when testing instances for complete equality """ - _serialized_attributes = ['session_id', 'session_secret', 'created', 'last_accessed'] + _serialized_attributes: ClassVar[List[str]]= ['session_id', 'session_secret', 'created', 'last_accessed'] """ list of str: the names of attributes/properties to include when serializing an instance """ - _session_timeout_delta = datetime.timedelta(minutes=30.0) + _session_timeout_delta: ClassVar[datetime.timedelta] = datetime.timedelta(minutes=30.0) + + @validator("created", "last_accessed", pre=True) + def validate_date(cls, value): + if isinstance(value, datetime): + return value - @classmethod - def _init_datetime_val(cls, value): try: - if value is None: - return datetime.datetime.now() - elif isinstance(value, str): - return datetime.datetime.strptime(value, Session._DATETIME_FORMAT) - elif not isinstance(value, datetime.datetime): - raise RuntimeError() - else: - return value - except Exception as e: + return datetime.datetime.strptime(value, cls.get_datetime_str_format()) + # TODO: improve error handling, or throw 
something know for downstream users. + except: return datetime.datetime.now() - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict): - """ - Factory create a new instance of this type based on a JSON object dictionary deserialized from received JSON. - - Parameters - ---------- - json_obj - - Returns - ------- - A new object of this type instantiated from the deserialize JSON object dictionary - """ - int_converter = lambda x: int(x) - str_converter = lambda s: str(s) - date_converter = lambda date_str: datetime.datetime.strptime(date_str, cls.get_datetime_str_format()) - - return cls(session_id=cls.parse_simple_serialized(json_obj, 'session_id', int, True, int_converter), - session_secret=cls.parse_simple_serialized(json_obj, 'session_secret', str, False, str_converter), - created=cls.parse_simple_serialized(json_obj, 'created', datetime.datetime, False, date_converter), - last_accessed=cls.parse_simple_serialized(json_obj, 'last_accessed', datetime.datetime, False, - date_converter)) - @classmethod def get_datetime_str_format(cls): return cls._DATETIME_FORMAT @@ -104,7 +84,7 @@ def get_full_equality_attributes(cls) -> tuple: a tuple-ized (and therefore immutable) collection of attribute names for those attributes used for determining full/complete equality between instances. 
""" - return tuple(cls._full_equality_attributes) + return tuple(cls.__fields__) @classmethod def get_serialized_attributes(cls) -> tuple: @@ -119,7 +99,7 @@ def get_serialized_attributes(cls) -> tuple: tuple of str: a tuple-ized (and therefore immutable) collection of attribute names for attributes used in serialization """ - return tuple(cls._serialized_attributes) + return tuple(cls.__fields__) @classmethod def get_session_timeout_delta(cls) -> datetime.timedelta: @@ -128,46 +108,10 @@ def get_session_timeout_delta(cls) -> datetime.timedelta: def __eq__(self, other): return isinstance(other, Session) and self.session_id == other.session_id - def __init__(self, - session_id: Union[str, int], - session_secret: str = None, - created: Union[datetime.datetime, str, None] = None, - last_accessed: Union[datetime.datetime, str, None] = None): - """ - Instantiate, either from an existing record - in which case values for 'secret' and 'created' are provided - or - from a newly acquired session id - in which case 'secret' is randomly generated, 'created' is set to now(), and - the expectation is that a new session record will be created from this instance. 
- - Parameters - ---------- - session_id : Union[str, int] - numeric session id value - session_secret : :obj:`str`, optional - the session secret, if deserializing this object from an existing session record - created : Union[:obj:`datetime.datetime`, :obj:`str`] - the date and time of session creation, either as a datetime object or parseable string, set to - :method:`datetime.datetime.now()` by default - """ - - self._session_id = int(session_id) - if session_secret is None: - random.seed() - self._session_secret = hashlib.sha256(str(random.random()).encode('utf-8')).hexdigest() - else: - self._session_secret = session_secret - - self._created = self._init_datetime_val(created) - self._last_accessed = self._init_datetime_val(last_accessed) - def __hash__(self): return self.session_id - @property - def created(self): - """:obj:`datetime.datetime`: The date and time this session was created.""" - return self._created - - def full_equals(self, other) -> bool: + def full_equals(self, other: object) -> bool: """ Test if this object and another are both of the exact same type and are more "fully" equal than can be determined from the standard equality implementation, by comparing all the attributes from @@ -185,16 +129,7 @@ def full_equals(self, other) -> bool: fully_equal : bool whether the objects are of the same type and with equal values for all serialized attributes """ - if self.__class__ != other.__class__: - return False - try: - for attr in self.get_full_equality_attributes(): - if getattr(self, attr) != getattr(other, attr): - return False - return True - except Exception as e: - # TODO: do something with this exception - return False + return super().__eq__(other) def get_as_dict(self) -> dict: """ @@ -205,17 +140,7 @@ def get_as_dict(self) -> dict: dict a serialized representation of this instance """ - attributes = {} - for attr in self._serialized_attributes: - attr_val = getattr(self, attr) - if isinstance(attr_val, datetime.datetime): - 
attributes[attr] = attr_val.strftime(self.get_datetime_str_format()) - elif isinstance(attr_val, Number) or isinstance(attr_val, str): - attributes[attr] = attr_val - else: - attributes[attr] = str(attr_val) - - return attributes + return self.dict() def get_as_json(self) -> str: """ @@ -232,12 +157,12 @@ def get_created_serialized(self): return self.created.strftime(Session._DATETIME_FORMAT) def get_last_accessed_serialized(self): - return self._last_accessed.strftime(Session._DATETIME_FORMAT) + return self.last_accessed.strftime(Session._DATETIME_FORMAT) def is_expired(self): - return self._last_accessed + self.get_session_timeout_delta() < datetime.datetime.now() + return self.last_accessed + self.get_session_timeout_delta() < datetime.datetime.now() - def is_serialized_attribute(self, attribute) -> bool: + def is_serialized_attribute(self, attribute: str) -> bool: """ Test whether an attribute of the given name is included in the serialized version of the instance returned by :method:`get_as_dict` and/or :method:`get_as_json` (at the top level). 
From ea3634b5ec16546939b49bf518ecb50f3551c47a Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Thu, 12 Jan 2023 13:42:40 -0500 Subject: [PATCH 069/205] refactor FullAuthSession --- .../dmod/communication/session.py | 98 +++++++++---------- 1 file changed, 45 insertions(+), 53 deletions(-) diff --git a/python/lib/communication/dmod/communication/session.py b/python/lib/communication/dmod/communication/session.py index 635a334f1..c506d1397 100644 --- a/python/lib/communication/dmod/communication/session.py +++ b/python/lib/communication/dmod/communication/session.py @@ -177,30 +177,55 @@ def is_serialized_attribute(self, attribute: str) -> bool: True if there is an attribute with the given name in the :attr:`_serialized_attributes` list, or False otherwise """ - for attr in self._serialized_attributes: - if attribute == attr: - return True - return False - - @property - def session_id(self): - """int: The unique identifier for this session.""" - return int(self._session_id) - - @property - def session_secret(self): - """str: The unique random secret for this session.""" - return self._session_secret - - def to_dict(self) -> dict: - return self.get_as_dict() + if not isinstance(attribute, str): + return False + return attribute in self.__fields__ + + def dict( + self, + *, + include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + by_alias: bool = True, # Note this follows Serializable convention + skip_defaults: Optional[bool] = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False + ) -> Dict[str, Union[str, int]]: + _exclude = {"created", "last_accessed"} + if exclude is not None: + _exclude = {*_exclude, *exclude} + + serial = super().dict( + include=include, + exclude=_exclude, + by_alias=by_alias, + skip_defaults=skip_defaults, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, 
+ ) + + if exclude is None or "created" not in exclude: + serial["created"] = self.created.strftime(self.get_datetime_str_format()) + + if exclude is None or "last_accessed" not in exclude: + serial["last_accessed"] = self.last_accessed.strftime(self.get_datetime_str_format()) + + return serial # TODO: work more on this later, when authentication becomes more important class FullAuthSession(Session): - _full_equality_attributes = ['session_id', 'session_secret', 'created', 'ip_address', 'user', 'last_accessed'] - _serialized_attributes = ['session_id', 'session_secret', 'created', 'ip_address', 'user', 'last_accessed'] + ip_address: str + user: str = 'default' + + @validator("ip_address", pre=True) + def cast_ip_address_to_str(cls, value: str) -> str: + # this will raise if cannot be coerced into IPv(4|6)Address + IPvAnyAddress.validate(value) + return value @classmethod def factory_init_from_deserialized_json(cls, json_obj: dict): @@ -215,45 +240,12 @@ def factory_init_from_deserialized_json(cls, json_obj: dict): ------- A new object of this type instantiated from the deserialize JSON object dictionary """ - # TODO: these are duplicated ... 
try to improve on that - int_converter = lambda x: int(x) - str_converter = lambda s: str(s) - date_converter = lambda date_str: datetime.datetime.strptime(date_str, cls.get_datetime_str_format()) try: - return cls(session_id=cls.parse_simple_serialized(json_obj, 'session_id', int, True, int_converter), - session_secret=cls.parse_simple_serialized(json_obj, 'session_secret', str, False, str_converter), - created=cls.parse_simple_serialized(json_obj, 'created', datetime.datetime, False, date_converter), - ip_address=cls.parse_simple_serialized(json_obj, 'ip_address', str, True, str_converter), - user=cls.parse_simple_serialized(json_obj, 'user', str, True, str_converter), - last_accessed=cls.parse_simple_serialized(json_obj, 'last_accessed', datetime.datetime, False, date_converter)) + return cls(**json_obj) except: return Session.factory_init_from_deserialized_json(json_obj) - def __init__(self, - ip_address: str, - session_id: Union[str, int], - session_secret: str = None, - user: str = 'default', - created: Union[datetime.datetime, str, None] = None, - last_accessed: Union[datetime.datetime, str, None] = None): - super().__init__(session_id=session_id, session_secret=session_secret, created=created, - last_accessed=last_accessed) - self._user = user if user is not None else 'default' - self._ip_address = ip_address - - @property - def ip_address(self): - return self._ip_address - - @property - def last_accessed(self): - return self._last_accessed - - @property - def user(self): - return self._user - class SessionInitMessage(AbstractInitRequest): """ From 76b13a784fe092813be9c687314b5b07bff7133c Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Thu, 12 Jan 2023 13:43:47 -0500 Subject: [PATCH 070/205] refactor SessionInitMessage --- .../dmod/communication/session.py | 30 +++---------------- 1 file changed, 4 insertions(+), 26 deletions(-) diff --git a/python/lib/communication/dmod/communication/session.py b/python/lib/communication/dmod/communication/session.py 
index c506d1397..6dc215dff 100644 --- a/python/lib/communication/dmod/communication/session.py +++ b/python/lib/communication/dmod/communication/session.py @@ -266,33 +266,11 @@ class SessionInitMessage(AbstractInitRequest): The secret through which the client entity establishes the authenticity of its username assertion """ - event_type: MessageEventType = MessageEventType.SESSION_INIT - """ :class:`MessageEventType`: the event type for this message implementation """ - - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict): - """ - Factory create a new instance of this type based on a JSON object dictionary deserialized from received JSON. - - Parameters - ---------- - json_obj + username: str + user_secret: str - Returns - ------- - A new object of this type instantiated from the deserialize JSON object dictionary - """ - try: - return SessionInitMessage(username=json_obj['username'], user_secret=json_obj['user_secret']) - except: - return None - - def __init__(self, username: str, user_secret: str): - self.username = username - self.user_secret = user_secret - - def to_dict(self) -> dict: - return {'username': self.username, 'user_secret': self.user_secret} + event_type: ClassVar[MessageEventType] = MessageEventType.SESSION_INIT + """ :class:`MessageEventType`: the event type for this message implementation """ class FailedSessionInitInfo(Serializable): From 06bc81fa2293633151b8cdcba357a4df57832bff Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Thu, 12 Jan 2023 13:44:02 -0500 Subject: [PATCH 071/205] refactor FailedSessionInitInfo --- .../dmod/communication/session.py | 64 ++----------------- 1 file changed, 5 insertions(+), 59 deletions(-) diff --git a/python/lib/communication/dmod/communication/session.py b/python/lib/communication/dmod/communication/session.py index 6dc215dff..c92e73c5c 100644 --- a/python/lib/communication/dmod/communication/session.py +++ b/python/lib/communication/dmod/communication/session.py @@ -279,73 +279,19 
@@ class FailedSessionInitInfo(Serializable): successfully init a session. """ + user: str + reason: SessionInitFailureReason = SessionInitFailureReason.UNKNOWN + fail_time: datetime.datetime = Field(default_factory=datetime.datetime.now) + details: Optional[str] + @classmethod def get_datetime_str_format(cls): return Session.get_datetime_str_format() - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict): - date_converter = lambda date_str: datetime.datetime.strptime(date_str, cls.get_datetime_str_format()) - reason_converter = lambda r: SessionInitFailureReason[r] - try: - user = cls.parse_simple_serialized(json_obj, 'user', str, True) - fail_time = cls.parse_simple_serialized(json_obj, 'fail_time', datetime.datetime, False, date_converter) - reason = cls.parse_simple_serialized(json_obj, 'reason', SessionInitFailureReason, False, reason_converter) - details = cls.parse_simple_serialized(json_obj, 'details', str, False) - - if reason is None: - FailedSessionInitInfo(user=user, fail_time=fail_time, details=details) - else: - return FailedSessionInitInfo(user=user, reason=reason, fail_time=fail_time, details=details) - except: - return None - - def __eq__(self, other): - if self.__class__ != other.__class__ or self.user != other.user or self.reason != other.reason: - return False - if self.fail_time is not None and other.fail_time is not None and self.fail_time != other.fail_time: - return False - return True - - def __init__(self, user: str, reason: SessionInitFailureReason = SessionInitFailureReason.UNKNOWN, - fail_time: Optional[datetime.datetime] = None, details: Optional[str] = None): - self.user = user - self.reason = reason - self.fail_time = fail_time if fail_time is not None else datetime.datetime.now() - self.details = details - - def to_dict(self) -> Dict[str, str]: - """ - Get the representation of this instance as a serialized dictionary or dictionary-like object (e.g., a JSON - object). 
- - Since the returned value must be serializable and JSON-like, key and value types are restricted. For this - implementation, all keys and values in the returned dictionary must be strings. Thus, for the - ::attribute:`fail_time` and ::attribute:`details` attributes, there should be no key or value if the attribute - has a current value of ``None``. - - Returns - ------- - Dict[str, str] - The representation of this instance as a serialized dictionary or dictionary-like object, with valid types - of keys and values. - - See Also - ------- - ::method:`Serializable.to_dict` - """ - result = {'user': self.user, 'reason': self.reason.value} - if self.fail_time is not None: - result['fail_time'] = self.fail_time.strftime(self.get_datetime_str_format()) - if self.details is not None: - result['details'] = self.details - return result - # Define this custom type here for hinting SessionInitDataType = Union[Session, FailedSessionInitInfo] - class SessionInitResponse(Response): """ The :class:`~.message.Response` subtype used to response to a :class:`.SessionInitMessage`, either From 0c7aff376827cb70f6290434adbaa03c504420e1 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Thu, 12 Jan 2023 13:44:28 -0500 Subject: [PATCH 072/205] refactor SessionInitResponse --- .../dmod/communication/session.py | 104 ++++++++---------- 1 file changed, 44 insertions(+), 60 deletions(-) diff --git a/python/lib/communication/dmod/communication/session.py b/python/lib/communication/dmod/communication/session.py index c92e73c5c..31cc6ca1b 100644 --- a/python/lib/communication/dmod/communication/session.py +++ b/python/lib/communication/dmod/communication/session.py @@ -335,42 +335,54 @@ class SessionInitResponse(Response): """ - response_to_type = SessionInitMessage + response_to_type: ClassVar[Type[AbstractInitRequest]] = SessionInitMessage """ Type[`SessionInitMessage`]: the type or subtype of :class:`Message` for which this type is the response""" - @classmethod - def 
_factory_init_data_attribute(cls, json_obj: dict) -> Optional[SessionInitDataType]: - """ - Initialize the argument value for a constructor param used to set the :attr:`data` attribute appropriate for - this type, given the parent JSON object, which for this type means deserializing the dict value to either a - session object or a failure info object. - - Parameters - ---------- - json_obj : dict - the parent JSON object containing the desired session data serialized value - - Returns - ------- - data - the resulting :class:`Session` or :class:`FailedSessionInitInfo` object obtained after processing, - or None if no valid object could be processed of either type - """ - data = None - try: - data = json_obj['data'] - except: - det = 'Received serialized JSON response object that did not contain expected key for serialized session.' - return FailedSessionInitInfo(user='', reason=SessionInitFailureReason.SESSION_DETAILS_MISSING, details=det) - - try: - # If we can, return the FullAuthSession or Session obtained by this class method - return FullAuthSession.factory_init_from_deserialized_json(data) - except: + # NOTE: this field _is_ optional, however `data` will be FailedSessionInitInfo if it is not + # provided or set to None. + # NOTE: order of this Union matters. types will be coerced from left to right. meaning, more + # specific types (i.e. subtypes) should be listed before more general types. 
see `SmartUnion` + # for more detail: https://docs.pydantic.dev/usage/model_config/#smart-union + data: Union[FailedSessionInitInfo, FullAuthSession, Session] + + @root_validator(pre=True) + def _coerce_data_field(cls, values): + data = values.get("data") + + if data is None: + details = "Instantiated SessionInitResponse object without session data; defaulting to failure" + values["data"] = FailedSessionInitInfo( + user="", + reason=SessionInitFailureReason.SESSION_DETAILS_MISSING, + details=details, + ) + return values + + # run `data` field validators + coerced_data, errors = cls.__fields__["data"].validate(data, {}, loc="") + if errors is not None: + details = 'Instantiated SessionInitResponse object using unexpected type for data ({})'.format( + data.__class__.__name__) try: - return FailedSessionInitInfo.factory_init_from_deserialized_json(data) + as_str = '; converted to string: \n{}'.format(str(data)) + details += as_str except: - return None + # If we can't cast to string, don't worry; just leave out that part in details + pass + values["data"] = FailedSessionInitInfo( + user="", + reason=SessionInitFailureReason.SESSION_DETAILS_MISSING, + details=details, + ) + return values + + values["data"] = coerced_data + return values + + @validator("success") + def _update_success(cls, value: bool, values): + # Make sure to reset/change self.success if self.data ends up being a failure info object + return value and isinstance(values["data"], Session) def __eq__(self, other): return self.__class__ == other.__class__ \ @@ -379,34 +391,6 @@ def __eq__(self, other): and self.message == other.message \ and self.data.full_equals(other.data) if isinstance(self.data, Session) else self.data == other.data - def __init__(self, success: bool, reason: str, message: str = '', data: Optional[SessionInitDataType] = None): - super().__init__(success=success, reason=reason, message=message, data=data) - - # If we received a dict for data, try to deserialize using the class method 
(failures will set to None, - # which will get handled by the next conditional logic) - if isinstance(self.data, dict): - # Remember, the class method expects a JSON obj dict with the data as a child element, not the data directly - self.data = self.__class__._factory_init_data_attribute({'success': self.success, 'data': data}) - - if self.data is None: - details = 'Instantiated SessionInitResponse object without session data; defaulting to failure' - self.data = FailedSessionInitInfo(user='', reason=SessionInitFailureReason.SESSION_DETAILS_MISSING, - details=details) - elif not (isinstance(self.data, Session) or isinstance(self.data, FailedSessionInitInfo)): - details = 'Instantiated SessionInitResponse object using unexpected type for data ({})'.format( - self.data.__class__.__name__) - try: - as_str = '; converted to string: \n{}'.format(str(self.data)) - details += as_str - except: - # If we can't cast to string, don't worry; just leave out that part in details - pass - self.data = FailedSessionInitInfo(user='', reason=SessionInitFailureReason.SESSION_DETAILS_MISSING, - details=details) - - # Make sure to reset/change self.success if self.data ends up being a failure info object - self.success = self.success and isinstance(self.data, Session) - class SessionManager(ABC): """ From 05f1853983b73f19afce2136f3b608e7101c22e3 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Thu, 12 Jan 2023 14:06:32 -0500 Subject: [PATCH 073/205] refactor UnsupportedMessageTypeResponse message field should not be optional --- .../dmod/communication/unsupported_message.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/python/lib/communication/dmod/communication/unsupported_message.py b/python/lib/communication/dmod/communication/unsupported_message.py index 88295ad74..77a95a0f6 100644 --- a/python/lib/communication/dmod/communication/unsupported_message.py +++ b/python/lib/communication/dmod/communication/unsupported_message.py @@ -4,12 +4,16 @@ class 
UnsupportedMessageTypeResponse(Response): + actual_event_type: MessageEventType + listener_type: Type[WebSocketInterface] + message: str + + success = False + reason = "Message Event Type Unsupported" def __init__(self, actual_event_type: MessageEventType, listener_type: Type[WebSocketInterface], - message: str = None, data=None): + message: str = None, data=None, **kwargs): if message is None: message = 'The {} event type is not supported by this {} listener'.format( actual_event_type, listener_type.__name__) - super().__init__(success=False, reason='Message Event Type Unsupported', message=message, data=data) - self.actual_event_type = actual_event_type - self.listener_type = listener_type \ No newline at end of file + super().__init__(message=message, data=data, actual_event_type=actual_event_type, listener_type=listener_type) From ad87928cac80000d3b41487a195ed26a1efd580a Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Thu, 12 Jan 2023 14:11:31 -0500 Subject: [PATCH 074/205] format unsupported_message module --- .../dmod/communication/unsupported_message.py | 22 ++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/python/lib/communication/dmod/communication/unsupported_message.py b/python/lib/communication/dmod/communication/unsupported_message.py index 77a95a0f6..e1c64ec2b 100644 --- a/python/lib/communication/dmod/communication/unsupported_message.py +++ b/python/lib/communication/dmod/communication/unsupported_message.py @@ -11,9 +11,21 @@ class UnsupportedMessageTypeResponse(Response): success = False reason = "Message Event Type Unsupported" - def __init__(self, actual_event_type: MessageEventType, listener_type: Type[WebSocketInterface], - message: str = None, data=None, **kwargs): + def __init__( + self, + actual_event_type: MessageEventType, + listener_type: Type[WebSocketInterface], + message: str = None, + data=None, + **kwargs + ): if message is None: - message = 'The {} event type is not supported by this {} 
listener'.format( - actual_event_type, listener_type.__name__) - super().__init__(message=message, data=data, actual_event_type=actual_event_type, listener_type=listener_type) + message = "The {} event type is not supported by this {} listener".format( + actual_event_type, listener_type.__name__ + ) + super().__init__( + message=message, + data=data, + actual_event_type=actual_event_type, + listener_type=listener_type, + ) From c108eef7628ddcaeba0ebd17febff2ff0514f791 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Thu, 12 Jan 2023 15:01:55 -0500 Subject: [PATCH 075/205] refactor UpdateMessage --- .../dmod/communication/update_message.py | 163 ++++++------------ 1 file changed, 54 insertions(+), 109 deletions(-) diff --git a/python/lib/communication/dmod/communication/update_message.py b/python/lib/communication/dmod/communication/update_message.py index 787f5cfeb..130ba020c 100644 --- a/python/lib/communication/dmod/communication/update_message.py +++ b/python/lib/communication/dmod/communication/update_message.py @@ -1,6 +1,7 @@ from .message import AbstractInitRequest, MessageEventType, Response from pydoc import locate -from typing import Dict, Optional, Type, Union +from typing import ClassVar, Dict, Optional, Type, Union +from pydantic import Field, validator import uuid @@ -28,132 +29,76 @@ class type, but note that when messages are serialized, it is converted to the f update it conveys. 
""" - event_type: MessageEventType = MessageEventType.INFORMATION_UPDATE + event_type: ClassVar[MessageEventType] = MessageEventType.INFORMATION_UPDATE - _DIGEST_KEY = 'digest' - _OBJECT_ID_KEY = 'object_id' - _OBJECT_TYPE_KEY = 'object_type' - _UPDATED_DATA_KEY = 'updated_data' + object_id: str = Field(description="The identifier for the object being updated, as a string.") + object_type: Type[object] = Field(description="The type of object being updated.") + # NOTE: updated_data must container at least one key + updated_data: Dict[str, str] = Field(description="A serialized dictionary of properties to new values.") + digest: str = Field(default_factory=lambda: uuid.uuid4().hex) + + @validator("object_type", pre=True) + def _coerce_object_type(cls, value): + obj_type = locate(value) + if obj_type is None: + raise ValueError("could not resolve `object_type`") + return obj_type + + @validator("updated_data") + def _validate_updated_data_has_keys(cls, value: Dict[str, str]): + if not value.keys(): + raise ValueError("`updated_data` must have at least one key.") + return value @classmethod def get_digest_key(cls) -> str: - return cls._DIGEST_KEY + return cls.__fields__["digest"].alias @classmethod def get_object_id_key(cls) -> str: - return cls._OBJECT_ID_KEY + return cls.__fields__["object_id"].alias @classmethod def get_object_type_key(cls) -> str: - return cls._OBJECT_TYPE_KEY + return cls.__fields__["object_type"].alias @classmethod def get_updated_data_key(cls) -> str: - return cls._UPDATED_DATA_KEY - - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict): - """ - Factory create a new instance of this type based on a JSON object dictionary deserialized from received JSON. - - The method expects the ::attribute:`object_type` to be represented as the fully-qualified name string for the - particular class type. If the method cannot located the actual class type by this string, the JSON is - considered invalid. 
- - Additionally, if the representation of the ::attribute:`updated_data` property is not a (serialized) nested - dictionary, or is an empty dictionary, this is also considered invalid. - - Both ::attribute:`digest` and ::attribute:`object_id` representations are valid if they can be cast to strings. - - The JSON is not considered invalid if it has other keys/values at the root level beyond those for the standard - properties. - - For invalid JSON representations, ``None`` is returned. - - Parameters - ---------- - json_obj - - Returns - ------- - Optional[UpdateMessage] - A new object of this type instantiated from the deserialize JSON object dictionary, or ``None`` if the JSON - is not a valid serialized representation of this type. - """ - try: - obj_type = locate(json_obj[cls.get_object_type_key()]) - if obj_type is None: - return None - obj_id = str(json_obj[cls.get_object_id_key()]) - updated_data = json_obj[cls.get_updated_data_key()] - if not isinstance(updated_data, dict) or len(updated_data.keys()) == 0: - return None - message = cls(object_id=obj_id, object_type=obj_type, updated_data=updated_data) - message._digest = str(json_obj[cls.get_digest_key()]) - except: - return None - - def __init__(self, object_id: str, object_type: Type, updated_data: Dict[str, str]): - """ - Initialize a new object. - - Parameters - ---------- - object_id : str - The identifier for the object being updated, as a string. - object_type : Type - The type of object being updated. - updated_data : Dict[str, str] - A serialized dictionary of properties to new values. 
- """ - self._digest = None - self._object_type = object_type - self._object_id = object_id - self._updated_data = updated_data - - @property - def digest(self) -> str: - if self._digest is None: - self._digest = uuid.uuid4().hex - return self._digest - - @property - def object_id(self) -> str: - return self._object_id - - @property - def object_type(self) -> Type: - return self._object_type + return cls.__fields__["updated_data"].alias @property def object_type_string(self) -> str: return '{}.{}'.format(self.object_type.__module__, self.object_type.__name__) - def to_dict(self) -> dict: - """ - Get the representation of this instance as a dictionary or dictionary-like object (e.g., a JSON object). - - Returns - ------- - dict - The representation of this instance as a dictionary or dictionary-like object (e.g., a JSON object). - """ - return {self.get_object_id_key(): self.object_id, self.get_digest_key(): self.digest, - self.get_object_type_key(): self.object_type_string, self.get_updated_data_key(): self.updated_data} - - @property - def updated_data(self) -> Dict[str, str]: - """ - Get the updated properties of the updated entity and the new values, as a dictionary of string property name - keys mapped to string representations of the values. - - Returns - ------- - Dict[str, str] - The updated properties of the updated entity and the new values, as a dictionary of string property name - keys mapped to string representations of the values. 
- """ - return self._updated_data + def dict( + self, + *, + include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + by_alias: bool = True, # Note this follows Serializable convention + skip_defaults: Optional[bool] = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False + ) -> Dict[str, Union[str, int]]: + _exclude = {"object_type"} + if exclude is not None: + _exclude = {*_exclude, *exclude} + + serial = super().dict( + include=include, + exclude=_exclude, + by_alias=by_alias, + skip_defaults=skip_defaults, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + + if exclude is None or "object_type" not in exclude: + serial["object_type"] = self.object_type_string + + return serial class UpdateMessageResponse(Response): From aabf8c79492e3b73f7f58d126c951984788327f2 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Thu, 12 Jan 2023 15:30:54 -0500 Subject: [PATCH 076/205] add UpdateMessageData; inner type of UpdateMessageResponse --- .../lib/communication/dmod/communication/update_message.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/lib/communication/dmod/communication/update_message.py b/python/lib/communication/dmod/communication/update_message.py index 130ba020c..b5f71c9fd 100644 --- a/python/lib/communication/dmod/communication/update_message.py +++ b/python/lib/communication/dmod/communication/update_message.py @@ -4,6 +4,8 @@ from pydantic import Field, validator import uuid +from dmod.core.serializable import Serializable + class UpdateMessage(AbstractInitRequest): """ @@ -100,6 +102,10 @@ def dict( return serial +class UpdateMessageData(Serializable): + digest: Optional[str] + object_found: Optional[bool] + class UpdateMessageResponse(Response): """ From 0eadac1777b5827018bcb4997f68bf012b339a05 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Thu, 12 
Jan 2023 15:31:12 -0500 Subject: [PATCH 077/205] refactor UpdateMessageResponse --- .../dmod/communication/update_message.py | 30 ++++++++----------- 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/python/lib/communication/dmod/communication/update_message.py b/python/lib/communication/dmod/communication/update_message.py index b5f71c9fd..6d0e61cbc 100644 --- a/python/lib/communication/dmod/communication/update_message.py +++ b/python/lib/communication/dmod/communication/update_message.py @@ -107,15 +107,14 @@ class UpdateMessageData(Serializable): object_found: Optional[bool] -class UpdateMessageResponse(Response): +class UpdateMessageResponse(UpdateMessageData, Response): """ The subtype of ::class:`Response` appropriate for ::class:`UpdateMessage` objects. """ - _DIGEST_SUBKEY = 'digest' - _OBJECT_FOUND_SUBKEY = 'object_found' + response_to_type: ClassVar[Type[AbstractInitRequest]] = UpdateMessage - response_to_type = UpdateMessage + data: Optional[UpdateMessageData] = Field(default_factory=UpdateMessageData) @classmethod def get_digest_subkey(cls) -> str: @@ -129,7 +128,7 @@ def get_digest_subkey(cls) -> str: The "subkey" (i.e., the key for the value within the nested ``data`` dictionary) for the ``digest`` in serialized representations. """ - return cls._DIGEST_SUBKEY + return cls.__fields__["digest"].alias @classmethod def get_object_found_subkey(cls) -> str: @@ -143,25 +142,20 @@ def get_object_found_subkey(cls) -> str: The "subkey" (i.e., the key for the value within the nested ``data`` dictionary) for the ``digest`` in serialized representations. 
""" - return cls._OBJECT_FOUND_SUBKEY + return cls.__fields__["object_found"].alias def __init__(self, success: bool, reason: str, response_text: str = '', data: Optional[Dict[str, Union[str, bool]]] = None, digest: Optional[str] = None, - object_found: Optional[bool] = None): + object_found: Optional[bool] = None, **kwargs): # Work with digest/found either as params or contained within data param # However, move explicit params into the data dict param, allowing non-None params to overwrite data = dict() if data is None else data - digest = data[self.get_digest_subkey()] if digest is None and self.get_digest_subkey() in data else digest - if object_found is None and self.get_object_found_subkey(): - object_found = data[self.get_object_found_subkey()] - super().__init__(success=success, reason=reason, message=response_text, - data={self.get_digest_subkey(): digest, self.get_object_found_subkey(): object_found}) + if digest is None and self.get_digest_subkey() in data: + digest = data[self.get_digest_subkey()] - @property - def digest(self) -> str: - return self.data[self.get_digest_subkey()] + if object_found is None and self.get_object_found_subkey() in data: + object_found = data[self.get_object_found_subkey()] - @property - def object_found(self) -> bool: - return self.data[self.get_object_found_subkey()] + super().__init__(success=success, reason=reason, message=response_text, + data=UpdateMessageData(digest=digest, object_found=object_found)) From 0238e951824d7f949d5241e5e99590a59726a740 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Thu, 12 Jan 2023 16:02:36 -0500 Subject: [PATCH 078/205] refactor SchedulerRequestMessage --- .../dmod/communication/scheduler_request.py | 158 +++++++----------- 1 file changed, 58 insertions(+), 100 deletions(-) diff --git a/python/lib/communication/dmod/communication/scheduler_request.py b/python/lib/communication/dmod/communication/scheduler_request.py index 790534475..68985cc02 100644 --- 
a/python/lib/communication/dmod/communication/scheduler_request.py +++ b/python/lib/communication/dmod/communication/scheduler_request.py @@ -1,14 +1,30 @@ from dmod.core.execution import AllocationParadigm from .maas_request import ModelExecRequest, ModelExecRequestResponse from .message import AbstractInitRequest, MessageEventType, Response -from typing import Optional, Union +from pydantic import Field, PrivateAttr +from typing import ClassVar, Dict, Optional, Union class SchedulerRequestMessage(AbstractInitRequest): - event_type: MessageEventType = MessageEventType.SCHEDULER_REQUEST + event_type: ClassVar[MessageEventType] = MessageEventType.SCHEDULER_REQUEST """ :class:`MessageEventType`: the event type for this message implementation """ + model_request: ModelExecRequest = Field(description="The underlying request for a job to be scheduled.") + user_id: str = Field(description="The associated user id for this scheduling request.") + memory: int = Field(500_000, description="The amount of memory, in bytes, requested for the scheduling of this job.") + cpus_: Optional[int] = Field(description="The number of processors requested for the scheduling of this job.") + allocation_paradigm_: Optional[AllocationParadigm] + + _memory_unset: bool = PrivateAttr() + + class Config: + fields = { + "memory": {"alias": "mem"}, + "cpus_": {"alias": "cpus"}, + "allocation_paradigm_": {"alias": "allocation_paradigm"}, + } + @classmethod def default_allocation_paradigm_str(cls) -> str: """ @@ -27,61 +43,26 @@ def default_allocation_paradigm_str(cls) -> str: """ return AllocationParadigm.get_default_selection().name - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict): - """ - Factory create a new instance of this type based on a JSON object dictionary deserialized from received JSON. 
- - Parameters - ---------- - json_obj - - Returns - ------- - SchedulerRequestMessage - A new object of this type instantiated from the deserialize JSON object dictionary, or ``None`` if the - provided parameter could not be used to instantiated a new object of this type. - """ - try: - model_request = ModelExecRequest.factory_init_correct_subtype_from_deserialized_json(json_obj['model_request']) - if model_request is not None: - alloc_paradigm = json_obj['allocation_paradigm'] if 'allocation_paradigm' in json_obj else None - return cls(model_request=model_request, - user_id=json_obj['user_id'], - # This may be absent to indicate use the value from the backing model request - cpus=json_obj['cpus'] if 'cpus' in json_obj else None, - # This may be absent to indicate it should be marked "unset" and a default should be used - mem=json_obj['mem'] if 'mem' in json_obj else None, - allocation_paradigm=alloc_paradigm) - else: - return None - except: - return None - # TODO: may need to generalize the underlying request to support, say, scheduling evaluation jobs - def __init__(self, model_request: ModelExecRequest, user_id: str, cpus: Optional[int] = None, mem: Optional[int] = None, - allocation_paradigm: Optional[Union[str, AllocationParadigm]] = None): - self._model_request = model_request - self._user_id = user_id - self._cpus = cpus + def __init__( + self, + model_request: ModelExecRequest, + user_id: str, + cpus: Optional[int] = None, + mem: Optional[int] = None, + allocation_paradigm: Optional[Union[str, AllocationParadigm]] = None, + ): + super().__init__( + model_request=model_request, + user_id=user_id, + cpus=cpus, + memory=mem, + allocation_paradigm=allocation_paradigm + ) if mem is None: self._memory_unset = True - self._memory = 500000 else: self._memory_unset = False - self._memory = mem - if isinstance(allocation_paradigm, str): - self._allocation_paradigm = AllocationParadigm.get_from_name(allocation_paradigm) - else: - self._allocation_paradigm = 
allocation_paradigm - - def __eq__(self, other): - return self.__class__ == other.__class__ \ - and self.model_request == other.model_request \ - and self.cpus == other.cpus \ - and self.memory == other.memory \ - and self.user_id == other.user_id \ - and self.allocation_paradigm == other.allocation_paradigm @property def allocation_paradigm(self) -> AllocationParadigm: @@ -93,10 +74,10 @@ def allocation_paradigm(self) -> AllocationParadigm: AllocationParadigm The allocation paradigm requested for the job to be scheduled. """ - if self._allocation_paradigm is None: + if self.allocation_paradigm_ is None: return self.model_request.allocation_paradigm else: - return self._allocation_paradigm + return self.allocation_paradigm_ @property def cpus(self) -> int: @@ -111,19 +92,7 @@ def cpus(self) -> int: int The number of processors requested for the scheduling of this job. """ - return self.model_request.cpu_count if self._cpus is None else self._cpus - - @property - def memory(self) -> int: - """ - The amount of memory, in bytes, requested for the scheduling of this job. - - Returns - ------- - int - The amount of memory, in bytes, requested for the scheduling of this job. - """ - return self._memory + return self.model_request.cpu_count if self.cpus_ is None else self.cpus_ @property def memory_unset(self) -> bool: @@ -137,18 +106,6 @@ def memory_unset(self) -> bool: """ return self._memory_unset - @property - def model_request(self) -> ModelExecRequest: - """ - The underlying request for a job to be scheduled. - - Returns - ------- - ModelExecRequest - The underlying request for a job to be scheduled. - """ - return self._model_request - @property def nested_event(self) -> MessageEventType: """ @@ -161,29 +118,30 @@ def nested_event(self) -> MessageEventType: """ return self.model_request.get_message_event_type() - @property - def user_id(self) -> str: - """ - The associated user id for this scheduling request. 
- - Returns - ------- - str - The associated user id for this scheduling request. - """ - return self._user_id - - def to_dict(self) -> dict: - serial = {'model_request': self.model_request.to_dict(), 'user_id': self.user_id} - if self._allocation_paradigm is not None: - serial['allocation_paradigm'] = self._allocation_paradigm.name - # Don't include this in serial form if property value is sourced from underlying model request - if self._cpus is not None: - serial['cpus'] = self._cpus + def dict( + self, + *, + include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + by_alias: bool = True, # Note this follows Serializable convention + skip_defaults: Optional[bool] = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False + ) -> Dict[str, Union[str, int]]: # Only including memory value in serial form if it was explicitly set in the first place if not self.memory_unset: - serial['mem'] = self.memory - return serial + exclude = {"memory"} if exclude is None else {"memory", *exclude} + + return super().dict( + include=include, + exclude=exclude, + by_alias=by_alias, + skip_defaults=skip_defaults, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) class SchedulerRequestResponse(Response): From f6bd5527623ae72303baf71daf40b4171b6e63cf Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Thu, 12 Jan 2023 16:12:40 -0500 Subject: [PATCH 079/205] refactor parameter module --- .../communication/maas_request/parameter.py | 30 ++++++------------- 1 file changed, 9 insertions(+), 21 deletions(-) diff --git a/python/lib/communication/dmod/communication/maas_request/parameter.py b/python/lib/communication/dmod/communication/maas_request/parameter.py index 6ac1f5ea3..71362f8ac 100644 --- a/python/lib/communication/dmod/communication/maas_request/parameter.py +++ 
b/python/lib/communication/dmod/communication/maas_request/parameter.py @@ -1,16 +1,11 @@ -class Scalar: +from dmod.core.serializable import Serializable + + +class Scalar(Serializable): """ Represents a parameter value that is bound to a single number """ - - def __init__(self, scalar: int): - """ - :param int scalar: The value for the parameter - """ - self.scalar = scalar - - def to_dict(self): - return {"scalar": self.scalar} + scalar: int def __str__(self): return str(self.scalar) @@ -19,16 +14,11 @@ def __repr__(self): return self.__str__() -class Parameter: +class Parameter(Serializable): """ Base clase for model parameter descriptions that a given model may expose to DMOD for dynamic parameter selection. """ - - def __init__(self, name): - """ - Set the base meta data of the parameter - """ - self.name = name + name: str class ScalarParameter(Parameter): @@ -36,7 +26,5 @@ class ScalarParameter(Parameter): A Scalar parameter is a simple interger parameter who's valid range are integer increments between min and max, inclusive. """ - - def __init__(self, min, max): - self.min = min - self.max = max + min: int + max: int From 787d9b07f8f963770ba9b64d4ac6be1b8f238a35 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Thu, 12 Jan 2023 16:14:02 -0500 Subject: [PATCH 080/205] refactor DmodJobRequest --- .../dmod/communication/maas_request/dmod_job_request.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/python/lib/communication/dmod/communication/maas_request/dmod_job_request.py b/python/lib/communication/dmod/communication/maas_request/dmod_job_request.py index 5288e7941..451e59e56 100644 --- a/python/lib/communication/dmod/communication/maas_request/dmod_job_request.py +++ b/python/lib/communication/dmod/communication/maas_request/dmod_job_request.py @@ -11,9 +11,6 @@ class DmodJobRequest(AbstractInitRequest, ABC): The base class underlying all types of messages requesting execution of some kind of workflow job. 
""" - def __int__(self, *args, **kwargs): - super(DmodJobRequest, self).__int__(*args, **kwargs) - @property @abstractmethod def data_requirements(self) -> List[DataRequirement]: From 7b948b1ba8ff5d8fed57a0d6cf011ad5f5227c78 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Thu, 12 Jan 2023 16:30:00 -0500 Subject: [PATCH 081/205] refactor distribution module --- .../maas_request/distribution.py | 66 +++++++++++++++---- 1 file changed, 53 insertions(+), 13 deletions(-) diff --git a/python/lib/communication/dmod/communication/maas_request/distribution.py b/python/lib/communication/dmod/communication/maas_request/distribution.py index aa1bef5a2..d088302bc 100644 --- a/python/lib/communication/dmod/communication/maas_request/distribution.py +++ b/python/lib/communication/dmod/communication/maas_request/distribution.py @@ -1,8 +1,28 @@ -class Distribution: +from dmod.core.serializable import Serializable + +from typing import Literal + + +class DistributionBounds(Serializable): + minimum: int = 0 + maximum: int = 0 + distribution_type: Literal["normal"] = "normal" + + class Config: + feilds = { + "distribution_type": {"alias": "type"}, + "minimum": {"alias": "min"}, + "maximum": {"alias": "max"}, + } + + +class Distribution(Serializable): """ Represents the definition of a distribution of numbers """ + distribution: DistributionBounds + def __init__( self, minimum: int = 0, maximum: int = 0, distribution_type: str = "normal" ): @@ -11,18 +31,38 @@ def __init__( :param int maximum: The upper bound of the distribution :param str distribution_type: The type of the distribution """ - self.minimum = minimum - self.maximum = maximum - self.distribution_type = distribution_type - - def to_dict(self): - return { - "distribution": { - "min": self.minimum, - "max": self.maximum, - "type": self.distribution_type, - } - } + super().__init__( + distribution=DistributionBounds( + minimum=minimum, maximum=maximum, distribution_type=distribution_type + ) + ) + + @property + def 
minimum(self) -> int: + """The lower bound for the distribution""" + return self.distribution.minimum + + @minimum.setter + def minimum(self, value: int): + self.distribution.minimum = value + + @property + def maximum(self) -> int: + """The upper bound for the distribution""" + return self.distribution.maximum + + @maximum.setter + def maximum(self, value: int): + self.distribution.maximum = value + + @property + def distribution_type(self) -> str: + """The type of the distribution""" + return self.distribution.distribution_type + + @distribution_type.setter + def distribution_type(self, value: str) -> str: + self.distribution.distribution_type = value def __str__(self): return str(self.to_dict()) From 70e10353991b7205185884585aa947fd6d303da3 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Thu, 12 Jan 2023 16:31:14 -0500 Subject: [PATCH 082/205] refactor ExternalRequestResponse --- .../maas_request/external_request_response.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/python/lib/communication/dmod/communication/maas_request/external_request_response.py b/python/lib/communication/dmod/communication/maas_request/external_request_response.py index 36b320de1..f03d10d2d 100644 --- a/python/lib/communication/dmod/communication/maas_request/external_request_response.py +++ b/python/lib/communication/dmod/communication/maas_request/external_request_response.py @@ -1,13 +1,12 @@ from abc import ABC -from ..message import Response +from ..message import AbstractInitRequest, Response from .external_request import ExternalRequest +from typing import ClassVar, Type + class ExternalRequestResponse(Response, ABC): - response_to_type = ExternalRequest + response_to_type: ClassVar[Type[AbstractInitRequest]] = ExternalRequest """ The type of :class:`AbstractInitRequest` for which this type is the response""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) From 358e5dae5230a6bac890f045c0ac1b577d79a798 Mon Sep 17 
00:00:00 2001 From: Austin Raney Date: Tue, 17 Jan 2023 10:09:02 -0500 Subject: [PATCH 083/205] refactor DatasetQuery --- .../dataset_management_message.py | 20 +------------------ 1 file changed, 1 insertion(+), 19 deletions(-) diff --git a/python/lib/communication/dmod/communication/dataset_management_message.py b/python/lib/communication/dmod/communication/dataset_management_message.py index c9e275401..e62e43c6d 100644 --- a/python/lib/communication/dmod/communication/dataset_management_message.py +++ b/python/lib/communication/dmod/communication/dataset_management_message.py @@ -41,29 +41,11 @@ def get_for_name(cls, name_str: str) -> 'QueryType': class DatasetQuery(Serializable): - _KEY_QUERY_TYPE = 'query_type' - - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict) -> Optional['DatasetQuery']: - try: - return cls(query_type=QueryType.get_for_name(json_obj[cls._KEY_QUERY_TYPE])) - except Exception as e: - return None + query_type: QueryType def __hash__(self): return hash(self.query_type) - def __eq__(self, other): - return isinstance(other, DatasetQuery) and self.query_type == other.query_type - - def __init__(self, query_type: QueryType): - self.query_type = query_type - - def to_dict(self) -> Dict[str, Union[str, Number, dict, list]]: - serial = dict() - serial[self._KEY_QUERY_TYPE] = self.query_type.name - return serial - class ManagementAction(PydanticEnum): """ From 1377def619cbf9514b225bf48c81d731f7b1ad1b Mon Sep 17
+++ b/python/lib/communication/dmod/communication/dataset_management_message.py @@ -3,8 +3,9 @@ from .maas_request import ExternalRequest, ExternalRequestResponse from dmod.core.meta_data import DataCategory, DataDomain, DataFormat, DataRequirement from dmod.core.enum import PydanticEnum +from pydantic import root_validator, Field from numbers import Number -from typing import Dict, Optional, Union, List +from typing import ClassVar, Dict, Optional, Union, List class QueryType(PydanticEnum): @@ -157,65 +158,43 @@ class DatasetManagementMessage(AbstractInitRequest): Valid actions are enumerated by the ::class:`ManagementAction`. """ - event_type: MessageEventType = MessageEventType.DATASET_MANAGEMENT + event_type: ClassVar[MessageEventType] = MessageEventType.DATASET_MANAGEMENT - _SERIAL_KEY_ACTION = 'action' - _SERIAL_KEY_CATEGORY = 'category' - _SERIAL_KEY_DATA_DOMAIN = 'data_domain' - _SERIAL_KEY_DATA_LOCATION = 'data_location' - _SERIAL_KEY_DATASET_NAME = 'dataset_name' - _SERIAL_KEY_IS_PENDING_DATA = 'pending_data' - _SERIAL_KEY_QUERY = 'query' - _SERIAL_KEY_IS_READ_ONLY = 'read_only' - - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict) -> Optional['DatasetManagementMessage']: - """ - Inflate serialized representation back to a full object, if serial representation is valid. 
+ management_action: ManagementAction = Field(description="The type of ::class:`ManagementAction` this message embodies or requests.") + dataset_name: Optional[str] = Field(description="The name of the involved dataset, if applicable.") + is_read_only_dataset: bool = Field(False, description="Whether the dataset involved is, should be, or must be (depending on action) read-only.") + data_category: Optional[DataCategory] = Field(description="The category of the involved data, if applicable.") + data_domain: Optional[DataDomain] = Field(description="The domain of the involved data, if applicable.") + data_location: Optional[str] = Field(description="Location for acted-upon data.") + is_pending_data: bool = Field(False, description="Whether the sender has data pending transmission after this message.") + """ + Whether the sender has data it wants to transmit after this message. The typical use case is during a + ``CREATE`` action, where this indicates there is already data to add to the newly created dataset. + """ + query: Optional[DatasetQuery] - Parameters - ---------- - json_obj : dict - Serialized representation of a ::class:`DatasetManagementMessage` instance. 
+ @root_validator() + def _post_init_validate_dependent_fields(cls, values): + # Sanity check certain param values depending on the action; e.g., can't CREATE a dataset without a name + action: ManagementAction = values["management_action"] + name, category, domain = values["dataset_name"], values["data_category"], values["data_domain"] + err_msg_template = "Cannot create {} for action {} without {}" + if name is None and action.requires_dataset_name: + raise RuntimeError(err_msg_template.format(cls.__name__, action, "a dataset name")) + if category is None and action.requires_data_category: + raise RuntimeError(err_msg_template.format(cls.__name__, action, "a data category")) + if domain is None and action.requires_data_domain: + raise RuntimeError(err_msg_template.format(cls.__name__, action, "a data domain")) - Returns - ------- - Optional[DatasetManagementMessage] - The inflated ::class:`DatasetManagementMessage`, or ``None`` if the serialized form was invalid. - """ - try: - # Grab the class to deserialize, popping it from the json obj (it was temp injected by a subclass) if there - deserialized_class = json_obj.pop('deserialized_class', cls) - - # Similarly, get/pop any temporarily injected kwargs values to pass to deserialized_class's init function - deserialized_class_kwargs = json_obj.pop('deserialized_class_kwargs', dict()) - - action = ManagementAction.get_for_name(json_obj[cls._SERIAL_KEY_ACTION]) - if json_obj[cls._SERIAL_KEY_ACTION] != action.name: - raise RuntimeError("Unparseable serialized {} value: {}".format(ManagementAction.__name__, - json_obj[cls._SERIAL_KEY_ACTION])) - - dataset_name = json_obj.get(cls._SERIAL_KEY_DATASET_NAME) - category_str = json_obj.get(cls._SERIAL_KEY_CATEGORY) - category = None if category_str is None else DataCategory.get_for_name(category_str) - data_loc = json_obj.get(cls._SERIAL_KEY_DATA_LOCATION) - #page = json_obj[cls._SERIAL_KEY_PAGE] if cls._SERIAL_KEY_PAGE in json_obj else None - if cls._SERIAL_KEY_QUERY in 
json_obj: - query = DatasetQuery.factory_init_from_deserialized_json(json_obj[cls._SERIAL_KEY_QUERY]) - else: - query = None - if cls._SERIAL_KEY_DATA_DOMAIN in json_obj: - domain = DataDomain.factory_init_from_deserialized_json(json_obj[cls._SERIAL_KEY_DATA_DOMAIN]) - else: - domain = None + return values - return deserialized_class(action=action, dataset_name=dataset_name, category=category, - is_read_only_dataset=json_obj[cls._SERIAL_KEY_IS_READ_ONLY], domain=domain, - data_location=data_loc, - is_pending_data=json_obj.get(cls._SERIAL_KEY_IS_PENDING_DATA), #page=page, - query=query, **deserialized_class_kwargs) - except Exception as e: - return None + class Config: + fields = { + "management_action": {"alias": "action"}, + "data_category": {"alias": "category"}, + "is_read_only_dataset": {"alias": "read_only"}, + "is_pending_data": {"alias": "pending_data"}, + } def __eq__(self, other): try: @@ -241,10 +220,20 @@ def __hash__(self): self.data_category.name, str(hash(self.data_domain)), self.data_location, str(self.is_pending_data), self.query.to_json()])) - def __init__(self, action: ManagementAction, dataset_name: Optional[str] = None, is_read_only_dataset: bool = False, - category: Optional[DataCategory] = None, domain: Optional[DataDomain] = None, - data_location: Optional[str] = None, is_pending_data: bool = False, - query: Optional[DatasetQuery] = None, *args, **kwargs): + def __init__( + self, + *, + # NOTE: default is None for backwards compatibility. could be specified using alias. + action: ManagementAction = None, + dataset_name: Optional[str] = None, + is_read_only_dataset: bool = False, + category: Optional[DataCategory] = None, + domain: Optional[DataDomain] = None, + data_location: Optional[str] = None, + is_pending_data: bool = False, + query: Optional[DatasetQuery] = None, + **data + ): """ Initialize this instance. 
@@ -265,134 +254,17 @@ def __init__(self, action: ManagementAction, dataset_name: Optional[str] = None, query : Optional[DatasetQuery] Optional ::class:`DatasetQuery` object for query messages. """ - # Sanity check certain param values depending on the action; e.g., can't CREATE a dataset without a name - err_msg_template = "Cannot create {} for action {} without {}" - if dataset_name is None and action.requires_dataset_name: - raise RuntimeError(err_msg_template.format(self.__class__.__name__, action, "a dataset name")) - if category is None and action.requires_data_category: - raise RuntimeError(err_msg_template.format(self.__class__.__name__, action, "a data category")) - if domain is None and action.requires_data_domain: - raise RuntimeError(err_msg_template.format(self.__class__.__name__, action, "a data domain")) - - super(DatasetManagementMessage, self).__init__(*args, **kwargs) - - # TODO: raise exceptions for actions for which the workflow is not yet supported (e.g., REMOVE_DATA) - - self._action = action - self._dataset_name = dataset_name - self._is_read_only_dataset = is_read_only_dataset - self._category = category - self._domain = domain - self._data_location = data_location - self._query = query - self._is_pending_data = is_pending_data - - @property - def data_location(self) -> Optional[str]: - """ - Location for acted-upon data. - - Returns - ------- - Optional[str] - Location for acted-upon data. - """ - return self._data_location - - @property - def is_pending_data(self) -> bool: - """ - Whether the sender has data pending transmission after this message. - - Whether the sender has data it wants to transmit after this message. The typical use case is during a - ``CREATE`` action, where this indicates there is already data to add to the newly created dataset. - - Returns - ------- - bool - Whether the sender has data pending transmission after this message. 
- """ - return self._is_pending_data - - @property - def data_category(self) -> Optional[DataCategory]: - """ - The category of the involved data, if applicable. - - Returns - ------- - bool - The category of the involved data, if applicable. - """ - return self._category - - @property - def data_domain(self) -> Optional[DataDomain]: - """ - The domain of the involved data, if applicable. - - Returns - ------- - Optional[DataDomain] - The domain of the involved data, if applicable. - """ - return self._domain - - @property - def dataset_name(self) -> Optional[str]: - """ - The name of the involved dataset, if applicable. - - Returns - ------- - Optional - The name of the involved dataset, if applicable. - """ - return self._dataset_name - - @property - def is_read_only_dataset(self) -> bool: - """ - Whether the dataset involved is, should be, or must be (depending on action) read-only. - - Returns - ------- - bool - Whether the dataset involved is, should be, or must be (depending on action) read-only. - """ - return self._is_read_only_dataset - - @property - def management_action(self) -> ManagementAction: - """ - The type of ::class:`ManagementAction` this message embodies or requests. - - Returns - ------- - ManagementAction - The type of ::class:`ManagementAction` this message embodies or requests. 
- """ - return self._action - - @property - def query(self) -> Optional[DatasetQuery]: - return self._query - - def to_dict(self) -> Dict[str, Union[str, Number, dict, list]]: - serial = {self._SERIAL_KEY_ACTION: self.management_action.name, - self._SERIAL_KEY_IS_READ_ONLY: self.is_read_only_dataset, - self._SERIAL_KEY_IS_PENDING_DATA: self.is_pending_data} - if self.dataset_name is not None: - serial[self._SERIAL_KEY_DATASET_NAME] = self.dataset_name - if self.data_category is not None: - serial[self._SERIAL_KEY_CATEGORY] = self.data_category.name - if self.data_location is not None: - serial[self._SERIAL_KEY_DATA_LOCATION] = self.data_location - if self.data_domain is not None: - serial[self._SERIAL_KEY_DATA_DOMAIN] = self.data_domain.to_dict() - if self.query is not None: - serial[self._SERIAL_KEY_QUERY] = self.query.to_dict() - return serial + super().__init__( + management_action=action or data.pop("management_action", None), + dataset_name=dataset_name, + is_read_only_dataset=is_read_only_dataset or data.pop("read_only", False), + data_category=category or data.pop("data_category", None), + data_domain=domain or data.pop("data_domain", None), + data_location=data_location, + is_pending_data=is_pending_data or data.pop("pending_data", False), + query=query, + **data + ) class DatasetManagementResponse(Response): From 9969f32e81efd1b9e52b9a803e6b82762f8ef4db Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Tue, 17 Jan 2023 22:22:11 -0500 Subject: [PATCH 085/205] refactor MaaSDatasetManagementMessage --- .../dataset_management_message.py | 134 ++++++------------ 1 file changed, 47 insertions(+), 87 deletions(-) diff --git a/python/lib/communication/dmod/communication/dataset_management_message.py b/python/lib/communication/dmod/communication/dataset_management_message.py index 08284e310..03e88b43d 100644 --- a/python/lib/communication/dmod/communication/dataset_management_message.py +++ b/python/lib/communication/dmod/communication/dataset_management_message.py 
@@ -390,9 +390,26 @@ class MaaSDatasetManagementMessage(DatasetManagementMessage, ExternalRequest): the superclass. """ - _SERIAL_KEY_DATA_REQUIREMENTS = 'data_requirements' - _SERIAL_KEY_OUTPUT_FORMATS = 'output_formats' - _SERIAL_KEY_SESSION_SECRET = 'session_secret' + data_requirements: List[DataRequirement] = Field( + default_factory=list, + description="List of all the explicit and implied data requirements for this request.", + ) + """ + By default, this is an empty list, though it is possible to append requirements to the list. + """ + + output_formats: List[DataFormat] = Field( + default_factory=list, + description="List of the formats of each required output dataset for the requested task." + ) + """ + By default, this will be an empty list, though if any request does need to produce output, + formats can be appended to it. + """ + + class Config: + # NOTE: in parent class, `ExternalRequest`, `session_secret` is aliased using `session-secret` + fields = {"session_secret": {"alias": "session_secret"}} @classmethod def factory_create(cls, mgmt_msg: DatasetManagementMessage, session_secret: str) -> 'MaaSDatasetManagementMessage': @@ -416,90 +433,33 @@ def factory_init_correct_response_subtype(cls, json_obj: dict) -> 'MaaSDatasetMa """ return MaaSDatasetManagementResponse.factory_init_from_deserialized_json(json_obj=json_obj) - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict) -> Optional['MaaSDatasetManagementMessage']: - try: - # Inject this if necessary before passing to supertype - if 'deserialized_class' not in json_obj: - json_obj['deserialized_class'] = cls - elif isinstance(json_obj['deserialized_class'], str): - json_obj['deserialized_class'] = globals()[json_obj['deserialized_class']] - # Also inject things that will be used as additional kwargs to the eventual class init - if 'deserialized_class_kwargs' not in json_obj: - json_obj['deserialized_class_kwargs'] = dict() - if 'session_secret' not in 
json_obj['deserialized_class_kwargs']: - json_obj['deserialized_class_kwargs']['session_secret'] = json_obj[cls._SERIAL_KEY_SESSION_SECRET] - - obj = super().factory_init_from_deserialized_json(json_obj=json_obj) - - # Also add these if there happened to be any present - if cls._SERIAL_KEY_DATA_REQUIREMENTS in json_obj: - obj.data_requirements.extend([DataRequirement.factory_init_from_deserialized_json(json) for json in - json_obj[cls._SERIAL_KEY_DATA_REQUIREMENTS]]) - if cls._SERIAL_KEY_OUTPUT_FORMATS in json_obj: - obj.output_formats.extend( - [DataFormat.get_for_name(f) for f in json_obj[cls._SERIAL_KEY_OUTPUT_FORMATS]]) - - # Finally, return the object - return obj - except Exception as e: - return None - - def __init__(self, session_secret: str, *args, **kwargs): - """ - - Keyword Args - ---------- - session_secret : str - action : ManagementAction - dataset_name : Optional[str] - is_read_only_dataset : bool - category : Optional[DataCategory] - data_location : Optional[str] - is_pending_data : bool - query : Optional[DataQuery] - """ - super(MaaSDatasetManagementMessage, self).__init__(session_secret=session_secret, *args, **kwargs) - self._data_requirements = [] - self._output_formats = [] - - @property - def data_requirements(self) -> List[DataRequirement]: - """ - List of all the explicit and implied data requirements for this request. - - By default, this is an empty list, though it is possible to append requirements to the list. - - Returns - ------- - List[DataRequirement] - List of all the explicit and implied data requirements for this request. - """ - return self._data_requirements - - @property - def output_formats(self) -> List[DataFormat]: - """ - List of the formats of each required output dataset for the requested task. 
- - By default, this will be an empty list, though if any request does need to produce output, formats can be - appended to it - - Returns - ------- - List[DataFormat] - List of the formats of each required output dataset for the requested. - """ - return self._output_formats - - def to_dict(self) -> Dict[str, Union[str, Number, dict, list]]: - serial = super(MaaSDatasetManagementMessage, self).to_dict() - serial[self._SERIAL_KEY_SESSION_SECRET] = self.session_secret - if len(self.data_requirements) > 0: - serial[self._SERIAL_KEY_DATA_REQUIREMENTS] = [r.to_dict() for r in self.data_requirements] - if len(self.output_formats) > 0: - serial[self._SERIAL_KEY_OUTPUT_FORMATS] = [f.name for f in self.output_formats] - return serial + def dict( + self, + *, + include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + by_alias: bool = True, # Note this follows Serializable convention + skip_defaults: Optional[bool] = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False + ) -> Dict[str, Union[str, int]]: + exclude = exclude or set() + + if not self.data_requirements: + exclude.add("data_requirements") + if not self.output_formats: + exclude.add("output_formats") + + return super().dict( + include=include, + exclude=exclude, + by_alias=by_alias, + skip_defaults=skip_defaults, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) class MaaSDatasetManagementResponse(ExternalRequestResponse, DatasetManagementResponse): From 6267ab93259f684f68b394b8a1388581e31176b0 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Tue, 17 Jan 2023 22:50:20 -0500 Subject: [PATCH 086/205] refactor PartitionRequest --- .../dmod/communication/partition_request.py | 153 +++++++----------- 1 file changed, 61 insertions(+), 92 deletions(-) diff --git a/python/lib/communication/dmod/communication/partition_request.py 
b/python/lib/communication/dmod/communication/partition_request.py index 4184eaa92..bc1d41689 100644 --- a/python/lib/communication/dmod/communication/partition_request.py +++ b/python/lib/communication/dmod/communication/partition_request.py @@ -1,6 +1,7 @@ from uuid import uuid4 from numbers import Number -from typing import Optional, Union, Dict +from pydantic import Field +from typing import ClassVar, Dict, Optional, Union from .message import AbstractInitRequest, MessageEventType, Response from .maas_request import ExternalRequest @@ -11,27 +12,23 @@ class PartitionRequest(AbstractInitRequest): Request for partitioning of the catchments in a hydrofabric, typically for distributed processing. """ - event_type = MessageEventType.PARTITION_REQUEST - _KEY_NUM_PARTS = 'partition_count' - _KEY_NUM_CATS = 'catchment_count' - _KEY_UUID = 'uuid' - _KEY_HYDROFABRIC_UID = 'hydrofabric_uid' - _KEY_HYDROFABRIC_DATA_ID = 'hydrofabric_data_id' - _KEY_HYDROFABRIC_DESC = 'hydrofabric_description' + event_type: ClassVar[MessageEventType] = MessageEventType.PARTITION_REQUEST - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict, **kwargs): - hy_data_id = json_obj[cls._KEY_HYDROFABRIC_DATA_ID] if cls._KEY_HYDROFABRIC_DATA_ID in json_obj else None + num_partitions: int + uuid: Optional[str] = Field(default_factory=lambda: str(uuid4()), description="Get (as a string) the UUID for this instance.") + hydrofabric_uid: str = Field(description="The unique identifier for the hydrofabric that is to be partitioned.") + hydrofabric_data_id: Optional[str] = Field(description="When known, the 'data_id' for the dataset containing the associated hydrofabric.") + description: Optional[str] = Field(description="The optional description or name of the hydrofabric that is to be partitioned.") - try: - return cls(hydrofabric_uid=json_obj[cls._KEY_HYDROFABRIC_UID], - hydrofabric_data_id=hy_data_id, - num_partitions=json_obj[cls._KEY_NUM_PARTS], - 
description=json_obj.get(cls._KEY_HYDROFABRIC_DESC), - uuid=json_obj[cls._KEY_UUID], - **kwargs) - except: - return None + class Config: + fields = { + "num_partitions": {"alias": "partition_count"}, + "description": {"alias": "hydrofabric_description"} + } + + # QUESTION: is this unused? + # catchment_count: str + # _KEY_NUM_CATS = 'catchment_count' @classmethod def factory_init_correct_response_subtype(cls, json_obj: dict): @@ -48,8 +45,16 @@ def factory_init_correct_response_subtype(cls, json_obj: dict): """ return PartitionResponse.factory_init_from_deserialized_json(json_obj=json_obj) - def __init__(self, num_partitions: int, hydrofabric_uid: str, hydrofabric_data_id: Optional[str] = None, - uuid: Optional[str] = None, description: Optional[str] = None, *args, **kwargs): + def __init__( + self, + # NOTE: default is None for backwards compatibility. could be specified using alias. + num_partitions: int = None, + hydrofabric_uid: str = None, + hydrofabric_data_id: Optional[str] = None, + uuid: Optional[str] = None, + description: Optional[str] = None, + **data + ): """ Initialize the request. @@ -66,12 +71,15 @@ def __init__(self, num_partitions: int, hydrofabric_uid: str, hydrofabric_data_i description : Optional[str] An optional description or name for the hydrofabric. 
""" - super(PartitionRequest, self).__init__(*args, **kwargs) - self._hydrofabric_uid = hydrofabric_uid - self._hydrofabric_data_id = hydrofabric_data_id - self._num_partitions = num_partitions - self._uuid = uuid if uuid else str(uuid4()) - self._description = description + + super().__init__( + num_partitions=num_partitions or data.pop("partition_count", None), + hydrofabric_uid=hydrofabric_uid or data.pop("hydrofabric_description", None), + hydrofabric_data_id=hydrofabric_data_id, + uuid=uuid, + description=description, + **data + ) def __eq__(self, other): return self.uuid == other.uuid and self.hydrofabric_uid == other.hydrofabric_uid and self.hydrofabric_data_id == other.hydrofabric_data_id @@ -79,71 +87,32 @@ def __eq__(self, other): def __hash__(self): return hash("{}{}{}".format(self.uuid, self.hydrofabric_uid, self.hydrofabric_data_id)) - @property - def description(self) -> Optional[str]: - """ - The optional description or name of the hydrofabric that is to be partitioned. - - Returns - ------- - Optional[str] - The optional description or name of the hydrofabric that is to be partitioned. - """ - return self._description - - @property - def hydrofabric_data_id(self) -> Optional[str]: - """ - When known, the 'data_id' for the dataset containing the associated hydrofabric. - - Returns - ------- - Optional[str] - When known, the 'data_id' for the dataset containing the associated hydrofabric. - """ - return self._hydrofabric_data_id - - @property - def hydrofabric_uid(self) -> str: - """ - The unique identifier for the hydrofabric that is to be partitioned. - - Returns - ------- - str - The unique identifier for the hydrofabric that is to be partitioned. 
- """ - return self._hydrofabric_uid - - @property - def num_partitions(self) -> int: - return self._num_partitions - - def to_dict(self) -> Dict[str, Union[str, Number, dict, list]]: - serialized = { - 'class_name': self.__class__.__name__, - self._KEY_HYDROFABRIC_UID: self.hydrofabric_uid, - self._KEY_NUM_PARTS: self.num_partitions, - #self._KEY_SECRET: self.session_secret, - self._KEY_UUID: self.uuid, - } - if self.description is not None: - serialized[self._KEY_HYDROFABRIC_DESC] = self.description - if self.hydrofabric_data_id is not None: - serialized[self._KEY_HYDROFABRIC_DATA_ID] = self.hydrofabric_data_id - return serialized - - @property - def uuid(self) -> str: - """ - Get (as a string) the UUID for this instance. + def dict( + self, + *, + include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + by_alias: bool = True, # Note this follows Serializable convention + skip_defaults: Optional[bool] = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False + ) -> Dict[str, Union[str, int]]: + serial = super().dict( + include=include, + exclude=exclude, + by_alias=by_alias, + skip_defaults=skip_defaults, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + + # include "class_name" if not in excludes + if exclude is not None and "class_name" not in exclude: + serial["class_name"] = self.__class__.__name__ - Returns - ------- - str - The UUID for this instance, as a string. 
- """ - return self._uuid + return serial class PartitionResponse(Response): From 1d1449b9d561a00f6a6757c73cf1ffefa307839e Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Tue, 17 Jan 2023 22:50:36 -0500 Subject: [PATCH 087/205] refactor PartitionExternalRequest --- .../dmod/communication/partition_request.py | 20 +++---------------- 1 file changed, 3 insertions(+), 17 deletions(-) diff --git a/python/lib/communication/dmod/communication/partition_request.py b/python/lib/communication/dmod/communication/partition_request.py index bc1d41689..729acf384 100644 --- a/python/lib/communication/dmod/communication/partition_request.py +++ b/python/lib/communication/dmod/communication/partition_request.py @@ -173,23 +173,9 @@ def to_dict(self) -> Dict[str, Union[str, Number, dict, list]]: class PartitionExternalRequest(PartitionRequest, ExternalRequest): - _KEY_SECRET = 'session_secret' - - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict, **kwargs): - try: - kwargs['session_secret'] = json_obj[cls._KEY_SECRET] - return super(PartitionExternalRequest, cls).factory_init_from_deserialized_json(json_obj, **kwargs) - except: - return None - - def __init__(self, *args, **kwargs): - super(PartitionExternalRequest, self).__init__(*args, **kwargs) - - def to_dict(self) -> Dict[str, Union[str, Number, dict, list]]: - serial = super(PartitionExternalRequest, self).to_dict() - serial[self._KEY_SECRET] = self.session_secret - return serial + class Config: + # NOTE: in parent class, `ExternalRequest`, `session_secret` is aliased using `session-secret` + fields = {"session_secret": {"alias": "session_secret"}} class PartitionExternalResponse(PartitionResponse): From 3bf30e7f394b4e0fe85c039e0fc5b9f41dff6582 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Tue, 17 Jan 2023 22:52:50 -0500 Subject: [PATCH 088/205] fix bug in PartitionRequest __init__ --- .../communication/dmod/communication/partition_request.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 
deletions(-) diff --git a/python/lib/communication/dmod/communication/partition_request.py b/python/lib/communication/dmod/communication/partition_request.py index 729acf384..a2d0c8eae 100644 --- a/python/lib/communication/dmod/communication/partition_request.py +++ b/python/lib/communication/dmod/communication/partition_request.py @@ -47,9 +47,10 @@ def factory_init_correct_response_subtype(cls, json_obj: dict): def __init__( self, + *, + hydrofabric_uid: str, # NOTE: default is None for backwards compatibility. could be specified using alias. num_partitions: int = None, - hydrofabric_uid: str = None, hydrofabric_data_id: Optional[str] = None, uuid: Optional[str] = None, description: Optional[str] = None, @@ -74,10 +75,10 @@ def __init__( super().__init__( num_partitions=num_partitions or data.pop("partition_count", None), - hydrofabric_uid=hydrofabric_uid or data.pop("hydrofabric_description", None), + hydrofabric_uid=hydrofabric_uid, hydrofabric_data_id=hydrofabric_data_id, uuid=uuid, - description=description, + description=description or data.pop("hydrofabric_description", None), **data ) From d9376b576eb433868ccf4a3698e7e8a3eb1d6e99 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Tue, 17 Jan 2023 23:02:42 -0500 Subject: [PATCH 089/205] refactor EvaluationRequest --- .../dmod/communication/evaluation_request.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/python/lib/communication/dmod/communication/evaluation_request.py b/python/lib/communication/dmod/communication/evaluation_request.py index 4c633d259..f18b08289 100644 --- a/python/lib/communication/dmod/communication/evaluation_request.py +++ b/python/lib/communication/dmod/communication/evaluation_request.py @@ -4,7 +4,7 @@ from numbers import Number from typing import Dict -from typing import Union +from typing import ClassVar, Union from .message import Message, MessageEventType, Response @@ -16,18 +16,16 @@ class EvaluationRequest(Message, abc.ABC): A request to be 
forwarded to the evaluation service """ - event_type: MessageEventType = MessageEventType.EVALUATION_REQUEST + event_type: ClassVar[MessageEventType] = MessageEventType.EVALUATION_REQUEST """ :class:`MessageEventType`: the event type for this message implementation """ + action: str + @classmethod @abc.abstractmethod def get_action(cls) -> str: ... - @property - def action(self) -> str: - return self.get_action() - class EvaluationConnectionRequest(EvaluationRequest): """ From 702ea32eeeaa844f5e8f06711f33b26596127061 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Tue, 17 Jan 2023 23:04:17 -0500 Subject: [PATCH 090/205] refactor EvaluationConnectionRequest --- .../dmod/communication/evaluation_request.py | 29 +++++-------------- 1 file changed, 7 insertions(+), 22 deletions(-) diff --git a/python/lib/communication/dmod/communication/evaluation_request.py b/python/lib/communication/dmod/communication/evaluation_request.py index f18b08289..72430ade4 100644 --- a/python/lib/communication/dmod/communication/evaluation_request.py +++ b/python/lib/communication/dmod/communication/evaluation_request.py @@ -31,18 +31,17 @@ class EvaluationConnectionRequest(EvaluationRequest): """ A request used to communicate through a chained websocket connection """ - _action_parameters: typing.Dict[str, typing.Any] + action: typing.Literal["connect"] = "connect" + parameters: typing.Dict[str, typing.Any] - def __init__(self, **kwargs): - self._action_parameters = kwargs or dict() + class Config: + fields = { + "parameters": {"alias": "action_parameters"} + } @classmethod def get_action(cls) -> str: - return "connect" - - @property - def parameters(self) -> typing.Dict[str, typing.Any]: - return self._action_parameters + return cls.action @classmethod def factory_init_from_deserialized_json(cls, json_obj: dict) -> typing.Optional[EvaluationRequest]: @@ -62,20 +61,6 @@ def factory_init_from_deserialized_json(cls, json_obj: dict) -> typing.Optional[ return cls(**json_obj) - def 
to_dict(self) -> Dict[str, Union[str, Number, dict, list]]: - """ - Returns: - A dictionary representation of this request - """ - dictionary_representation = { - "action": self.action - } - - if self._action_parameters: - dictionary_representation['action_parameters'] = self._action_parameters.copy() - - return dictionary_representation - class EvaluationConnectionRequestResponse(Response): pass From 5e9c2411a613d88ea05319216c7ce1ed4ad7d0f4 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Wed, 18 Jan 2023 14:04:35 -0500 Subject: [PATCH 091/205] fix typing hints, default value, and retrieval of action in EvaluationConnectionRequest --- .../dmod/communication/evaluation_request.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/python/lib/communication/dmod/communication/evaluation_request.py b/python/lib/communication/dmod/communication/evaluation_request.py index 72430ade4..203c2ac51 100644 --- a/python/lib/communication/dmod/communication/evaluation_request.py +++ b/python/lib/communication/dmod/communication/evaluation_request.py @@ -3,8 +3,7 @@ import json from numbers import Number -from typing import Dict -from typing import ClassVar, Union +from pydantic import Field, validator from .message import Message, MessageEventType, Response @@ -16,7 +15,7 @@ class EvaluationRequest(Message, abc.ABC): A request to be forwarded to the evaluation service """ - event_type: ClassVar[MessageEventType] = MessageEventType.EVALUATION_REQUEST + event_type: typing.ClassVar[MessageEventType] = MessageEventType.EVALUATION_REQUEST """ :class:`MessageEventType`: the event type for this message implementation """ action: str @@ -32,7 +31,7 @@ class EvaluationConnectionRequest(EvaluationRequest): A request used to communicate through a chained websocket connection """ action: typing.Literal["connect"] = "connect" - parameters: typing.Dict[str, typing.Any] + parameters: typing.Dict[str, typing.Any] = Field(default_factory=dict) class Config: fields = { @@ 
-41,7 +40,7 @@ class Config: @classmethod def get_action(cls) -> str: - return cls.action + return cls.__fields__["action"].default @classmethod def factory_init_from_deserialized_json(cls, json_obj: dict) -> typing.Optional[EvaluationRequest]: From 56ec508a341eedfce33fe15be8d4683f43a357f6 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Wed, 18 Jan 2023 14:05:28 -0500 Subject: [PATCH 092/205] add ActionParameters class. encapsulates evaluation_name and instructions --- .../dmod/communication/evaluation_request.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/python/lib/communication/dmod/communication/evaluation_request.py b/python/lib/communication/dmod/communication/evaluation_request.py index 203c2ac51..5a99a6676 100644 --- a/python/lib/communication/dmod/communication/evaluation_request.py +++ b/python/lib/communication/dmod/communication/evaluation_request.py @@ -5,6 +5,7 @@ from numbers import Number from pydantic import Field, validator +from dmod.core.serializable import Serializable from .message import Message, MessageEventType, Response SERIALIZABLE_DICT = typing.Dict[str, typing.Union[str, Number, dict, typing.List]] @@ -68,6 +69,16 @@ class EvaluationConnectionRequestResponse(Response): class SaveEvaluationRequest(EvaluationRequest): pass +class ActionParameters(Serializable): + evaluation_name: str + instructions: str + + @validator("instructions", pre=True) + def _coerce_instructions(cls, value): + if isinstance(value, dict): + return json.dumps(value, indent=4) + return value + class StartEvaluationRequest(EvaluationRequest): @classmethod From 7ae2d071613f3d067c872554bd6099d956c0023e Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Wed, 18 Jan 2023 14:05:42 -0500 Subject: [PATCH 093/205] refactor StartEvaluationRequest --- .../dmod/communication/evaluation_request.py | 66 +++++++------------ 1 file changed, 25 insertions(+), 41 deletions(-) diff --git a/python/lib/communication/dmod/communication/evaluation_request.py 
b/python/lib/communication/dmod/communication/evaluation_request.py index 5a99a6676..54cfef349 100644 --- a/python/lib/communication/dmod/communication/evaluation_request.py +++ b/python/lib/communication/dmod/communication/evaluation_request.py @@ -81,62 +81,46 @@ def _coerce_instructions(cls, value): class StartEvaluationRequest(EvaluationRequest): - @classmethod - def get_action(cls) -> str: - return "launch" + action: typing.Literal["launch"] = "launch" - evaluation_name: str = None + # Note: `parameters`, from parent class, is in this subclass, a dictionary representation of + # `ActionParameters` plus arbitrary keys and values + @validator("parameters", pre=True) + def _coerce_action_parameters(cls, value: typing.Union[typing.Dict[str, typing.Any], ActionParameters]): + if isinstance(value, ActionParameters): + return value.to_dict() - instructions: typing.Union[str, dict] = None + parameters = ActionParameters(**value) + return {**value, **parameters.to_dict()} - action_parameters: dict = None + @classmethod + def get_action(cls) -> str: + return cls.__fields__["action"].default @classmethod def factory_init_from_deserialized_json(cls, json_obj: dict) -> typing.Optional[EvaluationRequest]: try: - if "action" in json_obj and json_obj['action'] != cls.get_action(): + if "action" in json_obj and json_obj["action"] != cls.get_action(): return None - if "action_parameters" in json_obj: - parameters = json_obj['action_parameters'] - else: - parameters = json_obj - - missing_instructions = not parameters.get("instructions") \ - or not isinstance(parameters.get("instructions"), (str, dict)) - missing_name = not parameters.get("evaluation_name") - - if missing_instructions or missing_name: - return None - - return cls( - instructions=parameters.get("instructions"), - evaluation_name=parameters.get("evaluation_name"), - **parameters - ) - except Exception as e: + return cls(**json_obj) + except Exception: return None - def to_dict(self) -> SERIALIZABLE_DICT: - return 
{ - "action": self.action, - "action_parameters": self.action_parameters.update( - { - "evaluation_name": self.evaluation_name, - "instructions": self.instructions - } - ) - } - def __init__( self, - instructions: str, - evaluation_name: str, + # NOTE: None for backwards compatibility + instructions: str = None, + evaluation_name: str = None, **kwargs ): - self._instructions = json.dumps(instructions, indent=4) if isinstance(instructions, dict) else instructions - self._evaluation_name = evaluation_name - self._action_parameters = kwargs + # assume no need for backwards compatibility + if instructions is None or evaluation_name is None: + super().__init__(**kwargs) + return + + parameters = ActionParameters(instructions=instructions, evaluation_name=evaluation_name, **kwargs) + super().__init__(parameters=parameters.to_dict()) class FindEvaluationRequest(EvaluationRequest): From 66a28ecbade380fb53e739bb4d51d43de5422ea8 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Wed, 18 Jan 2023 14:29:16 -0500 Subject: [PATCH 094/205] refactor DataTransmitMessage --- .../communication/data_transmit_message.py | 98 ++++++++----------- 1 file changed, 40 insertions(+), 58 deletions(-) diff --git a/python/lib/communication/dmod/communication/data_transmit_message.py b/python/lib/communication/dmod/communication/data_transmit_message.py index 6702bb7a0..1a52229d8 100644 --- a/python/lib/communication/dmod/communication/data_transmit_message.py +++ b/python/lib/communication/dmod/communication/data_transmit_message.py @@ -1,5 +1,6 @@ from .message import AbstractInitRequest, MessageEventType, Response -from typing import Dict, Optional, Union +from pydantic import Field +from typing import ClassVar, Dict, Optional, Union from numbers import Number from uuid import UUID @@ -18,64 +19,45 @@ class DataTransmitMessage(AbstractInitRequest): ::class:`str` object. However, instances can be initialized using either ::class:`str` or ::class:`bytes` data. 
""" - _KEY_SERIES_UUID = 'series_uuid' + event_type: ClassVar[MessageEventType] = MessageEventType.DATA_TRANSMISSION - event_type: MessageEventType = MessageEventType.DATA_TRANSMISSION - - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict) -> Optional['DataTransmitMessage']: - try: - return cls(data=json_obj['data'], series_uuid=UUID(json_obj[cls._KEY_SERIES_UUID]), - is_last=json_obj['is_last']) - except Exception as e: - return None - - def __init__(self, data: Union[str, bytes], series_uuid: UUID, is_last: bool = False, *args, **kwargs): - super(DataTransmitMessage, self).__init__(*args, **kwargs) - self._data: str = data if isinstance(data, str) else data.decode() - self._series_uuid = series_uuid - self._is_last: bool = is_last - - @property - def data(self) -> str: - """ - The data carried by this message, in decoded string form. - - Returns - ------- - str - The data carried by this message, in decoded string form. - """ - return self._data - - @property - def is_last(self) -> bool: - """ - Whether this is the last data transmission message in this series. - - Returns - ------- - bool - Whether this is the last data transmission message in this series. - """ - return self._is_last - - @property - def series_uuid(self) -> UUID: - """ - A unique id for the collective series of transmission message this instance is a part of. - - The expectation is that a larger amount of data will be broken up into multiple messages in a series. - - Returns - ------- - UUID - A unique id for the collective series of transmission message this instance is a part of. 
- """ - return self._series_uuid - - def to_dict(self) -> Dict[str, Union[str, Number, dict, list]]: - return {'data': self.data, self._KEY_SERIES_UUID: str(self.series_uuid), 'is_last': self.is_last} + data: str = Field(description="The data carried by this message, in decoded string form.") + series_uuid: UUID = Field(description="A unique id for the collective series of transmission message this instance is a part of.") + """ + The expectation is that a larger amount of data will be broken up into multiple messages in a series. + """ + is_last: bool = Field(False, description="Whether this is the last data transmission message in this series.") + + def dict( + self, + *, + include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + by_alias: bool = True, # Note this follows Serializable convention + skip_defaults: Optional[bool] = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False + ) -> Dict[str, Union[str, int]]: + SERIES_UUID_KEY = "series_uuid" + exclude = exclude or set() + series_uuid_in_exclude = SERIES_UUID_KEY in exclude + exclude.add(SERIES_UUID_KEY) + + serial = super().dict( + include=include, + exclude=exclude, + by_alias=by_alias, + skip_defaults=skip_defaults, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + + if not series_uuid_in_exclude: + serial[SERIES_UUID_KEY] = str(self.series_uuid) + + return serial class DataTransmitResponse(Response): From a0e1c88c498e91abc2caa106defd1da4d16d9523 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Wed, 18 Jan 2023 14:54:27 -0500 Subject: [PATCH 095/205] refactor dict method logic to follow pattern foundelse where in the codebase --- .../dmod/communication/update_message.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/python/lib/communication/dmod/communication/update_message.py 
b/python/lib/communication/dmod/communication/update_message.py index 6d0e61cbc..0e130c6c2 100644 --- a/python/lib/communication/dmod/communication/update_message.py +++ b/python/lib/communication/dmod/communication/update_message.py @@ -83,13 +83,14 @@ def dict( exclude_defaults: bool = False, exclude_none: bool = False ) -> Dict[str, Union[str, int]]: - _exclude = {"object_type"} - if exclude is not None: - _exclude = {*_exclude, *exclude} + OBJECT_TYPE_KEY = "object_type" + exclude = exclude or set() + object_type_in_exclude = OBJECT_TYPE_KEY in exclude + exclude.add(OBJECT_TYPE_KEY) serial = super().dict( include=include, - exclude=_exclude, + exclude=exclude, by_alias=by_alias, skip_defaults=skip_defaults, exclude_unset=exclude_unset, @@ -97,8 +98,8 @@ def dict( exclude_none=exclude_none, ) - if exclude is None or "object_type" not in exclude: - serial["object_type"] = self.object_type_string + if not object_type_in_exclude: + serial[OBJECT_TYPE_KEY] = self.object_type_string return serial From 37cac10bed0b99a70dc41a6d1ee1a82561a29689 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Wed, 18 Jan 2023 15:12:20 -0500 Subject: [PATCH 096/205] complete refactor of UpdateMessageResponse --- .../dmod/communication/update_message.py | 25 ++++++++++++++----- 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/python/lib/communication/dmod/communication/update_message.py b/python/lib/communication/dmod/communication/update_message.py index 0e130c6c2..cc136fab6 100644 --- a/python/lib/communication/dmod/communication/update_message.py +++ b/python/lib/communication/dmod/communication/update_message.py @@ -108,14 +108,14 @@ class UpdateMessageData(Serializable): object_found: Optional[bool] -class UpdateMessageResponse(UpdateMessageData, Response): +class UpdateMessageResponse(Response): """ The subtype of ::class:`Response` appropriate for ::class:`UpdateMessage` objects. 
""" response_to_type: ClassVar[Type[AbstractInitRequest]] = UpdateMessage - data: Optional[UpdateMessageData] = Field(default_factory=UpdateMessageData) + data: UpdateMessageData = Field(default_factory=UpdateMessageData) @classmethod def get_digest_subkey(cls) -> str: @@ -129,7 +129,7 @@ def get_digest_subkey(cls) -> str: The "subkey" (i.e., the key for the value within the nested ``data`` dictionary) for the ``digest`` in serialized representations. """ - return cls.__fields__["digest"].alias + return UpdateMessageData.__fields__["digest"].alias @classmethod def get_object_found_subkey(cls) -> str: @@ -143,7 +143,7 @@ def get_object_found_subkey(cls) -> str: The "subkey" (i.e., the key for the value within the nested ``data`` dictionary) for the ``digest`` in serialized representations. """ - return cls.__fields__["object_found"].alias + return UpdateMessageData.__fields__["object_found"].alias def __init__(self, success: bool, reason: str, response_text: str = '', data: Optional[Dict[str, Union[str, bool]]] = None, digest: Optional[str] = None, @@ -158,5 +158,18 @@ def __init__(self, success: bool, reason: str, response_text: str = '', if object_found is None and self.get_object_found_subkey() in data: object_found = data[self.get_object_found_subkey()] - super().__init__(success=success, reason=reason, message=response_text, - data=UpdateMessageData(digest=digest, object_found=object_found)) + super().__init__( + success=success, + reason=reason, + message=response_text, + data=UpdateMessageData(digest=digest, object_found=object_found), + **kwargs + ) + + @property + def digest(self) -> Optional[str]: + return self.data.digest + + @property + def object_found(self) -> Optional[bool]: + return self.data.object_found From 56716a6ec162e0183f8d1c09b1f2f51e639aaa3c Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Wed, 18 Jan 2023 15:34:21 -0500 Subject: [PATCH 097/205] add missing parameters field to StartEvaluationRequest --- 
.../dmod/communication/evaluation_request.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/python/lib/communication/dmod/communication/evaluation_request.py b/python/lib/communication/dmod/communication/evaluation_request.py index 54cfef349..78f76b6d4 100644 --- a/python/lib/communication/dmod/communication/evaluation_request.py +++ b/python/lib/communication/dmod/communication/evaluation_request.py @@ -82,9 +82,15 @@ def _coerce_instructions(cls, value): class StartEvaluationRequest(EvaluationRequest): action: typing.Literal["launch"] = "launch" + parameters: typing.Dict[str, typing.Any] - # Note: `parameters`, from parent class, is in this subclass, a dictionary representation of - # `ActionParameters` plus arbitrary keys and values + class Config: + fields = { + "parameters": {"alias": "action_parameters"} + } + + # Note: `parameters` is a dictionary representation of `ActionParameters` plus arbitrary keys + # and values @validator("parameters", pre=True) def _coerce_action_parameters(cls, value: typing.Union[typing.Dict[str, typing.Any], ActionParameters]): if isinstance(value, ActionParameters): From 6c875d8aa6b9ba522a1499a690236381ddc5ea4b Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Wed, 18 Jan 2023 15:44:00 -0500 Subject: [PATCH 098/205] fix bug in validating Sesssion dates --- python/lib/communication/dmod/communication/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/lib/communication/dmod/communication/session.py b/python/lib/communication/dmod/communication/session.py index 31cc6ca1b..6a6bf68d9 100644 --- a/python/lib/communication/dmod/communication/session.py +++ b/python/lib/communication/dmod/communication/session.py @@ -58,7 +58,7 @@ class Session(Serializable): @validator("created", "last_accessed", pre=True) def validate_date(cls, value): - if isinstance(value, datetime): + if isinstance(value, datetime.datetime): return value try: From 77c22d3d1171f66a87eaf6735a9271ed1f93d4d3 
Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Wed, 18 Jan 2023 15:46:41 -0500 Subject: [PATCH 099/205] explicity cast from IPAddress type into str. this ensures ip_addresses are correcly formatted as well as validated. --- python/lib/communication/dmod/communication/session.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/lib/communication/dmod/communication/session.py b/python/lib/communication/dmod/communication/session.py index 6a6bf68d9..93c7b2620 100644 --- a/python/lib/communication/dmod/communication/session.py +++ b/python/lib/communication/dmod/communication/session.py @@ -224,8 +224,7 @@ class FullAuthSession(Session): @validator("ip_address", pre=True) def cast_ip_address_to_str(cls, value: str) -> str: # this will raise if cannot be coerced into IPv(4|6)Address - IPvAnyAddress.validate(value) - return value + return str(IPvAnyAddress.validate(value)) @classmethod def factory_init_from_deserialized_json(cls, json_obj: dict): From ad8833da27716f9af3d7ce1cb0bc2ad6136c56ef Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Wed, 18 Jan 2023 15:48:14 -0500 Subject: [PATCH 100/205] fix validation of success --- .../communication/dmod/communication/session.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/python/lib/communication/dmod/communication/session.py b/python/lib/communication/dmod/communication/session.py index 93c7b2620..be91e4c2e 100644 --- a/python/lib/communication/dmod/communication/session.py +++ b/python/lib/communication/dmod/communication/session.py @@ -339,10 +339,16 @@ class SessionInitResponse(Response): # NOTE: this field _is_ optional, however `data` will be FailedSessionInitInfo if it is not # provided or set to None. - # NOTE: order of this Union matters. types will be coerced from left to right. meaning, more + # NOTE: order of this Union matters. types will be coerced from left to right. meaning, more # specific types (i.e. 
subtypes) should be listed before more general types. see `SmartUnion` # for more detail: https://docs.pydantic.dev/usage/model_config/#smart-union - data: Union[FailedSessionInitInfo, FullAuthSession, Session] + data: Union[FailedSessionInitInfo, FullAuthSession, Session] = Field( + default_factory=lambda: FailedSessionInitInfo( + user="", + reason=SessionInitFailureReason.SESSION_DETAILS_MISSING, + details="Instantiated SessionInitResponse object without session data; defaulting to failure", + ) + ) @root_validator(pre=True) def _coerce_data_field(cls, values): @@ -378,10 +384,11 @@ def _coerce_data_field(cls, values): values["data"] = coerced_data return values - @validator("success") - def _update_success(cls, value: bool, values): + @root_validator() + def _update_success(cls, values): # Make sure to reset/change self.success if self.data ends up being a failure info object - return value and isinstance(values["data"], Session) + values["success"] = values["success"] and isinstance(values["data"], Session) + return values def __eq__(self, other): return self.__class__ == other.__class__ \ From 0ffca3782a6c762fcb01abf7f9d5505841bf0651 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Wed, 18 Jan 2023 16:34:12 -0500 Subject: [PATCH 101/205] fix FailedSessionInitInfo's fail_time serialization formatting --- .../dmod/communication/session.py | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/python/lib/communication/dmod/communication/session.py b/python/lib/communication/dmod/communication/session.py index be91e4c2e..db669a07d 100644 --- a/python/lib/communication/dmod/communication/session.py +++ b/python/lib/communication/dmod/communication/session.py @@ -287,6 +287,37 @@ class FailedSessionInitInfo(Serializable): def get_datetime_str_format(cls): return Session.get_datetime_str_format() + def dict( + self, + *, + include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + exclude: Optional[Union["AbstractSetIntStr", 
"MappingIntStrAny"]] = None, + by_alias: bool = True, # Note this follows Serializable convention + skip_defaults: Optional[bool] = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False + ) -> Dict[str, Union[str, int]]: + FAIL_TIME_KEY = "fail_time" + exclude = exclude or set() + fail_time_in_exclude = FAIL_TIME_KEY in exclude + exclude.add(FAIL_TIME_KEY) + + serial = super().dict( + include=include, + exclude=exclude, + by_alias=by_alias, + skip_defaults=skip_defaults, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + + if not fail_time_in_exclude: + serial[FAIL_TIME_KEY] = self.fail_time.strftime(self.get_datetime_str_format()) + + return serial + # Define this custom type here for hinting SessionInitDataType = Union[Session, FailedSessionInitInfo] From 5ce9223c23f9afcbaac04bf657f7feaa80f70212 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Wed, 18 Jan 2023 16:36:20 -0500 Subject: [PATCH 102/205] try and coerce SessionInitResponse' data field in correct order --- python/lib/communication/dmod/communication/session.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/python/lib/communication/dmod/communication/session.py b/python/lib/communication/dmod/communication/session.py index db669a07d..746088b57 100644 --- a/python/lib/communication/dmod/communication/session.py +++ b/python/lib/communication/dmod/communication/session.py @@ -373,13 +373,7 @@ class SessionInitResponse(Response): # NOTE: order of this Union matters. types will be coerced from left to right. meaning, more # specific types (i.e. subtypes) should be listed before more general types. 
see `SmartUnion` # for more detail: https://docs.pydantic.dev/usage/model_config/#smart-union - data: Union[FailedSessionInitInfo, FullAuthSession, Session] = Field( - default_factory=lambda: FailedSessionInitInfo( - user="", - reason=SessionInitFailureReason.SESSION_DETAILS_MISSING, - details="Instantiated SessionInitResponse object without session data; defaulting to failure", - ) - ) + data: Union[FullAuthSession, Session, FailedSessionInitInfo] @root_validator(pre=True) def _coerce_data_field(cls, values): From 11632386278f336071d81d15fa5f7564b43d561c Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Thu, 19 Jan 2023 08:52:57 -0500 Subject: [PATCH 103/205] add NWMRequestBody and NWMInnerRequestBody For backwards compatibility sake (nested properties), it was required to break this functionality into two classes. For now, NWMRequestBody forwards the properties of NWMInnerRequestBody. A TODO was added to revist this section of code and consider flattening the property heirarchy to become more consistent with NGENRequestBody. 
--- .../maas_request/nwm/nwm_exec_request_body.py | 75 +++++++++++++++++++ 1 file changed, 75 insertions(+) create mode 100644 python/lib/communication/dmod/communication/maas_request/nwm/nwm_exec_request_body.py diff --git a/python/lib/communication/dmod/communication/maas_request/nwm/nwm_exec_request_body.py b/python/lib/communication/dmod/communication/maas_request/nwm/nwm_exec_request_body.py new file mode 100644 index 000000000..b6e1d2705 --- /dev/null +++ b/python/lib/communication/dmod/communication/maas_request/nwm/nwm_exec_request_body.py @@ -0,0 +1,75 @@ +from pydantic import root_validator + +from dmod.core.meta_data import ( + DataCategory, + DataDomain, + DataFormat, + DataRequirement, + DiscreteRestriction, +) +from dmod.core.execution import AllocationParadigm +from dmod.core.serializable import Serializable +from ..model_exec_request_body import ModelExecRequestBody + +from typing import List, Literal + + +class NWMInnerRequestBody(ModelExecRequestBody): + name: Literal["nwm"] = "nwm" + + # NOTE: default value, `None`, is not validated by pydantic + data_requirements: List[DataRequirement] = None + + @root_validator() + def _add_data_requirements_if_missing(cls, values: dict): + data_requirements = values["data_requirements"] + + # None is non-validated default + if data_requirements is None: + config_data_id: str = values["config_data_id"] + + data_id_restriction = DiscreteRestriction( + variable="data_id", values=[config_data_id] + ) + values["data_requirements"] = [ + DataRequirement( + domain=DataDomain( + data_format=DataFormat.NWM_CONFIG, + discrete_restrictions=[data_id_restriction], + ), + is_input=True, + category=DataCategory.CONFIG, + ) + ] + + return values + + class Config: + # NOTE: `name` field is not included at this point for backwards compatibility sake. This + # may change in the future. 
+ fields = {"name": {"exclude": True}} + + +class NWMRequestBody(Serializable): + # TODO: flatten this hierarchy by replacing NWMRequestBody with NWMInnerRequestBody. + nwm: NWMInnerRequestBody + + @property + def name(self) -> str: + return self.nwm.name + + @property + def config_data_id(self) -> str: + return self.nwm.config_data_id + + @property + def cpu_count(self) -> int: + return self.nwm.cpu_count + + @property + def allocation_paradigm(self) -> AllocationParadigm: + return self.nwm.allocation_paradigm + + @property + def data_requirements(self) -> List[DataRequirement]: + return self.nwm.data_requirements From 25ed1748ee6618035db576acc3f488d292a45248 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Thu, 19 Jan 2023 09:08:36 -0500 Subject: [PATCH 104/205] refactor NWMRequest --- .../maas_request/nwm/nwm_request.py | 144 ++++-------------- 1 file changed, 32 insertions(+), 112 deletions(-) diff --git a/python/lib/communication/dmod/communication/maas_request/nwm/nwm_request.py b/python/lib/communication/dmod/communication/maas_request/nwm/nwm_request.py index e9d947601..f417c8e99 100644 --- a/python/lib/communication/dmod/communication/maas_request/nwm/nwm_request.py +++ b/python/lib/communication/dmod/communication/maas_request/nwm/nwm_request.py @@ -1,26 +1,27 @@ -from typing import List +from typing import ClassVar, List, Optional, Union +from dmod.core.execution import AllocationParadigm from dmod.core.meta_data import ( - DataCategory, - DataDomain, DataFormat, DataRequirement, - DiscreteRestriction, ) from ...message import MessageEventType from ..model_exec_request import ModelExecRequest from ..model_exec_request_response import ModelExecRequestResponse +from .nwm_exec_request_body import NWMRequestBody class NWMRequest(ModelExecRequest): - event_type = MessageEventType.MODEL_EXEC_REQUEST + event_type: ClassVar[MessageEventType] = MessageEventType.MODEL_EXEC_REQUEST """(:class:`MessageEventType`) The type of event for this message""" # Once more 
the case senstivity of this model name is called into question # note: this is essentially keyed to image_and_domain.yml and the cases must match! - model_name = "nwm" + model_name: ClassVar[str] = "nwm" """(:class:`str`) The name of the model to be used""" + model: NWMRequestBody + @classmethod def factory_init_correct_response_subtype( cls, json_obj: dict @@ -38,65 +39,32 @@ def factory_init_correct_response_subtype( """ return NWMRequestResponse.factory_init_from_deserialized_json(json_obj=json_obj) - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict): - """ - Factory create a new instance of this type based on a JSON object dictionary deserialized from received JSON. - - Recall this will look something like: - - { - 'model': { - 'NWM': { - 'allocation_paradigm': '', - 'config_data_id': '', - 'cpu_count': , - 'data_requirements': [ ... (serialized DataRequirement objects) ... ] - } - } - 'session-secret': 'secret-string-val' - } + def __init__( + self, + # required in prior version of code + config_data_id: str = None, + # optional in prior version of code + cpu_count: Optional[int] = None, + allocation_paradigm: Optional[Union[str, AllocationParadigm]] = None, + **data + ): + # assume no need for backwards compatibility + if "model" in data: + super().__init__(**data) + return - Parameters - ---------- - json_obj + data["model"] = dict() + nwm_inner_request_body = {"config_data_id": config_data_id} - Returns - ------- - A new object of this type instantiated from the deserialize JSON object dictionary, or none if the provided - parameter could not be used to instantiated a new object. 
- """ - try: - nwm_element = json_obj["model"][cls.model_name] - additional_kwargs = dict() - if "cpu_count" in nwm_element: - additional_kwargs["cpu_count"] = nwm_element["cpu_count"] - - if "allocation_paradigm" in nwm_element: - additional_kwargs["allocation_paradigm"] = nwm_element[ - "allocation_paradigm" - ] - - obj = cls( - config_data_id=nwm_element["config_data_id"], - session_secret=json_obj["session-secret"], - **additional_kwargs - ) - - reqs = [ - DataRequirement.factory_init_from_deserialized_json(req_json) - for req_json in json_obj["model"][cls.model_name]["data_requirements"] - ] - - obj._data_requirements = reqs - - return obj - except Exception as e: - return None - - def __init__(self, *args, **kwargs): - super(NWMRequest, self).__init__(*args, **kwargs) - self._data_requirements = None + if cpu_count is not None: + nwm_inner_request_body["cpu_count"] = cpu_count + + if allocation_paradigm is not None: + nwm_inner_request_body["allocation_paradigm"] = allocation_paradigm + + data["model"]["nwm"] = nwm_inner_request_body + + super().__init__(**data) @property def data_requirements(self) -> List[DataRequirement]: @@ -108,21 +76,7 @@ def data_requirements(self) -> List[DataRequirement]: List[DataRequirement] List of all the explicit and implied data requirements for this request. 
""" - if self._data_requirements is None: - data_id_restriction = DiscreteRestriction( - variable="data_id", values=[self.config_data_id] - ) - self._data_requirements = [ - DataRequirement( - domain=DataDomain( - data_format=DataFormat.NWM_CONFIG, - discrete_restrictions=[data_id_restriction], - ), - is_input=True, - category=DataCategory.CONFIG, - ) - ] - return self._data_requirements + return self.model.data_requirements @property def output_formats(self) -> List[DataFormat]: @@ -136,40 +90,6 @@ def output_formats(self) -> List[DataFormat]: """ return [DataFormat.NWM_OUTPUT] - def to_dict(self) -> dict: - """ - Converts the request to a dictionary that may be passed to web requests. - - Will look like: - - { - 'model': { - 'NWM': { - 'allocation_paradigm': '', - 'config_data_id': '', - 'cpu_count': , - 'data_requirements': [ ... (serialized DataRequirement objects) ... ] - } - } - 'session-secret': 'secret-string-val' - } - - Returns - ------- - dict - A dictionary containing all the data in such a way that it may be used by a web request - """ - model = dict() - model[self.get_model_name()] = dict() - model[self.get_model_name()][ - "allocation_paradigm" - ] = self.allocation_paradigm.name - model[self.get_model_name()]["config_data_id"] = self.config_data_id - model[self.get_model_name()]["cpu_count"] = self.cpu_count - model[self.get_model_name()]["data_requirements"] = [ - r.to_dict() for r in self.data_requirements - ] - return {"model": model, "session-secret": self.session_secret} class NWMRequestResponse(ModelExecRequestResponse): From 92efb6faf6063acbb7244ecdd7d70d50965fd920 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Thu, 19 Jan 2023 09:22:04 -0500 Subject: [PATCH 105/205] adds ModelExecRequestBody. refactored abstract encapsulation of model exec request metadata. This metadata previously was encapsulated in ModelExecRequest subtypes. Separation of these concepts aims to improve readability and congruence. 
ModelExecRequest subtypes should now embed a subtype of ModelExecRequestBody that adds model specific metadata properties. --- .../maas_request/model_exec_request_body.py | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 python/lib/communication/dmod/communication/maas_request/model_exec_request_body.py diff --git a/python/lib/communication/dmod/communication/maas_request/model_exec_request_body.py b/python/lib/communication/dmod/communication/maas_request/model_exec_request_body.py new file mode 100644 index 000000000..4d063a448 --- /dev/null +++ b/python/lib/communication/dmod/communication/maas_request/model_exec_request_body.py @@ -0,0 +1,31 @@ +from abc import ABC +from pydantic import Field, validator + +from dmod.core.serializable import Serializable +from dmod.core.execution import AllocationParadigm + +from typing import ClassVar + + +class ModelExecRequestBody(Serializable, ABC): + _DEFAULT_CPU_COUNT: ClassVar[int] = 1 + """ The default number of CPUs to assume are being requested for the job, when not explicitly provided. """ + + # model type discriminator field. enables constructing correct subclass based on `name` field + # value. + # override `name` in subclasses using `typing.Literal` + # e.g. `name: Literal["ngen"] = "ngen"` + name: str = Field("", description="The name of the model to be used") + + config_data_id: str = Field(description="Uniquely identifies the dataset with the primary configuration for this request.") + cpu_count: int = Field(_DEFAULT_CPU_COUNT, gt=0, description="The number of processors requested for this job.") + allocation_paradigm: AllocationParadigm = Field( + default_factory=AllocationParadigm.get_default_selection, + description="The allocation paradigm desired for use when allocating resources for this request." 
+ ) + + @validator("name", pre=True) + def _lower_model_name_(cls, value: str): + # NOTE: this should enable case insensitive subclass construction based on `name`, that is + # if all `name` field's are lowercase. + return str(value).lower() From c5495d5240f8546192c68b7c055d6da87602fb41 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Thu, 19 Jan 2023 09:41:04 -0500 Subject: [PATCH 106/205] improve get_available_model type hints --- .../dmod/communication/maas_request/model_exec_request.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/lib/communication/dmod/communication/maas_request/model_exec_request.py b/python/lib/communication/dmod/communication/maas_request/model_exec_request.py index de537091b..5d5c9a437 100644 --- a/python/lib/communication/dmod/communication/maas_request/model_exec_request.py +++ b/python/lib/communication/dmod/communication/maas_request/model_exec_request.py @@ -1,6 +1,6 @@ from abc import ABC -from typing import Optional, Union +from typing import Dict, Optional, Union from dmod.core.execution import AllocationParadigm from ..message import MessageEventType @@ -8,7 +8,7 @@ from .external_request import ExternalRequest -def get_available_models() -> dict: +def get_available_models() -> Dict[str, "ModelExecRequest"]: """ :return: The names of all models mapped to their class """ From 90f42f0c3937607bf43b47edb3aacbe6e64c2dff Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Thu, 19 Jan 2023 09:41:42 -0500 Subject: [PATCH 107/205] refactor ModelExecRequest. The main change, other than the move to pydantic, is ModelExecRequest's now embed ModelExecRequestBody subtypes that capture model specific metadata. See 123d9ae3 for more detail on ModelExecRequestBody. 
--- .../maas_request/model_exec_request.py | 65 ++++++++++--------- 1 file changed, 33 insertions(+), 32 deletions(-) diff --git a/python/lib/communication/dmod/communication/maas_request/model_exec_request.py b/python/lib/communication/dmod/communication/maas_request/model_exec_request.py index 5d5c9a437..af12a97d0 100644 --- a/python/lib/communication/dmod/communication/maas_request/model_exec_request.py +++ b/python/lib/communication/dmod/communication/maas_request/model_exec_request.py @@ -1,11 +1,12 @@ from abc import ABC -from typing import Dict, Optional, Union +from typing import ClassVar, Dict, Optional, Union from dmod.core.execution import AllocationParadigm from ..message import MessageEventType from .dmod_job_request import DmodJobRequest from .external_request import ExternalRequest +from .model_exec_request_body import ModelExecRequestBody def get_available_models() -> Dict[str, "ModelExecRequest"]: @@ -25,14 +26,16 @@ class ModelExecRequest(ExternalRequest, DmodJobRequest, ABC): An abstract extension of ::class:`DmodJobRequest` for requesting model execution jobs. """ - event_type: MessageEventType = MessageEventType.MODEL_EXEC_REQUEST + event_type: ClassVar[MessageEventType] = MessageEventType.MODEL_EXEC_REQUEST - model_name = None + model_name: ClassVar[str] = None """(:class:`str`) The name of the model to be used""" - _DEFAULT_CPU_COUNT = 1 + _DEFAULT_CPU_COUNT: ClassVar[int] = 1 """ The default number of CPUs to assume are being requested for the job, when not explicitly provided. """ + model: ModelExecRequestBody + @classmethod def factory_init_correct_subtype_from_deserialized_json( cls, json_obj: dict @@ -55,14 +58,10 @@ def factory_init_correct_subtype_from_deserialized_json( A deserialized ::class:`ModelExecRequest` of the appropriate subtype. 
""" try: - for model in get_available_models(): - if model in json_obj["model"] or ( - "name" in json_obj["model"] and json_obj["model"]["name"] == model - ): - return get_available_models()[ - model - ].factory_init_from_deserialized_json(json_obj) - return None + model_name = json_obj["model"]["name"] + models = get_available_models() + + return models[model_name].factory_init_from_deserialized_json(json_obj) except: return None @@ -71,15 +70,16 @@ def get_model_name(cls) -> str: """ :return: The name of this model """ - return cls.model_name + return cls.__fields__["model"].type_.__fields__["name"].default def __init__( self, - config_data_id: str, + # required in prior version of code + config_data_id: str = None, + # optional in prior version of code cpu_count: Optional[int] = None, allocation_paradigm: Optional[Union[str, AllocationParadigm]] = None, - *args, - **kwargs + **data ): """ Initialize model-exec-specific attributes and state of this request object common to all model exec requests. @@ -89,19 +89,20 @@ def __init__( session_secret : str The session secret for the right session when communicating with the request handler. 
""" - super(ModelExecRequest, self).__init__(*args, **kwargs) - self._config_data_id = config_data_id - self._cpu_count = ( - cpu_count if cpu_count is not None else self._DEFAULT_CPU_COUNT - ) - if allocation_paradigm is None: - self._allocation_paradigm = AllocationParadigm.get_default_selection() - elif isinstance(allocation_paradigm, str): - self._allocation_paradigm = AllocationParadigm.get_from_name( - allocation_paradigm - ) - else: - self._allocation_paradigm = allocation_paradigm + # assume no need for backwards compatibility + if "model" in data: + super().__init__(**data) + return + + data["model"] = {"config_data_id": config_data_id} + + if cpu_count is not None: + data["model"]["cpu_count"] = cpu_count + + if allocation_paradigm is not None: + data["model"]["allocation_paradigm"] = cpu_count + + super().__init__(**data) def __eq__(self, other): if not self._check_class_compatible_for_equality(other): @@ -132,7 +133,7 @@ def allocation_paradigm(self) -> AllocationParadigm: AllocationParadigm The allocation paradigm desired for use with this request. """ - return self._allocation_paradigm + return self.model.allocation_paradigm @property def config_data_id(self) -> str: @@ -144,7 +145,7 @@ def config_data_id(self) -> str: str Value of ``data_id`` identifying the dataset with the primary configuration applicable to this request. """ - return self._config_data_id + return self.model.config_data_id @property def cpu_count(self) -> int: @@ -156,4 +157,4 @@ def cpu_count(self) -> int: int The number of processors requested for this job. 
""" - return self._cpu_count + return self.model.cpu_count From b25abf033f3028825b9f64122b2b08da8abc53e2 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Thu, 19 Jan 2023 17:54:35 -0500 Subject: [PATCH 108/205] refactor NGENRequest --- .../maas_request/ngen/ngen_request.py | 209 +++++------------- 1 file changed, 59 insertions(+), 150 deletions(-) diff --git a/python/lib/communication/dmod/communication/maas_request/ngen/ngen_request.py b/python/lib/communication/dmod/communication/maas_request/ngen/ngen_request.py index 074045370..cf0eb65ff 100644 --- a/python/lib/communication/dmod/communication/maas_request/ngen/ngen_request.py +++ b/python/lib/communication/dmod/communication/maas_request/ngen/ngen_request.py @@ -1,6 +1,6 @@ -from numbers import Number +from pydantic import PrivateAttr -from typing import Dict, List, Optional, Set, Union +from typing import List, Optional, Set, Union from dmod.core.meta_data import ( DataCategory, @@ -13,6 +13,7 @@ from ...message import MessageEventType from ..model_exec_request import ModelExecRequest from ..model_exec_request_response import ModelExecRequestResponse +from .ngen_exec_request_body import NGENRequestBody class NGENRequest(ModelExecRequest): @@ -23,58 +24,13 @@ class NGENRequest(ModelExecRequest): model_name = "ngen" # FIXME case sentitivity """(:class:`str`) The name of the model to be used""" - @classmethod - def factory_init_from_deserialized_json( - cls, json_obj: dict - ) -> Optional["NGENRequest"]: - """ - Deserialize request formated as JSON to an instance. - - See the documentation of this type's ::method:`to_dict` for an example of the format of valid JSON. - - Parameters - ---------- - json_obj : dict - The serialized JSON representation of a request object. + model: NGENRequestBody - Returns - ------- - The deserialized ::class:`NGENRequest`, or ``None`` if the JSON was not valid for deserialization. 
- - See Also - ------- - ::method:`to_dict` - """ - try: - optional_kwargs_w_defaults = dict() - if "cpu_count" in json_obj["model"]: - optional_kwargs_w_defaults["cpu_count"] = json_obj["model"]["cpu_count"] - if "allocation_paradigm" in json_obj["model"]: - optional_kwargs_w_defaults["allocation_paradigm"] = json_obj["model"][ - "allocation_paradigm" - ] - if "catchments" in json_obj["model"]: - optional_kwargs_w_defaults["catchments"] = json_obj["model"][ - "catchments" - ] - if "partition_config_data_id" in json_obj["model"]: - optional_kwargs_w_defaults["partition_config_data_id"] = json_obj[ - "model" - ]["partition_config_data_id"] - - return cls( - time_range=TimeRange.factory_init_from_deserialized_json( - json_obj["model"]["time_range"] - ), - hydrofabric_uid=json_obj["model"]["hydrofabric_uid"], - hydrofabric_data_id=json_obj["model"]["hydrofabric_data_id"], - config_data_id=json_obj["model"]["config_data_id"], - bmi_cfg_data_id=json_obj["model"]["bmi_config_data_id"], - session_secret=json_obj["session-secret"], - **optional_kwargs_w_defaults - ) - except Exception as e: - return None + _hydrofabric_data_requirement = PrivateAttr(None) + _forcing_data_requirement = PrivateAttr(None) + _realization_cfg_data_requirement = PrivateAttr(None) + _bmi_cfg_data_requirement = PrivateAttr(None) + _partition_cfg_data_requirement = PrivateAttr(None) @classmethod def factory_init_correct_response_subtype( @@ -95,20 +51,23 @@ def factory_init_correct_response_subtype( json_obj=json_obj ) - def __eq__(self, other): - return ( - self.time_range == other.time_range - and self.hydrofabric_data_id == other.hydrofabric_data_id - and self.hydrofabric_uid == other.hydrofabric_uid - and self.config_data_id == other.config_data_id - and self.bmi_config_data_id == other.bmi_config_data_id - and self.session_secret == other.session_secret - and self.cpu_count == other.cpu_count - and self.partition_cfg_data_id == other.partition_cfg_data_id - and self.catchments == 
other.catchments - ) + def __eq__(self, other: "NGENRequest"): + try: + return ( + self.time_range == other.time_range + and self.hydrofabric_data_id == other.hydrofabric_data_id + and self.hydrofabric_uid == other.hydrofabric_uid + and self.config_data_id == other.config_data_id + and self.bmi_config_data_id == other.bmi_config_data_id + and self.session_secret == other.session_secret + and self.cpu_count == other.cpu_count + and self.partition_cfg_data_id == other.partition_cfg_data_id + and self.catchments == other.catchments + ) + except AttributeError: + return False - def __hash__(self): + def __hash__(self) -> int: hash_str = "{}-{}-{}-{}-{}-{}-{}-{}-{}".format( self.time_range.to_json(), self.hydrofabric_data_id, @@ -124,14 +83,15 @@ def __hash__(self): def __init__( self, - time_range: TimeRange, - hydrofabric_uid: str, - hydrofabric_data_id: str, - bmi_cfg_data_id: str, + # required in prior version of code + time_range: TimeRange = None, + hydrofabric_uid: str = None, + hydrofabric_data_id: str = None, + bmi_cfg_data_id: str = None, + # optional in prior version of code catchments: Optional[Union[Set[str], List[str]]] = None, partition_cfg_data_id: Optional[str] = None, - *args, - **kwargs + **data ): """ Initialize an instance. 
@@ -159,28 +119,24 @@ def __init__( session_secret : str The session secret for the right session when communicating with the MaaS request handler """ - super().__init__(*args, **kwargs) - self._time_range = time_range - self._hydrofabric_uid = hydrofabric_uid - self._hydrofabric_data_id = hydrofabric_data_id - self._bmi_config_data_id = bmi_cfg_data_id - self._part_config_data_id = partition_cfg_data_id - # Convert an initial list to a set to remove duplicates - try: - catchments = set(catchments) - # TypeError should mean that we received `None`, so just use that to set _catchments - except TypeError: - self._catchments = catchments - # Assuming we have a set now, move this set back to list and sort - else: - self._catchments = list(catchments) - self._catchments.sort() - - self._hydrofabric_data_requirement = None - self._forcing_data_requirement = None - self._realization_cfg_data_requirement = None - self._bmi_cfg_data_requirement = None - self._partition_cfg_data_requirement = None + # If `model` key is present, assume there is not a need for backwards compatibility + if "model" in data: + super().__init__(**data) + return + + # NOTE: backwards compatibility support. + model = NGENRequestBody( + time_range=time_range, + hydrofabric_uid=hydrofabric_uid, + hydrofabric_data_id=hydrofabric_data_id, + catchments=catchments, + partition_cfg_data_id=partition_cfg_data_id, + # previous version of code used `bmi_cfg_data_id` as parameter name. + bmi_config_data_id=bmi_cfg_data_id, + **data + ) + + super().__init__(model=model, **data) def _gen_catchments_domain_restriction( self, var_name: str = "catchment_id" @@ -237,7 +193,7 @@ def bmi_config_data_id(self) -> str: str Index value of ``data_id`` to uniquely identify sets of BMI module config data that are otherwise similar. 
""" - return self._bmi_config_data_id + return self.model.bmi_config_data_id @property def bmi_cfg_data_requirement(self) -> DataRequirement: @@ -276,7 +232,7 @@ def catchments(self) -> Optional[List[str]]: Optional[List[str]] An optional list of catchment ids for those catchments in the request ngen execution. """ - return self._catchments + return self.model.catchments @property def forcing_data_requirement(self) -> DataRequirement: @@ -292,7 +248,7 @@ def forcing_data_requirement(self) -> DataRequirement: # TODO: going to need to address the CSV usage later forcing_domain = DataDomain( data_format=DataFormat.AORC_CSV, - continuous_restrictions=[self._time_range], + continuous_restrictions=[self.model.time_range], discrete_restrictions=[self._gen_catchments_domain_restriction()], ) self._forcing_data_requirement = DataRequirement( @@ -313,10 +269,10 @@ def hydrofabric_data_requirement(self) -> DataRequirement: if self._hydrofabric_data_requirement is None: hydro_restrictions = [ DiscreteRestriction( - variable="hydrofabric_id", values=[self._hydrofabric_uid] + variable="hydrofabric_id", values=[self.model.hydrofabric_uid] ), DiscreteRestriction( - variable="data_id", values=[self._hydrofabric_data_id] + variable="data_id", values=[self.model.hydrofabric_data_id] ), ] hydro_domain = DataDomain( @@ -343,7 +299,7 @@ def hydrofabric_data_id(self) -> str: str The data format ``data_id`` for the hydrofabric dataset to use in requested modeling. """ - return self._hydrofabric_data_id + return self.model.hydrofabric_data_id @property def hydrofabric_uid(self) -> str: @@ -355,7 +311,7 @@ def hydrofabric_uid(self) -> str: str The unique id of the hydrofabric for this modeling request. 
""" - return self._hydrofabric_uid + return self.model.hydrofabric_uid @property def output_formats(self) -> List[DataFormat]: @@ -384,7 +340,7 @@ def partition_cfg_data_id(self) -> Optional[str]: Optional[str] The data format ``data_id`` for the partition config dataset to use in requested modeling, or ``None``. """ - return self._part_config_data_id + return self.model.partition_cfg_data_id @property def partition_cfg_data_requirement(self) -> DataRequirement: @@ -480,54 +436,7 @@ def time_range(self) -> TimeRange: TimeRange The time range for the requested model execution. """ - return self._time_range - - def to_dict(self) -> Dict[str, Union[str, Number, dict, list]]: - """ - Converts the request to a dictionary that may be passed to web requests - - Will look like: - - { - 'model': { - 'name': 'ngen', - 'allocation_paradigm': , - 'cpu_count': , - 'time_range': { }, - 'hydrofabric_data_id': 'hy-data-id-val', - 'hydrofabric_uid': 'hy-uid-val', - 'config_data_id': 'config-data-id-val', - 'bmi_config_data_id': 'bmi-config-data-id', - 'partition_config_data_id': 'partition_config_data_id', - ['catchments': { },] - 'version': 4.0 - }, - 'session-secret': 'secret-string-val' - } - - As a reminder, the ``catchments`` item may be absent, which implies the object does not have a specified list of - catchment ids. 
- - Returns - ------- - Dict[str, Union[str, Number, dict, list]] - A dictionary containing all the data in such a way that it may be used by a web request - """ - model = dict() - model["name"] = self.get_model_name() - model["allocation_paradigm"] = self.allocation_paradigm.name - model["cpu_count"] = self.cpu_count - model["time_range"] = self.time_range.to_dict() - model["hydrofabric_data_id"] = self.hydrofabric_data_id - model["hydrofabric_uid"] = self.hydrofabric_uid - model["config_data_id"] = self.config_data_id - model["bmi_config_data_id"] = self._bmi_config_data_id - if self.catchments is not None: - model["catchments"] = self.catchments - if self.partition_cfg_data_id is not None: - model["partition_config_data_id"] = self.partition_cfg_data_id - - return {"model": model, "session-secret": self.session_secret} + return self.model.time_range class NGENRequestResponse(ModelExecRequestResponse): From 45586a1ca45ae88d0472fed5bfa6759dfe3e6296 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Thu, 19 Jan 2023 19:32:30 -0500 Subject: [PATCH 109/205] refactor SchedulerRequestResponse. split it's data property into its own model. 
--- .../dmod/communication/scheduler_request.py | 50 +++++++++++-------- .../scheduler_request_response_body.py | 16 ++++++ 2 files changed, 44 insertions(+), 22 deletions(-) create mode 100644 python/lib/communication/dmod/communication/scheduler_request_response_body.py diff --git a/python/lib/communication/dmod/communication/scheduler_request.py b/python/lib/communication/dmod/communication/scheduler_request.py index 68985cc02..250377766 100644 --- a/python/lib/communication/dmod/communication/scheduler_request.py +++ b/python/lib/communication/dmod/communication/scheduler_request.py @@ -1,9 +1,9 @@ from dmod.core.execution import AllocationParadigm -from .maas_request import ModelExecRequest, ModelExecRequestResponse +from .maas_request import ModelExecRequest from .message import AbstractInitRequest, MessageEventType, Response +from .scheduler_request_response_body import SchedulerRequestResponseBody, UNSUCCESSFUL_JOB from pydantic import Field, PrivateAttr -from typing import ClassVar, Dict, Optional, Union - +from typing import ClassVar, Dict, Optional, Type, Union class SchedulerRequestMessage(AbstractInitRequest): @@ -51,13 +51,15 @@ def __init__( cpus: Optional[int] = None, mem: Optional[int] = None, allocation_paradigm: Optional[Union[str, AllocationParadigm]] = None, + **data ): super().__init__( model_request=model_request, user_id=user_id, - cpus=cpus, - memory=mem, - allocation_paradigm=allocation_paradigm + cpus=cpus or data.pop("cpus_", None), + memory=mem or data.pop("memory", None) or self.__fields__["memory"].default, + allocation_paradigm=allocation_paradigm or data.pop("allocation_paradigm_", None), + **data ) if mem is None: self._memory_unset = True @@ -143,37 +145,44 @@ def dict( exclude_none=exclude_none, ) - class SchedulerRequestResponse(Response): - response_to_type = SchedulerRequestMessage + + response_to_type: ClassVar[Type[AbstractInitRequest]] = SchedulerRequestMessage + data: SchedulerRequestResponseBody def __init__(self, job_id: 
Optional[int] = None, output_data_id: Optional[str] = None, data: dict = None, **kwargs): # TODO: how to handle if kwargs has success=True, but job_id value (as param or in data) implies success=False - key_job_id = ModelExecRequestResponse.get_job_id_key() + # Create an empty data if not supplied a dict, but only if there is a job_id or output_data_id to insert if data is None and (job_id is not None or output_data_id is not None): data = {} + # Prioritize provided job_id over something already in data # Note that this condition implies that either a data dict was passed as param, or one just got created above if job_id is not None: - data[key_job_id] = job_id + data["job_id"] = job_id + # Insert this into dict if present also (again, it being non-None implies data must be a dict object) if output_data_id is not None: - data[ModelExecRequestResponse.get_output_data_id_key()] = output_data_id + data["output_data_id"] = output_data_id + + data_body = SchedulerRequestResponseBody(**data if data is not None else {}) + # Ensure that 'success' is being passed as a kwarg to the superclass constructor - if 'success' not in kwargs: - kwargs['success'] = data is not None and key_job_id in data and data[key_job_id] > 0 - super(SchedulerRequestResponse, self).__init__(data=data, **kwargs) + if "success" not in kwargs: + kwargs["success"] = data is not None and data_body.job_id > 0 + + super().__init__(data=data_body, **kwargs) def __eq__(self, other): return self.__class__ == other.__class__ and self.success == other.success and self.job_id == other.job_id @property - def job_id(self): - if self.success and self.data is not None: - return self.data[ModelExecRequestResponse.get_job_id_key()] + def job_id(self) -> int: + if self.success: + return self.data.job_id else: - return -1 + return UNSUCCESSFUL_JOB # TODO: make sure this value gets included in the data dict @property @@ -186,7 +195,4 @@ def output_data_id(self) -> Optional[str]: Optional[str] The 'data_id' of the output 
dataset for requested job, or ``None`` if not known. """ - if self.data is not None and ModelExecRequestResponse.get_output_data_id_key() in self.data: - return self.data[ModelExecRequestResponse.get_output_data_id_key()] - else: - return None + return self.data.output_data_id diff --git a/python/lib/communication/dmod/communication/scheduler_request_response_body.py b/python/lib/communication/dmod/communication/scheduler_request_response_body.py new file mode 100644 index 000000000..e068f2fa4 --- /dev/null +++ b/python/lib/communication/dmod/communication/scheduler_request_response_body.py @@ -0,0 +1,16 @@ +from pydantic import Extra + +from dmod.core.serializable import Serializable + +from typing import Optional + +UNSUCCESSFUL_JOB = -1 + + +class SchedulerRequestResponseBody(Serializable): + job_id: int = UNSUCCESSFUL_JOB + output_data_id: Optional[str] + + class Config: + # allow extra model fields + extra = Extra.allow From 5e091541bbfbeccca75d06d9a79cf2f521e0d5d9 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Fri, 20 Jan 2023 14:54:42 -0500 Subject: [PATCH 110/205] add NGENRequestBody. encapsulates NGENRequest's `model` field. 
--- .../ngen/ngen_exec_request_body.py | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 python/lib/communication/dmod/communication/maas_request/ngen/ngen_exec_request_body.py diff --git a/python/lib/communication/dmod/communication/maas_request/ngen/ngen_exec_request_body.py b/python/lib/communication/dmod/communication/maas_request/ngen/ngen_exec_request_body.py new file mode 100644 index 000000000..a10d5d75f --- /dev/null +++ b/python/lib/communication/dmod/communication/maas_request/ngen/ngen_exec_request_body.py @@ -0,0 +1,33 @@ +from pydantic import validator + +from dmod.core.meta_data import TimeRange +from ..model_exec_request_body import ModelExecRequestBody + +from typing import List, Literal, Optional + + +class NGENRequestBody(ModelExecRequestBody): + name: Literal["ngen"] = "ngen" + + time_range: TimeRange + hydrofabric_uid: str + hydrofabric_data_id: str + bmi_config_data_id: str + # NOTE: consider pydantic.conlist to constrain this type rather than using validators + catchments: Optional[List[str]] + partition_cfg_data_id: Optional[str] + + @validator("catchments") + def validate_deduplicate_and_sort_catchments( + cls, value: List[str] + ) -> Optional[List[str]]: + if value is None: + return None + + deduped = set(value) + return sorted(list(deduped)) + + class Config: + fields = { + "partition_cfg_data_id": {"alias": "partition_config_data_id"}, + } From 014dc3091ade94ff29cba2f7449a8475bb3d449d Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Fri, 20 Jan 2023 14:56:53 -0500 Subject: [PATCH 111/205] add class var type hints to NGENRequestResponse --- .../dmod/communication/maas_request/ngen/ngen_request.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/lib/communication/dmod/communication/maas_request/ngen/ngen_request.py b/python/lib/communication/dmod/communication/maas_request/ngen/ngen_request.py index cf0eb65ff..c30b8ef73 100644 --- 
a/python/lib/communication/dmod/communication/maas_request/ngen/ngen_request.py +++ b/python/lib/communication/dmod/communication/maas_request/ngen/ngen_request.py @@ -1,6 +1,6 @@ from pydantic import PrivateAttr -from typing import List, Optional, Set, Union +from typing import ClassVar, List, Optional, Set, Type, Union from dmod.core.meta_data import ( DataCategory, @@ -10,7 +10,7 @@ DiscreteRestriction, TimeRange, ) -from ...message import MessageEventType +from ...message import AbstractInitRequest, MessageEventType from ..model_exec_request import ModelExecRequest from ..model_exec_request_response import ModelExecRequestResponse from .ngen_exec_request_body import NGENRequestBody @@ -474,4 +474,4 @@ class NGENRequestResponse(ModelExecRequestResponse): } """ - response_to_type = NGENRequest + response_to_type: ClassVar[Type[AbstractInitRequest]] = NGENRequest From d5113107da1145f5789049e81997504b7e3e235e Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Fri, 20 Jan 2023 15:01:39 -0500 Subject: [PATCH 112/205] specify ExternalRequest aliases in Config class. This ensures that subtypes will carry these aliases if they are, for example, to use `pydantic.Field` to provide more subclass specific information. However, this does allow subtypes to override this behavior by providing an overriding alias in their own Config class. 
--- .../dmod/communication/maas_request/external_request.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/lib/communication/dmod/communication/maas_request/external_request.py b/python/lib/communication/dmod/communication/maas_request/external_request.py index 0b6df842a..04fe934ac 100644 --- a/python/lib/communication/dmod/communication/maas_request/external_request.py +++ b/python/lib/communication/dmod/communication/maas_request/external_request.py @@ -6,6 +6,11 @@ class ExternalRequest(AbstractInitRequest, ABC): """ The base class underlying all types of externally-initiated (and, therefore, authenticated) MaaS system requests. """ + # NOTE: in some places this is serialized as `session-secret` + session_secret: str + + class Config: + fields = {"session_secret": {"alias": "session-secret"}} @classmethod @abstractmethod From 119bb31aa45135a54f5a2ab63a9e5e7302994e00 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Fri, 20 Jan 2023 15:05:36 -0500 Subject: [PATCH 113/205] remove, now, unnecessary ExternalRequest initializer --- .../communication/maas_request/external_request.py | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/python/lib/communication/dmod/communication/maas_request/external_request.py b/python/lib/communication/dmod/communication/maas_request/external_request.py index 04fe934ac..b84093e55 100644 --- a/python/lib/communication/dmod/communication/maas_request/external_request.py +++ b/python/lib/communication/dmod/communication/maas_request/external_request.py @@ -28,19 +28,7 @@ def factory_init_correct_response_subtype(cls, json_obj: dict): """ pass - def __init__(self, session_secret: str, *args, **kwargs): - """ - Initialize the base attributes and state of this request object. 
- - Parameters - ---------- - session_secret : str - The session secret for the right session when communicating with the MaaS request handler - """ - super(ExternalRequest, self).__init__(*args, **kwargs) - self.session_secret = session_secret - - def _check_class_compatible_for_equality(self, other) -> bool: + def _check_class_compatible_for_equality(self, other: object) -> bool: """ Check and return whether another object is of some class that is compatible for equality checking with the class of this instance, such that the class difference does not independently imply the other object and this instance From 8a7ac31ed61c9ba6276c1486fba6d9759ae575c8 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Fri, 20 Jan 2023 16:46:22 -0500 Subject: [PATCH 114/205] add class that encapsulates a ModelExecRequestResponseBody --- .../model_exec_request_response_body.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 python/lib/communication/dmod/communication/maas_request/model_exec_request_response_body.py diff --git a/python/lib/communication/dmod/communication/maas_request/model_exec_request_response_body.py b/python/lib/communication/dmod/communication/maas_request/model_exec_request_response_body.py new file mode 100644 index 000000000..9f2c82759 --- /dev/null +++ b/python/lib/communication/dmod/communication/maas_request/model_exec_request_response_body.py @@ -0,0 +1,23 @@ +from ..scheduler_request import SchedulerRequestResponse, SchedulerRequestResponseBody + + +class ModelExecRequestResponseBody(SchedulerRequestResponseBody): + scheduler_response: SchedulerRequestResponse + + @classmethod + def from_scheduler_request_response( + cls, scheduler_response: SchedulerRequestResponse + ) -> "ModelExecRequestResponseBody": + return cls( + job_id=scheduler_response.job_id, + output_data_id=scheduler_response.output_data_id, + scheduler_response=scheduler_response.copy(), + ) + + # NOTE: legacy support. 
previously this class was treated as a dictionary + def __contains__(self, element: str) -> bool: + return element in self.__dict__ + + # NOTE: legacy support. previously this class was treated as a dictionary + def __getitem__(self, item: str): + return self.__dict__[item] From dc350b666e09cba312cf81faa60c3b81a2c671c8 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Fri, 20 Jan 2023 16:46:37 -0500 Subject: [PATCH 115/205] refactor ModelExecRequestResponse --- .../model_exec_request_response.py | 123 ++++++++---------- 1 file changed, 51 insertions(+), 72 deletions(-) diff --git a/python/lib/communication/dmod/communication/maas_request/model_exec_request_response.py b/python/lib/communication/dmod/communication/maas_request/model_exec_request_response.py index f82593cbf..c8514bb12 100644 --- a/python/lib/communication/dmod/communication/maas_request/model_exec_request_response.py +++ b/python/lib/communication/dmod/communication/maas_request/model_exec_request_response.py @@ -1,34 +1,30 @@ from abc import ABC +from typing import Any, ClassVar, Dict, Optional, Type, Union +from pydantic import validator -from typing import Optional - -from ..message import InitRequestResponseReason +from ..scheduler_request import SchedulerRequestResponse, UNSUCCESSFUL_JOB +from ..message import AbstractInitRequest, InitRequestResponseReason from .external_request_response import ExternalRequestResponse from .model_exec_request import ModelExecRequest +from .model_exec_request_response_body import ModelExecRequestResponseBody class ModelExecRequestResponse(ExternalRequestResponse, ABC): - _data_dict_key_job_id = "job_id" - _data_dict_key_output_data_id = "output_data_id" - _data_dict_key_scheduler_response = "scheduler_response" - response_to_type = ModelExecRequest + response_to_type: ClassVar[Type[AbstractInitRequest]] = ModelExecRequest """ The type of :class:`AbstractInitRequest` for which this type is the response""" - @classmethod - def 
_convert_scheduler_response_to_data_attribute(cls, scheduler_response=None): - if scheduler_response is None: - return None - elif isinstance(scheduler_response, dict) and len(scheduler_response) == 0: - return {} - elif isinstance(scheduler_response, dict): - return scheduler_response - else: - return { - cls._data_dict_key_job_id: scheduler_response.job_id, - cls._data_dict_key_output_data_id: scheduler_response.output_data_id, - cls._data_dict_key_scheduler_response: scheduler_response.to_dict(), - } + data: Optional[Union[ModelExecRequestResponseBody, Dict[str, Any]]] = None + + @validator("data", pre=True) + def _convert_data_field(cls, value: Optional[Union[SchedulerRequestResponse, ModelExecRequestResponseBody, Dict[str, Any]]]) -> Optional[Union[ModelExecRequestResponseBody, Dict[str, Any]]]: + if value is None: + return value + + elif isinstance(value, SchedulerRequestResponse): + return ModelExecRequestResponseBody.from_scheduler_request_response(value) + + return value @classmethod def get_job_id_key(cls) -> str: @@ -40,7 +36,7 @@ def get_job_id_key(cls) -> str: str Serialization dictionary key for the field containing the ::attribute:`job_id` property. """ - return str(cls._data_dict_key_job_id) + return "job_id" @classmethod def get_output_data_id_key(cls) -> str: @@ -52,7 +48,7 @@ def get_output_data_id_key(cls) -> str: str Serialization dictionary key for the field containing the ::attribute:`output_data_id` property. """ - return str(cls._data_dict_key_output_data_id) + return "output_data_id" @classmethod def get_scheduler_response_key(cls) -> str: @@ -64,52 +60,35 @@ def get_scheduler_response_key(cls) -> str: str Serialization dictionary key for the field containing the 'scheduler_response' value. 
""" - return str(cls._data_dict_key_scheduler_response) - - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict): - """ - Factory create a new instance of this type based on a JSON object dictionary deserialized from received JSON. + return "scheduler_response" + + def __init__( + self, + scheduler_response: Optional[ + Union[ + SchedulerRequestResponse, ModelExecRequestResponseBody, Dict[str, Any] + ] + ] = None, + **kwargs + ): + if scheduler_response is None: + super().__init__(**kwargs) + return - Parameters - ---------- - json_obj + # NOTE: if `scheduler_response` is not None, it is given precedence over "data" that might + # be present in `kwargs`. + kwargs["data"] = scheduler_response + super().__init__(**kwargs) - Returns - ------- - response_obj : Response - A new object of this type instantiated from the deserialize JSON object dictionary, or none if the provided - parameter could not be used to instantiated a new object. + @property + def job_id(self) -> int: + if isinstance(self.data, ModelExecRequestResponseBody): + return self.data.job_id - See Also - ------- - _factory_init_data_attribute - """ - try: - return cls( - success=json_obj["success"], - reason=json_obj["reason"], - message=json_obj["message"], - scheduler_response=json_obj["data"], - ) - except Exception as e: - return None - - def __init__(self, scheduler_response=None, *args, **kwargs): - data = self._convert_scheduler_response_to_data_attribute(scheduler_response) - if data is not None: - kwargs["data"] = data - super().__init__(*args, **kwargs) + elif isinstance(self.data, dict) and "job_id" in self.data: + return self.data["job_id"] - @property - def job_id(self): - if ( - not isinstance(self.data, dict) - or self._data_dict_key_job_id not in self.data - ): - return -1 - else: - return self.data[self._data_dict_key_job_id] + return UNSUCCESSFUL_JOB @property def output_data_id(self) -> Optional[str]: @@ -121,13 +100,13 @@ def output_data_id(self) -> 
Optional[str]: Optional[str] The 'data_id' of the output dataset for requested job, if request was successful; otherwise ``None``. """ - if ( - not isinstance(self.data, dict) - or self._data_dict_key_output_data_id not in self.data - ): - return None - else: - return self.data[self._data_dict_key_output_data_id] + if isinstance(self.data, ModelExecRequestResponseBody): + return self.data.output_data_id + + elif isinstance(self.data, dict) and "output_data_id" in self.data: + return self.data["output_data_id"] + + return None @property def reason_enum(self): From 38db3873c5dbba50a31bcfd3b04ddff32004ee7f Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Fri, 20 Jan 2023 17:17:25 -0500 Subject: [PATCH 116/205] add DataTransmitUUID class --- .../communication/data_transmit_message.py | 40 ++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/python/lib/communication/dmod/communication/data_transmit_message.py b/python/lib/communication/dmod/communication/data_transmit_message.py index 1a52229d8..c421376ae 100644 --- a/python/lib/communication/dmod/communication/data_transmit_message.py +++ b/python/lib/communication/dmod/communication/data_transmit_message.py @@ -1,3 +1,4 @@ +from dmod.core.serializable import Serializable from .message import AbstractInitRequest, MessageEventType, Response from pydantic import Field from typing import ClassVar, Dict, Optional, Union @@ -5,7 +6,44 @@ from uuid import UUID -class DataTransmitMessage(AbstractInitRequest): +class DataTransmitUUID(Serializable): + series_uuid: UUID = Field(description="A unique id for the collective series of transmission message this instance is a part of.") + """ + The expectation is that a larger amount of data will be broken up into multiple messages in a series. 
+ """ + + def dict( + self, + *, + include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + by_alias: bool = True, # Note this follows Serializable convention + skip_defaults: Optional[bool] = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False + ) -> Dict[str, Union[str, int]]: + SERIES_UUID_KEY = "series_uuid" + exclude = exclude or set() + series_uuid_in_exclude = SERIES_UUID_KEY in exclude + exclude.add(SERIES_UUID_KEY) + + serial = super().dict( + include=include, + exclude=exclude, + by_alias=by_alias, + skip_defaults=skip_defaults, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + + if not series_uuid_in_exclude: + serial[SERIES_UUID_KEY] = str(self.series_uuid) + + return serial + +class DataTransmitMessage(DataTransmitUUID, AbstractInitRequest): """ Specialized message type for transmitting data. 
From 16e192622cc3a2b75e3e8d8fe86a93be77a029f1 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Fri, 20 Jan 2023 17:19:22 -0500 Subject: [PATCH 117/205] refactor DataTransmitMessage to use DataTransmitUUID --- .../communication/data_transmit_message.py | 35 ------------------- 1 file changed, 35 deletions(-) diff --git a/python/lib/communication/dmod/communication/data_transmit_message.py b/python/lib/communication/dmod/communication/data_transmit_message.py index c421376ae..3e7571419 100644 --- a/python/lib/communication/dmod/communication/data_transmit_message.py +++ b/python/lib/communication/dmod/communication/data_transmit_message.py @@ -60,43 +60,8 @@ class DataTransmitMessage(DataTransmitUUID, AbstractInitRequest): event_type: ClassVar[MessageEventType] = MessageEventType.DATA_TRANSMISSION data: str = Field(description="The data carried by this message, in decoded string form.") - series_uuid: UUID = Field(description="A unique id for the collective series of transmission message this instance is a part of.") - """ - The expectation is that a larger amount of data will be broken up into multiple messages in a series. 
- """ is_last: bool = Field(False, description="Whether this is the last data transmission message in this series.") - def dict( - self, - *, - include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, - exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, - by_alias: bool = True, # Note this follows Serializable convention - skip_defaults: Optional[bool] = None, - exclude_unset: bool = False, - exclude_defaults: bool = False, - exclude_none: bool = False - ) -> Dict[str, Union[str, int]]: - SERIES_UUID_KEY = "series_uuid" - exclude = exclude or set() - series_uuid_in_exclude = SERIES_UUID_KEY in exclude - exclude.add(SERIES_UUID_KEY) - - serial = super().dict( - include=include, - exclude=exclude, - by_alias=by_alias, - skip_defaults=skip_defaults, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - ) - - if not series_uuid_in_exclude: - serial[SERIES_UUID_KEY] = str(self.series_uuid) - - return serial - class DataTransmitResponse(Response): """ From a5cdd691faa1601303e74671c1279cbd632ecef3 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Fri, 20 Jan 2023 17:31:22 -0500 Subject: [PATCH 118/205] refactor DataTransmitResponse --- .../communication/data_transmit_message.py | 57 ++++++++----------- 1 file changed, 25 insertions(+), 32 deletions(-) diff --git a/python/lib/communication/dmod/communication/data_transmit_message.py b/python/lib/communication/dmod/communication/data_transmit_message.py index 3e7571419..9d5b2c7da 100644 --- a/python/lib/communication/dmod/communication/data_transmit_message.py +++ b/python/lib/communication/dmod/communication/data_transmit_message.py @@ -1,7 +1,8 @@ from dmod.core.serializable import Serializable +from pydantic import Extra from .message import AbstractInitRequest, MessageEventType, Response from pydantic import Field -from typing import ClassVar, Dict, Optional, Union +from typing import ClassVar, Dict, Optional, Type, Union from numbers 
import Number from uuid import UUID @@ -63,6 +64,12 @@ class DataTransmitMessage(DataTransmitUUID, AbstractInitRequest): is_last: bool = Field(False, description="Whether this is the last data transmission message in this series.") +class DataTransmitResponseBody(DataTransmitUUID): + + class Config: + extra = Extra.allow + + class DataTransmitResponse(Response): """ A ::class:`Response` subtype corresponding to ::class:`DataTransmitMessage`. @@ -71,38 +78,24 @@ class DataTransmitResponse(Response): series of which it is a part. """ - response_to_type = DataTransmitMessage - - _KEY_SERIES_UUID = response_to_type._KEY_SERIES_UUID - - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict) -> 'DataTransmitResponse': - """ - Factory create a new instance of this type based on a JSON object dictionary deserialized from received JSON. - - Parameters - ---------- - json_obj - - Returns - ------- - response_obj : Response - A new object of this type instantiated from the deserialize JSON object dictionary, or none if the provided - parameter could not be used to instantiated a new object. 
- """ - try: - return cls(success=json_obj['success'], reason=json_obj['reason'], message=json_obj['message'], - series_uuid=json_obj['data'][cls._KEY_SERIES_UUID], data=json_obj['data']) - except Exception as e: - return None - - def __init__(self, series_uuid: Union[str, UUID], *args, **kwargs): - if 'data' not in kwargs: - kwargs['data'] = dict() - kwargs['data'][self._KEY_SERIES_UUID] = str(series_uuid) - super(DataTransmitResponse, self).__init__(*args, **kwargs) + response_to_type: ClassVar[Type[AbstractInitRequest]] = DataTransmitMessage + + data: DataTransmitResponseBody + + # `series_uuid` required in prior version of code + def __init__(self, series_uuid: Union[str, UUID] = None, **kwargs): + # assume no need for backwards compatibility + if series_uuid is None: + super().__init__(**kwargs) + return + + if "data" not in kwargs: + kwargs["data"] = dict() + + kwargs["data"]["series_uuid"] = series_uuid + super().__init__(**kwargs) @property def series_uuid(self) -> UUID: - return UUID(self.data[self._KEY_SERIES_UUID]) + return self.data.series_uuid From 687e3175609ffa7e3f8a56fe47e927ca57e887a8 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Fri, 20 Jan 2023 18:45:36 -0500 Subject: [PATCH 119/205] refactor DatasetManagementResponse. 
separated its body into own data model --- .../dataset_management_message.py | 80 +++++++++---------- 1 file changed, 38 insertions(+), 42 deletions(-) diff --git a/python/lib/communication/dmod/communication/dataset_management_message.py b/python/lib/communication/dmod/communication/dataset_management_message.py index 03e88b43d..9b5949c2d 100644 --- a/python/lib/communication/dmod/communication/dataset_management_message.py +++ b/python/lib/communication/dmod/communication/dataset_management_message.py @@ -4,8 +4,7 @@ from dmod.core.meta_data import DataCategory, DataDomain, DataFormat, DataRequirement from dmod.core.enum import PydanticEnum from pydantic import root_validator, Field -from numbers import Number -from typing import ClassVar, Dict, Optional, Union, List +from typing import Any, ClassVar, Dict, Optional, Type, Union, List @@ -267,41 +266,45 @@ def __init__( ) +class DatasetManagementResponseBody(Serializable): + action: Optional[ManagementAction] + data_id: Optional[str] + dataset_name: Optional[str] + item_name: Optional[str] + # TODO: in the future, tighten the type restrictions of this field + query_results: Optional[Dict[str, Any]] + is_awaiting: bool = False + + class DatasetManagementResponse(Response): - _DATA_KEY_ACTION= 'action' - _DATA_KEY_DATA_ID = 'data_id' - _DATA_KEY_DATASET_NAME = 'dataset_name' - _DATA_KEY_ITEM_NAME = 'item_name' - _DATA_KEY_QUERY_RESULTS = 'query_results' - _DATA_KEY_IS_AWAITING = 'is_awaiting' - response_to_type = DatasetManagementMessage + response_to_type: ClassVar[Type[AbstractInitRequest]] = DatasetManagementMessage - def __init__(self, action: Optional[ManagementAction] = None, is_awaiting: bool = False, - data_id: Optional[str] = None, dataset_name: Optional[str] = None, data: Optional[dict] = None, - **kwargs): - if data is None: - data = {} + data: DatasetManagementResponseBody + + def __init__( + self, + action: Optional[ManagementAction] = None, + is_awaiting: bool = False, 
+ data_id: Optional[str] = None, + dataset_name: Optional[str] = None, + data: Optional[Union[dict, DatasetManagementResponseBody]] = None, + **kwargs + ): + data = data if isinstance(data, DatasetManagementResponseBody) else DatasetManagementResponseBody(**data or {}) # Make sure 'action' param and action string within 'data' param aren't both present and conflicting if action is not None: - if action.name != data.get(self._DATA_KEY_ACTION, action.name): + if action != data.action: msg = '{} initialized with {} action param, but {} action in initial data.' - raise ValueError(msg.format(self.__class__.__name__, action.name, data.get(self._DATA_KEY_ACTION))) - data[self._DATA_KEY_ACTION] = action.name - # Additionally, if not using an explicit 'action', make sure it's a valid action string in 'data', or bail - else: - data_action_str = data.get(self._DATA_KEY_ACTION, '') - # Compare the string to the 'name' string of the action value obtain by passing the string to get_for_name() - if data_action_str.strip().upper() != ManagementAction.get_for_name(data_action_str).name.upper(): - msg = "No valid action param or within 'data' when initializing {} instance (received only '{}')" - raise ValueError(msg.format(self.__class__.__name__, data_action_str)) - - data[self._DATA_KEY_IS_AWAITING] = is_awaiting + raise ValueError(msg.format(self.__class__.__name__, action.name, data.action.name if data.action else data.action)) + data.action = action + + data.is_awaiting = is_awaiting if data_id is not None: - data[self._DATA_KEY_DATA_ID] = data_id + data.data_id = data_id if dataset_name is not None: - data[self._DATA_KEY_DATASET_NAME] = dataset_name + data.dataset_name = dataset_name super().__init__(data=data, **kwargs) @property @@ -314,16 +317,9 @@ def action(self) -> ManagementAction: ManagementAction The action requested by the ::class:`DatasetManagementMessage` for which this instance is the response. 
""" - if self._DATA_KEY_ACTION not in self.data: - return ManagementAction.UNKNOWN - elif isinstance(self.data[self._DATA_KEY_ACTION], str): - return ManagementAction.get_for_name(self.data[self._DATA_KEY_ACTION]) - elif isinstance(self.data[self._DATA_KEY_ACTION], ManagementAction): - val = self.data[self._DATA_KEY_ACTION] - self.data[self._DATA_KEY_ACTION] = val.name - return val - else: + if self.data.action is None: return ManagementAction.UNKNOWN + return self.data.action @property def data_id(self) -> Optional[str]: @@ -335,7 +331,7 @@ def data_id(self) -> Optional[str]: Optional[str] When available, the 'data_id' of the related dataset. """ - return self.data[self._DATA_KEY_DATA_ID] if self._DATA_KEY_DATA_ID in self.data else None + return self.data.data_id @property def dataset_name(self) -> Optional[str]: @@ -347,7 +343,7 @@ def dataset_name(self) -> Optional[str]: Optional[str] When available, the name of the relevant dataset; otherwise ``None``. """ - return self.data[self._DATA_KEY_DATASET_NAME] if self._DATA_KEY_DATASET_NAME in self.data else None + return self.data.dataset_name @property def item_name(self) -> Optional[str]: @@ -359,11 +355,11 @@ def item_name(self) -> Optional[str]: Optional[str] The name of the relevant dataset item/object/file, or ``None``. """ - return self.data.get(self._DATA_KEY_ITEM_NAME) + return self.data.item_name @property def query_results(self) -> Optional[dict]: - return self.data.get(self._DATA_KEY_QUERY_RESULTS) + return self.data.query_results @property def is_awaiting(self) -> bool: @@ -379,7 +375,7 @@ def is_awaiting(self) -> bool: bool Whether the response indicates the response sender is awaiting something additional. 
""" - return self.data[self._DATA_KEY_IS_AWAITING] + return self.data.is_awaiting class MaaSDatasetManagementMessage(DatasetManagementMessage, ExternalRequest): From 170a66faca3c6dbcce8cb56b28520d8b0db729d3 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Fri, 20 Jan 2023 18:46:53 -0500 Subject: [PATCH 120/205] add type hints to MaaSDatasetManagementResponse class vars --- .../dmod/communication/dataset_management_message.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/lib/communication/dmod/communication/dataset_management_message.py b/python/lib/communication/dmod/communication/dataset_management_message.py index 9b5949c2d..bc396dc8e 100644 --- a/python/lib/communication/dmod/communication/dataset_management_message.py +++ b/python/lib/communication/dmod/communication/dataset_management_message.py @@ -463,7 +463,7 @@ class MaaSDatasetManagementResponse(ExternalRequestResponse, DatasetManagementRe Analog of ::class:`DatasetManagementResponse`, but for the ::class:`MaaSDatasetManagementMessage` message type. 
""" - response_to_type = MaaSDatasetManagementMessage + response_to_type: ClassVar[Type[AbstractInitRequest]] = MaaSDatasetManagementMessage @classmethod def factory_create(cls, dataset_mgmt_response: DatasetManagementResponse) -> 'MaaSDatasetManagementResponse': From dfa37763e880181667b16164cfb2b00f80688c58 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Fri, 20 Jan 2023 19:06:17 -0500 Subject: [PATCH 121/205] refactor PartitionResponse --- .../dmod/communication/partition_request.py | 72 ++++++++++++++----- 1 file changed, 54 insertions(+), 18 deletions(-) diff --git a/python/lib/communication/dmod/communication/partition_request.py b/python/lib/communication/dmod/communication/partition_request.py index a2d0c8eae..5437cf4f2 100644 --- a/python/lib/communication/dmod/communication/partition_request.py +++ b/python/lib/communication/dmod/communication/partition_request.py @@ -1,7 +1,7 @@ from uuid import uuid4 -from numbers import Number -from pydantic import Field -from typing import ClassVar, Dict, Optional, Union +from pydantic import Field, Extra +from typing import ClassVar, Dict, Optional, Type, Union +from dmod.core.serializable import Serializable from .message import AbstractInitRequest, MessageEventType, Response from .maas_request import ExternalRequest @@ -116,30 +116,43 @@ def dict( return serial +class PartitionResponseBody(Serializable): + data_id: Optional[str] + dataset_name: Optional[str] + + class Config: + extra = Extra.allow + + def __contains__(self, key: str) -> bool: + return key in self.__dict__ + + def __getitem__(self, key: str): + return self.__dict__[key] + class PartitionResponse(Response): """ A response to a ::class:`PartitionRequest`. A successful response will contain the serialized partition representation within the ::attribute:`data` property. 
""" - _DATA_KEY_DATASET_DATA_ID = 'data_id' - _DATA_KEY_DATASET_NAME = 'dataset_name' - response_to_type = PartitionRequest + data: PartitionResponseBody + + response_to_type: ClassVar[Type[AbstractInitRequest]] = PartitionRequest @classmethod def factory_create(cls, dataset_name: Optional[str], dataset_data_id: Optional[str], reason: str, message: str = '', data: Optional[dict] = None): - data_dict = {cls._DATA_KEY_DATASET_DATA_ID: dataset_data_id, cls._DATA_KEY_DATASET_NAME: dataset_name} + data_dict = {"data_id": dataset_data_id, "dataset_name": dataset_name} if data is not None: data_dict.update(data) return cls(success=(dataset_data_id is not None), reason=reason, message=message, data=data_dict) - def __init__(self, success: bool, reason: str, message: str = '', data: Optional[dict] = None): - if data is None: - data = {} + def __init__(self, success: bool, reason: str, message: str = '', data: Optional[Union[dict, PartitionResponseBody]] = None): + data = data if isinstance(data, PartitionResponseBody) else PartitionResponseBody(**data or {}) + if not success: - data[self._DATA_KEY_DATASET_DATA_ID] = None - data[self._DATA_KEY_DATASET_NAME] = None + data.data_id = None + data.dataset_name = None super().__init__(success=success, reason=reason, message=message, data=data) @property @@ -152,7 +165,7 @@ def dataset_data_id(self) -> Optional[str]: Optional[str] The 'data_id' of the dataset where the partition config is saved when requests are successful. """ - return self.data[self._DATA_KEY_DATASET_DATA_ID] + return self.data.data_id @property def dataset_name(self) -> Optional[str]: @@ -164,11 +177,34 @@ def dataset_name(self) -> Optional[str]: Optional[str] The name of the dataset where the partitioning config is saved when requests are successful. 
""" - return self.data[self._DATA_KEY_DATASET_NAME] + return self.data.dataset_name + + def dict( + self, + *, + include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + by_alias: bool = True, # Note this follows Serializable convention + skip_defaults: Optional[bool] = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False + ) -> Dict[str, Union[str, int]]: + class_name_in_exclude = exclude is not None and "class_name" in exclude + + serial = super().dict( + include=include, + exclude=exclude, + by_alias=by_alias, + skip_defaults=skip_defaults, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + + if not class_name_in_exclude: + serial["class_name"] = self.__class__.__name__ - def to_dict(self) -> Dict[str, Union[str, Number, dict, list]]: - serial = super(PartitionResponse, self).to_dict() - serial['class_name'] = self.__class__.__name__ return serial @@ -181,4 +217,4 @@ class Config: class PartitionExternalResponse(PartitionResponse): - response_to_type = PartitionExternalRequest + response_to_type: ClassVar[Type[AbstractInitRequest]] = PartitionExternalRequest From 5a44b3424f8499b16b57b8c3dac4ef677aca3f2e Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 23 Jan 2023 16:57:08 -0500 Subject: [PATCH 122/205] fix faulty logic in SchedulerRequestMessage dict method --- .../lib/communication/dmod/communication/scheduler_request.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/lib/communication/dmod/communication/scheduler_request.py b/python/lib/communication/dmod/communication/scheduler_request.py index 250377766..a01442513 100644 --- a/python/lib/communication/dmod/communication/scheduler_request.py +++ b/python/lib/communication/dmod/communication/scheduler_request.py @@ -132,7 +132,7 @@ def dict( exclude_none: bool = False ) -> Dict[str, 
Union[str, int]]: # Only including memory value in serial form if it was explicitly set in the first place - if not self.memory_unset: + if self.memory_unset: exclude = {"memory"} if exclude is None else {"memory", *exclude} return super().dict( From 604c2b3d0ff3e8154c2d5a376897735832ac9cf5 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 23 Jan 2023 17:00:52 -0500 Subject: [PATCH 123/205] fix typo in DatasetQuery field name --- .../dmod/communication/dataset_management_message.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/lib/communication/dmod/communication/dataset_management_message.py b/python/lib/communication/dmod/communication/dataset_management_message.py index bc396dc8e..9b14f2485 100644 --- a/python/lib/communication/dmod/communication/dataset_management_message.py +++ b/python/lib/communication/dmod/communication/dataset_management_message.py @@ -41,7 +41,7 @@ def get_for_name(cls, name_str: str) -> 'QueryType': class DatasetQuery(Serializable): - query_file: QueryType + query_type: QueryType def __hash__(self): return hash(self.query_type) From 0c41a645c1a4d023928701f75275e57b4be8a782 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 23 Jan 2023 17:01:53 -0500 Subject: [PATCH 124/205] update dataset query tests --- python/lib/communication/dmod/test/test_dataset_query.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/lib/communication/dmod/test/test_dataset_query.py b/python/lib/communication/dmod/test/test_dataset_query.py index f313e97fe..8334e953e 100644 --- a/python/lib/communication/dmod/test/test_dataset_query.py +++ b/python/lib/communication/dmod/test/test_dataset_query.py @@ -10,7 +10,7 @@ def setUp(self) -> None: self.examples = [] self.ex_query_types.append(QueryType.LIST_FILES) - self.ex_json_data.append({DatasetQuery._KEY_QUERY_TYPE: 'LIST_FILES'}) + self.ex_json_data.append({"query_type": 'LIST_FILES'}) self.examples.append(DatasetQuery(query_type=QueryType.LIST_FILES)) def 
test_factory_init_from_deserialized_json_0_a(self): From 19962667d165f1e46f97d74f5c1771ab619d5a2c Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 23 Jan 2023 17:04:05 -0500 Subject: [PATCH 125/205] factory init model exec requests handles nwm and ngen cases. this should be revisited in the future --- .../dmod/communication/maas_request/model_exec_request.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/lib/communication/dmod/communication/maas_request/model_exec_request.py b/python/lib/communication/dmod/communication/maas_request/model_exec_request.py index af12a97d0..b98c3722a 100644 --- a/python/lib/communication/dmod/communication/maas_request/model_exec_request.py +++ b/python/lib/communication/dmod/communication/maas_request/model_exec_request.py @@ -58,7 +58,10 @@ def factory_init_correct_subtype_from_deserialized_json( A deserialized ::class:`ModelExecRequest` of the appropriate subtype. """ try: - model_name = json_obj["model"]["name"] + model = json_obj["model"] + + # TODO: remove logic once `nwm` ModelExecRequest changes where it store the model name. 
+ model_name = model["name"] if "name" in model else "nwm" models = get_available_models() return models[model_name].factory_init_from_deserialized_json(json_obj) From c8ad6eefb89f0cea43c593bb6aeed50f68a05f6c Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 23 Jan 2023 17:04:47 -0500 Subject: [PATCH 126/205] SchedulerRequestMessage factory inits ModelExecRequest correctly --- .../communication/dmod/communication/scheduler_request.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/python/lib/communication/dmod/communication/scheduler_request.py b/python/lib/communication/dmod/communication/scheduler_request.py index a01442513..25b8215b5 100644 --- a/python/lib/communication/dmod/communication/scheduler_request.py +++ b/python/lib/communication/dmod/communication/scheduler_request.py @@ -2,7 +2,7 @@ from .maas_request import ModelExecRequest from .message import AbstractInitRequest, MessageEventType, Response from .scheduler_request_response_body import SchedulerRequestResponseBody, UNSUCCESSFUL_JOB -from pydantic import Field, PrivateAttr +from pydantic import Field, PrivateAttr, validator from typing import ClassVar, Dict, Optional, Type, Union class SchedulerRequestMessage(AbstractInitRequest): @@ -18,6 +18,12 @@ class SchedulerRequestMessage(AbstractInitRequest): _memory_unset: bool = PrivateAttr() + @validator("model_request", pre=True) + def _factory_init_model_request(cls, value): + if isinstance(value, ModelExecRequest): + return value + return ModelExecRequest.factory_init_correct_subtype_from_deserialized_json(value) + class Config: fields = { "memory": {"alias": "mem"}, From ec2691e81617c2575048179e93d5801e8a96c82c Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 23 Jan 2023 17:07:02 -0500 Subject: [PATCH 127/205] allow SchedulerRequestResponse instances to be treated as dictionary --- .../dmod/communication/scheduler_request.py | 26 +++++++++++++++---- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git 
a/python/lib/communication/dmod/communication/scheduler_request.py b/python/lib/communication/dmod/communication/scheduler_request.py index 25b8215b5..37adae36f 100644 --- a/python/lib/communication/dmod/communication/scheduler_request.py +++ b/python/lib/communication/dmod/communication/scheduler_request.py @@ -154,7 +154,8 @@ def dict( class SchedulerRequestResponse(Response): response_to_type: ClassVar[Type[AbstractInitRequest]] = SchedulerRequestMessage - data: SchedulerRequestResponseBody + + data: Union[SchedulerRequestResponseBody, Dict[None, None], None] def __init__(self, job_id: Optional[int] = None, output_data_id: Optional[str] = None, data: dict = None, **kwargs): # TODO: how to handle if kwargs has success=True, but job_id value (as param or in data) implies success=False @@ -172,13 +173,11 @@ def __init__(self, job_id: Optional[int] = None, output_data_id: Optional[str] = if output_data_id is not None: data["output_data_id"] = output_data_id - data_body = SchedulerRequestResponseBody(**data if data is not None else {}) - # Ensure that 'success' is being passed as a kwarg to the superclass constructor if "success" not in kwargs: - kwargs["success"] = data is not None and data_body.job_id > 0 + kwargs["success"] = data is not None and "job_id" in data and data["job_id"] > 0 - super().__init__(data=data_body, **kwargs) + super().__init__(data=data, **kwargs) def __eq__(self, other): return self.__class__ == other.__class__ and self.success == other.success and self.job_id == other.job_id @@ -202,3 +201,20 @@ def output_data_id(self) -> Optional[str]: The 'data_id' of the output dataset for requested job, or ``None`` if not known. """ return self.data.output_data_id + + @classmethod + def factory_init_from_deserialized_json(cls, json_obj: dict) -> "SchedulerRequestResponse": + # TODO: remove in future. 
necessary for backwards compatibility + if isinstance(json_obj, SchedulerRequestResponse): + return json_obj + + return super().factory_init_from_deserialized_json(json_obj=json_obj) + + # NOTE: legacy support. previously this class was treated as a dictionary + def __contains__(self, element: str) -> bool: + return element in self.__dict__ + + # NOTE: legacy support. previously this class was treated as a dictionary + def __getitem__(self, item: str): + return self.__dict__[item] + From 268068382dbd6d19c2e1cc7b1b0c5f8a05110579 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 23 Jan 2023 17:07:37 -0500 Subject: [PATCH 128/205] allow SchedulerRequestResponseBody to be treated like it is a dictionary --- .../scheduler_request_response_body.py | 38 ++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/python/lib/communication/dmod/communication/scheduler_request_response_body.py b/python/lib/communication/dmod/communication/scheduler_request_response_body.py index e068f2fa4..d4accd642 100644 --- a/python/lib/communication/dmod/communication/scheduler_request_response_body.py +++ b/python/lib/communication/dmod/communication/scheduler_request_response_body.py @@ -2,7 +2,10 @@ from dmod.core.serializable import Serializable -from typing import Optional +from typing import Optional, Union, TYPE_CHECKING + +if TYPE_CHECKING: + from pydantic.typing import AbstractSetIntStr, MappingIntStrAny, DictStrAny UNSUCCESSFUL_JOB = -1 @@ -14,3 +17,36 @@ class SchedulerRequestResponseBody(Serializable): class Config: # allow extra model fields extra = Extra.allow + + def __eq__(self, other: object): + if isinstance(other, dict): + return self.to_dict() == other + return super().__eq__(other) + + def __contains__(self, key: str) -> bool: + return key in self.__dict__ + + def __getattr__(self, key: str): + return self.__dict__[key] + + def dict( + self, + *, + include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + exclude: 
Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + by_alias: bool = True, # Note this follows Serializable convention + skip_defaults: Optional[bool] = None, + exclude_unset: bool = True, # noop + exclude_defaults: bool = False, + exclude_none: bool = False + ) -> "DictStrAny": + # Note: for backwards compatibility, unset fields are excluded by default + return super().dict( + include=include, + exclude=exclude, + by_alias=by_alias, + skip_defaults=skip_defaults, + exclude_unset=True, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) From 6c77e74cc5642aa084616fb4f801141deaf493af Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 23 Jan 2023 17:37:18 -0500 Subject: [PATCH 129/205] add missing required fields in decorated interface tests --- .../lib/communication/dmod/test/test_decorated_interface.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/lib/communication/dmod/test/test_decorated_interface.py b/python/lib/communication/dmod/test/test_decorated_interface.py index 48524ac1b..9df30d1c1 100644 --- a/python/lib/communication/dmod/test/test_decorated_interface.py +++ b/python/lib/communication/dmod/test/test_decorated_interface.py @@ -153,7 +153,10 @@ def setUp(self): "model": { "nwm": { "config_data_id": "1", - "data_requirements": [{"domain": {"data_format": "NWM_CONFIG", "continuous": [], + "data_requirements": [{ + "category": "CONFIG", + "is_input": True, + "domain": {"data_format": "NWM_CONFIG", "continuous": [], "discrete": [{"variable": "data_id", "values": ["1"]}]}}] } }, From aa5c579144f7a7ebdcccd5eabe0374ac95e92153 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 23 Jan 2023 17:41:24 -0500 Subject: [PATCH 130/205] add missing required fields to websocket interface tests --- .../lib/communication/dmod/test/test_websocket_interface.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/lib/communication/dmod/test/test_websocket_interface.py 
b/python/lib/communication/dmod/test/test_websocket_interface.py index 3fb8a973f..5f8adc0e6 100644 --- a/python/lib/communication/dmod/test/test_websocket_interface.py +++ b/python/lib/communication/dmod/test/test_websocket_interface.py @@ -155,7 +155,10 @@ def setUp(self): "allocation_paradigm": "ROUND_ROBIN", "config_data_id": "1", "cpu_count": 2, - "data_requirements": [{"domain": {"data_format": "NWM_CONFIG", "continuous": [], + "data_requirements": [{ + "category": "CONFIG", + "is_input": True, + "domain": {"data_format": "NWM_CONFIG", "continuous": [], "discrete": [{"variable": "data_id", "values": ["1"]}]}}] } }, From c4d1af418f2f7db746fa17aa359c3fa46b774aa6 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 23 Jan 2023 21:31:46 -0500 Subject: [PATCH 131/205] use relative imports in tests. this avoids, potentially unexpected behaviour when asserting isinstance --- .../lib/communication/dmod/test/test_decorated_interface.py | 4 ++-- .../communication/dmod/test/test_ngen_request_response.py | 1 + .../lib/communication/dmod/test/test_nwm_request_response.py | 1 + .../lib/communication/dmod/test/test_websocket_interface.py | 5 +++-- 4 files changed, 7 insertions(+), 4 deletions(-) diff --git a/python/lib/communication/dmod/test/test_decorated_interface.py b/python/lib/communication/dmod/test/test_decorated_interface.py index 9df30d1c1..0850c1a97 100644 --- a/python/lib/communication/dmod/test/test_decorated_interface.py +++ b/python/lib/communication/dmod/test/test_decorated_interface.py @@ -7,8 +7,8 @@ import sys import unittest from ..communication.message import MessageEventType -from dmod.communication import ModelExecRequest, SessionInitMessage -from dmod.communication.dataset_management_message import MaaSDatasetManagementMessage +from ..communication import ModelExecRequest, SessionInitMessage +from ..communication.dataset_management_message import MaaSDatasetManagementMessage from ..communication.websocket_interface import NoOpHandler from pathlib 
import Path from socket import gethostname diff --git a/python/lib/communication/dmod/test/test_ngen_request_response.py b/python/lib/communication/dmod/test/test_ngen_request_response.py index 25d3c24e0..df36ed086 100644 --- a/python/lib/communication/dmod/test/test_ngen_request_response.py +++ b/python/lib/communication/dmod/test/test_ngen_request_response.py @@ -1,6 +1,7 @@ import json import unittest from ..communication.maas_request import NGENRequestResponse +from ..communication.maas_request.model_exec_request_response_body import ModelExecRequestResponseBody from ..communication.message import InitRequestResponseReason from ..communication.scheduler_request import SchedulerRequestResponse diff --git a/python/lib/communication/dmod/test/test_nwm_request_response.py b/python/lib/communication/dmod/test/test_nwm_request_response.py index 848ddc709..76908e570 100644 --- a/python/lib/communication/dmod/test/test_nwm_request_response.py +++ b/python/lib/communication/dmod/test/test_nwm_request_response.py @@ -1,6 +1,7 @@ import json import unittest from ..communication.maas_request import NWMRequestResponse +from ..communication.maas_request.model_exec_request_response_body import ModelExecRequestResponseBody from ..communication.message import InitRequestResponseReason from ..communication.scheduler_request import SchedulerRequestResponse diff --git a/python/lib/communication/dmod/test/test_websocket_interface.py b/python/lib/communication/dmod/test/test_websocket_interface.py index 5f8adc0e6..b3908b6fb 100644 --- a/python/lib/communication/dmod/test/test_websocket_interface.py +++ b/python/lib/communication/dmod/test/test_websocket_interface.py @@ -7,8 +7,9 @@ import sys import unittest from ..communication.message import MessageEventType -from dmod.communication import ModelExecRequest, SessionInitMessage -from dmod.communication.dataset_management_message import MaaSDatasetManagementMessage +from ..communication.maas_request import ModelExecRequest +from 
..communication.session import SessionInitMessage +from ..communication.dataset_management_message import MaaSDatasetManagementMessage from ..communication.websocket_interface import NoOpHandler from pathlib import Path from socket import gethostname From 6809ecf887fcf06ab75c748025a9f8363e8b239a Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 23 Jan 2023 21:33:32 -0500 Subject: [PATCH 132/205] change assert eq types in ngen request response tests --- .../lib/communication/dmod/test/test_ngen_request_response.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/lib/communication/dmod/test/test_ngen_request_response.py b/python/lib/communication/dmod/test/test_ngen_request_response.py index df36ed086..96e63dd15 100644 --- a/python/lib/communication/dmod/test/test_ngen_request_response.py +++ b/python/lib/communication/dmod/test/test_ngen_request_response.py @@ -96,7 +96,7 @@ def test_factory_init_from_deserialized_json_2_e(self): the expected dictionary value for ``data``. """ obj = NGENRequestResponse.factory_init_from_deserialized_json(self.response_jsons[2]) - self.assertEqual(obj.data.__class__, dict) + self.assertEqual(obj.data.__class__, ModelExecRequestResponseBody) def test_factory_init_from_deserialized_json_2_f(self): """ @@ -128,7 +128,7 @@ def test_factory_init_from_deserialized_json_2_i(self): the expected dictionary value for ``data``, with the ``scheduler_response`` being of the right type. 
""" obj = NGENRequestResponse.factory_init_from_deserialized_json(self.response_jsons[2]) - self.assertEqual(obj.data['scheduler_response'].__class__, dict) + self.assertEqual(obj.data['scheduler_response'].__class__, SchedulerRequestResponse) def test_factory_init_from_deserialized_json_2_j(self): """ From c9430060ab1efd3c0b904860497e9241c6144c6e Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 23 Jan 2023 21:35:32 -0500 Subject: [PATCH 133/205] change assert eq types in nwn request response tests --- .../dmod/test/test_nwm_request_response.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/python/lib/communication/dmod/test/test_nwm_request_response.py b/python/lib/communication/dmod/test/test_nwm_request_response.py index 76908e570..d3a8e224e 100644 --- a/python/lib/communication/dmod/test/test_nwm_request_response.py +++ b/python/lib/communication/dmod/test/test_nwm_request_response.py @@ -93,10 +93,11 @@ def test_factory_init_from_deserialized_json_2_d(self): def test_factory_init_from_deserialized_json_2_e(self): """ Test ``factory_init_from_deserialized_json()`` on raw string example 2 to make sure the deserialized object has - the expected dictionary value for ``data``. + the expected ModelExecRequestResponseBody value for ``data``. For legacy support, this can still be + treated like a dictionary. """ obj = NWMRequestResponse.factory_init_from_deserialized_json(self.response_jsons[2]) - self.assertEqual(obj.data.__class__, dict) + self.assertEqual(obj.data.__class__, ModelExecRequestResponseBody) def test_factory_init_from_deserialized_json_2_f(self): """ @@ -125,10 +126,11 @@ def test_factory_init_from_deserialized_json_2_h(self): def test_factory_init_from_deserialized_json_2_i(self): """ Test ``factory_init_from_deserialized_json()`` on raw string example 2 to make sure the deserialized object has - the expected dictionary value for ``data``, with the ``scheduler_response`` being of the right type. 
+ the expected SchedulerRequestResponse value for ``data``, with the ``scheduler_response`` being of the right type. + For legacy support, ``SchedulerRequestResponse`` can still be treated as a dictionary. """ obj = NWMRequestResponse.factory_init_from_deserialized_json(self.response_jsons[2]) - self.assertEqual(obj.data['scheduler_response'].__class__, dict) + self.assertEqual(obj.data['scheduler_response'].__class__, SchedulerRequestResponse) def test_factory_init_from_deserialized_json_2_j(self): """ From 2ef41ab4e3d7a30dfa47c066b74f2ce09ba9ef98 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 23 Jan 2023 21:37:16 -0500 Subject: [PATCH 134/205] add classvar type hints to FieldedMessage --- .../dmod/communication/registered/registered_message.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/lib/communication/dmod/communication/registered/registered_message.py b/python/lib/communication/dmod/communication/registered/registered_message.py index 12c20233e..c80f58982 100644 --- a/python/lib/communication/dmod/communication/registered/registered_message.py +++ b/python/lib/communication/dmod/communication/registered/registered_message.py @@ -5,8 +5,7 @@ import abc import typing from numbers import Number -from typing import Dict -from typing import Union +from typing import ClassVar, Dict, Union from ..message import AbstractInitRequest from ..message import MessageEventType @@ -296,7 +295,7 @@ class FieldedMessage(AbstractInitRequest): """ A message formed by dictated fields coming from subclasses """ - event_type: MessageEventType = MessageEventType.INFORMATION_UPDATE + event_type: ClassVar[MessageEventType] = MessageEventType.INFORMATION_UPDATE """ The event type for this message; this shouldn't have as much bearing on how to handle this message. Use members and class type instead. 
From f994fdcf03e62acf6695a5a40703634f73a34b90 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 23 Jan 2023 22:20:56 -0500 Subject: [PATCH 135/205] refactor Session dict override with Config field_serializers --- .../dmod/communication/session.py | 42 ++++--------------- 1 file changed, 9 insertions(+), 33 deletions(-) diff --git a/python/lib/communication/dmod/communication/session.py b/python/lib/communication/dmod/communication/session.py index 746088b57..f0db86dc4 100644 --- a/python/lib/communication/dmod/communication/session.py +++ b/python/lib/communication/dmod/communication/session.py @@ -67,6 +67,15 @@ def validate_date(cls, value): except: return datetime.datetime.now() + class Config: + def _serialize_datetime(self: "Session", value: datetime.datetime) -> str: + return value.strftime(self.get_datetime_str_format()) + + field_serializers = { + "created": _serialize_datetime, + "last_accessed": _serialize_datetime, + } + @classmethod def get_datetime_str_format(cls): return cls._DATETIME_FORMAT @@ -181,39 +190,6 @@ def is_serialized_attribute(self, attribute: str) -> bool: return False return attribute in self.__fields__ - def dict( - self, - *, - include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, - exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, - by_alias: bool = True, # Note this follows Serializable convention - skip_defaults: Optional[bool] = None, - exclude_unset: bool = False, - exclude_defaults: bool = False, - exclude_none: bool = False - ) -> Dict[str, Union[str, int]]: - _exclude = {"created", "last_accessed"} - if exclude is not None: - _exclude = {*_exclude, *exclude} - - serial = super().dict( - include=include, - exclude=_exclude, - by_alias=by_alias, - skip_defaults=skip_defaults, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - ) - - if exclude is None or "created" not in exclude: - serial["created"] = 
self.created.strftime(self.get_datetime_str_format()) - - if exclude is None or "last_accessed" not in exclude: - serial["last_accessed"] = self.last_accessed.strftime(self.get_datetime_str_format()) - - return serial - # TODO: work more on this later, when authentication becomes more important class FullAuthSession(Session): From e0ca55c4d6c828909a0b4b8b15b82155db63a453 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 23 Jan 2023 22:23:04 -0500 Subject: [PATCH 136/205] refactor FailedSessionInitInfo dict override with Config field_serializers --- .../dmod/communication/session.py | 35 +++---------------- 1 file changed, 5 insertions(+), 30 deletions(-) diff --git a/python/lib/communication/dmod/communication/session.py b/python/lib/communication/dmod/communication/session.py index f0db86dc4..d92e91ee8 100644 --- a/python/lib/communication/dmod/communication/session.py +++ b/python/lib/communication/dmod/communication/session.py @@ -263,36 +263,11 @@ class FailedSessionInitInfo(Serializable): def get_datetime_str_format(cls): return Session.get_datetime_str_format() - def dict( - self, - *, - include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, - exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, - by_alias: bool = True, # Note this follows Serializable convention - skip_defaults: Optional[bool] = None, - exclude_unset: bool = False, - exclude_defaults: bool = False, - exclude_none: bool = False - ) -> Dict[str, Union[str, int]]: - FAIL_TIME_KEY = "fail_time" - exclude = exclude or set() - fail_time_in_exclude = FAIL_TIME_KEY in exclude - exclude.add(FAIL_TIME_KEY) - - serial = super().dict( - include=include, - exclude=exclude, - by_alias=by_alias, - skip_defaults=skip_defaults, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - ) - - if not fail_time_in_exclude: - serial[FAIL_TIME_KEY] = self.fail_time.strftime(self.get_datetime_str_format()) - - return serial + 
class Config: + def _serialize_datetime(self: "Session", value: datetime.datetime) -> str: + return value.strftime(self.get_datetime_str_format()) + + field_serializers = {"fail_time": _serialize_datetime} # Define this custom type here for hinting From 50382e4471065a81517b5cd191ab7d98f3c2d855 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 23 Jan 2023 22:25:59 -0500 Subject: [PATCH 137/205] refactor DataTransmitUUID dict override with Config field_serializers --- .../communication/data_transmit_message.py | 33 ++----------------- 1 file changed, 3 insertions(+), 30 deletions(-) diff --git a/python/lib/communication/dmod/communication/data_transmit_message.py b/python/lib/communication/dmod/communication/data_transmit_message.py index 9d5b2c7da..8cec352f9 100644 --- a/python/lib/communication/dmod/communication/data_transmit_message.py +++ b/python/lib/communication/dmod/communication/data_transmit_message.py @@ -13,36 +13,9 @@ class DataTransmitUUID(Serializable): The expectation is that a larger amount of data will be broken up into multiple messages in a series. 
""" - def dict( - self, - *, - include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, - exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, - by_alias: bool = True, # Note this follows Serializable convention - skip_defaults: Optional[bool] = None, - exclude_unset: bool = False, - exclude_defaults: bool = False, - exclude_none: bool = False - ) -> Dict[str, Union[str, int]]: - SERIES_UUID_KEY = "series_uuid" - exclude = exclude or set() - series_uuid_in_exclude = SERIES_UUID_KEY in exclude - exclude.add(SERIES_UUID_KEY) - - serial = super().dict( - include=include, - exclude=exclude, - by_alias=by_alias, - skip_defaults=skip_defaults, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - ) - - if not series_uuid_in_exclude: - serial[SERIES_UUID_KEY] = str(self.series_uuid) - - return serial + class Config: + field_serializers = {"series_uuid": lambda s: str(s)} + class DataTransmitMessage(DataTransmitUUID, AbstractInitRequest): """ From 1cbb85f6405c0e638c7c2f48fdf5f9288755bfc6 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 23 Jan 2023 22:26:16 -0500 Subject: [PATCH 138/205] remove unused imports --- .../communication/dmod/communication/data_transmit_message.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/lib/communication/dmod/communication/data_transmit_message.py b/python/lib/communication/dmod/communication/data_transmit_message.py index 8cec352f9..b1a77dd94 100644 --- a/python/lib/communication/dmod/communication/data_transmit_message.py +++ b/python/lib/communication/dmod/communication/data_transmit_message.py @@ -2,8 +2,7 @@ from pydantic import Extra from .message import AbstractInitRequest, MessageEventType, Response from pydantic import Field -from typing import ClassVar, Dict, Optional, Type, Union -from numbers import Number +from typing import ClassVar, Type, Union from uuid import UUID From 3911c4fad74fd5e34086cb678fb3aa9fdf56b5af 
Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 23 Jan 2023 22:28:48 -0500 Subject: [PATCH 139/205] refactor UpdateMessage dict override with Config field_serializers --- .../dmod/communication/update_message.py | 33 ++----------------- 1 file changed, 3 insertions(+), 30 deletions(-) diff --git a/python/lib/communication/dmod/communication/update_message.py b/python/lib/communication/dmod/communication/update_message.py index cc136fab6..e2432009f 100644 --- a/python/lib/communication/dmod/communication/update_message.py +++ b/python/lib/communication/dmod/communication/update_message.py @@ -52,6 +52,9 @@ def _validate_updated_data_has_keys(cls, value: Dict[str, str]): raise ValueError("`updated_data` must have at least one key.") return value + class Config: + field_serializers = {"object_type": lambda self, _: self.object_type_string} + @classmethod def get_digest_key(cls) -> str: return cls.__fields__["digest"].alias @@ -72,36 +75,6 @@ def get_updated_data_key(cls) -> str: def object_type_string(self) -> str: return '{}.{}'.format(self.object_type.__module__, self.object_type.__name__) - def dict( - self, - *, - include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, - exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, - by_alias: bool = True, # Note this follows Serializable convention - skip_defaults: Optional[bool] = None, - exclude_unset: bool = False, - exclude_defaults: bool = False, - exclude_none: bool = False - ) -> Dict[str, Union[str, int]]: - OBJECT_TYPE_KEY = "object_type" - exclude = exclude or set() - object_type_in_exclude = OBJECT_TYPE_KEY in exclude - exclude.add(OBJECT_TYPE_KEY) - - serial = super().dict( - include=include, - exclude=exclude, - by_alias=by_alias, - skip_defaults=skip_defaults, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - ) - - if not object_type_in_exclude: - serial[OBJECT_TYPE_KEY] = self.object_type_string - - return serial 
class UpdateMessageData(Serializable): digest: Optional[str] From 3f35a607866e9d06135df25476f49111e8b17794 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Thu, 26 Jan 2023 13:36:29 -0500 Subject: [PATCH 140/205] bump minimum required version of dmod.core to 0.4.2 --- python/lib/communication/setup.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/lib/communication/setup.py b/python/lib/communication/setup.py index 30289f00d..45edb9a0d 100644 --- a/python/lib/communication/setup.py +++ b/python/lib/communication/setup.py @@ -21,7 +21,6 @@ url='', license='', include_package_data=True, - #install_requires=['websockets', 'jsonschema'],vi - install_requires=['dmod-core>=0.1.2', 'websockets>=8.1', 'jsonschema', 'redis', 'pydantic'], + install_requires=['dmod-core>=0.4.2', 'websockets>=8.1', 'jsonschema', 'redis', 'pydantic'], packages=find_namespace_packages(include=['dmod.*'], exclude=['dmod.test']) ) From dd6b36a8794fe522a90a5d6c9c25c6887338ccef Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Wed, 1 Feb 2023 09:30:49 -0500 Subject: [PATCH 141/205] dekabob (i.e. 
this-that) SchedulerRequestMessage allocation_paradigm field --- .../communication/dmod/communication/scheduler_request.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/lib/communication/dmod/communication/scheduler_request.py b/python/lib/communication/dmod/communication/scheduler_request.py index 37adae36f..043300e29 100644 --- a/python/lib/communication/dmod/communication/scheduler_request.py +++ b/python/lib/communication/dmod/communication/scheduler_request.py @@ -24,6 +24,12 @@ def _factory_init_model_request(cls, value): return value return ModelExecRequest.factory_init_correct_subtype_from_deserialized_json(value) + @validator("allocation_paradigm_", pre=True) + def _dekabob_input(cls, value: Optional[Union[AllocationParadigm, str]]) -> Optional[Union[AllocationParadigm, str]]: + if isinstance(value, str): + return value.replace("-", "_") + return value + class Config: fields = { "memory": {"alias": "mem"}, From 15378db3e4df7d503b5c41ab7f2e7e7a95eefad9 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Fri, 3 Feb 2023 17:17:10 -0500 Subject: [PATCH 142/205] fix property name that does not exist --- .../dmod/communication/maas_request/ngen/ngen_request.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/lib/communication/dmod/communication/maas_request/ngen/ngen_request.py b/python/lib/communication/dmod/communication/maas_request/ngen/ngen_request.py index c30b8ef73..965889ab4 100644 --- a/python/lib/communication/dmod/communication/maas_request/ngen/ngen_request.py +++ b/python/lib/communication/dmod/communication/maas_request/ngen/ngen_request.py @@ -208,7 +208,7 @@ def bmi_cfg_data_requirement(self) -> DataRequirement: if self._bmi_cfg_data_requirement is None: bmi_config_restrict = [ DiscreteRestriction( - variable="data_id", values=[self._bmi_config_data_id] + variable="data_id", values=[self.bmi_config_data_id] ) ] bmi_config_domain = DataDomain( From 5208dc85e7fffeea86aee9937d2e12af97c5507d Mon Sep 17 00:00:00 
2001 From: Austin Raney Date: Fri, 3 Feb 2023 17:18:07 -0500 Subject: [PATCH 143/205] fix instance where DataRequirement's init params were not specified as kwargs --- .../dmod/communication/maas_request/ngen/ngen_request.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/lib/communication/dmod/communication/maas_request/ngen/ngen_request.py b/python/lib/communication/dmod/communication/maas_request/ngen/ngen_request.py index 965889ab4..d0b0a22ce 100644 --- a/python/lib/communication/dmod/communication/maas_request/ngen/ngen_request.py +++ b/python/lib/communication/dmod/communication/maas_request/ngen/ngen_request.py @@ -216,7 +216,7 @@ def bmi_cfg_data_requirement(self) -> DataRequirement: discrete_restrictions=bmi_config_restrict, ) self._bmi_cfg_data_requirement = DataRequirement( - bmi_config_domain, True, DataCategory.CONFIG + domain=bmi_config_domain, is_input=True, category=DataCategory.CONFIG ) return self._bmi_cfg_data_requirement From 2e4d9c87efc412dff4fc7bfea30f32afdce6fb45 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 6 Feb 2023 14:16:04 -0500 Subject: [PATCH 144/205] fix NWMRequest get_model_name classmethod. see note --- .../dmod/communication/maas_request/nwm/nwm_request.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/lib/communication/dmod/communication/maas_request/nwm/nwm_request.py b/python/lib/communication/dmod/communication/maas_request/nwm/nwm_request.py index f417c8e99..1d8acdac4 100644 --- a/python/lib/communication/dmod/communication/maas_request/nwm/nwm_request.py +++ b/python/lib/communication/dmod/communication/maas_request/nwm/nwm_request.py @@ -66,6 +66,12 @@ def __init__( super().__init__(**data) + @classmethod + def get_model_name(cls) -> str: + # NOTE: overridden b.c. nwm request has nested model field. In the future we should be able + # to remove this. 
+ return cls.__fields__["model"].type_.__fields__["nwm"].type_.__fields__["name"].default + @property def data_requirements(self) -> List[DataRequirement]: """ From 8e3cd627ab9caf6f3932aaa5be7bed53110754a8 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Fri, 10 Feb 2023 16:17:36 -0500 Subject: [PATCH 145/205] Response subtypes are now SerializableDict subtypes. This is more flexable and better captures the pre-pydantic port behavior --- .../dmod/communication/data_transmit_message.py | 7 +++---- .../communication/dataset_management_message.py | 3 ++- .../lib/communication/dmod/communication/message.py | 7 ++++--- .../dmod/communication/metadata_message.py | 4 ++-- .../dmod/communication/partition_request.py | 4 ++-- .../scheduler_request_response_body.py | 13 ++----------- .../lib/communication/dmod/communication/session.py | 5 +++-- .../dmod/communication/update_message.py | 4 ++-- 8 files changed, 20 insertions(+), 27 deletions(-) diff --git a/python/lib/communication/dmod/communication/data_transmit_message.py b/python/lib/communication/dmod/communication/data_transmit_message.py index b1a77dd94..9f216c796 100644 --- a/python/lib/communication/dmod/communication/data_transmit_message.py +++ b/python/lib/communication/dmod/communication/data_transmit_message.py @@ -1,5 +1,6 @@ from dmod.core.serializable import Serializable from pydantic import Extra +from dmod.core.serializable_dict import SerializableDict from .message import AbstractInitRequest, MessageEventType, Response from pydantic import Field from typing import ClassVar, Type, Union @@ -36,10 +37,8 @@ class DataTransmitMessage(DataTransmitUUID, AbstractInitRequest): is_last: bool = Field(False, description="Whether this is the last data transmission message in this series.") -class DataTransmitResponseBody(DataTransmitUUID): - - class Config: - extra = Extra.allow +class DataTransmitResponseBody(SerializableDict, DataTransmitUUID): + ... 
class DataTransmitResponse(Response): diff --git a/python/lib/communication/dmod/communication/dataset_management_message.py b/python/lib/communication/dmod/communication/dataset_management_message.py index 9b14f2485..e04352cc5 100644 --- a/python/lib/communication/dmod/communication/dataset_management_message.py +++ b/python/lib/communication/dmod/communication/dataset_management_message.py @@ -1,5 +1,6 @@ from .message import AbstractInitRequest, MessageEventType, Response from dmod.core.serializable import Serializable +from dmod.core.serializable_dict import SerializableDict from .maas_request import ExternalRequest, ExternalRequestResponse from dmod.core.meta_data import DataCategory, DataDomain, DataFormat, DataRequirement from dmod.core.enum import PydanticEnum @@ -266,7 +267,7 @@ def __init__( ) -class DatasetManagementResponseBody(Serializable): +class DatasetManagementResponseBody(SerializableDict): action: Optional[ManagementAction] data_id: Optional[str] dataset_name: Optional[str] diff --git a/python/lib/communication/dmod/communication/message.py b/python/lib/communication/dmod/communication/message.py index f9f08a76b..3d4db4074 100644 --- a/python/lib/communication/dmod/communication/message.py +++ b/python/lib/communication/dmod/communication/message.py @@ -3,6 +3,7 @@ from pydantic import Field from dmod.core.serializable import Serializable, ResultIndicator +from dmod.core.serializable_dict import SerializableDict from dmod.core.enum import PydanticEnum @@ -122,7 +123,7 @@ class Response(ResultIndicator, Message, ABC): response_to_type: ClassVar[Type[AbstractInitRequest]] = AbstractInitRequest """ The type of :class:`AbstractInitRequest` for which this type is the response""" - data: Optional[Serializable] + data: Optional[SerializableDict] @classmethod def get_message_event_type(cls) -> MessageEventType: @@ -172,13 +173,13 @@ class InvalidMessageResponse(Response): success = False reason: Literal["Invalid Request message"] = "Invalid Request 
message" message: Literal["Request message was not formatted as any known valid type"] = "Request message was not formatted as any known valid type" - data: Optional[Serializable] + data: Optional[SerializableDict] def __init__(self, data: Optional[Serializable]=None, **kwargs): super().__init__(data=data) -class HttpCode(Serializable): +class HttpCode(SerializableDict): http_code: int = Field(ge=100, le=599) class ErrorResponse(Response): diff --git a/python/lib/communication/dmod/communication/metadata_message.py b/python/lib/communication/dmod/communication/metadata_message.py index 3256dede0..d32ee1ecb 100644 --- a/python/lib/communication/dmod/communication/metadata_message.py +++ b/python/lib/communication/dmod/communication/metadata_message.py @@ -1,9 +1,9 @@ from .message import AbstractInitRequest, MessageEventType, Response +from dmod.core.serializable_dict import SerializableDict from numbers import Number from typing import ClassVar, Dict, Optional, Type, Union from pydantic import Field, root_validator -from dmod.core.serializable import Serializable from dmod.core.enum import PydanticEnum @@ -28,7 +28,7 @@ def get_value_for_name(cls, name_str: str) -> Optional['MetadataPurpose']: return None -class MetadataSignal(Serializable): +class MetadataSignal(SerializableDict): purpose: MetadataPurpose metadata_follows: bool diff --git a/python/lib/communication/dmod/communication/partition_request.py b/python/lib/communication/dmod/communication/partition_request.py index 5437cf4f2..973ebbc3b 100644 --- a/python/lib/communication/dmod/communication/partition_request.py +++ b/python/lib/communication/dmod/communication/partition_request.py @@ -1,7 +1,7 @@ from uuid import uuid4 from pydantic import Field, Extra from typing import ClassVar, Dict, Optional, Type, Union -from dmod.core.serializable import Serializable +from dmod.core.serializable_dict import SerializableDict from .message import AbstractInitRequest, MessageEventType, Response from .maas_request 
import ExternalRequest @@ -116,7 +116,7 @@ def dict( return serial -class PartitionResponseBody(Serializable): +class PartitionResponseBody(SerializableDict): data_id: Optional[str] dataset_name: Optional[str] diff --git a/python/lib/communication/dmod/communication/scheduler_request_response_body.py b/python/lib/communication/dmod/communication/scheduler_request_response_body.py index d4accd642..2eee90a69 100644 --- a/python/lib/communication/dmod/communication/scheduler_request_response_body.py +++ b/python/lib/communication/dmod/communication/scheduler_request_response_body.py @@ -1,6 +1,4 @@ -from pydantic import Extra - -from dmod.core.serializable import Serializable +from dmod.core.serializable_dict import SerializableDict from typing import Optional, Union, TYPE_CHECKING @@ -10,22 +8,15 @@ UNSUCCESSFUL_JOB = -1 -class SchedulerRequestResponseBody(Serializable): +class SchedulerRequestResponseBody(SerializableDict): job_id: int = UNSUCCESSFUL_JOB output_data_id: Optional[str] - class Config: - # allow extra model fields - extra = Extra.allow - def __eq__(self, other: object): if isinstance(other, dict): return self.to_dict() == other return super().__eq__(other) - def __contains__(self, key: str) -> bool: - return key in self.__dict__ - def __getattr__(self, key: str): return self.__dict__[key] diff --git a/python/lib/communication/dmod/communication/session.py b/python/lib/communication/dmod/communication/session.py index d92e91ee8..35b2a381f 100644 --- a/python/lib/communication/dmod/communication/session.py +++ b/python/lib/communication/dmod/communication/session.py @@ -3,6 +3,7 @@ import random from .message import AbstractInitRequest, MessageEventType, Response from dmod.core.serializable import Serializable +from dmod.core.serializable_dict import SerializableDict from dmod.core.enum import PydanticEnum from abc import ABC, abstractmethod from numbers import Number @@ -34,7 +35,7 @@ class SessionInitFailureReason(PydanticEnum): UNKNOWN = -1 -class 
Session(Serializable): +class Session(SerializableDict): """ A bare-bones representation of a session between some compatible server and client, over which various requests may be made, and potentially other communication may take place. @@ -248,7 +249,7 @@ class SessionInitMessage(AbstractInitRequest): """ :class:`MessageEventType`: the event type for this message implementation """ -class FailedSessionInitInfo(Serializable): +class FailedSessionInitInfo(SerializableDict): """ A :class:`~.serializeable.Serializable` type for representing details on why a :class:`SessionInitMessage` didn't successfully init a session. diff --git a/python/lib/communication/dmod/communication/update_message.py b/python/lib/communication/dmod/communication/update_message.py index e2432009f..3102b7ff7 100644 --- a/python/lib/communication/dmod/communication/update_message.py +++ b/python/lib/communication/dmod/communication/update_message.py @@ -4,7 +4,7 @@ from pydantic import Field, validator import uuid -from dmod.core.serializable import Serializable +from dmod.core.serializable_dict import SerializableDict class UpdateMessage(AbstractInitRequest): @@ -76,7 +76,7 @@ def object_type_string(self) -> str: return '{}.{}'.format(self.object_type.__module__, self.object_type.__name__) -class UpdateMessageData(Serializable): +class UpdateMessageData(SerializableDict): digest: Optional[str] object_found: Optional[bool] From 245c34dcf0c5b802010ddd5010432e2f01dbe369 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Fri, 10 Feb 2023 16:18:43 -0500 Subject: [PATCH 146/205] metadata_follows _defaults_ to False. 
This was pre-pydantic behavior --- python/lib/communication/dmod/communication/metadata_message.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/lib/communication/dmod/communication/metadata_message.py b/python/lib/communication/dmod/communication/metadata_message.py index d32ee1ecb..2adf75b73 100644 --- a/python/lib/communication/dmod/communication/metadata_message.py +++ b/python/lib/communication/dmod/communication/metadata_message.py @@ -30,7 +30,7 @@ def get_value_for_name(cls, name_str: str) -> Optional['MetadataPurpose']: class MetadataSignal(SerializableDict): purpose: MetadataPurpose - metadata_follows: bool + metadata_follows: bool = False class Config: fields = { From 8c56cae708853d4a6973000ba82fe02731fb5052 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Fri, 10 Feb 2023 16:20:46 -0500 Subject: [PATCH 147/205] remove, now, unnecessary PartitionResponseBody magic method overrides. SerializableDict implements these methods --- .../dmod/communication/partition_request.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/python/lib/communication/dmod/communication/partition_request.py b/python/lib/communication/dmod/communication/partition_request.py index 973ebbc3b..5a85342ac 100644 --- a/python/lib/communication/dmod/communication/partition_request.py +++ b/python/lib/communication/dmod/communication/partition_request.py @@ -120,15 +120,6 @@ class PartitionResponseBody(SerializableDict): data_id: Optional[str] dataset_name: Optional[str] - class Config: - extra = Extra.allow - - def __contains__(self, key: str) -> bool: - return key in self.__dict__ - - def __getitem__(self, key: str): - return self.__dict__[key] - class PartitionResponse(Response): """ A response to a ::class:`PartitionRequest`. 
From 4518788bd0a6311f27fc78724f5b770a402a87b2 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Fri, 10 Feb 2023 16:21:13 -0500 Subject: [PATCH 148/205] fix, could be none, error in SchedulerRequestResponse --- .../lib/communication/dmod/communication/scheduler_request.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/lib/communication/dmod/communication/scheduler_request.py b/python/lib/communication/dmod/communication/scheduler_request.py index 043300e29..9fb41ba91 100644 --- a/python/lib/communication/dmod/communication/scheduler_request.py +++ b/python/lib/communication/dmod/communication/scheduler_request.py @@ -206,7 +206,9 @@ def output_data_id(self) -> Optional[str]: Optional[str] The 'data_id' of the output dataset for requested job, or ``None`` if not known. """ - return self.data.output_data_id + if self.data is None: + return None + return self.data.get("output_data_id") @classmethod def factory_init_from_deserialized_json(cls, json_obj: dict) -> "SchedulerRequestResponse": From ec3658a275bf40823984c63940cd3e325e834318 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Fri, 10 Feb 2023 16:22:22 -0500 Subject: [PATCH 149/205] add session_secret validator. handles None as input. 
pre-pydantic handled this --- .../communication/dmod/communication/session.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/python/lib/communication/dmod/communication/session.py b/python/lib/communication/dmod/communication/session.py index 35b2a381f..32bbe3ea8 100644 --- a/python/lib/communication/dmod/communication/session.py +++ b/python/lib/communication/dmod/communication/session.py @@ -7,9 +7,12 @@ from dmod.core.enum import PydanticEnum from abc import ABC, abstractmethod from numbers import Number -from typing import ClassVar, Dict, Optional, List, Type, Union +from typing import ClassVar, Dict, Optional, List, Type, TYPE_CHECKING, Union from pydantic import Field, IPvAnyAddress, validator, root_validator +if TYPE_CHECKING: + from pydantic.fields import ModelField + def _generate_secret() -> str: """Generate random sha256 session secret. @@ -57,6 +60,16 @@ class Session(SerializableDict): _session_timeout_delta: ClassVar[datetime.timedelta] = datetime.timedelta(minutes=30.0) + @validator("session_secret", pre=True) + def _populate_session_secret_if_none(cls, value: Optional[str], field: "ModelField") -> str: + # NOTE: pre-pydantic, this field was a computed optional: + # (i.e. `__init__(..., session_secret: str = None)`) but if None, a value was generated. + # this validator handles that case + if value is None: + return field.default_factory() # type: ignore + + return value + @validator("created", "last_accessed", pre=True) def validate_date(cls, value): if isinstance(value, datetime.datetime): From da4d0c5a967fc7fbda0d5a4f441e626203ff76c5 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Fri, 10 Feb 2023 16:23:00 -0500 Subject: [PATCH 150/205] fix UpdateMessage object_type validator. 
_guards_ str types --- .../communication/dmod/communication/update_message.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/python/lib/communication/dmod/communication/update_message.py b/python/lib/communication/dmod/communication/update_message.py index 3102b7ff7..c639c06fa 100644 --- a/python/lib/communication/dmod/communication/update_message.py +++ b/python/lib/communication/dmod/communication/update_message.py @@ -41,10 +41,12 @@ class type, but note that when messages are serialized, it is converted to the f @validator("object_type", pre=True) def _coerce_object_type(cls, value): - obj_type = locate(value) - if obj_type is None: - raise ValueError("could not resolve `object_type`") - return obj_type + if isinstance(value, str): + obj_type = locate(value) + if obj_type is None: + raise ValueError("could not resolve `object_type`") + return obj_type + return value @validator("updated_data") def _validate_updated_data_has_keys(cls, value: Dict[str, str]): From 02890f78fe32dbb7cd6138895c67a920e3c6e4a9 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Fri, 27 Jan 2023 11:38:01 -0500 Subject: [PATCH 151/205] add pydantic as dep to dmod.scheduler --- python/lib/scheduler/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/lib/scheduler/setup.py b/python/lib/scheduler/setup.py index 817e02aa4..7e10fb7cc 100644 --- a/python/lib/scheduler/setup.py +++ b/python/lib/scheduler/setup.py @@ -21,7 +21,7 @@ url='', license='', install_requires=['docker', 'Faker', 'dmod-communication>=0.8.0', 'dmod-modeldata>=0.7.1', 'dmod-redis>=0.1.0', - 'dmod-core>=0.2.0', 'cryptography', 'uri', 'pyyaml'], + 'dmod-core>=0.2.0', 'cryptography', 'uri', 'pyyaml', 'pydantic'], packages=find_namespace_packages(exclude=['dmod.test', 'src']) ) From c752133f28a9bcbf1f100f8d6b86d86cfb51dbd3 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Fri, 27 Jan 2023 11:39:29 -0500 Subject: [PATCH 152/205] refactor enum variants to use 
dmod.core.enum.PydanticEnum --- python/lib/scheduler/dmod/scheduler/job/job.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/lib/scheduler/dmod/scheduler/job/job.py b/python/lib/scheduler/dmod/scheduler/job/job.py index e7cf9b912..b36ea09b9 100644 --- a/python/lib/scheduler/dmod/scheduler/job/job.py +++ b/python/lib/scheduler/dmod/scheduler/job/job.py @@ -6,8 +6,8 @@ from dmod.communication import ExternalRequest, ModelExecRequest, NGENRequest, SchedulerRequestMessage from dmod.core.serializable import Serializable from dmod.core.meta_data import DataRequirement +from dmod.core.enum import PydanticEnum from dmod.modeldata.hydrofabric import PartitionConfig -from enum import Enum from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union from uuid import UUID from uuid import uuid4 as uuid_func @@ -20,7 +20,7 @@ import logging -class JobExecStep(Enum): +class JobExecStep(PydanticEnum): """ A component of a JobStatus, representing the particular step within a "phase" encoded within the current status. @@ -133,7 +133,7 @@ def uid(self) -> int: return self._uid -class JobExecPhase(Enum): +class JobExecPhase(PydanticEnum): """ A component of a JobStatus, representing the high level transition stage at which a status exists. 
""" From 5cb52d7c40aa91a410ef3c2eedb9680f82ac371f Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Tue, 31 Jan 2023 16:40:40 -0500 Subject: [PATCH 153/205] refactor JobStatus --- .../lib/scheduler/dmod/scheduler/job/job.py | 60 ++++++++++--------- 1 file changed, 33 insertions(+), 27 deletions(-) diff --git a/python/lib/scheduler/dmod/scheduler/job/job.py b/python/lib/scheduler/dmod/scheduler/job/job.py index b36ea09b9..65a4e9901 100644 --- a/python/lib/scheduler/dmod/scheduler/job/job.py +++ b/python/lib/scheduler/dmod/scheduler/job/job.py @@ -1,6 +1,9 @@ from abc import ABC, abstractmethod from datetime import datetime from numbers import Number +from pydantic import Field, validator, root_validator +from pydantic.fields import ModelField +from warnings import warn from dmod.core.execution import AllocationParadigm from dmod.communication import ExternalRequest, ModelExecRequest, NGENRequest, SchedulerRequestMessage @@ -8,7 +11,7 @@ from dmod.core.meta_data import DataRequirement from dmod.core.enum import PydanticEnum from dmod.modeldata.hydrofabric import PartitionConfig -from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union from uuid import UUID from uuid import uuid4 as uuid_func @@ -220,12 +223,27 @@ class JobStatus(Serializable): """ _NAME_DELIMITER = ':' - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict) -> 'JobStatus': - try: - cls(phase=JobExecPhase.get_for_name(json_obj['phase']), step=JobExecStep.get_for_name(json_obj['step'])) - except: - return None + # NOTE: `None` is valid input, default value for field will be used. + phase: Optional[JobExecPhase] = Field(JobExecPhase.UNKNOWN) + # NOTE: field value will be derived from `phase` field if field is unset or None. 
+ step: Optional[JobExecStep] + + @validator("phase", pre=True) + def _set_default_phase_if_none(cls, value: Optional[JobExecPhase], field: ModelField) -> JobExecPhase: + if value is None: + return field.default + + return value + + @validator("step", always=True) + def _set_default_or_derived_step_if_none(cls, value: Optional[JobExecStep], values: Dict[str, JobExecPhase]) -> JobExecStep: + # implicit assertion that `phase` key has already been processed by it's validator + phase: JobExecPhase = values["phase"] + + if value is None: + return phase.default_start_step + + return value @classmethod def get_for_name(cls, name: str) -> 'JobStatus': @@ -258,26 +276,20 @@ def get_for_name(cls, name: str) -> 'JobStatus': if len(parsed_list) != 2: return JobStatus(JobExecPhase.UNKNOWN, JobExecStep.DEFAULT) - return JobStatus(phase=JobExecPhase.get_for_name(parsed_list[0]), - step=JobExecStep.get_for_name(parsed_list[1])) + phase, step = parsed_list + return JobStatus(phase=phase, step=step) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, JobStatus): return self.job_exec_phase == other.job_exec_phase and self.job_exec_step == other.job_exec_step else: return False - def __hash__(self): + def __hash__(self) -> int: return hash(self.name) - def __init__(self, phase: Optional[JobExecPhase], step: Optional[JobExecStep] = None): - self._phase = JobExecPhase.UNKNOWN if phase is None else phase - if step is not None: - self._step = step - elif self._phase is not None: - self._step = self._phase.default_start_step - else: - self._step = JobExecStep.DEFAULT + def __init__(self, phase: Optional[JobExecPhase], step: Optional[JobExecStep] = None, **data): + super().__init__(phase=phase, step=step, **data) def get_for_new_step(self, step: JobExecStep) -> 'JobStatus': """ @@ -316,22 +328,16 @@ def is_interrupted(self) -> bool: @property def job_exec_phase(self) -> JobExecPhase: - return self._phase + return self.phase @property def 
job_exec_step(self) -> JobExecStep: - return self._step + return self.step @property def name(self) -> str: return self.job_exec_phase.name + self._NAME_DELIMITER + self.job_exec_step.name - def to_dict(self) -> Dict[str, Union[str, Number, dict, list]]: - serial = dict() - serial['phase'] = self.job_exec_phase.name - serial['step'] = self.job_exec_step.name - return serial - class Job(Serializable, ABC): """ From a5412d09aa084e136d67b72ed515671c4ff654a7 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Tue, 31 Jan 2023 18:09:27 -0500 Subject: [PATCH 154/205] resource enum's now are based on PydanticEnum --- python/lib/scheduler/dmod/scheduler/resources/resource.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/lib/scheduler/dmod/scheduler/resources/resource.py b/python/lib/scheduler/dmod/scheduler/resources/resource.py index 755f58c07..56f43885c 100644 --- a/python/lib/scheduler/dmod/scheduler/resources/resource.py +++ b/python/lib/scheduler/dmod/scheduler/resources/resource.py @@ -2,14 +2,16 @@ from enum import Enum from typing import Any, Dict, Optional, Tuple, Type, Union +from dmod.core.enum import PydanticEnum -class ResourceAvailability(Enum): + +class ResourceAvailability(PydanticEnum): ACTIVE = 1, INACTIVE = 2, UNKNOWN = -1 -class ResourceState(Enum): +class ResourceState(PydanticEnum): READY = 1 NOT_READY = 2, UNKNOWN = -1 From 69bc26b013b66891d1d74cd345d143224723dd1a Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Tue, 31 Jan 2023 20:01:02 -0500 Subject: [PATCH 155/205] refactor AbstractProcessingAssetPool to use pydantic --- .../dmod/scheduler/resources/resource.py | 61 +++++++------------ 1 file changed, 23 insertions(+), 38 deletions(-) diff --git a/python/lib/scheduler/dmod/scheduler/resources/resource.py b/python/lib/scheduler/dmod/scheduler/resources/resource.py index 56f43885c..88987e0df 100644 --- a/python/lib/scheduler/dmod/scheduler/resources/resource.py +++ 
b/python/lib/scheduler/dmod/scheduler/resources/resource.py @@ -1,8 +1,10 @@ from abc import ABC, abstractmethod -from enum import Enum -from typing import Any, Dict, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, Optional, Tuple, Type, Union +from pydantic import Field, Extra, validator + from dmod.core.enum import PydanticEnum +from dmod.core.serializable import Serializable class ResourceAvailability(PydanticEnum): @@ -17,7 +19,7 @@ class ResourceState(PydanticEnum): UNKNOWN = -1 -class AbstractProcessingAssetPool(ABC): +class AbstractProcessingAssetPool(Serializable, ABC): """ Abstract representation of some collection of assets used for processing jobs/tasks. @@ -28,8 +30,12 @@ class AbstractProcessingAssetPool(ABC): ::method:`factory_init_from_dict` class method, and serialization using the ::method:`to_dict` method. """ + cpu_count: int + memory: int + pool_id: str + unique_id_separator: str = ":" + @classmethod - @abstractmethod def factory_init_from_dict(cls, init_dict: Dict[str, Any], ignore_extra_keys: bool = False) -> 'AbstractProcessingAssetPool': """ @@ -71,45 +77,24 @@ def factory_init_from_dict(cls, init_dict: Dict[str, Any], TypeError If any parameters sourced from the init dictionary are not of a supported type for that param. 
""" - pass - - def __init__(self, pool_id: str, cpu_count: int, memory: int): - self._pool_id = pool_id - self._cpu_count = cpu_count - self._memory = memory - self.unique_id_separator = ':' + original_extra_level = getattr(cls.Config, "extra", None) - @property - def cpu_count(self) -> int: - return self._cpu_count - - @cpu_count.setter - def cpu_count(self, cpu_count: int): - self._cpu_count = cpu_count - - @property - def memory(self) -> int: - return self._memory + if ignore_extra_keys: + setattr(cls.Config, "extra", Extra.ignore) + else: + setattr(cls.Config, "extra", Extra.forbid) - @memory.setter - def memory(self, memory: int): - self._memory = memory + o = cls.parse_obj(init_dict) - @property - def pool_id(self) -> str: - return self._pool_id + if original_extra_level is None: + delattr(cls.Config, "extra") + else: + setattr(cls.Config, "extra", original_extra_level) - @abstractmethod - def to_dict(self) -> Dict[str, Union[str, int]]: - """ - Convert the object to a serialized dictionary. + return o - Returns - ------- - Dict[str, Union[str, int]] - The object as a serialized dictionary - """ - pass + class Config: + extra = Extra.forbid @property @abstractmethod From 271528717ae4e42a21957bfc2303378d9853ee32 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Tue, 31 Jan 2023 20:01:23 -0500 Subject: [PATCH 156/205] refactor SingleHostProcessingAssetPool --- python/lib/scheduler/dmod/scheduler/resources/resource.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/python/lib/scheduler/dmod/scheduler/resources/resource.py b/python/lib/scheduler/dmod/scheduler/resources/resource.py index 88987e0df..919c3a66f 100644 --- a/python/lib/scheduler/dmod/scheduler/resources/resource.py +++ b/python/lib/scheduler/dmod/scheduler/resources/resource.py @@ -119,13 +119,7 @@ class SingleHostProcessingAssetPool(AbstractProcessingAssetPool, ABC): creation. 
""" - def __init__(self, pool_id: str, hostname: str, cpu_count: int, memory: int): - super().__init__(pool_id=pool_id, cpu_count=cpu_count, memory=memory) - self._hostname = hostname - - @property - def hostname(self) -> str: - return self._hostname + hostname: str class Resource(SingleHostProcessingAssetPool): From 473114f04200295c25204c0c8e43fadb837a7881 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Tue, 31 Jan 2023 20:01:44 -0500 Subject: [PATCH 157/205] refactor Resource --- .../dmod/scheduler/resources/resource.py | 268 +++++++----------- 1 file changed, 109 insertions(+), 159 deletions(-) diff --git a/python/lib/scheduler/dmod/scheduler/resources/resource.py b/python/lib/scheduler/dmod/scheduler/resources/resource.py index 919c3a66f..1c5be27b9 100644 --- a/python/lib/scheduler/dmod/scheduler/resources/resource.py +++ b/python/lib/scheduler/dmod/scheduler/resources/resource.py @@ -146,63 +146,56 @@ class Resource(SingleHostProcessingAssetPool): are expected to never change for a resource. """ - @classmethod - def factory_init_from_dict(cls, init_dict: dict, ignore_extra_keys: bool = False) -> 'Resource': - """ - Initialize a new object from the given dictionary, raising a ::class:`ValueError` if there are missing expected - keys or there are extra keys when the method is not set to ignore them. + availability: ResourceAvailability + """ + The availability of the resource. - Note that this method will allow ::class:`ResourceAvailability` and ::class:`ResourceState` values for the - init values of ``availability`` and ``state`` respectively, in addition to strings. It will also convert - numeric types from string values appropriately. + Note that the property setter accepts both string and ::class:`ResourceAvailability` values. For a string, the + argument is converted to a ::class:`ResourceAvailability` value using ::method:`get_resource_enum_value`. 
- Also, unlike other implementations, ``total cpus`` and ``total memory`` are expected keys, but they are not - required. If they are not present, the defaults (the respective available values) are used by the initializer. + However, if the conversion of a string with ::method:`get_resource_enum_value` returns ``None``, the setter + sets ::attribute:`availability` to the ``UNKNOWN`` enum value, rather than ``None``. This is more applicable + and allows the getter to always return an actual ::class:`ResourceAvailability` instance. + """ - parent: - """ - node_id = None - hostname = None - avail = None - state = None - cpus = None - total_cpus = None - memory = None - total_memory = None - - for param_key in init_dict: - # We don't care about non-string keys directly, but they are implicitly extra ... - if not isinstance(param_key, str): - if not ignore_extra_keys: - raise ValueError("Unexpected non-string resource init key") - else: - continue - lower_case_key = param_key.lower() - if lower_case_key == 'node_id' and node_id is None: - node_id = init_dict[param_key] - elif lower_case_key == 'hostname' and hostname is None: - hostname = init_dict[param_key] - elif lower_case_key == 'availability' and avail is None: - avail = init_dict[param_key] - elif lower_case_key == 'state' and state is None: - state = init_dict[param_key] - elif lower_case_key == 'cpus' and cpus is None: - cpus = int(init_dict[param_key]) - elif lower_case_key == 'memorybytes' and memory is None: - memory = int(init_dict[param_key]) - elif lower_case_key == 'total cpus' and total_cpus is None: - total_cpus = int(init_dict[param_key]) - elif lower_case_key == 'total memory' and total_memory is None: - total_memory = int(init_dict[param_key]) - elif not ignore_extra_keys: - raise ValueError("Unexpected resource init key (or case-insensitive duplicate) {}".format(param_key)) - - # Make sure we have everything required set - if node_id is None or hostname is None or cpus is None or memory is None or 
avail is None or state is None: - raise ValueError("Insufficient valid values keyed within resource init dictionary") - - return cls(resource_id=node_id, hostname=hostname, availability=avail, state=state, cpu_count=cpus, - memory=memory, total_cpu_count=total_cpus, total_memory=total_memory) + state: ResourceState = Field(description="The readiness state of the resource.") + """ + Note that the property setter accepts both string and ::class:`ResourceState` values. For a string, the + argument is converted to a ::class:`ResourceState` value using ::method:`get_resource_enum_value`. + + However, if the conversion of a string with ::method:`get_resource_enum_value` returns ``None``, the setter sets + ::attribute:`state` to the ``UNKNOWN`` enum value, rather than ``None``. This is more applicable and allows the + getter to always return an actual ::class:`ResourceState` instance. + """ + + total_cpus: Optional[int] = Field(description="The total number of CPUs known to be on this resource.") + + total_memory: Optional[int] = Field(description="The total amount of memory known to be on this resource.") + + class Config: + fields = { + "availability": {"alias": "Availability"}, + "cpu_count": {"alias": "CPUs"}, + "hostname": {"alias": "Hostname"}, + "memory": {"alias": "MemoryBytes"}, + "pool_id": {"alias": "node_id"}, + "state": {"alias": "State"}, + "total_cpus": {"alias": "Total CPUs"}, + "total_memory": {"alias": "Total Memory"}, + "unique_id_separator": {"exclude": True} + } + + @validator("availability", pre=True) + def _validate_availability(cls, value: Optional[Any]) -> Union[Any, ResourceAvailability]: + if value is None: + return ResourceAvailability.UNKNOWN + return value + + @validator("state", pre=True) + def _validate_state(cls, value: Optional[Any]) -> Union[Any, ResourceState]: + if value is None: + return ResourceState.UNKNOWN + return value @classmethod def generate_unique_id(cls, resource_id: str, separator: str): @@ -220,7 +213,7 @@ def 
generate_unique_id(cls, resource_id: str, separator: str): str The derived unique id. """ - return cls.__name__ + separator + resource_id + return f"{cls.__name__}{separator}{resource_id}" @classmethod def get_cpu_hash_key(cls) -> str: @@ -232,7 +225,7 @@ def get_cpu_hash_key(cls) -> str: str The hash key value for serialized dictionaries/hashes representations. """ - return 'CPUs' + return "CPUs" @classmethod def get_resource_enum_value(cls, enum_type: Union[Type[ResourceAvailability], Type[ResourceState]], @@ -269,7 +262,8 @@ def get_resource_enum_value(cls, enum_type: Union[Type[ResourceAvailability], Ty return val return None - def __eq__(self, other): + + def __eq__(self, other: object): if not isinstance(other, Resource): return super().__eq__(other) else: @@ -278,19 +272,32 @@ def __eq__(self, other): and self.cpu_count == other.cpu_count and self.memory == other.memory \ and self.total_cpu_count == other.total_cpu_count and self.total_memory == other.total_memory - def __init__(self, resource_id: str, hostname: str, availability: Union[str, ResourceAvailability], - state: Union[str, ResourceState], cpu_count: int, memory: int, total_cpu_count: Optional[int], - total_memory: Optional[int]): - super().__init__(pool_id=resource_id, hostname=hostname, cpu_count=cpu_count, memory=memory) - - self._availability = None - self.availability = availability - - self._state = state - self.state = state - - self._total_cpu_count = cpu_count if total_cpu_count is None else total_cpu_count - self._total_memory = memory if total_memory is None else total_memory + def __init__( + self, + resource_id: str = None, + hostname: str = None, + availability: Union[str, ResourceAvailability] = None, + state: Union[str, ResourceState] = None, + cpu_count: int = None, + memory: int = None, + total_cpu_count: Optional[int] = None, + total_memory: Optional[int] = None, + **data + ): + if data: + super().__init__(**data) + return + + super().__init__( + pool_id=resource_id, + 
hostname=hostname, + cpu_count=cpu_count, + memory=memory, + availability=availability, + state=state, + total_cpu_count=cpu_count if total_cpu_count is None else total_cpu_count, + total_memory=memory if total_memory is None else total_memory, + ) def allocate(self, cpu_count: int, memory: int) -> Tuple[int, int, bool]: """ @@ -333,32 +340,12 @@ def allocate(self, cpu_count: int, memory: int) -> Tuple[int, int, bool]: self.memory = 0 return allocated_cpus, allocated_mem, is_fully_allocated - @property - def availability(self) -> ResourceAvailability: - """ - The availability of the resource. - - Note that the property setter accepts both string and ::class:`ResourceAvailability` values. For a string, the - argument is converted to a ::class:`ResourceAvailability` value using ::method:`get_resource_enum_value`. - - However, if the conversion of a string with ::method:`get_resource_enum_value` returns ``None``, the setter - sets ::attribute:`availability` to the ``UNKNOWN`` enum value, rather than ``None``. This is more applicable - and allows the getter to always return an actual ::class:`ResourceAvailability` instance. - - Returns - ------- - ResourceAvailability - The availability of the resource. 
- """ - return self._availability - - @availability.setter - def availability(self, availability: Union[str, ResourceAvailability]): + def set_availability(self, availability: Union[str, ResourceAvailability]): if isinstance(availability, ResourceAvailability): enum_val = availability else: enum_val = self.get_resource_enum_value(ResourceAvailability, availability) - self._availability = ResourceAvailability.UNKNOWN if enum_val is None else enum_val + self.__dict__["availability"] = ResourceAvailability.UNKNOWN if enum_val is None else enum_val def is_allocatable(self) -> bool: """ @@ -396,85 +383,48 @@ def release(self, cpu_count: int, memory: int): def resource_id(self) -> str: return self.pool_id - @property - def state(self) -> ResourceState: - """ - The readiness state of the resource. - - Note that the property setter accepts both string and ::class:`ResourceState` values. For a string, the - argument is converted to a ::class:`ResourceState` value using ::method:`get_resource_enum_value`. - - However, if the conversion of a string with ::method:`get_resource_enum_value` returns ``None``, the setter sets - ::attribute:`state` to the ``UNKNOWN`` enum value, rather than ``None``. This is more applicable and allows the - getter to always return an actual ::class:`ResourceState` instance. - - Returns - ------- - ResourceState - The readiness state of the resource. - """ - return self._state - - @state.setter - def state(self, state: Union[str, ResourceState]): + def set_state(self, state: Union[str, ResourceState]): if isinstance(state, ResourceState): enum_val = state else: enum_val = self.get_resource_enum_value(ResourceState, state) - self._state = ResourceState.UNKNOWN if enum_val is None else enum_val - - def to_dict(self) -> Dict[str, Union[str, int]]: - """ - Convert the object to a serialized dictionary. + self.__dict__["state"] = ResourceState.UNKNOWN if enum_val is None else enum_val - Key names are as shown in the example below. 
Enum values are represented as the lower-case version of the name - for the given value. Values shown for CPU and Memory are the max values. + @property + def unique_id(self) -> str: + return self.generate_unique_id(resource_id=self.resource_id, separator=self.unique_id_separator) - E.g.: - { - 'node_id': "Node-0001", - 'Hostname': "my-host", - 'Availability': "active", - 'State': "ready", - 'CPUs': 18, - 'MemoryBytes': 33548128256, - 'Total CPUs': 18, - 'Total Memory: 33548128256 + def _setter_methods(self) -> Dict[str, Callable]: + """Mapping of attribute name to setter method. This supports backwards functional compatibility.""" + # TODO: remove once migration to setters by down stream users is complete + return { + "state": self.set_state, + "availability": self.set_availability, } - Returns - ------- - Dict[str, Union[str, int]] - The object as a serialized dictionary. + def __setattr__(self, name: str, value: Any): """ - return {'node_id': self.resource_id, 'Hostname': self.hostname, 'Availability': self.availability.name.lower(), - 'State': self.state.name.lower(), self.get_cpu_hash_key(): self.cpu_count, 'MemoryBytes': self.memory, - 'Total CPUs': self.total_cpu_count, 'Total Memory': self.total_memory} + Use property setter method when available. - @property - def total_cpu_count(self) -> int: - """ - The total number of CPUs known to be on this resource. + Note, all setter methods should modify their associated property using the instance `__dict__`. + This ensures that calls to, for example, `set_id` don't raise a warning, while `o.id = "new + id"` do. - Returns - ------- - int - The total number of CPUs known to be on this resource. - """ - return self._total_cpu_count + Example: + ``` + class SomeJob(Job): + id: str - @property - def total_memory(self) -> int: + def set_id(self, value: str): + self.__dict__["id"] = value + ``` """ - The total amount of memory known to be on this resource. 
+ if name not in self._setter_methods(): + return super().__setattr__(name, value) - Returns - ------- - int - The total amount of memory known to be on this resource. - """ - return self._total_memory + setter_fn = self._setter_methods()[name] - @property - def unique_id(self) -> str: - return self.generate_unique_id(resource_id=self.resource_id, separator=self.unique_id_separator) + message = f"Setting by attribute is deprecated. Use `{self.__class__.__name__}.{setter_fn.__name__}` method instead." + warn(message, DeprecationWarning) + + setter_fn(value) From 45f60326b2e06a1095fd1bd755015bfd7788e9c6 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Tue, 31 Jan 2023 20:02:23 -0500 Subject: [PATCH 158/205] refactor ResourceAllocation --- .../resources/resource_allocation.py | 129 +++++++----------- 1 file changed, 46 insertions(+), 83 deletions(-) diff --git a/python/lib/scheduler/dmod/scheduler/resources/resource_allocation.py b/python/lib/scheduler/dmod/scheduler/resources/resource_allocation.py index f78289358..379e341c8 100644 --- a/python/lib/scheduler/dmod/scheduler/resources/resource_allocation.py +++ b/python/lib/scheduler/dmod/scheduler/resources/resource_allocation.py @@ -1,5 +1,6 @@ from datetime import datetime -from typing import Dict, Optional, Union +from typing import Any, Dict, Optional, Union +from pydantic import root_validator, validator from .resource import SingleHostProcessingAssetPool @@ -8,54 +9,36 @@ class ResourceAllocation(SingleHostProcessingAssetPool): Implementation of ::class:`SingleHostProcessingAssetPool` representing a sub-collection of processing assets on a resource that have been allocated for a job. 
""" - - @classmethod - def factory_init_from_dict(cls, alloc_dict: dict, ignore_extra_keys: bool = False) -> 'ResourceAllocation': - """ - parent: - """ - node_id = None - hostname = None - cpus_allocated = None - mem = None - created = None - separator = None - - for param_key in alloc_dict: - # We don't care about non-string keys directly, but they are implicitly extra ... - if not isinstance(param_key, str): - if not ignore_extra_keys: - raise ValueError("Unexpected non-string allocation key") - else: - continue - lower_case_key = param_key.lower() - if lower_case_key == 'node_id' and node_id is None: - node_id = alloc_dict[param_key] - elif lower_case_key == 'hostname' and hostname is None: - hostname = alloc_dict[param_key] - elif lower_case_key == 'cpus_allocated' and cpus_allocated is None: - cpus_allocated = int(alloc_dict[param_key]) - elif lower_case_key == 'mem' and mem is None: - mem = int(alloc_dict[param_key]) - elif lower_case_key == 'created' and created is None: - created = alloc_dict[param_key] - elif lower_case_key == 'separator' and separator is None: - separator = alloc_dict[param_key] - elif not ignore_extra_keys: - raise ValueError("Unexpected allocation key (or case-insensitive duplicate) {}".format(param_key)) - - # Make sure we have everything required set - if node_id is None or hostname is None or cpus_allocated is None or mem is None: - raise ValueError("Insufficient valid values keyed within allocation dictionary") - - deserialized = cls(resource_id=node_id, hostname=hostname, cpus_allocated=cpus_allocated, requested_memory=mem, - created=created) - if isinstance(separator, str): - deserialized.unique_id_separator = separator - - return deserialized - - def __eq__(self, other): + created: datetime + + class Config: + fields = { + "pool_id": {"alias": "node_id"}, + "hostname": {"alias": "Hostname"}, + "cpu_count": {"alias": "cpus_allocated"}, + "memory": {"alias": "mem"}, + "created": {"alias": "Created"}, + "unique_id_separator": 
{"alias": "separator"}, + } + field_serializers = { + "created": lambda v: v.timestamp() + } + + @validator("created", pre=True) + def _validate_datetime(cls, value) -> datetime: + if value is None: + return datetime.now() + elif isinstance(value, datetime): + return value + elif isinstance(value, float): + return datetime.fromtimestamp(value) + return datetime.fromtimestamp(float(value)) + + @root_validator(pre=True) + def _lowercase_all_keys(cls, values: Dict[str, Any]) -> Dict[str, Any]: + return {k.lower(): v for k, v in values.items()} + + def __eq__(self, other: object) -> bool: if not isinstance(other, ResourceAllocation): return False else: @@ -65,38 +48,22 @@ def __eq__(self, other): and self.memory == other.memory \ and self.created == other.created - def __init__(self, resource_id: str, hostname: str, cpus_allocated: int, requested_memory: int, - created: Optional[Union[str, float, datetime]] = None): - super().__init__(pool_id=resource_id, hostname=hostname, cpu_count=cpus_allocated, memory=requested_memory) - self._set_created(created) - - def _set_created(self, created: Optional[Union[str, float, datetime]] = None): - """ - A "private" method for setting the ::attribute:`created` property, potentially converting to value to set. - - A ``None`` argument is interpreted as ``now``. Other non-datetime args are interpreted as string or numeric - epoch timestamp representations (i.e., values like those from ::method:`datetime.timestamp`). - - Parameters - ---------- - created - The value to set. 
- """ - if created is None: - self._created = datetime.now() - elif isinstance(created, datetime): - self._created = created - elif isinstance(created, float): - self._created = datetime.fromtimestamp(created) - else: - self._created = datetime.fromtimestamp(float(created)) - - @property - def created(self) -> datetime: - return self._created + def __init__( + self, + resource_id: str = None, + hostname: str = None, + cpus_allocated: int = None, + requested_memory: int = None, + created: Optional[Union[str, float, datetime]] = None, + **data + ): + if data: + super().__init__(cpus_allocated=cpus_allocated, **data) + return + super().__init__(pool_id=resource_id, hostname=hostname, cpu_count=cpus_allocated, memory=requested_memory, created=created) def get_unique_id(self, separator: str) -> str: - return self.__class__.__name__ + separator + self.resource_id + separator + str(self.created.timestamp()) + return f"{self.__class__.__name__}{separator}{self.resource_id}{separator}{str(self.created.timestamp())}" @property def node_id(self) -> str: @@ -128,10 +95,6 @@ def resource_id(self) -> str: """ return self.pool_id - def to_dict(self) -> Dict[str, Union[str, int]]: - return {'node_id': self.node_id, 'Hostname': self.hostname, 'cpus_allocated': self.cpu_count, - 'mem': self.memory, 'Created': self.created.timestamp(), 'separator': self.unique_id_separator} - @property def unique_id(self) -> str: return self.get_unique_id(self.unique_id_separator) From 840c36b91197fea59cb786f2a0082940b892e31b Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Wed, 1 Feb 2023 13:40:52 -0500 Subject: [PATCH 159/205] resource fields accepted case insensitively --- .../dmod/scheduler/resources/resource.py | 52 +++++++++++++++++-- 1 file changed, 47 insertions(+), 5 deletions(-) diff --git a/python/lib/scheduler/dmod/scheduler/resources/resource.py b/python/lib/scheduler/dmod/scheduler/resources/resource.py index 1c5be27b9..b46435a16 100644 --- 
a/python/lib/scheduler/dmod/scheduler/resources/resource.py +++ b/python/lib/scheduler/dmod/scheduler/resources/resource.py @@ -1,6 +1,9 @@ from abc import ABC, abstractmethod from typing import Any, Callable, Dict, Optional, Tuple, Type, Union -from pydantic import Field, Extra, validator +from typing_extensions import Self +from pydantic import Field, Extra, validator, root_validator +from functools import cache +from warnings import warn from dmod.core.enum import PydanticEnum @@ -37,7 +40,7 @@ class AbstractProcessingAssetPool(Serializable, ABC): @classmethod def factory_init_from_dict(cls, init_dict: Dict[str, Any], - ignore_extra_keys: bool = False) -> 'AbstractProcessingAssetPool': + ignore_extra_keys: bool = False) -> Self: """ Initialize a new object from the given dictionary, raising a ::class:`ValueError` if there are missing expected keys or there are extra keys when the method is not set to ignore them. @@ -197,6 +200,34 @@ def _validate_state(cls, value: Optional[Any]) -> Union[Any, ResourceState]: return ResourceState.UNKNOWN return value + @root_validator(pre=True) + def _remap_alias_case_insensitive(cls, values: Dict[str, Any]) -> Dict[str, Any]: + alias_field_map = cls._alias_field_map() + + # NOTE: consider removing this in the future and enforcing case sensitive keys + new_values: Dict[str, Any] = dict() + for k, v in values.items(): + if k.lower() in alias_field_map: + new_values[alias_field_map[k.lower()]] = v + continue + new_values[k] = v + return new_values + + @root_validator() + def _set_total_cpus_and_total_memory_if_unset(cls, values: Dict[str, Any]) -> Dict[str, Any]: + if values.get("total_cpus") is None: + values["total_cpus"] = values["cpu_count"] + + if values.get("total_memory") is None: + values["total_memory"] = values["memory"] + return values + + @classmethod + @cache + def _alias_field_map(cls) -> Dict[str, str]: + """Mapping of lower cased alias names to cased alias names.""" + return {v.alias.lower(): v.alias for v in 
cls.__fields__.values()} + @classmethod def generate_unique_id(cls, resource_id: str, separator: str): """ @@ -270,7 +301,7 @@ def __eq__(self, other: object): return self.resource_id == other.resource_id and self.hostname == other.hostname \ and self.availability == other.availability and self.state == other.state \ and self.cpu_count == other.cpu_count and self.memory == other.memory \ - and self.total_cpu_count == other.total_cpu_count and self.total_memory == other.total_memory + and self.total_cpus == other.total_cpus and self.total_memory == other.total_memory def __init__( self, @@ -285,6 +316,17 @@ def __init__( **data ): if data: + # NOTE: this can be removed alias field names _are_ case sensitive + potentially_aliased_fields = { + "availability": availability, + "hostname": hostname, + "state": state, + "total_memory": total_memory + } + + for field_name, value in potentially_aliased_fields.items(): + if value is not None: + data[field_name] = value super().__init__(**data) return @@ -295,8 +337,8 @@ def __init__( memory=memory, availability=availability, state=state, - total_cpu_count=cpu_count if total_cpu_count is None else total_cpu_count, - total_memory=memory if total_memory is None else total_memory, + total_cpus=total_cpu_count, + total_memory=total_memory ) def allocate(self, cpu_count: int, memory: int) -> Tuple[int, int, bool]: From 160af12bc8ba558915d3ea1efbf88c16ef835e3c Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Wed, 1 Feb 2023 13:58:15 -0500 Subject: [PATCH 160/205] add accidentally remove total_cpu_count prop --- .../scheduler/dmod/scheduler/resources/resource.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/python/lib/scheduler/dmod/scheduler/resources/resource.py b/python/lib/scheduler/dmod/scheduler/resources/resource.py index b46435a16..9d0456054 100644 --- a/python/lib/scheduler/dmod/scheduler/resources/resource.py +++ b/python/lib/scheduler/dmod/scheduler/resources/resource.py @@ -421,6 +421,19 @@ def 
release(self, cpu_count: int, memory: int): self.cpu_count = self.cpu_count + cpu_count self.memory = self.memory + memory + @property + def total_cpu_count(self) -> int: + """ + The total number of CPUs known to be on this resource. + Returns + ------- + int + The total number of CPUs known to be on this resource. + """ + # NOTE: total cpus will be set or derived from `cpu_count` + return self.total_cpus # type: ignore + + @property def resource_id(self) -> str: return self.pool_id From 521abd8ad4435953c95a4bf29d03ad0814280733 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Wed, 1 Feb 2023 14:06:05 -0500 Subject: [PATCH 161/205] validate that cpu and memory total cannot be larger than cpu and memory --- python/lib/scheduler/dmod/scheduler/resources/resource.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/python/lib/scheduler/dmod/scheduler/resources/resource.py b/python/lib/scheduler/dmod/scheduler/resources/resource.py index 9d0456054..4cbd45016 100644 --- a/python/lib/scheduler/dmod/scheduler/resources/resource.py +++ b/python/lib/scheduler/dmod/scheduler/resources/resource.py @@ -220,6 +220,14 @@ def _set_total_cpus_and_total_memory_if_unset(cls, values: Dict[str, Any]) -> Di if values.get("total_memory") is None: values["total_memory"] = values["memory"] + + msg_template = "`{}` cannot be larger than `{}`. 
{} > {}" + if values["cpu_count"] > values["total_cpus"]: + raise ValueError(msg_template.format("cpu_count", "total_cpus", values["cpu_count"], values["total_cpus"])) + + if values["memory"] > values["total_memory"]: + raise ValueError(msg_template.format("memory", "total_memory", values["memory"], values["total_memory"])) + return values @classmethod From 8a7bccc2b8281482806d658da526ae51bf4a3764 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Thu, 2 Feb 2023 18:26:50 -0500 Subject: [PATCH 162/205] refactor RsaKeyPair --- .../scheduler/dmod/scheduler/rsa_key_pair.py | 643 +++++++++--------- 1 file changed, 320 insertions(+), 323 deletions(-) diff --git a/python/lib/scheduler/dmod/scheduler/rsa_key_pair.py b/python/lib/scheduler/dmod/scheduler/rsa_key_pair.py index ffc790ece..35c30edf5 100644 --- a/python/lib/scheduler/dmod/scheduler/rsa_key_pair.py +++ b/python/lib/scheduler/dmod/scheduler/rsa_key_pair.py @@ -3,29 +3,235 @@ from cryptography.hazmat.backends import default_backend from dmod.core.serializable import Serializable from pathlib import Path -from typing import Dict, Union +from pydantic import Field, PrivateAttr, validator +from typing import ClassVar, Dict, Optional, Tuple, Union +from typing_extensions import Self import datetime import os from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKeyWithSerialization +class _RsaKeyPair(Serializable): + """ + This is a shim object that enables partial instantiation of a :class:`RsaKeyPair`. This class exposes methods and + properties to interact, generate, and write a key pair. However, it does not expose a way to serialize or + deserialize keys and other associated metadata from a dictionary. For the functionality, see :class:`RsaKeyPair`. + """ + + directory: Path + """ + The directory in which the key pair files have been or will be written, as a :class:`Path`. + + If `None` is provided, `directory` defaults to ``$HOME/.ssh/``. 
If the default or provided directory does not + exists, it and any intermediate directories will be created. Directory inputs that exist and are not directories + (i.e. a file) will raise a ValueError. + """ + + name: str = Field(min_length=1) + """Basename of private key file.""" + + _priv_key: RSAPrivateKeyWithSerialization = PrivateAttr(None) + _priv_key_pem: bytes = PrivateAttr(None) + _is_deserialized: bool = PrivateAttr(False) + + @validator("directory", pre=True) + def _validate_directory(cls, value: Union[str, Path, None]) -> Union[str, Path]: + if value is None: + return Path.home() / ".ssh" + + if isinstance(value, str): + return value.strip() + + return value + + @validator("directory") + def _post_validate_directory(cls, value: Path) -> Path: + if not value.exists(): + value.mkdir(parents=True) + + elif not value.is_dir(): + raise ValueError(f"Existing non-directory file at path provided for key pair directory. {value!r}") + + return value + + @validator("name") + def _validate_name(cls, value: str) -> str: + return value.strip() + + @property + def private_key_file(self) -> Path: + """ + + Returns + ------- + Path + Path to private key file. Is not guaranteed to exist. + """ + return self.directory / self.name + + @property + def public_key_file(self) -> Path: + """ + Same as private key filepath, but with the suffix ".pub". + + Returns + ------- + Path + Path to public key file. Is not guaranteed to exist. 
+ """ + return self.directory / f"{self.name}.pub" + + @property + def private_key_pem(self) -> bytes: + """ + + Returns + ------- + bytes + Encoded private key in PEM format + """ + if self._priv_key_pem is None: + self._priv_key_pem = self._private_key_bytes_from_private_key(self._private_key) + return self._priv_key_pem # type: ignore + + def delete_key_files(self) -> Tuple[bool, bool]: + """ + Delete the files at the paths specified by :attr:`private_key_file` and :attr:`public_key_file`, as long as + there is an existing, regular (i.e., from :method:`Path.is_file`) file at the individual paths. + + Note that whether a delete is performed for one file is independent of what the state of the other. I.e., if + the private key file does not exist, thus resulting in no attempt to delete it, this will not affect whether + there is a delete operation on the public key file. + + Returns + ------- + tuple + A tuple of boolean values, representing whether the private key file and the public key file respectively + were deleted + """ + deleted_private = False + deleted_public = False + if self.private_key_file.exists() and self.private_key_file.is_file(): + self.private_key_file.unlink() + deleted_private = True + if self.public_key_file.exists() and self.public_key_file.is_file(): + self.public_key_file.unlink() + deleted_public = True + return deleted_private, deleted_public + + def write_key_files(self, write_private: bool = True, write_public: bool = True): + """ + Write private and/or public keys to files at :attr:`private_key_file` and :attr:`public_key_file` respectively, + assuming the respective file does not already exist. 
+ + Parameters + ---------- + write_private : bool + An option, ``True`` by default, for whether the private key should be written to :attr:`private_key_file` + + write_public : bool + An option, ``True`` by default, for whether the public key should be written to :attr:`public_key_file` + """ + # if fail to write private key file, delete any existing pub / priv key files. + try: + if write_private and not self.private_key_file.exists(): + self._write_private_key(self._private_key, raise_on_fail=True) + except Exception as e: + if self.public_key_file.exists(): + _, deleted_public = self.delete_key_files() + if not deleted_public: + raise RuntimeError(f"Failed to write private key file. During failure, failed to remove public key file. '{self.public_key_file}'") from e + raise e + + # NOTE: if cannot write pub key file, priv key file, if it exists, will not be removed. + if write_public and not self.public_key_file.exists(): + self._write_public_key(self._private_key, raise_on_fail=True) + + @property + def _private_key(self) -> RSAPrivateKeyWithSerialization: + """ + Serialized private key. Lazily loads private key from :property:`private_key_file` or dynamically generates one. + + If the private key is loaded from :property:`private_key_file` and :property:`public_key_file` does not exist, a + public key is written to disk at :property:`public_key_file`. 
+ """ + if self._priv_key is None and self.private_key_file.exists(): + priv_key_file = self.private_key_file.read_bytes() + self._priv_key = serialization.load_pem_private_key(priv_key_file, None, default_backend()) + self._is_deserialized = True -class RsaKeyPair(Serializable): + self._write_public_key(self._priv_key, overwrite=False, raise_on_fail=True) + + elif self._priv_key is None: + self._priv_key = rsa.generate_private_key(backend=default_backend(), public_exponent=65537, key_size=3072) + + return self._priv_key # type: ignore + + @staticmethod + def _public_key_bytes_from_private_key(private_key: RSAPrivateKeyWithSerialization) -> bytes: + return private_key.public_key().public_bytes(serialization.Encoding.OpenSSH, serialization.PublicFormat.OpenSSH) + + @staticmethod + def _private_key_bytes_from_private_key(private_key: RSAPrivateKeyWithSerialization) -> bytes: + return private_key.private_bytes(encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption()) + + @staticmethod + def _read_private_key_ctime(location: Path) -> datetime.datetime: + return datetime.datetime.fromtimestamp(os.path.getctime(str(location))) + + @staticmethod + def __try_write(content: str, location: Path, overwrite: bool = False, raise_on_fail: bool = False) -> bool: + if not overwrite and location.exists(): + return False + try: + location.write_text(content) + except Exception as e: + if raise_on_fail: + raise e + return False + return True + + def _write_public_key(self, private_key: RSAPrivateKeyWithSerialization, overwrite: bool = False, raise_on_fail: bool = False) -> bool: + pub_key = self._public_key_bytes_from_private_key(private_key).decode("utf-8") + return self.__try_write(pub_key, self.public_key_file, overwrite=overwrite, raise_on_fail=raise_on_fail) + + def _write_private_key(self, private_key: RSAPrivateKeyWithSerialization, overwrite: bool = False, raise_on_fail: bool = False) -> 
bool: + priv_key = self._private_key_bytes_from_private_key(private_key).decode("utf-8") + return self.__try_write(priv_key, self.private_key_file, overwrite=overwrite, raise_on_fail=raise_on_fail) + + def _delete_existing_key_files_if_priv_keys_differ(self): + # Remove any existing private/public key files unless the contents match serialized private key value + if self.private_key_file.exists(): + priv_key_file_bytes = self.private_key_file.read_bytes() + + if priv_key_file_bytes != self._private_key_bytes_from_private_key(self._private_key): + self.public_key_file.unlink(missing_ok=True) + self.private_key_file.unlink() + raise RuntimeError("Existing private key from file does not match provided private.") + + elif self.public_key_file.exists(): + # Always remove an existing public key file if there was not a private key file + self.public_key_file.unlink() + + +class RsaKeyPair(_RsaKeyPair, Serializable): """ Representation of an RSA key pair and certain meta properties, in particular a name for the key and a pair of :class:`Path` objects for its private and public key files. Keys may be either dynamically generated or deserialized from existing files. Key file basenames are derived from the :attr:`name` value for the object, which is set from an init param that - defaults to ``id_rsa`` if not provided. The public key file will have the same basename as the private key file, - except with the ``.pub`` extension added. + defaults to ``id_rsa`` if not provided. However, :attr:`name` is a required field when initializing from a + dictionary. The public key file will have the same basename as the private key file, except with the ``.pub`` + extension added. When the private key file already exists, the private key will be deserialized from the file contents. This will happen immediately when the object is created. - When the private key file does not already exists, the actual keys will be generated dynamically, though this is - performed lazily. 
The :method:`generate_key_pair` method will trigger all necessary lazy instantiations and also - cause the key files to be written. + When the private key file does not already exists, the actual keys will be generated dynamically -- but not written + to a file. Use the :method:`write_key_files` to write key pairs to a file. Note that rich comparisons for ``==`` and ``<`` are expressly defined, with the other implementations being derived from these two. @@ -34,19 +240,48 @@ class RsaKeyPair(Serializable): # The basename of the private key file will always be the key pair's name self.name == self.private_key_file.name - # The returned generation time property value will always be equal to the time stamp of the private key file - self.generation_time == datetime.datetime.fromtimestamp(os.path.getctime(str(self.private_key_file))) + """ + private_key: RSAPrivateKeyWithSerialization + """ + Serialized private key for this key pair object. """ - _SERIAL_DATETIME_STR_FORMAT = '%Y-%m-%d %H:%M:%S.%f' - _SERIAL_KEY_DIRECTORY = 'directory' - _SERIAL_KEY_NAME = 'name' - _SERIAL_KEY_PRIVATE_KEY = 'private_key' - _SERIAL_KEY_GENERATION_TIME = 'generation_time' - _SERIAL_KEYS_REQUIRED = [_SERIAL_KEY_NAME, _SERIAL_KEY_DIRECTORY, _SERIAL_KEY_PRIVATE_KEY, _SERIAL_KEY_GENERATION_TIME] + + generation_time: datetime.datetime + + _pub_key: bytes = PrivateAttr(None) + __private_key_text: str = PrivateAttr(None) + + _SERIAL_DATETIME_STR_FORMAT: ClassVar[str] = '%Y-%m-%d %H:%M:%S.%f' + + @validator("generation_time", pre=True) + def _validate_datetime(cls, value: Union[str, datetime.datetime]) -> datetime.datetime: + if isinstance(value, datetime.datetime): + return value + + return datetime.datetime.strptime(value, cls.get_datetime_str_format()) + + @validator("private_key", pre=True) + def _validate_private_key(cls, value: Union[str, RSAPrivateKeyWithSerialization ]) -> RSAPrivateKeyWithSerialization: + if isinstance(value, RSAPrivateKeyWithSerialization): + return value + + 
priv_key_bytes = value.encode("utf-8") + return serialization.load_pem_private_key(priv_key_bytes, None, default_backend()) + + class Config: # type: ignore + arbitrary_types_allowed = True + def _serialize_datetime(self: "RsaKeyPair", value: datetime.datetime) -> str: + return value.strftime(self.get_datetime_str_format()) + + field_serializers = { + "generation_time": _serialize_datetime, + "private_key": lambda self, _: self._private_key_text, + "directory": lambda directory: str(directory), + } @classmethod - def factory_init_from_deserialized_json(cls, json_obj: Dict[str, str]): + def factory_init_from_deserialized_json(cls, json_obj: Dict[str, str]) -> Optional[Self]: """ Factory create a new instance of this type based on a JSON object dictionary deserialized from received JSON. @@ -73,204 +308,104 @@ def factory_init_from_deserialized_json(cls, json_obj: Dict[str, str]): err_msg_start = 'Cannot deserialize {} object'.format(cls.__name__) try: # Sanity check serialized structure - for key in cls._SERIAL_KEYS_REQUIRED: - if key not in json_obj: - raise RuntimeError('{}: missing required serial {} key'.format(err_msg_start, key)) - # Parse the generation time - gen_time_str = json_obj[cls._SERIAL_KEY_GENERATION_TIME] - try: - gen_time_val = datetime.datetime.strptime(gen_time_str, cls.get_datetime_str_format()) - except: - raise RuntimeError('{}: invalid format for generation time ({})'.format(err_msg_start, gen_time_str)) - # Create the instance, passing serialize values for directory and name - try: - new_obj = RsaKeyPair(directory=json_obj[cls._SERIAL_KEY_DIRECTORY], name=json_obj[cls._SERIAL_KEY_NAME]) - except ValueError as ve: - raise RuntimeError('{}: problem with directory - {}'.format(err_msg_start, str(ve))) - # Manually set the generation time attribute - new_obj._generation_time = gen_time_val - # Set the private key value from serialized data - priv_key_str = json_obj[cls._SERIAL_KEY_PRIVATE_KEY] - priv_key_bytes = priv_key_str.encode('utf-8') - 
new_obj._priv_key = serialization.load_pem_private_key(priv_key_bytes, None, default_backend()) - # Remove any existing private/public key files unless the contents match serialized private key value - if new_obj.private_key_file.exists(): - try: - with new_obj.private_key_file.open('rb') as priv_key_file: - priv_key_file_bytes = priv_key_file.read() - if priv_key_file_bytes != priv_key_bytes: - raise RuntimeError('clear key file') - except: - new_obj.public_key_file.unlink(missing_ok=True) - new_obj.private_key_file.unlink() - elif new_obj.public_key_file.exists(): - # Always remove an existing public key file if there was not a private key file - new_obj.public_key_file.unlink() - # Finally, return the instance - return new_obj - - except RuntimeError as e: + for field in cls.__fields__.values(): + if field.alias not in json_obj: + raise RuntimeError('{}: missing required serial {} key'.format(err_msg_start, field.alias)) + + o = cls(**json_obj) + o._is_deserialized = True + return o + except: # TODO: log error return None - def __eq__(self, other: 'RsaKeyPair') -> bool: + def __eq__(self, other: Self) -> bool: return other is not None \ and self.generation_time == other.generation_time \ - and self._get_private_key_text() == other._get_private_key_text() \ + and self._private_key_text == other._private_key_text \ and self.private_key_file.absolute() == other.private_key_file.absolute() - def __ge__(self, other): + def __ge__(self, other: Self): return not self < other - def __gt__(self, other): + def __gt__(self, other: Self): return not self <= other - def __init__(self, directory: Union[str, Path, None], name: str = 'id_rsa'): + def __init__(self, directory: Union[str, Path, None], name: str = "id_rsa", **data): """ Initialize an instance. 
- Initializing an instance, setting the ``directory`` and ``name`` properties, and creating the other required - backing attributes used by the object, setting them to ``None`` (except for ::attribute:`_files_written`, which - is set to ``False``. - Parameters ---------- directory : str, Path, None The path (either as a :class:`Path` or string) to the parent directory for the backing key files, or - ``None`` if the default of ``.ssh/`` in the user's home directory should be used. + ``None`` if the default of ``{$HOME}/.ssh/`` should be used. name : str The name to use for the key pair, which will also be the basename of the private key file and the basis of the basename of the public key file (``id_rsa`` by default). """ - self._name = name.strip() - if self._name is None or len(self._name) < 1: - raise ValueError("Invalid key pair name") - - self.directory = directory - - self._public_key_file = None - self._private_key_file = None - - self._priv_key = None - self._priv_key_pem = None - self._pub_key = None - - self._private_key_text = None - self._public_key_text = None - - self._is_deserialized = None - self._generation_time = None - self._files_written = False - # Track whether actually in the process of writing something already, to not double-write during lazy load - self._is_writing_private_file = False - self._is_writing_public_file = False + # If `data` exists, we assume we are deserializing a message with all required fields. + # NOTE: method, `factory_init_from_deserialized_json`, verifies that all fields are passed + # before trying to initialize. + if data: + super().__init__( + directory=directory, + name=name, + **data + ) + # indirectly set `_private_key` property of parent class `_RsaKeyPair`. + # as a result a public key file will not be created during initialization even if a + # private key file exists and its contents match the passed `private_key` field and a + # public key file does not exist. 
+ self._priv_key = self.private_key + self._delete_existing_key_files_if_priv_keys_differ() + + # If `data` does not exists, partially initialize using fields we have, then derive / create + # all required byt unspecified fields. Then, fully initialize. + else: + key_pair = _RsaKeyPair(directory=directory, name=name) + # lazily generate or load private key + private_key = key_pair._private_key + # could raise `RuntimeError` + key_pair._delete_existing_key_files_if_priv_keys_differ() + key_pair.write_key_files() + generation_time = key_pair._read_private_key_ctime(key_pair.private_key_file) + + super().__init__( + directory=directory, + name=name, + private_key=private_key, + generation_time=generation_time, + ) + + # transfer how the key pair was created + self._is_deserialized = key_pair._is_deserialized + # no one should access this directly nor through property, `_private_key`, but just in case. + self._priv_key = self.private_key def __hash__(self) -> int: - hash_str = '{}:{}:{}'.format(self._get_private_key_text(), + hash_str = '{}:{}:{}'.format(self._private_key_text, str(self.private_key_file.absolute()), self.generation_time.strftime(self.get_datetime_str_format())) - return hash_str.__hash__() + return hash(hash_str) - def __le__(self, other: 'RsaKeyPair') -> bool: + def __le__(self, other: Self) -> bool: return self == other or self < other - def __lt__(self, other: 'RsaKeyPair') -> bool: + def __lt__(self, other: Self) -> bool: if self.generation_time != other.generation_time: return self.generation_time < other.generation_time - elif self._get_private_key_text != other._get_private_key_text: - return self._get_private_key_text < other._get_private_key_text + elif self._private_key_text != other._private_key_text: + return self._private_key_text < other._private_key_text else: return self.private_key_file.absolute() < other.private_key_file.absolute() - def _get_private_key_text(self): - if self._private_key_text is None: - self._load_key_text() - return 
self._private_key_text - - def _load_key_text(self): - if self._private_key_text is None: - self._private_key_text = self.private_key_pem.decode('utf-8') - if self._public_key_text is None: - self._public_key_text = self.public_key.decode('utf-8') - - def _read_private_key_ctime(self, skip_file_exists_check=False): - if skip_file_exists_check or self.private_key_file.exists(): - return datetime.datetime.fromtimestamp(os.path.getctime(str(self.private_key_file))) - else: - return None - - def delete_key_files(self) -> tuple: - """ - Delete the files at the paths specified by :attr:`private_key_file` and :attr:`public_key_file`, as long as - there is an existing, regular (i.e., from :method:`Path.is_file`) file at the individual paths. - - Note that whether a delete is performed for one file is independent of what the state of the other. I.e., if - the private key file does not exist, thus resulting in no attempt to delete it, this will not affect whether - there is a delete operation on the public key file. - - Returns - ------- - tuple - A tuple of boolean values, representing whether the private key file and the public key file respectively - were deleted - """ - deleted_private = False - deleted_public = False - if self.private_key_file.exists() and self.private_key_file.is_file(): - self.private_key_file.unlink() - deleted_private = True - if self.public_key_file.exists() and self.public_key_file.is_file(): - self.public_key_file.unlink() - deleted_public = True - return deleted_private, deleted_public - - @property - def directory(self) -> Path: - """ - The directory in which the key pair files have been or will be written, as a :class:`Path`. - - The property getter will lazily instantiate the backing attribute to ``/.ssh/`` if the attribute is - set to ``None``. This is done using the property setter function, thus triggering its potential side effects. - - The property setter will accept string or ::class:`Path` objects, as well as ``None``. 
- - The setter may, as a side effect, create the directory represented by the argument in the filesystem. This is - done in cases when a valid argument other than ``None`` is received, and no file or directory currently exists - in the file system at that path. For string arguments, the string is first stripped of whitespace and converted - to a ::class:`Path` object before checking if the directory should be created. All of this logic is executed - before setting the backing attribute, so if an error is raised, then the attribute value will not be modified. - - In particular, if the setter receives an argument representing a path to an existing, non-directory file, then a - the setter will raise ::class:`ValueError`, and the attribute will remain unchanged. - - Returns - ------- - Path - The directory in which the key pair files have been or will be written - """ - if self._directory is None: - self.directory = Path.home().joinpath(".ssh") - return self._directory - - @directory.setter - def directory(self, d: Union[str, Path, None]): - # Make sure we are working with either None or the equivalent Path object for a path as a string - d_path = Path(d.strip()) if isinstance(d, str) else d - if d_path is not None: - if not d_path.exists(): - d_path.mkdir() - elif not d_path.is_dir(): - raise ValueError("Existing non-directory file at path provided for key pair directory") - self._directory = d_path - @property - def generation_time(self): - if self._generation_time is None: - if not self.private_key_file.exists(): - self.write_key_files() - self._generation_time = self._read_private_key_ctime(skip_file_exists_check=True) - return self._generation_time + def _private_key_text(self) -> str: + if self.__private_key_text is None: + self.__private_key_text = self.private_key_pem.decode("utf-8") + return self.__private_key_text # type: ignore @property def is_deserialized(self) -> bool: @@ -278,159 +413,21 @@ def is_deserialized(self) -> bool: Whether this object was 
deserialized from an already-existing file or serialized object, as opposed to being created and dynamically generating its keys. - pre: self._is_deserialized is not None or self._priv_key is None - - post: self._is_deserialized is not None and self._priv_key is not None - Returns ------- bool Whether this object was created from a pre-existing private key file """ - if self._is_deserialized is None: - # We don't actually need the value directly, but the lazy instantiation will set _is_deserialized as a side- - # effect, since it intrinsically has to determine whether it can/should deserialized the private key - priv_key = self.private_key return self._is_deserialized @property - def name(self): - return self._name - - @property - def private_key(self) -> RSAPrivateKeyWithSerialization: - """ - Get the private key for this key pair object, lazily instantiating if necessary either through deserialization - or by dynamically generating a key. - - Note that, since lazy instantiation requires determining if the value should be deserialized, the attribute - backing the :attr:`is_deserialized` property is set as a side effect when performing that step. 
- - post: self._is_deserialized is not None - - Returns - ------- - RSAPrivateKeyWithSerialization - The actual RSA private key object - """ - if self._priv_key is None and self.private_key_file.exists(): - with self.private_key_file.open('rb') as priv_key_file: - self._priv_key = serialization.load_pem_private_key(priv_key_file.read(), None, default_backend()) - if not self.public_key_file.exists(): - self.write_key_files(write_private=False) - self._files_written = True - self._is_deserialized = True - elif self._priv_key is None: - self._priv_key = rsa.generate_private_key(backend=default_backend(), public_exponent=65537, key_size=3072) - return self._priv_key - - @property - def private_key_file(self) -> Path: - """ - Get the path to the private key file, lazily instantiating using the :attr:`name` and :method:`directory`. - - Returns - ------- - Path - The path to the private key file - """ - if self._private_key_file is None: - self._private_key_file = None if self.directory is None else self.directory.joinpath(self._name) - return self._private_key_file - - @property - def private_key_pem(self): - if self._priv_key_pem is None: - self._priv_key_pem = self.private_key.private_bytes(encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.NoEncryption()) - return self._priv_key_pem - - @property - def public_key(self): + def public_key(self) -> bytes: if self._pub_key is None: - self._pub_key = self.private_key.public_key().public_bytes(serialization.Encoding.OpenSSH, - serialization.PublicFormat.OpenSSH) + self._pub_key = self._public_key_bytes_from_private_key(self.private_key) return self._pub_key - @property - def public_key_file(self) -> Path: - """ - Get the path to the public key file, lazily instantiating based on the :attr:`name` and :method:`directory`. 
- - Returns - ------- - Path - The path to the public key file - """ - if self._public_key_file is None: - self._public_key_file = None if self.directory is None else self.directory.joinpath(self._name + '.pub') - return self._public_key_file - - def to_dict(self) -> Dict[str, str]: - """ - Serialize to a dictionary representation of string keys and values. - - The format is as follows: - - { - 'name': 'name_value', - 'directory': 'directory_path_as_string', - 'private_key': 'private_key_text', - 'generation_time': 'generation_time_str' - } - - Returns - ------- - Dict[str, str] - The serialized form of this instance as a dictionary object with string keys and string values. - """ - return { - self._SERIAL_KEY_NAME: self.name, - self._SERIAL_KEY_DIRECTORY: str(self.directory), - self._SERIAL_KEY_PRIVATE_KEY: self._get_private_key_text(), - self._SERIAL_KEY_GENERATION_TIME: self.generation_time.strftime(self.get_datetime_str_format()) - } - - def write_key_files(self, write_private=True, write_public=True): - """ - Write private and/or public keys to files at :attr:`private_key_file` and :attr:`public_key_file` respectively, - assuming the respective file does not already exist. - - Parameters - ---------- - write_private : bool - An option, ``True`` by default, for whether the private key should be written to :attr:`private_key_file` - - write_public : bool - An option, ``True`` by default, for whether the public key should be written to :attr:`public_key_file` - """ - # Keep track of whether we are in the process of writing public/private files. - # Also, adjust parameter values based on whether this is nested inside another call due to lazy loading. - # I.e., both the param and the corresponding instance variable will only be True for the highest applicable - # call/scope in the stack. 
- if self._is_writing_private_file: - write_private = False - else: - self._is_writing_private_file = write_private - - if self._is_writing_public_file: - write_public = False - else: - self._is_writing_public_file = write_public - - # Next, actually perform the writes, loading things as necessary via property getters - try: - self._load_key_text() - if write_private and not self.private_key_file.exists(): - self.private_key_file.write_text(self._get_private_key_text()) - self._is_deserialized = False - if write_public and not self.public_key_file.exists(): - self.public_key_file.write_text(self._public_key_text) - finally: - # Finally, put back instance values to False appropriately if True and the param is True (indicating this is - # the highest call in the stack and should not be skipped for the public/private key file) - if self._is_writing_private_file and write_private: - self._is_writing_private_file = False - if self._is_writing_public_file and write_public: - self._is_writing_public_file = False + def write_key_files(self, write_private: bool = True, write_public: bool = True): + super().write_key_files(write_private=write_private, write_public=write_public) + if write_private: + # update generation time + self.generation_time = self._read_private_key_ctime(self.private_key_file) From 4071df04c64151219ddca2ce39635f4e657fb3f3 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Thu, 2 Feb 2023 18:27:08 -0500 Subject: [PATCH 163/205] add RsaKeyPair unit tests --- .../scheduler/dmod/test/test_rsa_key_pair.py | 136 +++++++++++++++++- 1 file changed, 130 insertions(+), 6 deletions(-) diff --git a/python/lib/scheduler/dmod/test/test_rsa_key_pair.py b/python/lib/scheduler/dmod/test/test_rsa_key_pair.py index e4d1276ec..768ffaa06 100644 --- a/python/lib/scheduler/dmod/test/test_rsa_key_pair.py +++ b/python/lib/scheduler/dmod/test/test_rsa_key_pair.py @@ -1,16 +1,22 @@ import unittest +from tempfile import TemporaryDirectory from ..scheduler.rsa_key_pair import 
RsaKeyPair +from typing import Dict class TestRsaKeyPair(unittest.TestCase): def setUp(self) -> None: - self.rsa_key_pairs = dict() + self.rsa_key_pairs: Dict[int, RsaKeyPair] = dict() self.rsa_key_pairs[1] = RsaKeyPair(directory='.', name='id_rsa_1') + self.serial_rsa_key_pairs: Dict[int, dict] = dict() + self.serial_rsa_key_pairs[1] = self.rsa_key_pairs[1].to_dict() + + def tearDown(self) -> None: - self.rsa_key_pairs[1].private_key_file.unlink() - self.rsa_key_pairs[1].public_key_file.unlink() + self.rsa_key_pairs[1].private_key_file.unlink(missing_ok=True) + self.rsa_key_pairs[1].public_key_file.unlink(missing_ok=True) def test_generate_key_pair_1_a(self): """ @@ -37,7 +43,7 @@ def test_generate_key_pair_1_c(self): # This should result in the same file names as key_pair, and so the constructor should resolve that it needs to # load the key, not regenerate it reserialized_key = RsaKeyPair(directory=key_pair.directory, name=key_pair.name) - self.assertTrue(key_pair, reserialized_key) + self.assertEqual(key_pair, reserialized_key) def test_generate_key_pair_1_d(self): """ @@ -49,7 +55,7 @@ def test_generate_key_pair_1_d(self): # This should result in the same file names as key_pair, and so the constructor should resolve that it needs to # load the key, not regenerate it reserialized_key = RsaKeyPair(directory=key_pair.directory, name=key_pair.name) - self.assertTrue(key_pair.private_key, reserialized_key.private_key) + self.assertEqual(key_pair.private_key, reserialized_key.private_key) def test_generate_key_pair_1_e(self): """ @@ -61,4 +67,122 @@ def test_generate_key_pair_1_e(self): # This should result in the same file names as key_pair, and so the constructor should resolve that it needs to # load the key, not regenerate it reserialized_key = RsaKeyPair(directory=key_pair.directory, name=key_pair.name) - self.assertTrue(key_pair.private_key_pem, reserialized_key.private_key_pem) + self.assertEqual(key_pair.private_key_pem, reserialized_key.private_key_pem) + + 
def test_generate_key_pair_1_from_dict_a(self): + """ + """ + key_pair = self.rsa_key_pairs[1] + key_pair.write_key_files() + # This should result in the same file names as key_pair, and so the constructor should resolve that it needs to + # load the key, not regenerate it + + key_pair_dict = key_pair.to_dict() + key_pair_from_dict = RsaKeyPair.factory_init_from_deserialized_json(key_pair_dict) + self.assertEqual(key_pair_from_dict, key_pair) + + def test_delete_key_files(self): + """ + Verify that the `delete_key_files` method deletes both public and private key _if they + exist_ to start with. + """ + key_pair = self.rsa_key_pairs[1] + key_pair.delete_key_files() + self.assertFalse(key_pair.private_key_file.exists()) + self.assertFalse(key_pair.public_key_file.exists()) + + def test_factory_init_from_deserialized_json_does_not_write_key_files_on_init(self): + """ + verify key files are not created if they do not already exist on factory init. + """ + key_pair = self.rsa_key_pairs[1] + kp_as_dict = key_pair.to_dict() + key_pair.delete_key_files() + + kp_from_factory = RsaKeyPair.factory_init_from_deserialized_json(kp_as_dict) + + assert kp_from_factory is not None + self.assertFalse(kp_from_factory.private_key_file.exists()) + self.assertFalse(kp_from_factory.public_key_file.exists()) + + def test_factory_init_from_deserialized_json_verifies_private_key_matches_successfully(self): + """ + verify key files are not created if they do not already exist on factory init. 
+ """ + key_pair = self.rsa_key_pairs[1] + self.assertTrue(key_pair.private_key_file.exists()) + + kp_as_dict = key_pair.to_dict() + kp_from_factory = RsaKeyPair.factory_init_from_deserialized_json(kp_as_dict) + assert kp_from_factory is not None + # this should have been called in __init__ + kp_from_factory._delete_existing_key_files_if_priv_keys_differ() # type: ignore + self.assertTrue(kp_from_factory.private_key_file.exists()) + + def test_factory_init_from_deserialized_json_does_not_write_pub_key_file_when_priv_exists(self): + """ + verify pub key file is not created by factory init if priv key file already exists. + are not created if they do not already exist on factory init. + """ + key_pair = self.rsa_key_pairs[1] + self.assertTrue(key_pair.private_key_file.exists()) + self.assertTrue(key_pair.public_key_file.exists()) + + key_pair.public_key_file.unlink(missing_ok=True) + self.assertFalse(key_pair.public_key_file.exists()) + + kp_as_dict = key_pair.to_dict() + kp_from_factory = RsaKeyPair.factory_init_from_deserialized_json(kp_as_dict) + assert kp_from_factory is not None + + # main concern being tested + self.assertFalse(kp_from_factory.public_key_file.exists()) + + self.assertTrue(kp_from_factory.private_key_file.exists()) + + def test_factory_init_from_deserialized_json_is_deserialized(self): + """ + verify object `is_deserialized` property is true on factory init with no key files on disk. 
+ """ + key_pair = self.rsa_key_pairs[1] + + kp_as_dict = key_pair.to_dict() + + # remove key files + key_pair.delete_key_files() + self.assertFalse(key_pair.private_key_file.exists()) + self.assertFalse(key_pair.public_key_file.exists()) + + kp_from_factory = RsaKeyPair.factory_init_from_deserialized_json(kp_as_dict) + assert kp_from_factory is not None + + # main concern being tested + self.assertTrue(kp_from_factory.is_deserialized) + + def test_factory_init_from_deserialized_json_is_deserialized_with_key_files_present(self): + """ + verify object `is_deserialized` property is true on factory init key files on disk. + """ + key_pair = self.serial_rsa_key_pairs[1] + + kp_from_factory = RsaKeyPair.factory_init_from_deserialized_json(key_pair) + assert kp_from_factory is not None + + # main concern being tested + self.assertTrue(kp_from_factory.is_deserialized) + + def test_is_deserialized_is_false_when_key_is_generated(self): + """ + verify object `is_deserialized` property is false when key is generated. + """ + with TemporaryDirectory() as dir: + key_pair = RsaKeyPair(directory=dir, name="test_is_deserialized") + self.assertFalse(key_pair.is_deserialized) + + def test_is_deserialized_is_true_when_key_is_present(self): + """ + verify object `is_deserialized` property is false when key is generated. 
+ """ + key_pair = self.rsa_key_pairs[1] + kp = RsaKeyPair(directory=key_pair.directory, name=key_pair.name) + self.assertTrue(kp.is_deserialized) From 6327246d20bd2b0f8e76ace1edbe87e8b32e192b Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Fri, 3 Feb 2023 09:08:56 -0500 Subject: [PATCH 164/205] restore unit test assertion condition --- python/lib/scheduler/dmod/test/test_rsa_key_pair.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/lib/scheduler/dmod/test/test_rsa_key_pair.py b/python/lib/scheduler/dmod/test/test_rsa_key_pair.py index 768ffaa06..16aa51da0 100644 --- a/python/lib/scheduler/dmod/test/test_rsa_key_pair.py +++ b/python/lib/scheduler/dmod/test/test_rsa_key_pair.py @@ -55,7 +55,7 @@ def test_generate_key_pair_1_d(self): # This should result in the same file names as key_pair, and so the constructor should resolve that it needs to # load the key, not regenerate it reserialized_key = RsaKeyPair(directory=key_pair.directory, name=key_pair.name) - self.assertEqual(key_pair.private_key, reserialized_key.private_key) + self.assertTrue(key_pair.private_key, reserialized_key.private_key) def test_generate_key_pair_1_e(self): """ From 755b25449dca52ee16224fd7217bb17a0e822b2a Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Fri, 3 Feb 2023 09:09:48 -0500 Subject: [PATCH 165/205] validate RsaKeyPair property assignment. 
Required to retain previous behaviour --- python/lib/scheduler/dmod/scheduler/rsa_key_pair.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/lib/scheduler/dmod/scheduler/rsa_key_pair.py b/python/lib/scheduler/dmod/scheduler/rsa_key_pair.py index 35c30edf5..fd24f1f44 100644 --- a/python/lib/scheduler/dmod/scheduler/rsa_key_pair.py +++ b/python/lib/scheduler/dmod/scheduler/rsa_key_pair.py @@ -271,6 +271,7 @@ def _validate_private_key(cls, value: Union[str, RSAPrivateKeyWithSerialization class Config: # type: ignore arbitrary_types_allowed = True + validate_assignment = True def _serialize_datetime(self: "RsaKeyPair", value: datetime.datetime) -> str: return value.strftime(self.get_datetime_str_format()) From dbce96aaa695fe0eec639b7fee7c977d8f34ef87 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Fri, 3 Feb 2023 09:11:11 -0500 Subject: [PATCH 166/205] testing: verify modifying RsaKeyPair directory through a setter behaves as expected --- .../scheduler/dmod/test/test_rsa_key_pair.py | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/python/lib/scheduler/dmod/test/test_rsa_key_pair.py b/python/lib/scheduler/dmod/test/test_rsa_key_pair.py index 16aa51da0..7bc393f6c 100644 --- a/python/lib/scheduler/dmod/test/test_rsa_key_pair.py +++ b/python/lib/scheduler/dmod/test/test_rsa_key_pair.py @@ -1,4 +1,5 @@ import unittest +from pathlib import Path from tempfile import TemporaryDirectory from ..scheduler.rsa_key_pair import RsaKeyPair from typing import Dict @@ -186,3 +187,31 @@ def test_is_deserialized_is_true_when_key_is_present(self): key_pair = self.rsa_key_pairs[1] kp = RsaKeyPair(directory=key_pair.directory, name=key_pair.name) self.assertTrue(kp.is_deserialized) + + def test_reassign_directory_to_default(self): + """ + verify object `is_deserialized` property is false when key is generated. 
+ """ + key_pair = self.rsa_key_pairs[1] + default_location = Path.home() / ".ssh" + self.assertNotEqual(key_pair.directory, default_location) + + key_pair.directory = None + self.assertEqual(key_pair.directory, default_location) + + def test_reassign_directory_creates_directory_if_not_exist(self): + """ + verify object `is_deserialized` property is false when key is generated. + """ + key_pair = self.rsa_key_pairs[1] + with TemporaryDirectory() as dir: + dir = Path(dir) + new_dir = dir / ".ssh" + + self.assertFalse(new_dir.exists()) + self.assertNotEqual(key_pair.directory, new_dir) + + key_pair.directory = new_dir + + self.assertTrue(new_dir.exists()) + self.assertEqual(key_pair.directory, new_dir) From 005fbe45d11df396cc3786031231223e237b9c93 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 6 Feb 2023 13:31:27 -0500 Subject: [PATCH 167/205] add class var type hint to JobStatus --- python/lib/scheduler/dmod/scheduler/job/job.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/lib/scheduler/dmod/scheduler/job/job.py b/python/lib/scheduler/dmod/scheduler/job/job.py index 65a4e9901..a8c05103b 100644 --- a/python/lib/scheduler/dmod/scheduler/job/job.py +++ b/python/lib/scheduler/dmod/scheduler/job/job.py @@ -221,7 +221,7 @@ class JobStatus(Serializable): """ Representation of a ::class:`Job`'s status as a combination of phase and exec step. """ - _NAME_DELIMITER = ':' + _NAME_DELIMITER: ClassVar[str] = ':' # NOTE: `None` is valid input, default value for field will be used. 
phase: Optional[JobExecPhase] = Field(JobExecPhase.UNKNOWN) From c5407cdf04f358b8f84e19e9269809d32572bc07 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 6 Feb 2023 13:32:06 -0500 Subject: [PATCH 168/205] JobStatus `name` property now uses fstring --- python/lib/scheduler/dmod/scheduler/job/job.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/lib/scheduler/dmod/scheduler/job/job.py b/python/lib/scheduler/dmod/scheduler/job/job.py index a8c05103b..ae5cdb17d 100644 --- a/python/lib/scheduler/dmod/scheduler/job/job.py +++ b/python/lib/scheduler/dmod/scheduler/job/job.py @@ -336,8 +336,7 @@ def job_exec_step(self) -> JobExecStep: @property def name(self) -> str: - return self.job_exec_phase.name + self._NAME_DELIMITER + self.job_exec_step.name - + return f"{self.job_exec_phase.name}{self._NAME_DELIMITER}{self.job_exec_step.name}" class Job(Serializable, ABC): """ From c57f4be9eed4eaf65fff89e9dfc72facf0084fe9 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 6 Feb 2023 13:32:40 -0500 Subject: [PATCH 169/205] ignore type hints that static checker cannot catch --- python/lib/scheduler/dmod/scheduler/job/job.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/lib/scheduler/dmod/scheduler/job/job.py b/python/lib/scheduler/dmod/scheduler/job/job.py index ae5cdb17d..37ed5b186 100644 --- a/python/lib/scheduler/dmod/scheduler/job/job.py +++ b/python/lib/scheduler/dmod/scheduler/job/job.py @@ -328,11 +328,11 @@ def is_interrupted(self) -> bool: @property def job_exec_phase(self) -> JobExecPhase: - return self.phase + return self.phase # type: ignore @property def job_exec_step(self) -> JobExecStep: - return self.step + return self.step # type: ignore @property def name(self) -> str: From 15c16bd3cede9486cdfdc0cc70de1934830a460f Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 6 Feb 2023 13:33:46 -0500 Subject: [PATCH 170/205] refactor Job to use pydantic --- .../lib/scheduler/dmod/scheduler/job/job.py | 
328 +++++++----------- 1 file changed, 128 insertions(+), 200 deletions(-) diff --git a/python/lib/scheduler/dmod/scheduler/job/job.py b/python/lib/scheduler/dmod/scheduler/job/job.py index 37ed5b186..45650bc10 100644 --- a/python/lib/scheduler/dmod/scheduler/job/job.py +++ b/python/lib/scheduler/dmod/scheduler/job/job.py @@ -22,6 +22,9 @@ import logging +# SAFETY: tuple can be used in this context because this sentinel is being used to verify if the data is being +# deserialized from json. Tuple's are not datatypes in json or deserialized json. +JOB_CLASS_SENTINEL = tuple() class JobExecStep(PydanticEnum): """ @@ -348,8 +351,111 @@ class Job(Serializable, ABC): The hash value of a job is calculated as the hash of it's ::attribute:`job_id`. """ + allocation_paradigm: AllocationParadigm + """The ::class:`AllocationParadigm` type value that was used or should be used to make allocations.""" + + allocation_priority: int = 0 + """A score for how this job should be prioritized with respect to allocation.""" + + allocations: Optional[List[ResourceAllocation]] + """The scheduler resource allocations for this job, or ``None`` if it is queued or otherwise not yet allocated.""" + + cpu_count: int = Field(gt=0) + """The number of CPUs for this job.""" + + data_requirements: List[DataRequirement] = Field(default_factory=list) + """List of ::class:`DataRequirement` objects representing all data needed for the job.""" + + job_id: str = Field(default_factory=lambda: str(uuid_func())) + """The unique identifier for this particular job.""" + + last_updated: datetime = Field(default_factory=datetime.now) + """ The last time this objects state was updated.""" + + memory_size: int = Field(gt=0) + """The amount of the memory needed for this job.""" + + # TODO: do we need to account for jobs for anything other than model exec? 
+ model_request: ExternalRequest + """The underlying configuration for the model execution that is being requested.""" + + partition_config: Optional[PartitionConfig] + """This job's partitioning configuration.""" + + rsa_key_pair: Optional[RsaKeyPair] + """The ::class:`'RsaKeyPair'` for this job's shared SSH RSA keys, or ``None`` if not has been set.""" + + status: JobStatus = Field(default_factory=lambda: JobStatus(JobExecPhase.INIT)) + """The ::class:`JobStatus` of this object.""" + + job_class: Type[Self] = JOB_CLASS_SENTINEL + """A type or subtype of ::class:`Self`. This can be provided as a str (e.g. "Job"), but will be coerced into a Type + object. Class names, not including module namespace, are used when coercing from a str into a Type (i.e. "job.Job" + is invalid; "Job" is valid). This field is required when factory deserializing from a dictionary. The field defaults + to the type of Self when programmatically creating an instance. It may be possible to specify a `job_class` during + programmatic initialization, however that capability is subtype dependent. + + Notably, the `job_class` field of subtypes of Job are also covariant in Self. Meaning, the `job_class` field of a + subtype S can only be S or a subtype of S. Sibling and super types of S are not allowed. + """ + @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict): + def _subclass_search(cls, t: Union[str, Any]) -> Optional[Type[Self]]: + if isinstance(t, str): + # base case + if t == cls.__name__: + return cls + + current_level: List[Type[Self]] = cls.__subclasses__() + # bfs subclass search + while True: + next_level: List[Type[Self]] = list() + for subclass in current_level: + if t == subclass.__name__: + return subclass + next_level.extend(subclass.__subclasses__()) + + # no more levels to explore + if not next_level: + raise ValueError( + f"`t`: {t!r} must be a str with value name of Type[{cls.__name__}]. 
This includes subtypes of `{cls.__name__}`" + ) + + current_level = next_level + + return None + + @validator("job_class", pre=True, always=True) + def _validate_job_class(cls: Self, value: Union[str, Type[Self]]) -> Type[Self]: + # default case. Is unreachable when factory init from json. + if value is JOB_CLASS_SENTINEL: + return cls + + subclass = cls._subclass_search(value) + if subclass is not None: + return subclass + + if value == cls: + return value + + if issubclass(value, cls): + return value + + raise ValueError( + f"`job_class` field must be a Type[{cls.__name__}]. This includes subtypes of `{cls.__name__}`" + ) + + class Config: + fields = { + "partition_config": {"alias": "partitioning"} + } + field_serializers = { + "job_class": lambda cls: cls.__name__, + "last_updated": lambda self, value: value.strftime(self.get_datetime_str_format()) + } + + @classmethod + def factory_init_from_deserialized_json(cls, json_obj: dict) -> Optional[Self]: """ Factory create a new instance of the correct subtype based on a JSON object dictionary deserialized from received JSON, where this includes a ``job_class`` property containing the name of the appropriate subtype. @@ -363,36 +469,21 @@ def factory_init_from_deserialized_json(cls, json_obj: dict): A new object of the correct subtype instantiated from the deserialize JSON object dictionary, or ``None`` if this cannot be done successfully. 
""" - job_type_key = 'job_class' - recursive_loop_key = 'base_type_invoked_twice' + try: + if "job_class" not in json_obj: + raise KeyError("missing `job_class` field") - if job_type_key not in json_obj: - return None + subclass = cls._subclass_search(json_obj["job_class"]) + + if subclass is None: + raise ValueError("`job_class` field must be provided as a type `str`") - # Avoid accidental recursive infinite loop by adding an indicator key and bailing if we already see it - if recursive_loop_key in json_obj: + json_obj["job_class"] = subclass + return subclass(**json_obj) + except: return None - else: - json_obj[recursive_loop_key] = True - - # Traverse class type tree and get all subtypes of Job - subclasses = [] - subclasses.extend(cls.__subclasses__()) - traversed_subclasses = set() - while len(subclasses) > len(traversed_subclasses): - for s in subclasses: - if s not in traversed_subclasses: - subclasses.extend(s.__subclasses__()) - traversed_subclasses.add(s) - - for subclass in subclasses: - subclass_name = subclass.__name__ - if subclass_name == json_obj[job_type_key]: - json_obj.pop(job_type_key) - return subclass.factory_init_from_deserialized_json(json_obj) - return None - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if other is None: return False elif isinstance(other, Job): @@ -405,39 +496,12 @@ def __eq__(self, other): # infinite loop (perhaps via some shared interface where that's appropriate) return False - def __hash__(self): + def __hash__(self) -> int: return hash(self.job_id) - def __lt__(self, other): + def __lt__(self, other: "Job") -> bool: return self.allocation_priority < other.allocation_priority - @property - @abstractmethod - def allocation_paradigm(self) -> AllocationParadigm: - """ - The ::class:`AllocationParadigm` type value that was used or should be used to make allocations. 
- - Returns - ------- - AllocationParadigm - The ::class:`AllocationParadigm` type value that was used or should be used to make allocations. - """ - pass - - @property - @abstractmethod - def allocation_priority(self) -> int: - """ - Get a score for how this job should be prioritized with respect to allocation, with high scores being more - likely to received allocation. - - Returns - ------- - int - A score for how this job should be prioritized with respect to allocation. - """ - pass - @property @abstractmethod def allocation_service_names(self) -> Optional[Tuple[str]]: @@ -456,53 +520,28 @@ def allocation_service_names(self) -> Optional[Tuple[str]]: """ pass - @property @abstractmethod - def allocations(self) -> Optional[Tuple[ResourceAllocation]]: - """ - The resource allocations that have been allocated for this job. - - Returns - ------- - Optional[List[ResourceAllocation]] - The scheduler resource allocations for this job, or ``None`` if it is queued or otherwise not yet allocated. - """ + def set_allocations(self, allocations: List[ResourceAllocation]): pass - @allocations.setter @abstractmethod - def allocations(self, allocations: List[ResourceAllocation]): + def set_data_requirements(self, data_requirements: List[DataRequirement]): pass - @property @abstractmethod - def cpu_count(self) -> int: - """ - The number of CPUs for this job. - - Returns - ------- - int - The number of CPUs for this job. - """ + def set_partition_config(self, part_config: PartitionConfig): pass - @property @abstractmethod - def data_requirements(self) -> List[DataRequirement]: - """ - List of ::class:`DataRequirement` objects representing all data needed for the job. + def set_status(self, status: JobStatus): + pass - Returns - ------- - List[DataRequirement] - List of ::class:`DataRequirement` objects representing all data needed for the job. 
- """ + @abstractmethod + def set_status_phase(self, phase: JobExecPhase): pass - @data_requirements.setter @abstractmethod - def data_requirements(self, data_requirements: List[DataRequirement]): + def set_status_step(self, step: JobExecStep): pass @property @@ -518,89 +557,6 @@ def is_partitionable(self) -> bool: """ pass - @property - @abstractmethod - def job_id(self): - """ - The unique identifier for this particular job. - - Returns - ------- - The unique identifier for this particular job. - """ - pass - - @property - @abstractmethod - def last_updated(self) -> datetime: - """ - The last time this objects state was updated. - - Returns - ------- - datetime - The last time this objects state was updated. - """ - pass - - @property - @abstractmethod - def memory_size(self) -> int: - """ - The amount of the memory needed for this job. - - Returns - ------- - int - The amount of the memory needed for this job. - """ - pass - - # TODO: do we need to account for jobs for anything other than model exec? - @property - @abstractmethod - def model_request(self) -> ExternalRequest: - """ - Get the underlying configuration for the model execution that is being requested. - - Returns - ------- - ExternalRequest - The underlying configuration for the model execution that is being requested. - """ - pass - - @property - @abstractmethod - def partition_config(self) -> Optional[PartitionConfig]: - """ - Get this job's partitioning configuration. - - Returns - ------- - PartitionConfig - This job's partitioning configuration. - """ - pass - - @partition_config.setter - @abstractmethod - def partition_config(self, part_config: PartitionConfig): - pass - - @property - @abstractmethod - def rsa_key_pair(self) -> Optional['RsaKeyPair']: - """ - The ::class:`'RsaKeyPair'` for this job's shared SSH RSA keys. - - Returns - ------- - Optional['RsaKeyPair'] - The ::class:`'RsaKeyPair'` for this job's shared SSH RSA keys, or ``None`` if not has been set. 
- """ - pass - @property @abstractmethod def should_release_resources(self) -> bool: @@ -614,24 +570,6 @@ def should_release_resources(self) -> bool: """ pass - @property - @abstractmethod - def status(self) -> JobStatus: - """ - The ::class:`JobStatus` of this object. - - Returns - ------- - JobStatus - The ::class:`JobStatus` of this object. - """ - pass - - @status.setter - @abstractmethod - def status(self, status: JobStatus): - pass - @property def status_phase(self) -> JobExecPhase: """ @@ -644,11 +582,6 @@ def status_phase(self) -> JobExecPhase: """ return self.status.job_exec_phase - @status_phase.setter - @abstractmethod - def status_phase(self, phase: JobExecPhase): - pass - @property def status_step(self) -> JobExecStep: """ @@ -661,11 +594,6 @@ def status_step(self) -> JobExecStep: """ return self.status.job_exec_step - @status_step.setter - @abstractmethod - def status_step(self, step: JobExecStep): - pass - @property @abstractmethod def worker_data_requirements(self) -> List[List[DataRequirement]]: From 615a6aa2426f5ecfdf3dfdf531cd6490f637d304 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 6 Feb 2023 13:35:32 -0500 Subject: [PATCH 171/205] deprecate Job property setters. set through setter method instead. Pre-pydantic versions of Job class used property setters. ``` @property def prop(self) -> str: ... @prop.setter def prop(self, value: str): ... ``` Some properties with property setters mutatated state beyond just the associated property (e.g. reset `last_updated`). Pydantic will not allow you to create `@prop.setter`'s. Also, from a usage perspective, it not always clear if you can set a property without looking at documentation. To get around this, setter methods have been added to the Job class making it clear that you _can_ modify a properties value after instantiation. This also allows updating other instance internal state (i.e. `last_updated`). 
Likewise, to help us find and update portions of code that still set Job properties using prop setters, a DeprecationWarning is now thrown. To accomplish raising DeprecationWarnings, Job's `__setattr__` method was overridden. In short, there is a dict mapping of field names to associated setter methods that is called instead of `__setattr__`. As a side effect of this, setter methods must set field values through the instance `__dict__`. This avoids setter methods from raising a DeprecationWarning. --- .../lib/scheduler/dmod/scheduler/job/job.py | 86 +++++++++---------- 1 file changed, 40 insertions(+), 46 deletions(-) diff --git a/python/lib/scheduler/dmod/scheduler/job/job.py b/python/lib/scheduler/dmod/scheduler/job/job.py index 45650bc10..a41ab6e29 100644 --- a/python/lib/scheduler/dmod/scheduler/job/job.py +++ b/python/lib/scheduler/dmod/scheduler/job/job.py @@ -607,52 +607,46 @@ def worker_data_requirements(self) -> List[List[DataRequirement]]: """ pass - -class JobImpl(Job): - """ - Basic implementation of ::class:`Job` - - Job ids are simply the string cast of generated UUID values, stored within the ::attribute:`job_uuid` property. 
- """ - - @classmethod - def _parse_serialized_allocation_paradigm(cls, json_obj: dict, key: str): - paradigm = AllocationParadigm.get_from_name(name=json_obj[key], strict=True) if key in json_obj else None - if not isinstance(paradigm, AllocationParadigm): - if paradigm is None: - type_name = 'None' - else: - type_name = paradigm.__class__.__name__ - raise RuntimeError(cls._get_invalid_type_message().format(key, str.__name__, type_name)) - return paradigm - - @classmethod - def _parse_serialized_allocations(cls, json_obj: dict, key: Optional[str] = None): - if key is None: - key = 'allocations' - - if key not in json_obj: - return None - - serial_alloc_list = json_obj[key] - if not isinstance(serial_alloc_list, list): - raise RuntimeError("Invalid format for allocations list value '{}'".format(str(serial_alloc_list))) - allocations = [] - for serial_alloc in serial_alloc_list: - if not isinstance(serial_alloc, dict): - raise RuntimeError("Invalid format for allocation value '{}'".format(str(serial_alloc_list))) - allocation = ResourceAllocation.factory_init_from_dict(serial_alloc) - if not isinstance(allocation, ResourceAllocation): - raise RuntimeError( - "Unable to deserialize `{}` to resource allocation while deserializing {}".format( - str(allocation), cls.__name__)) - allocations.append(allocation) - return allocations - - @classmethod - def _parse_serialized_data_requirements(cls, json_obj: dict, key: Optional[str] = None): - if key is None: - key = 'data_requirements' + @cache + def _setter_methods(self) -> Dict[str, Callable]: + """Mapping of attribute name to setter method. 
This supports backwards functional compatibility.""" + # TODO: remove once migration to setters by down stream users is complete + return { + "allocations": self.set_allocations, + "data_requirements": self.set_data_requirements, + "partition_config": self.set_partition_config, + "status": self.set_status, + # derived properties + "status_phase": self.set_status_phase, + "status_step": self.set_status_step, + } + + def __setattr__(self, name: str, value: Any): + """ + Use property setter method when available. + + Note, all setter methods should modify their associated property using the instance `__dict__`. + This ensures that calls to, for example, `set_id` don't raise a warning, while `o.id = "new + id"` do. + + Example: + ``` + class SomeJob(Job): + id: str + + def set_id(self, value: str): + self.__dict__["id"] = value + ``` + """ + if name not in self._setter_methods(): + return super().__setattr__(name, value) + + setter_fn = self._setter_methods()[name] + + message = f"Setting by attribute is deprecated. Use `{self.__class__.__name__}.{setter_fn.__name__}` method instead." 
+ warn(message, DeprecationWarning) + + setter_fn(value) if key not in json_obj: return None From 263b9ca0848519bae8482a0d9a8af4908bcbb0d7 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 6 Feb 2023 13:51:11 -0500 Subject: [PATCH 172/205] import validator and cache --- python/lib/scheduler/dmod/scheduler/job/job.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/lib/scheduler/dmod/scheduler/job/job.py b/python/lib/scheduler/dmod/scheduler/job/job.py index a41ab6e29..56a9b7771 100644 --- a/python/lib/scheduler/dmod/scheduler/job/job.py +++ b/python/lib/scheduler/dmod/scheduler/job/job.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from datetime import datetime -from numbers import Number -from pydantic import Field, validator, root_validator +from functools import cache +from pydantic import Field, PrivateAttr, validator, root_validator from pydantic.fields import ModelField from warnings import warn From b90ee6c23ab53f6fe4694d2bdd7fc7e4629feb79 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 6 Feb 2023 13:52:13 -0500 Subject: [PATCH 173/205] refactor JobImp to use pydantic --- .../lib/scheduler/dmod/scheduler/job/job.py | 585 ++++++------------ 1 file changed, 202 insertions(+), 383 deletions(-) diff --git a/python/lib/scheduler/dmod/scheduler/job/job.py b/python/lib/scheduler/dmod/scheduler/job/job.py index 56a9b7771..799b97b7a 100644 --- a/python/lib/scheduler/dmod/scheduler/job/job.py +++ b/python/lib/scheduler/dmod/scheduler/job/job.py @@ -11,17 +11,19 @@ from dmod.core.meta_data import DataRequirement from dmod.core.enum import PydanticEnum from dmod.modeldata.hydrofabric import PartitionConfig -from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union +from typing import Any, Callable, ClassVar, Dict, List, Optional, Set, Tuple, Type, TYPE_CHECKING, Union +from typing_extensions import Self from uuid import UUID from uuid import uuid4 as uuid_func from ..resources import 
ResourceAllocation - -if TYPE_CHECKING: - from .. import RsaKeyPair +from .. import RsaKeyPair import logging +if TYPE_CHECKING: + from pydantic.typing import AbstractSetIntStr, MappingIntStrAny, DictStrAny + # SAFETY: tuple can be used in this context because this sentinel is being used to verify if the data is being # deserialized from json. Tuple's are not datatypes in json or deserialized json. JOB_CLASS_SENTINEL = tuple() @@ -648,158 +650,87 @@ def set_id(self, value: str): setter_fn(value) - if key not in json_obj: - return None +class JobImpl(Job): + """ + Basic implementation of ::class:`Job` - serial_list = json_obj[key] - if not isinstance(serial_list, list): - raise RuntimeError("Invalid format for data requirements list value '{}'".format(str(serial_list))) - data_req_list = [] - for serial_data_req in serial_list: - if not isinstance(serial_data_req, dict): - raise RuntimeError("Invalid format for data requirements value '{}'".format(str(serial_list))) - data_req = DataRequirement.factory_init_from_deserialized_json(serial_data_req) - if not isinstance(data_req, DataRequirement): - msg = "Unable to deserialize `{}` to nested data requirements while deserializing {}" - raise RuntimeError(msg.format(serial_data_req, cls.__name__)) - data_req_list.append(data_req) - return data_req_list + Job ids are simply the string cast of generated UUID values, stored within the ::attribute:`job_uuid` property. 
+ """ - @classmethod - def _parse_serialized_job_status(cls, json_obj: dict, key: Optional[str] = None): - # Set this to the default value if it is initially None - if key is None: - key = 'status' - status_str = cls.parse_simple_serialized(json_obj=json_obj, key=key, expected_type=str, required_present=False) - if status_str is None: - return None - return JobStatus.get_for_name(name=status_str) + # NOTE: more specific ExternalRequest subtype than super class + model_request: ModelExecRequest - @classmethod - def _parse_serialized_last_updated(cls, json_obj: dict, key: Optional[str] = None): - date_str_converter = lambda date_str: datetime.strptime(date_str, cls.get_datetime_str_format()) - if key is None: - key = 'last_updated' - if key in json_obj: - return cls.parse_simple_serialized(json_obj=json_obj, key=key, expected_type=datetime, - converter=date_str_converter, required_present=False) - else: - return None + _worker_data_requirements: Optional[List[List[DataRequirement]]] = PrivateAttr(None) + _allocation_service_names: Optional[Tuple[str]] = PrivateAttr(None) - @classmethod - def _parse_serialized_partition_config(cls, json_obj: dict, key: Optional[str] = None): - if key is None: - key = 'partitioning' - if key in json_obj: - return PartitionConfig.factory_init_from_deserialized_json(json_obj[key]) - else: - return None + @validator("allocation_paradigm", pre=True) + def _parse_allocation_paradigm(cls, value: Union[AllocationParadigm, str]) -> Union[str, AllocationParadigm]: + if isinstance(value, AllocationParadigm): + return value - @classmethod - def _parse_serialized_rsa_key_pair(cls, json_obj: dict, key: Optional[str] = None, warn_if_missing: bool = False): - # Doing this here for now to avoid import errors - # TODO: find a better way for this - from .. 
import RsaKeyPair - - # Set this to the default value if it is initially None - if key is None: - # TODO: set somewhere globally - key = 'rsa_key_pair' - if key not in json_obj: - if warn_if_missing: - # TODO: log this better. NJF changed print to logging.warning, anything else needed? - msg = 'Warning: expected serialized RSA key at {} when deserializing {} object' - logging.warning(msg.format(key, cls.__name__)) - return None - if key not in json_obj or json_obj[key] is None: - return None - rsa_key_pair = RsaKeyPair.factory_init_from_deserialized_json(json_obj=json_obj[key]) - if rsa_key_pair is None: - raise RuntimeError('Could not deserialized child RsaKeyPair when deserializing ' + cls.__name__) - else: - return rsa_key_pair + # NOTE: potentially remove in future. There are cases in codebase where kabob case is being used. + return value.replace("-", "_") - # TODO: unit test - # TODO: consider moving this up to Job or even Serializable + @validator("status", pre=True) + def _parse_status(cls, value: Optional[Union[str, JobStatus]], field: ModelField) -> JobStatus: + if value is None: + if field.default_factory is None: + raise RuntimeError("unreachable") + return field.default_factory() - @classmethod - def deserialize_core_attributes(cls, json_obj: dict): - """ - Deserialize the core attributes of the basic ::class:`JobImpl` implementation from the provided dictionary and - return as a tuple. + if isinstance(value, JobStatus): + return value - Parameters - ---------- - json_obj + value = str(value) + return JobStatus.get_for_name(name=value) - Returns - ------- - The tuple with parse values of (cpus, memory, paradigm, priority, job_id, rsa_key_pair, status, allocations, - updated, partitioning) from the provided dictionary. 
- """ - int_converter = lambda x: int(x) - cpus = cls.parse_simple_serialized(json_obj=json_obj, key='cpu_count', expected_type=int, - converter=int_converter) - memory = cls.parse_simple_serialized(json_obj=json_obj, key='memory_size', expected_type=int, - converter=int_converter) - paradigm = cls._parse_serialized_allocation_paradigm(json_obj=json_obj, key='allocation_paradigm') - priority = cls.parse_simple_serialized(json_obj=json_obj, key='allocation_priority', expected_type=int, - converter=int_converter) - job_id = cls.parse_serialized_job_id(serialized_value=None, json_obj=json_obj, key='job_id') - rsa_key_pair = cls._parse_serialized_rsa_key_pair(json_obj=json_obj) - status = cls._parse_serialized_job_status(json_obj=json_obj) - allocations = cls._parse_serialized_allocations(json_obj=json_obj) - updated = cls._parse_serialized_last_updated(json_obj=json_obj) - partitioning = cls._parse_serialized_partition_config(json_obj=json_obj, key='partitioning') - return cpus, memory, paradigm, priority, job_id, rsa_key_pair, status, allocations, updated, partitioning + @validator("last_updated", pre=True) + def _parse_serialized_last_updated(cls, value: Union[str, datetime]) -> datetime: + if isinstance(value, datetime): + return value - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict): - """ - Factory create a new instance of this type based on a JSON object dictionary deserialized from received JSON. 
+ try: + value = str(value) + return datetime.strptime(value, cls.get_datetime_str_format()) + except: + return datetime.now() - Parameters - ---------- - json_obj + @validator("data_requirements", pre=True) + def _populate_default_data_requirements(cls, value: Optional[List[DataRequirement]]) -> List[DataRequirement]: + if value is None: + return list() + return value - Returns - ------- - A new object of this type instantiated from the deserialize JSON object dictionary - """ + @validator("model_request", pre=True) + def _deserialize_model_request(cls, value: Union[Dict[str, Any], ModelExecRequest]) -> ModelExecRequest: + if isinstance(value, ModelExecRequest): + return value - try: - cpus, memory, paradigm, priority, job_id, rsa_key_pair, status, allocations, updated, partitioning = \ - cls.deserialize_core_attributes(json_obj) + return ModelExecRequest.factory_init_correct_subtype_from_deserialized_json(value) - if 'model_request' in json_obj: - model_request = ModelExecRequest.factory_init_correct_subtype_from_deserialized_json(json_obj['model_request']) - else: - # TODO: add serialize/deserialize support for other situations/requests (also change 'model_request' property name) - msg = "Type {} can only support deserializing JSON containing a {} under the 'model_request' key" - raise RuntimeError(msg.format(cls.__name__, ModelExecRequest.__name__)) - - obj = cls(cpu_count=cpus, memory_size=memory, model_request=model_request, allocation_paradigm=paradigm, - alloc_priority=priority) - - if job_id is not None: - obj.job_id = job_id - if rsa_key_pair is not None: - obj.rsa_key_pair = rsa_key_pair - if status is not None: - obj.status = status - if updated is not None: - obj._last_updated = updated - if allocations is not None: - obj.allocations = allocations - obj.data_requirements = cls._parse_serialized_data_requirements(json_obj) - if partitioning is not None: - obj.partition_config = partitioning - - return obj - - except RuntimeError as e: - logging.error(e) 
- return None + @validator("job_id", pre=True) + def _validate_job_id(cls, value: Optional[Union[UUID, str]], field: ModelField) -> str: + if value is None: + if field.default_factory is None: + raise RuntimeError("unreachable") + return field.default_factory() + + if isinstance(value, UUID): + return str(value) + + return str(UUID(value)) + + @root_validator(pre=True) + def _parse_job_id(cls, values: Dict[str, Any]) -> Dict[str, Any]: + job_id = values.get("job_id") + if job_id is not None: + return values + + values["job_id"] = cls.parse_serialized_job_id(job_id, **values) + return values + + # TODO: unit test + # TODO: consider moving this up to Job or even Serializable @classmethod def parse_serialized_job_id(cls, serialized_value: Optional[str], **kwargs): @@ -850,46 +781,38 @@ def parse_serialized_job_id(cls, serialized_value: Optional[str], **kwargs): RuntimeError Raised if the parameter does not parse to a UUID. """ + if serialized_value is not None: + return serialized_value + key_key = 'key' - json_obj_key = 'json_obj' # First, try to obtain a serialized value, if one was not already set - if serialized_value is None and kwargs is not None and json_obj_key in kwargs and key_key in kwargs: - if isinstance(kwargs[json_obj_key], dict) and kwargs[key_key] in kwargs[json_obj_key]: - try: - serialized_value = cls.parse_simple_serialized(json_obj=kwargs[json_obj_key], key=kwargs[key_key], - expected_type=str, converter=lambda x: str(x), - required_present=False) - except: - # TODO: consider logging this - return None - # Bail here if we don't have a serialized_value to work with - if serialized_value is None: - return None - try: - return UUID(str(serialized_value)) - except ValueError as e: - msg = "Failed parsing parameter value `{}` to UUID object: {}".format(str(serialized_value), str(e)) - raise RuntimeError(msg) + if kwargs is not None and key_key in kwargs: + if kwargs[key_key] in kwargs: + return kwargs[kwargs[key_key]] + + return None + def 
__init__(self, cpu_count: int, memory_size: int, model_request: ExternalRequest, - allocation_paradigm: Union[str, AllocationParadigm], alloc_priority: int = 0): - self._cpu_count = cpu_count - self._memory_size = memory_size - self._model_request = model_request - if isinstance(allocation_paradigm, AllocationParadigm): - self._allocation_paradigm = allocation_paradigm - else: - self._allocation_paradigm = AllocationParadigm.get_from_name(name=allocation_paradigm) - self._allocation_priority = alloc_priority - self._job_uuid = uuid_func() - self._rsa_key_pair = None - self._status = JobStatus(JobExecPhase.INIT) - self._allocations = None - self._data_requirements = None - self._worker_data_requirements = None - self._allocation_service_names = None - self._partition_config = None + allocation_paradigm: Union[str, AllocationParadigm], alloc_priority: int = 0, **data): + if data: + super().__init__( + allocation_paradigm=allocation_paradigm, + cpu_count=cpu_count, + memory_size=memory_size, + model_request=model_request, + **data, + ) + return + + super().__init__( + allocation_paradigm=allocation_paradigm, + allocation_priority=alloc_priority, + cpu_count=cpu_count, + memory_size=memory_size, + model_request=model_request, + ) self._reset_last_updated() def _process_per_worker_data_requirements(self) -> List[List[DataRequirement]]: @@ -901,11 +824,13 @@ def _process_per_worker_data_requirements(self) -> List[List[DataRequirement]]: List[List[DataRequirement]] List (indexed analogously to worker allocations) of lists of per-worker data requirements. 
""" + if self.allocations is None: + return [] # TODO: implement this properly/more efficiently - return [list(self.data_requirements) for a in self.allocations] + return [list(self.data_requirements) for _ in self.allocations] def _reset_last_updated(self): - self._last_updated = datetime.now() + self.last_updated = datetime.now() def add_allocation(self, allocation: ResourceAllocation): """ @@ -917,44 +842,16 @@ def add_allocation(self, allocation: ResourceAllocation): allocation : ResourceAllocation A resource allocation object to add. """ - if self._allocations is None: - self._allocations = list() - self._allocations.append(allocation) + if self.allocations is None: + self.set_allocations(list()) + self.allocations.append(allocation) # type: ignore self._allocation_service_names = None self._reset_last_updated() - @property - def allocation_paradigm(self) -> AllocationParadigm: - """ - The ::class:`AllocationParadigm` type value that was used or should be used to make allocations. - - For this type, the value is set as a private attribute during initialization, based on the value of the - ::attribute:`SchedulerRequestMessage.allocation_paradigm` string property present within the provided - ::class:`SchedulerRequestMessage` init param. - - Returns - ------- - AllocationParadigm - The ::class:`AllocationParadigm` type value that was used or should be used to make allocations. - """ - return self._allocation_paradigm - - @property - def allocation_priority(self) -> int: - """ - A score for how this job should be prioritized with respect to allocation, with high scores being more likely to - received allocation. - - Returns - ------- - int - A score for how this job should be prioritized with respect to allocation. 
- """ - return self._allocation_priority - - @allocation_priority.setter - def allocation_priority(self, priority: int): - self._allocation_priority = priority + def set_allocation_priority(self, priority: int): + # NOTE: set using dict to avoid deprecation warning thrown by `__setattr__`. See `Job.__setattr__` + # docstring for more detail. + self.__dict__["allocation_priority"] = priority self._reset_last_updated() @property @@ -977,7 +874,7 @@ def allocation_service_names(self) -> Optional[Tuple[str]]: allocations. """ if self._allocation_service_names is None and self.allocations is not None and len(self.allocations) > 0: - service_names = [] + service_names: List[str] = [] # TODO: read this from request metadata base_name = "{}-worker".format(self.model_request.get_model_name()) num_allocations = len(self.allocations) @@ -986,42 +883,24 @@ def allocation_service_names(self) -> Optional[Tuple[str]]: self._allocation_service_names = tuple(service_names) return self._allocation_service_names - @property - def allocations(self) -> Optional[Tuple[ResourceAllocation]]: - return None if self._allocations is None else tuple(self._allocations) - - @allocations.setter - def allocations(self, allocations: Union[List[ResourceAllocation], Tuple[ResourceAllocation]]): + def set_allocations(self, allocations: Union[List[ResourceAllocation], Tuple[ResourceAllocation]]): if isinstance(allocations, tuple): - self._allocations = list(allocations) + # NOTE: set using dict to avoid deprecation warning thrown by `__setattr__`. See `Job.__setattr__` + # docstring for more detail. + self.__dict__["allocations"] = list(allocations) else: - self._allocations = allocations + # NOTE: set using dict to avoid deprecation warning thrown by `__setattr__`. See `Job.__setattr__` + # docstring for more detail. 
+ self.__dict__["allocations"] = allocations self._allocation_service_names = None self._reset_last_updated() - @property - def cpu_count(self) -> int: - return self._cpu_count - - @property - def data_requirements(self) -> List[DataRequirement]: - """ - List of ::class:`DataRequirement` objects representing all data needed for the job. - - Returns - ------- - List[DataRequirement] - List of ::class:`DataRequirement` objects representing all data needed for the job. - """ - if self._data_requirements is None: - self._data_requirements = [] - return self._data_requirements - - @data_requirements.setter - def data_requirements(self, data_requirements: List[DataRequirement]): + def set_data_requirements(self, data_requirements: List[DataRequirement]): # Make sure to reset worker data requirements if this is changed self._worker_data_requirements = None - self._data_requirements = data_requirements + # NOTE: set using dict to avoid deprecation warning thrown by `__setattr__`. See `Job.__setattr__` + # docstring for more detail. + self.__dict__["data_requirements"] = data_requirements self._reset_last_updated() @property @@ -1038,68 +917,25 @@ def is_partitionable(self) -> bool: """ return self.model_request is not None and isinstance(self.model_request, NGENRequest) - @property - def job_id(self) -> Optional[str]: - """ - The unique job id for this job in the manager, if one has been set for it, or ``None``. - - The getter for the property returns the ::attribute:`UUID.bytes` field of the ::attribute:`job_uuid` property, - if it is set, or ``None`` if it is not set. - - The setter for the property will actually set the ::attribute:`job_uuid` attribute, via a call to the setter for - the ::attribute:`job_uuid` property. ::attribute:`job_id`'s setter can accept either a ::class:`UUID` or a - string, with the latter case being used to initialize a ::class:`UUID` object. 
- - Returns - ------- - Optional[str] - The unique job id for this job in the manager, if one has been set for it, or ``None``. - """ - return str(self._job_uuid) if isinstance(self._job_uuid, UUID) else None - - @job_id.setter - def job_id(self, job_id: Union[str, UUID]): + def set_job_id(self, job_id: Union[str, UUID]): job_uuid = job_id if isinstance(job_id, UUID) else UUID(str(job_id)) - if job_uuid != self._job_uuid: - self._job_uuid = job_uuid + job_uuid = str(job_uuid) + if job_uuid != self.job_id: + # NOTE: set using dict to avoid deprecation warning thrown by `__setattr__`. See `Job.__setattr__` + # docstring for more detail. + self.__dict__["job_id"] = job_uuid self._reset_last_updated() - @property - def memory_size(self) -> int: - return self._memory_size - - @property - def last_updated(self) -> datetime: - return self._last_updated - - @property - def model_request(self) -> ExternalRequest: - """ - Get the underlying configuration for the model execution that is being requested. - - Returns - ------- - ExternalRequest - The underlying configuration for the model execution that is being requested. - """ - return self._model_request - - @property - def partition_config(self) -> Optional[PartitionConfig]: - return self._partition_config - - @partition_config.setter - def partition_config(self, part_config: PartitionConfig): - self._partition_config = part_config - - @property - def rsa_key_pair(self) -> Optional['RsaKeyPair']: - return self._rsa_key_pair - - @rsa_key_pair.setter - def rsa_key_pair(self, key_pair: 'RsaKeyPair'): - if key_pair != self._rsa_key_pair: - self._rsa_key_pair = key_pair + def set_partition_config(self, part_config: PartitionConfig): + # NOTE: set using dict to avoid deprecation warning thrown by `__setattr__`. See `Job.__setattr__` + # docstring for more detail. 
+ self.__dict__["partition_config"] = part_config + + def set_rsa_key_pair(self, key_pair: 'RsaKeyPair'): + if key_pair != self.rsa_key_pair: + # NOTE: set using dict to avoid deprecation warning thrown by `__setattr__`. See `Job.__setattr__` + # docstring for more detail. + self.__dict__["rsa_key_pair"] = key_pair self._reset_last_updated() @property @@ -1117,34 +953,21 @@ def should_release_resources(self) -> bool: # TODO: confirm that allocations should be maintained for stopped output jobs while in eval or calibration phase return self.status_step == JobExecStep.FAILED or self.status_phase == JobExecPhase.CLOSED - @property - def status(self) -> JobStatus: - return self._status - - @status.setter - def status(self, new_status: JobStatus): - if new_status != self._status: - self._status = new_status + def set_status(self, status: JobStatus): + if status != self.status: + # NOTE: set using dict to avoid deprecation warning thrown by `__setattr__`. See `Job.__setattr__` + # docstring for more detail. 
+ self.__dict__["status"] = status self._reset_last_updated() - @property - def status_phase(self) -> JobExecPhase: - return super().status_phase - - @status_phase.setter - def status_phase(self, phase: JobExecPhase): - self.status = JobStatus(phase=phase, step=phase.default_start_step) - - @property - def status_step(self) -> JobExecStep: - return super().status_step + def set_status_phase(self, phase: JobExecPhase): + self.set_status(JobStatus(phase=phase, step=phase.default_start_step)) - @status_step.setter - def status_step(self, new_step: JobExecStep): - self.status = JobStatus(phase=self.status.job_exec_phase, step=new_step) + def set_status_step(self, step: JobExecStep): + self.set_status(JobStatus(phase=self.status.job_exec_phase, step=step)) @property - def worker_data_requirements(self) -> List[List[DataRequirement]]: + def worker_data_requirements(self) -> Optional[List[List[DataRequirement]]]: """ List of lists of per-worker data requirements, indexed analogously to worker allocations. @@ -1157,64 +980,60 @@ def worker_data_requirements(self) -> List[List[DataRequirement]]: self._worker_data_requirements = self._process_per_worker_data_requirements() return self._worker_data_requirements - def to_dict(self) -> dict: - """ - Get the representation of this instance as a dictionary or dictionary-like object (e.g., a JSON object). - - { - "job_class" : "", - "cpu_count" : 4, - "memory_size" : 1000, - "model_request" : {}, - "allocation_paradigm" : "SINGLE_NODE", - "allocation_priority" : 0, - "job_id" : "12345678-1234-5678-1234-567812345678", - "rsa_key_pair" : {}, - "status" : INIT:DEFAULT, - "last_updated" : "2020-07-10 12:05:45", - "allocations" : [...], - 'data_requirements" : [...], - "partitioning" : { "partitions": [ ... 
] } - } - - Returns - ------- - dict - the representation of this instance as a dictionary or dictionary-like object (e.g., a JSON object) - """ - serial = dict() - - serial['job_class'] = self.__class__.__name__ - serial['cpu_count'] = self.cpu_count - serial['memory_size'] = self.memory_size - - # TODO: support other scenarios along with deserializing (maybe even eliminate RequestedJob subtype) - if isinstance(self.model_request, ModelExecRequest): - request_key = 'model_request' - else: - msg = "Type {} can only support serializing to JSON when fulfilled request is a {}" - raise RuntimeError(msg.format(self.__class__.__name__, ModelExecRequest.__name__)) - serial[request_key] = self.model_request.to_dict() - - if self.allocation_paradigm: - serial['allocation_paradigm'] = self.allocation_paradigm.name - serial['allocation_priority'] = self.allocation_priority - if self.job_id is not None: - serial['job_id'] = str(self.job_id) - if self.rsa_key_pair is not None: - serial['rsa_key_pair'] = self.rsa_key_pair.to_dict() - serial['status'] = self.status.name - serial['last_updated'] = self._last_updated.strftime(self.get_datetime_str_format()) - serial['data_requirements'] = [] - for dr in self.data_requirements: - serial['data_requirements'].append(dr.to_dict()) - if self.allocations is not None and len(self.allocations) > 0: - serial['allocations'] = [] - for allocation in self.allocations: - serial['allocations'].append(allocation.to_dict()) - if self.partition_config is not None: - serial['partitioning'] = self.partition_config.to_dict() + @cache + def _setter_methods(self) -> Dict[str, Callable]: + return { + **super()._setter_methods(), + "allocation_priority": self.set_allocation_priority, + "job_id": self.set_job_id, + "rsa_key_pair": self.set_rsa_key_pair, + } + + def dict( + self, + *, + include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + by_alias: bool = 
True, # Note, this follows Serializable convention + skip_defaults: Optional[bool] = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = True, + ) -> "DictStrAny": + def add(*fields: str, collection: Union[Set[str], Dict[str, bool]]) -> Union[Set[str], Dict[str, bool]]: + if isinstance(collection, set): + collection_copy = {*collection} + for field in fields: + collection_copy.add(field) + return collection_copy + + elif isinstance(exclude, dict): + collection_copy = {**collection} + for field in fields: + collection_copy[field] = True + return collection_copy + + return collection + + exclude = exclude or set() + + # conditionally exclude `allocations` and `partitioning` if allocations is None or is empty + if self.allocations is None or not len(self.allocations): + exclude = add("allocations", "partitioning", collection=exclude) + + serial = super().dict( + include=include, + exclude=exclude, + by_alias=by_alias, + skip_defaults=skip_defaults, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + # serialize status as "{PHASE}:{STEP}" + if "status" not in exclude: + serial["status"] = self.status.name return serial From 204345990a91a38562bb029f4fac701beb9d7e09 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 6 Feb 2023 13:52:41 -0500 Subject: [PATCH 174/205] refactor RequestedJob to use pydantic --- .../lib/scheduler/dmod/scheduler/job/job.py | 151 +++++------------- 1 file changed, 38 insertions(+), 113 deletions(-) diff --git a/python/lib/scheduler/dmod/scheduler/job/job.py b/python/lib/scheduler/dmod/scheduler/job/job.py index 799b97b7a..f5c4a12ba 100644 --- a/python/lib/scheduler/dmod/scheduler/job/job.py +++ b/python/lib/scheduler/dmod/scheduler/job/job.py @@ -1043,119 +1043,44 @@ class RequestedJob(JobImpl): in the form of a ::class:`SchedulerRequestMessage` object. 
""" - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict): - """ - Factory create a new instance of this type based on a JSON object dictionary deserialized from received JSON. - - Parameters - ---------- - json_obj + originating_request: SchedulerRequestMessage + """The original request that resulted in the creation of this job.""" - Returns - ------- - A new object of this type instantiated from the deserialize JSON object dictionary - """ - - originating_request_key = 'originating_request' - - try: - cpus, memory, paradigm, priority, job_id, rsa_key_pair, status, allocations, updated, partitioning = \ - cls.deserialize_core_attributes(json_obj) - - if originating_request_key not in json_obj: - msg = 'Key for originating request ({}) not present when deserialize {} object' - raise RuntimeError(msg.format(originating_request_key, cls.__name__)) - request = SchedulerRequestMessage.factory_init_from_deserialized_json(json_obj[originating_request_key]) - if request is None: - msg = 'Invalid serialized scheduler request when deserialize {} object' - raise RuntimeError(msg.format(cls.__name__)) - except Exception as e: - logging.error(e) - return None - - # Create the object initially from the request - new_obj = cls(job_request=request) - - # Then update its properties based on the deserialized values, as those are considered most correct - - # Use property setter for job id to handle string or UUID - new_obj.job_id = job_id - - new_obj._cpu_count = cpus - new_obj._memory_size = memory - new_obj._allocation_paradigm = paradigm - new_obj._allocation_priority = priority - new_obj._rsa_key_pair = rsa_key_pair - new_obj._status = status - new_obj._allocations = allocations - new_obj.data_requirements = cls._parse_serialized_data_requirements(json_obj) - new_obj._partition_config = partitioning - - # Do last_updated last, as any usage of setters above might cause the value to be maladjusted - new_obj._last_updated = updated - - return new_obj - - 
def __init__(self, job_request: SchedulerRequestMessage): - self._originating_request = job_request - super().__init__(cpu_count=job_request.cpus, memory_size=job_request.memory, - model_request=job_request.model_request, - allocation_paradigm=job_request.allocation_paradigm) - self.data_requirements = self.model_request.data_requirements - - @property - def model_request(self) -> ExternalRequest: - """ - Get the underlying configuration for the model execution that is being requested. - - Returns - ------- - ExternalRequest - The underlying configuration for the model execution that is being requested. - """ - return self.originating_request.model_request - - @property - def originating_request(self) -> SchedulerRequestMessage: - """ - The original request that resulted in the creation of this job. - - Returns - ------- - SchedulerRequestMessage - The original request that resulted in the creation of this job. - """ - return self._originating_request + class Config: # type: ignore + fields = { + # exclude `model_request` during serialization + "model_request": {"exclude": True} + } - def to_dict(self) -> dict: - """ - Get the representation of this instance as a dictionary or dictionary-like object (e.g., a JSON object). - - { - "job_class" : "", - "cpu_count" : 4, - "memory_size" : 1000, - "allocation_paradigm" : "SINGLE_NODE", - "allocation_priority" : 0, - "job_id" : "12345678-1234-5678-1234-567812345678", - "rsa_key_pair" : {}, - "status" : INIT:DEFAULT, - "last_updated" : "2020-07-10 12:05:45", - "allocations" : [...], - 'data_requirements" : [...], - "partitioning" : { "partitions": [ ... ] }, - "originating_request" : {} - } + def __init__(self, job_request: SchedulerRequestMessage = None, **data): + if data: + # NOTE: in previous version of code, `model_request` was always a derived field. 
+ # this allows `model_request` be separately specified + if "model_request" in data: + super().__init__(**data) + return + + originating_request = data.get("originating_request") + if originating_request is None: + # this should fail, let pydantic handle that. + super().__init__(**data) + return + + if isinstance(originating_request, SchedulerRequestMessage): + # inject + data["model_request"] = originating_request.model_request + + data["model_request"] = originating_request.get("model_request") + super().__init__(**data) + return - Returns - ------- - dict - the representation of this instance as a dictionary or dictionary-like object (e.g., a JSON object) - """ - dictionary = super().to_dict() - # To avoid this being messy, rely on the superclass's implementation and the returned dict, but remove the - # 'model_request' key/value, since this is contained within the originating serialized scheduler request - dictionary.pop('model_request') - dictionary['originating_request'] = self.originating_request.to_dict() - return dictionary + # NOTE: consider refactoring this into `from_job_request` class method. + super().__init__( + cpu_count=job_request.cpus, + memory_size=job_request.memory, + model_request=job_request.model_request, + allocation_paradigm=job_request.allocation_paradigm, + originating_request=job_request, + ) + # NOTE: this implicitly resets `last_updated` field + self.set_data_requirements(job_request.model_request.data_requirements) From ecf7ef4b6e86a046876c324cc502192f7809c699 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 6 Feb 2023 14:04:35 -0500 Subject: [PATCH 175/205] fix rsa test bug. 
was not cleaning up --- python/lib/scheduler/dmod/test/test_rsa_key_pair.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/python/lib/scheduler/dmod/test/test_rsa_key_pair.py b/python/lib/scheduler/dmod/test/test_rsa_key_pair.py index 7bc393f6c..6615ed9b5 100644 --- a/python/lib/scheduler/dmod/test/test_rsa_key_pair.py +++ b/python/lib/scheduler/dmod/test/test_rsa_key_pair.py @@ -196,9 +196,16 @@ def test_reassign_directory_to_default(self): default_location = Path.home() / ".ssh" self.assertNotEqual(key_pair.directory, default_location) + o_pub_key = key_pair.public_key_file + o_priv_key = key_pair.private_key_file + key_pair.directory = None self.assertEqual(key_pair.directory, default_location) + # remove original public key and private key + o_priv_key.unlink(missing_ok=True) + o_pub_key.unlink(missing_ok=True) + def test_reassign_directory_creates_directory_if_not_exist(self): """ verify object `is_deserialized` property is false when key is generated. From 57782844fa3e3b97804dc4584d9f99d9cf0ba39b Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 6 Feb 2023 14:18:16 -0500 Subject: [PATCH 176/205] add type hints to tests --- python/lib/scheduler/dmod/test/test_JobImpl.py | 8 +++++--- python/lib/scheduler/dmod/test/test_job.py | 10 +++++++--- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/python/lib/scheduler/dmod/test/test_JobImpl.py b/python/lib/scheduler/dmod/test/test_JobImpl.py index 256304189..8a68e165e 100644 --- a/python/lib/scheduler/dmod/test/test_JobImpl.py +++ b/python/lib/scheduler/dmod/test/test_JobImpl.py @@ -4,6 +4,8 @@ from dmod.communication import NWMRequest from uuid import UUID +from typing import List + class TestJobImpl(unittest.TestCase): @@ -11,14 +13,14 @@ def setUp(self) -> None: self._nwm_model_request = NWMRequest.factory_init_from_deserialized_json( {"model": {"nwm": {"version": 2.0, "output": "streamflow", "domain": "blah", "parameters": {}}}, "session-secret": 
"f21f27ac3d443c0948aab924bddefc64891c455a756ca77a4d86ec2f697cd13c"}) - self._example_jobs = [] + self._example_jobs: List[JobImpl]= [] self._example_jobs.append(JobImpl(4, 1000, model_request=self._nwm_model_request, allocation_paradigm='single-node')) - self._uuid_str_vals = [] + self._uuid_str_vals: List[str] = [] self._uuid_str_vals.append('12345678-1234-5678-1234-567812345678') - self._resource_allocations = [] + self._resource_allocations: List[ResourceAllocation] = [] self._resource_allocations.append(ResourceAllocation('node001', 'node001', 4, 1000)) def tearDown(self) -> None: diff --git a/python/lib/scheduler/dmod/test/test_job.py b/python/lib/scheduler/dmod/test/test_job.py index 9a9f8e5cc..c7b1c533f 100644 --- a/python/lib/scheduler/dmod/test/test_job.py +++ b/python/lib/scheduler/dmod/test/test_job.py @@ -2,14 +2,18 @@ from ..scheduler.job.job import Job, JobImpl, RequestedJob from dmod.core.meta_data import TimeRange from dmod.communication import NWMRequest, NGENRequest, SchedulerRequestMessage +from typing import Any, List, TYPE_CHECKING + +if TYPE_CHECKING: + from dmod.communication import ModelExecRequest class TestJob(unittest.TestCase): def setUp(self) -> None: - self._example_jobs = [] - self._model_requests = [] - self._model_requests_json = [] + self._example_jobs: List[RequestedJob] = [] + self._model_requests: List["ModelExecRequest"]= [] + self._model_requests_json: Dict[str, Any] = [] # Example 0 - simple JobImpl instance based on NWMRequest for model_request value self._model_requests_json.append({ From 2fa262125a464f2d7aac70a2f93ddd2adb0d0ef2 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 6 Feb 2023 14:18:41 -0500 Subject: [PATCH 177/205] fix missing field in test setup --- python/lib/scheduler/dmod/test/test_JobImpl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/lib/scheduler/dmod/test/test_JobImpl.py b/python/lib/scheduler/dmod/test/test_JobImpl.py index 8a68e165e..6f147468a 100644 --- 
a/python/lib/scheduler/dmod/test/test_JobImpl.py +++ b/python/lib/scheduler/dmod/test/test_JobImpl.py @@ -11,7 +11,7 @@ class TestJobImpl(unittest.TestCase): def setUp(self) -> None: self._nwm_model_request = NWMRequest.factory_init_from_deserialized_json( - {"model": {"nwm": {"version": 2.0, "output": "streamflow", "domain": "blah", "parameters": {}}}, + {"model": {"nwm": {"version": 2.0, "output": "streamflow", "domain": "blah", "parameters": {}, "config_data_id": "42"}}, "session-secret": "f21f27ac3d443c0948aab924bddefc64891c455a756ca77a4d86ec2f697cd13c"}) self._example_jobs: List[JobImpl]= [] self._example_jobs.append(JobImpl(4, 1000, model_request=self._nwm_model_request, From c19718153507eb075d310c32d18eae2fef3567d0 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 6 Feb 2023 15:55:06 -0500 Subject: [PATCH 178/205] cannot cache _setter_methods. if this becomes bottleneck, we can cache a classmethod and walk the mro --- python/lib/scheduler/dmod/scheduler/job/job.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/python/lib/scheduler/dmod/scheduler/job/job.py b/python/lib/scheduler/dmod/scheduler/job/job.py index f5c4a12ba..b90cf0237 100644 --- a/python/lib/scheduler/dmod/scheduler/job/job.py +++ b/python/lib/scheduler/dmod/scheduler/job/job.py @@ -1,6 +1,5 @@ from abc import ABC, abstractmethod from datetime import datetime -from functools import cache from pydantic import Field, PrivateAttr, validator, root_validator from pydantic.fields import ModelField from warnings import warn @@ -609,7 +608,6 @@ def worker_data_requirements(self) -> List[List[DataRequirement]]: """ pass - @cache def _setter_methods(self) -> Dict[str, Callable]: """Mapping of attribute name to setter method. 
This supports backwards functional compatibility.""" # TODO: remove once migration to setters by down stream users is complete @@ -980,7 +978,6 @@ def worker_data_requirements(self) -> Optional[List[List[DataRequirement]]]: self._worker_data_requirements = self._process_per_worker_data_requirements() return self._worker_data_requirements - @cache def _setter_methods(self) -> Dict[str, Callable]: return { **super()._setter_methods(), From 57267cdb9b620a6e2dce932ccebe53845d178485 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 6 Feb 2023 15:59:46 -0500 Subject: [PATCH 179/205] change Job's allocations field to optional tuple from optional list --- python/lib/scheduler/dmod/scheduler/job/job.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/lib/scheduler/dmod/scheduler/job/job.py b/python/lib/scheduler/dmod/scheduler/job/job.py index b90cf0237..2b21eeea0 100644 --- a/python/lib/scheduler/dmod/scheduler/job/job.py +++ b/python/lib/scheduler/dmod/scheduler/job/job.py @@ -358,7 +358,7 @@ class Job(Serializable, ABC): allocation_priority: int = 0 """A score for how this job should be prioritized with respect to allocation.""" - allocations: Optional[List[ResourceAllocation]] + allocations: Optional[Tuple[ResourceAllocation, ...]] """The scheduler resource allocations for this job, or ``None`` if it is queued or otherwise not yet allocated.""" cpu_count: int = Field(gt=0) @@ -841,8 +841,8 @@ def add_allocation(self, allocation: ResourceAllocation): A resource allocation object to add. 
""" if self.allocations is None: - self.set_allocations(list()) - self.allocations.append(allocation) # type: ignore + self.set_allocations(tuple()) + self.set_allocations((*self.allocations, allocation)) # type: ignore self._allocation_service_names = None self._reset_last_updated() @@ -882,10 +882,10 @@ def allocation_service_names(self) -> Optional[Tuple[str]]: return self._allocation_service_names def set_allocations(self, allocations: Union[List[ResourceAllocation], Tuple[ResourceAllocation]]): - if isinstance(allocations, tuple): + if isinstance(allocations, list): # NOTE: set using dict to avoid deprecation warning thrown by `__setattr__`. See `Job.__setattr__` # docstring for more detail. - self.__dict__["allocations"] = list(allocations) + self.__dict__["allocations"] = tuple(allocations) else: # NOTE: set using dict to avoid deprecation warning thrown by `__setattr__`. See `Job.__setattr__` # docstring for more detail. From ec72debbf49e71f2c02cab1eb0707a06e1b63401 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Mon, 6 Feb 2023 16:38:02 -0500 Subject: [PATCH 180/205] Resource unit tests --- .../lib/scheduler/dmod/test/test_resource.py | 124 ++++++++++++++++++ 1 file changed, 124 insertions(+) create mode 100644 python/lib/scheduler/dmod/test/test_resource.py diff --git a/python/lib/scheduler/dmod/test/test_resource.py b/python/lib/scheduler/dmod/test/test_resource.py new file mode 100644 index 000000000..148bcceca --- /dev/null +++ b/python/lib/scheduler/dmod/test/test_resource.py @@ -0,0 +1,124 @@ +import unittest +from pydantic import ValidationError +from .scheduler_test_utils import _mock_resources +from ..scheduler.resources import Resource, ResourceAvailability + + +class TestResource(unittest.TestCase): + def setUp(self) -> None: + self._resource = Resource( + resource_id="1", + hostname="somehost", + availability="active", + state="ready", + cpu_count=4, + memory=(2 ** 30) * 8, + total_cpu_count=8, + total_memory=(2 ** 30) * 16, + ) + + def 
test_factory_init_from_dict_coerces_fields_correctly(self): + for i, input in enumerate(_mock_resources): + with self.subTest(i=i): + o = Resource.factory_init_from_dict(input) + assert o.resource_id == input["node_id"] + assert o.pool_id == input["node_id"] + assert o.hostname == input["Hostname"] + assert ( + o.availability.name.casefold() == input["Availability"].casefold() + ) + assert o.state.name.casefold() == input["State"].casefold() + assert o.memory == input["MemoryBytes"] + assert o.cpu_count == input["CPUs"] + assert o.total_cpus == input["CPUs"] + assert o.total_memory == input["MemoryBytes"] + + def test_factory_init_from_dict_works_case_insensitively(self): + input = { + "NODE_ID": "Node-0003", + "hostname": "hostname3", + "AVAILABILITY": "active", + "state": "ready", + "CPUS": 42, + "memorybytes": 200000000000, + } + o = Resource.factory_init_from_dict(input) + assert o.resource_id == input["NODE_ID"] + assert o.pool_id == input["NODE_ID"] + assert o.hostname == input["hostname"] + assert o.availability.name.casefold() == input["AVAILABILITY"].casefold() + assert o.state.name.casefold() == input["state"].casefold() + assert o.memory == input["memorybytes"] + assert o.cpu_count == input["CPUS"] + assert o.total_cpus == input["CPUS"] + assert o.total_memory == input["memorybytes"] + + def test_set_availability(self): + resource = self._resource + availability = ResourceAvailability.UNKNOWN + resource.set_availability(availability) + assert resource.availability == ResourceAvailability.UNKNOWN + + availability = ResourceAvailability.ACTIVE + resource.set_availability(availability) + assert resource.availability == ResourceAvailability.ACTIVE + + availability = ResourceAvailability.INACTIVE + resource.set_availability(availability) + assert resource.availability == ResourceAvailability.INACTIVE + + resource.set_availability("unknown") + assert resource.availability == ResourceAvailability.UNKNOWN + + resource.set_availability("active") + assert 
resource.availability == ResourceAvailability.ACTIVE + + resource.set_availability("inactive") + assert resource.availability == ResourceAvailability.INACTIVE + + # remove in future + with self.assertWarns(DeprecationWarning): + availability = ResourceAvailability.UNKNOWN + resource.availability = availability + assert resource.availability == ResourceAvailability.UNKNOWN + + with self.assertWarns(DeprecationWarning): + availability = ResourceAvailability.ACTIVE + resource.availability = availability + assert resource.availability == ResourceAvailability.ACTIVE + + with self.assertWarns(DeprecationWarning): + availability = ResourceAvailability.INACTIVE + resource.availability = availability + assert resource.availability == ResourceAvailability.INACTIVE + + def test_eq(self): + resource = self._resource + assert resource == resource + assert resource == Resource.factory_init_from_dict(resource.to_dict()) + + def test_init_with_more_cpu_than_total_cpu(self): + with self.assertRaises(ValidationError): + Resource( + cpu_count=8, + total_cpu_count=4, + resource_id="1", + hostname="somehost", + availability="active", + state="ready", + memory=8, + total_memory=8, + ) + + def test_init_with_more_memory_than_total_memory(self): + with self.assertRaises(ValidationError): + Resource( + memory=8, + total_memory=4, + resource_id="1", + hostname="somehost", + availability="active", + state="ready", + cpu_count=8, + total_cpu_count=8, + ) From 6d5c4ca9153e95497240481de2b57a4d9ff7e2d1 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Tue, 7 Feb 2023 15:57:30 -0500 Subject: [PATCH 181/205] add JobImpl setter tests --- .../lib/scheduler/dmod/test/test_JobImpl.py | 176 +++++++++++++++++- 1 file changed, 174 insertions(+), 2 deletions(-) diff --git a/python/lib/scheduler/dmod/test/test_JobImpl.py b/python/lib/scheduler/dmod/test/test_JobImpl.py index 6f147468a..44a2f3739 100644 --- a/python/lib/scheduler/dmod/test/test_JobImpl.py +++ b/python/lib/scheduler/dmod/test/test_JobImpl.py 
@@ -1,5 +1,5 @@ import unittest -from ..scheduler.job.job import JobImpl +from ..scheduler.job.job import JobImpl, JobStatus, JobExecPhase, JobExecStep from ..scheduler.resources.resource_allocation import ResourceAllocation from dmod.communication import NWMRequest from uuid import UUID @@ -191,5 +191,177 @@ def test_allocations_1_f(self): self.assertLess(initial_last_updated, job.last_updated) # TODO: add tests for rest of setters that should update last_updated property + def test_set_allocation_priority(self): + """ + Update allocation priority. + This should implicitly change the instance's `last_updated` field to the current time. + """ + example_index_job = 0 + job = self._example_jobs[example_index_job] + outdated_last_updated = job.last_updated + prior_allocation_priority = job.allocation_priority + + job.set_allocation_priority(prior_allocation_priority + 1) + self.assertEqual(job.allocation_priority, prior_allocation_priority + 1) + self.assertGreater(job.last_updated, outdated_last_updated) + + def test_add_allocation(self): + """ + Test that a resource allocation is added and that the instance's `last_updated` field is implicitly updated. + """ + example_index_job = 0 + job = self._example_jobs[example_index_job] + resource_allocation = self._resource_allocations[example_index_job] + + # we should not have any allocations up to this point + self.assertIsNone(job.allocations) + outdated_last_updated = job.last_updated + + job.add_allocation(resource_allocation) + + self.assertIsNotNone(job.allocations) + self.assertIsInstance(job.allocations, tuple) + self.assertEqual(len(job.allocations), 1) # type: ignore + + self.assertEqual(job.allocations[0], resource_allocation) # type: ignore + + self.assertGreater(job.last_updated, outdated_last_updated) + + def test_set_allocations(self): + """ + Test setting resource allocations and that the instance's `last_updated` field is implicitly updated. 
+ """ + example_index_job = 0 + job = self._example_jobs[example_index_job] + resource_allocation = self._resource_allocations[example_index_job] + + # we should not have any allocations up to this point + self.assertIsNone(job.allocations) + outdated_last_updated = job.last_updated + + job.set_allocations((resource_allocation, )) + + self.assertIsNotNone(job.allocations) + self.assertIsInstance(job.allocations, tuple) + self.assertEqual(len(job.allocations), 1) # type: ignore + + self.assertEqual(job.allocations[0], resource_allocation) # type: ignore + + # assert `last_updated` was updated and is greater than previous value + self.assertGreater(job.last_updated, outdated_last_updated) + + def test_set_data_requirements(self): + # importing here, not needed elsewhere + from dmod.core.meta_data import DataRequirement, DataCategory, DataDomain, DataFormat, DiscreteRestriction, StandardDatasetIndex + example_index_job = 0 + job = self._example_jobs[example_index_job] + + outdated_last_updated = job.last_updated + + domain = DataDomain( + data_format=DataFormat.NWM_CONFIG, + discrete=[DiscreteRestriction(variable=StandardDatasetIndex.DATA_ID, values=["42"])] + ) + data_reqs = [DataRequirement(category=DataCategory.CONFIG, domain=domain, is_input=True)] + + # data requirements should be an empty list at this point + self.assertFalse(job.data_requirements) + job.set_data_requirements(data_reqs) + + self.assertTrue(job.data_requirements) + self.assertIsInstance(job.data_requirements, list) + self.assertEqual(len(job.data_requirements), 1) # type: ignore + + # assert `last_updated` was updated and is greater than previous value + self.assertGreater(job.last_updated, outdated_last_updated) + + def test_set_job_id(self): + from uuid import UUID + example_index_job = 0 + job = self._example_jobs[example_index_job] + + fake_job_ids = ["00000000-0000-0000-0000-000000000000", UUID("11111111-1111-1111-1111-111111111111")] + + # test setting with `str` and `UUID` + for i, job_id 
in enumerate(fake_job_ids): + with self.subTest(i=i): + old_last_updated = job.last_updated + old_job_id = job.job_id + + self.assertIsInstance(old_job_id, str) + + job.set_job_id(job_id) + self.assertEqual(str(job_id), job.job_id) + + # assert `last_updated` was updated and is greater than previous value + self.assertGreater(job.last_updated, old_last_updated) + + def test_set_partition_config(self): + from dmod.modeldata.hydrofabric import Partition, PartitionConfig + + example_index_job = 0 + job = self._example_jobs[example_index_job] + + partition_config = PartitionConfig(partitions=[Partition(partition_id=42, catchment_ids=["42"], nexus_ids=["42"])]) + + # we should not have any partition configs up to this point + self.assertIsNone(job.partition_config) + job.set_partition_config(partition_config) + self.assertEqual(job.partition_config, partition_config) + + def test_set_rsa_key_pair(self): + from ..scheduler.rsa_key_pair import RsaKeyPair + from tempfile import TemporaryDirectory + example_index_job = 0 + job = self._example_jobs[example_index_job] + outdated_last_updated = job.last_updated + + self.assertIsNone(job.rsa_key_pair) + + with TemporaryDirectory() as dir: + key_pair = RsaKeyPair(directory=dir) + job.set_rsa_key_pair(key_pair) + self.assertEqual(job.rsa_key_pair, key_pair) + + # assert `last_updated` was updated and is greater than previous value + self.assertGreater(job.last_updated, outdated_last_updated) + + def test_set_status(self): + example_index_job = 0 + job = self._example_jobs[example_index_job] + outdated_last_updated = job.last_updated + + status = JobStatus(phase=None) + self.assertNotEqual(status, job.status) + job.set_status(status) + + self.assertEqual(status, job.status) + self.assertGreater(job.last_updated, outdated_last_updated) + + def test_set_status_phase(self): + example_index_job = 0 + job = self._example_jobs[example_index_job] + outdated_last_updated = job.last_updated + + new_status_phase = JobExecPhase.MODEL_EXEC + 
self.assertNotEqual(job.status_phase, new_status_phase) + + job.set_status_phase(new_status_phase) + self.assertEqual(job.status_phase, new_status_phase) + + # assert `last_updated` was implicitly updated and is greater than previous value + self.assertGreater(job.last_updated, outdated_last_updated) + + def test_set_status_step(self): + example_index_job = 0 + job = self._example_jobs[example_index_job] + outdated_last_updated = job.last_updated + + new_status_step = JobExecStep.AWAITING_ALLOCATION + self.assertNotEqual(job.status_phase, new_status_step) + + job.set_status_step(new_status_step) + self.assertEqual(job.status_step, new_status_step) - # TODO: add tests for status_phase and status_step + # assert `last_updated` was implicitly updated and is greater than previous value + self.assertGreater(job.last_updated, outdated_last_updated) From 2f811b52c4f8ba80d4bbe6564bc7633f5b847dec Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Wed, 25 Jan 2023 22:05:32 -0500 Subject: [PATCH 182/205] format modeldata's setup.py --- python/lib/modeldata/setup.py | 32 ++++++++++++++++++++------------ 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/python/lib/modeldata/setup.py b/python/lib/modeldata/setup.py index 7c389a0e7..6005ddf0b 100644 --- a/python/lib/modeldata/setup.py +++ b/python/lib/modeldata/setup.py @@ -4,23 +4,31 @@ ROOT = Path(__file__).resolve().parent try: - with open(ROOT / 'README.md', 'r') as readme: + with open(ROOT / "README.md", "r") as readme: long_description = readme.read() except: - long_description = '' + long_description = "" -exec(open(ROOT / 'dmod/modeldata/_version.py').read()) +exec(open(ROOT / "dmod/modeldata/_version.py").read()) setup( - name='dmod-modeldata', + name="dmod-modeldata", version=__version__, - description='', + description="", long_description=long_description, - author='', - author_email='', - url='', - license='', - install_requires=['numpy>=1.20.1', 'pandas', 'geopandas', 'dmod-communication>=0.4.2', 
'dmod-core>=0.3.0', 'minio', - 'aiohttp<=3.7.4', 'hypy@git+https://github.com/NOAA-OWP/hypy@master#egg=hypy&subdirectory=python'], - packages=find_namespace_packages(exclude=['dmod.test', 'schemas', 'ssl', 'src']) + author="", + author_email="", + url="", + license="", + install_requires=[ + "numpy>=1.20.1", + "pandas", + "geopandas", + "dmod-communication>=0.4.2", + "dmod-core>=0.3.0", + "minio", + "aiohttp<=3.7.4", + "hypy@git+https://github.com/NOAA-OWP/hypy@master#egg=hypy&subdirectory=python", + ], + packages=find_namespace_packages(exclude=["dmod.test", "schemas", "ssl", "src"]), ) From ce92d014159e3aeb80f543a2467a9e45ce68cfba Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Wed, 25 Jan 2023 22:06:32 -0500 Subject: [PATCH 183/205] add missing dep, gitpython to modeldata's setup.py --- python/lib/modeldata/setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/lib/modeldata/setup.py b/python/lib/modeldata/setup.py index 6005ddf0b..3ae0016d0 100644 --- a/python/lib/modeldata/setup.py +++ b/python/lib/modeldata/setup.py @@ -29,6 +29,7 @@ "minio", "aiohttp<=3.7.4", "hypy@git+https://github.com/NOAA-OWP/hypy@master#egg=hypy&subdirectory=python", + "gitpython", ], packages=find_namespace_packages(exclude=["dmod.test", "schemas", "ssl", "src"]), ) From 1669805c33bef5d4bb48d24208f8e27e9c79f8f5 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Wed, 25 Jan 2023 22:06:52 -0500 Subject: [PATCH 184/205] add pydantic dep --- python/lib/modeldata/setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/lib/modeldata/setup.py b/python/lib/modeldata/setup.py index 3ae0016d0..176c58c05 100644 --- a/python/lib/modeldata/setup.py +++ b/python/lib/modeldata/setup.py @@ -30,6 +30,7 @@ "aiohttp<=3.7.4", "hypy@git+https://github.com/NOAA-OWP/hypy@master#egg=hypy&subdirectory=python", "gitpython", + "pydantic", ], packages=find_namespace_packages(exclude=["dmod.test", "schemas", "ssl", "src"]), ) From bbb2cb5bc80ebb68c62825d2e35e5b634dcd0538 Mon Sep 17 00:00:00 
2001 From: Austin Raney Date: Wed, 25 Jan 2023 22:14:50 -0500 Subject: [PATCH 185/205] refactor SubsetDefinition --- .../modeldata/subset/subset_definition.py | 46 +++++++------------ 1 file changed, 17 insertions(+), 29 deletions(-) diff --git a/python/lib/modeldata/dmod/modeldata/subset/subset_definition.py b/python/lib/modeldata/dmod/modeldata/subset/subset_definition.py index d9b14f25b..e044f153f 100644 --- a/python/lib/modeldata/dmod/modeldata/subset/subset_definition.py +++ b/python/lib/modeldata/dmod/modeldata/subset/subset_definition.py @@ -1,5 +1,5 @@ -from numbers import Number -from typing import Collection, Tuple, Dict, Union +from typing import Collection, Tuple +from pydantic import validator from dmod.core.serializable import Serializable @@ -13,34 +13,29 @@ class SubsetDefinition(Serializable): to be immutable. """ - __slots__ = ["_catchment_ids", "_nexus_ids"] + catchment_ids: Tuple[str] + nexus_ids: Tuple[str] - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict): - try: - return cls(**json_obj) - except Exception as e: - return None + @validator("catchment_ids", "nexus_ids") + def _sort_and_dedupe_fields(cls, value: Tuple[str]) -> Tuple[str]: + return tuple(sorted(set(value))) def __init__(self, catchment_ids: Collection[str], nexus_ids: Collection[str]): - self._catchment_ids = tuple(sorted(set(catchment_ids))) - self._nexus_ids = tuple(sorted(set(nexus_ids))) + super().__init__(catchment_ids=catchment_ids, nexus_ids=nexus_ids) - def __eq__(self, other): - return isinstance(other, SubsetDefinition) \ - and self.catchment_ids == other.catchment_ids \ - and self.nexus_ids == other.nexus_ids + def __eq__(self, other: object): + return ( + isinstance(other, SubsetDefinition) + and self.catchment_ids == other.catchment_ids + and self.nexus_ids == other.nexus_ids + ) def __hash__(self): - joined_cats = ','.join(self.catchment_ids) - joined_nexs = ','.join(self.nexus_ids) - joined_all = ','.join((joined_cats, joined_nexs)) + 
joined_cats = ",".join(self.catchment_ids) + joined_nexs = ",".join(self.nexus_ids) + joined_all = ",".join((joined_cats, joined_nexs)) return hash(joined_all) - @property - def catchment_ids(self) -> Tuple[str]: - return self._catchment_ids - @property def id(self): """ @@ -53,10 +48,3 @@ def id(self): The unique id of this instance. """ return self.__hash__() - - @property - def nexus_ids(self) -> Tuple[str]: - return self._nexus_ids - - def to_dict(self) -> Dict[str, Union[str, Number, dict, list]]: - return {'catchment_ids': list(self.catchment_ids), 'nexus_ids': list(self.nexus_ids)} From 09190419087097d4a89ea1f0af72979885b694c0 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Wed, 25 Jan 2023 22:32:03 -0500 Subject: [PATCH 186/205] refactor Partition --- .../dmod/modeldata/hydrofabric/partition.py | 165 +++++------------- 1 file changed, 43 insertions(+), 122 deletions(-) diff --git a/python/lib/modeldata/dmod/modeldata/hydrofabric/partition.py b/python/lib/modeldata/dmod/modeldata/hydrofabric/partition.py index 54ba9ccd9..8e18f663b 100644 --- a/python/lib/modeldata/dmod/modeldata/hydrofabric/partition.py +++ b/python/lib/modeldata/dmod/modeldata/hydrofabric/partition.py @@ -1,5 +1,6 @@ from numbers import Number -from typing import Collection, Dict, FrozenSet, List, Union +from typing import Collection, Dict, FrozenSet, List, Tuple, Union +from pydantic import Field from dmod.core.serializable import Serializable @@ -13,56 +14,60 @@ class Partition(Serializable): in the context of the related hydrofabric. 
""" - __slots__ = ["_catchment_ids", "_hash_val", "_nexus_ids", "_partition_id", "_remote_downstream_nexus_ids", - "_remote_upstream_nexus_ids"] - - _KEY_CATCHMENT_IDS = 'cat-ids' - _KEY_PARTITION_ID = 'id' - # Note that these need to be included in the JSON, but initially aren't actually used at the JSON level - _KEY_NEXUS_IDS = 'nex-ids' - _KEY_REMOTE_UPSTREAM_NEXUS_IDS = 'remote-up' - _KEY_REMOTE_DOWNSTREAM_NEXUS_IDS = 'remote-down' - - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict): - try: - # TODO: later these may be required, but for now, keep optional - if cls._KEY_REMOTE_UPSTREAM_NEXUS_IDS in json_obj: - remote_up = json_obj[cls._KEY_REMOTE_UPSTREAM_NEXUS_IDS] - else: - remote_up = [] - if cls._KEY_REMOTE_DOWNSTREAM_NEXUS_IDS in json_obj: - remote_down = json_obj[cls._KEY_REMOTE_UPSTREAM_NEXUS_IDS] - else: - remote_down = [] - return Partition(catchment_ids=json_obj[cls._KEY_CATCHMENT_IDS], nexus_ids=json_obj[cls._KEY_NEXUS_IDS], - remote_up_nexuses=remote_up, remote_down_nexuses=remote_down, - partition_id=int(json_obj[cls._KEY_PARTITION_ID])) - except: - return None + partition_id: int + catchment_ids: FrozenSet[str] + nexus_ids: FrozenSet[str] + """ + Note that, at the time this is committed, partition ids should always be integers. This is so they can easily + correspond to MPI ranks. 
However, because of how the expected + """ + remote_upstream_nexus_ids: FrozenSet[str] = Field(default_factory=frozenset) + remote_downstream_nexus_ids: FrozenSet[str] = Field(default_factory=frozenset) + + class Config: + fields = { + "catchment_ids": {"alias": "cat-ids"}, + "partition_id": {"alias": "id"}, + "nexus_ids": {"alias": "nex-ids"}, + "remote_up_nexuses": {"alias": "remote-up"}, + "remote_down_nexuses": {"alias": "remote-down"}, + } def __init__(self, partition_id: int, catchment_ids: Collection[str], nexus_ids: Collection[str], - remote_up_nexuses: Collection[str] = tuple(), remote_down_nexuses: Collection[str] = tuple()): - self._partition_id = partition_id - self._catchment_ids = frozenset(catchment_ids) - self._nexus_ids = frozenset(nexus_ids) - self._remote_upstream_nexus_ids = frozenset(remote_up_nexuses) - self._remote_downstream_nexus_ids = frozenset(remote_down_nexuses) + remote_up_nexuses: Collection[str] = None, remote_down_nexuses: Collection[str] = None, **data): self._hash_val = None - def __eq__(self, other): + if remote_up_nexuses is None or remote_down_nexuses is None: + super().__init__( + partition_id=partition_id, + catchment_ids=catchment_ids, + nexus_ids=nexus_ids, + **data + ) + return + + super().__init__( + partition_id=partition_id, + catchment_ids=catchment_ids, + nexus_ids=nexus_ids, + remote_upstream_nexus_ids=remote_up_nexuses, + remote_downstream_nexus_ids=remote_down_nexuses + ) + + + def __eq__(self, other: object): if not isinstance(other, self.__class__) or other.partition_id != self.partition_id: return False else: return other.__hash__() == self.__hash__() - def __lt__(self, other): + def __lt__(self, other: "Partition"): # Go first by id, so this is clearly true - if self._partition_id < other._partition_id: + if self.partition_id < other.partition_id: return True # Again, going by id first, having greater id is also clear - elif self._partition_id > other._partition_id: + elif self.partition_id > 
other.partition_id: return False # Also can't be (strictly) less-than AND equal-to elif self == other: @@ -79,90 +84,6 @@ def __hash__(self): self._hash_val = hash(','.join(cat_id_list)) return self._hash_val - @property - def catchment_ids(self) -> FrozenSet[str]: - """ - Get the frozen set of ids for all catchments in this partition. - - Returns - ------- - Set[str] - The frozen set of string ids for all catchments in this partition. - """ - return self._catchment_ids - - @property - def nexus_ids(self) -> FrozenSet[str]: - """ - Get the frozen set of ids for all nexuses in this partition. - - Returns - ------- - Set[str] - The frozen set of string ids for all nexuses in this partition. - """ - return self._nexus_ids - - @property - def partition_id(self) -> int: - """ - Get the id of this partition. - - Note that, at the time this is committed, partition ids should always be integers. This is so they can easily - correspond to MPI ranks. However, because of how the expected - - Returns - ------- - str - The id of this partition, as a string. - """ - return self._partition_id - - @property - def remote_downstream_nexus_ids(self) -> FrozenSet[str]: - """ - Get the frozen set of ids for all remote downstream nexuses in this partition. - - Returns - ------- - Set[str] - The frozen set of string ids for all remote downstream nexuses in this partition. - """ - return self._remote_downstream_nexus_ids - - @property - def remote_upstream_nexus_ids(self) -> FrozenSet[str]: - """ - Get the frozen set of ids for all remote upstream nexuses in this partition. - - Returns - ------- - Set[str] - The frozen set of string ids for all remote upstream nexuses in this partition. - """ - return self._remote_upstream_nexus_ids - - def to_dict(self) -> Dict[str, Union[str, Number, dict, list]]: - """ - Get the instance represented as a dict (i.e., a JSON-like object). 
- - Note that, as described in the main docstring for the class, there are extra keys in the dict/JSON currently - that don't correspond to any attributes of the instance. This is for consistency with other tools. - - Returns - ------- - dict - The instance as a dict - """ - return { - self._KEY_PARTITION_ID: str(self.partition_id), - self._KEY_CATCHMENT_IDS: list(self.catchment_ids), - self._KEY_NEXUS_IDS: list(self.nexus_ids), - self._KEY_REMOTE_UPSTREAM_NEXUS_IDS: list(self.remote_upstream_nexus_ids), - self._KEY_REMOTE_DOWNSTREAM_NEXUS_IDS: list(self.remote_downstream_nexus_ids) - } - - class PartitionConfig(Serializable): """ A type to easily encapsulate the JSON object that is output from the NextGen partitioner. From a4be802b352890cd3c808f618156486c7b360106 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Wed, 25 Jan 2023 22:38:31 -0500 Subject: [PATCH 187/205] refactor PartitionConfig --- .../dmod/modeldata/hydrofabric/partition.py | 42 ++++++------------- 1 file changed, 12 insertions(+), 30 deletions(-) diff --git a/python/lib/modeldata/dmod/modeldata/hydrofabric/partition.py b/python/lib/modeldata/dmod/modeldata/hydrofabric/partition.py index 8e18f663b..dde8cb660 100644 --- a/python/lib/modeldata/dmod/modeldata/hydrofabric/partition.py +++ b/python/lib/modeldata/dmod/modeldata/hydrofabric/partition.py @@ -1,6 +1,5 @@ -from numbers import Number from typing import Collection, Dict, FrozenSet, List, Tuple, Union -from pydantic import Field +from pydantic import Field, validator from dmod.core.serializable import Serializable @@ -89,23 +88,20 @@ class PartitionConfig(Serializable): A type to easily encapsulate the JSON object that is output from the NextGen partitioner. 
""" - _KEY_PARTITIONS = 'partitions' + partitions: FrozenSet[Partition] - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict): - try: - return PartitionConfig([Partition.factory_init_from_deserialized_json(serial_p) for serial_p in json_obj[cls._KEY_PARTITIONS]]) - except: - return None + @validator("partitions") + def _sort_partitions(cls, value: FrozenSet[Partition]) -> FrozenSet[Partition]: + return frozenset(sorted(value)) @classmethod def get_serial_property_key_partitions(cls) -> str: - return cls._KEY_PARTITIONS + return "partitions" - def __init__(self, partitions: Collection[Partition]): - self._partitions = frozenset(partitions) + def __init__(self, partitions: Collection[Partition], **data): + super().__init__(partitions=partitions, **data) - def __eq__(self, other): + def __eq__(self, other: object): if not isinstance(other, PartitionConfig): return False other_partitions_dict = dict() @@ -118,7 +114,7 @@ def __eq__(self, other): return False return True - def __hash__(self): + def __hash__(self) -> int: """ Get the unique hash for this instance. @@ -127,22 +123,8 @@ def __hash__(self): Returns ------- - + int + Hash of instance """ - # return hash(','.join([str(p.__hash__()) for p in sorted(self._partitions)])) - @property - def partitions(self) -> List[Partition]: - """ - Get the (sorted) list of partitions for this config. - - Returns - ------- - List[Partition] - The (sorted) list of partitions for this config. 
- """ - return sorted(self._partitions) - - def to_dict(self) -> Dict[str, Union[str, Number, dict, list]]: - return {self._KEY_PARTITIONS: [p.to_dict() for p in self.partitions]} From 58e55d62b94f8af0c1d6e8498475b790e1597d88 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Wed, 25 Jan 2023 22:39:32 -0500 Subject: [PATCH 188/205] add kwargs argument to SubsetDefinition initializer --- .../lib/modeldata/dmod/modeldata/subset/subset_definition.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/lib/modeldata/dmod/modeldata/subset/subset_definition.py b/python/lib/modeldata/dmod/modeldata/subset/subset_definition.py index e044f153f..b9ec5b05a 100644 --- a/python/lib/modeldata/dmod/modeldata/subset/subset_definition.py +++ b/python/lib/modeldata/dmod/modeldata/subset/subset_definition.py @@ -20,8 +20,8 @@ class SubsetDefinition(Serializable): def _sort_and_dedupe_fields(cls, value: Tuple[str]) -> Tuple[str]: return tuple(sorted(set(value))) - def __init__(self, catchment_ids: Collection[str], nexus_ids: Collection[str]): - super().__init__(catchment_ids=catchment_ids, nexus_ids=nexus_ids) + def __init__(self, catchment_ids: Collection[str], nexus_ids: Collection[str], **data): + super().__init__(catchment_ids=catchment_ids, nexus_ids=nexus_ids, **data) def __eq__(self, other: object): return ( From fbcb8ed68583ad33dfe1f0405806c09d27cb1bf8 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Wed, 25 Jan 2023 22:46:39 -0500 Subject: [PATCH 189/205] refactor HydrofabricSubset --- .../dmod/modeldata/subset/hydrofabric_subset.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/python/lib/modeldata/dmod/modeldata/subset/hydrofabric_subset.py b/python/lib/modeldata/dmod/modeldata/subset/hydrofabric_subset.py index b059b4b72..dc6aa0f65 100644 --- a/python/lib/modeldata/dmod/modeldata/subset/hydrofabric_subset.py +++ b/python/lib/modeldata/dmod/modeldata/subset/hydrofabric_subset.py @@ -1,6 +1,6 @@ from abc import ABC,
abstractmethod from hypy import Catchment, Nexus -from typing import Collection, Optional, Sequence, Set, Tuple, Union +from typing import Collection, Optional, Set, Tuple from ..hydrofabric import Hydrofabric from .subset_definition import SubsetDefinition @@ -22,17 +22,19 @@ class HydrofabricSubset(SubsetDefinition, ABC): made in the case of invalid objects. In such cases, the hash is equal to the super class hash output plus ``1``. """ - __slots__ = ["_hydrofabric"] + hydrofabric: Hydrofabric - def __init__(self, catchment_ids: Collection[str], nexus_ids: Collection[str], hydrofabric: Hydrofabric): - super().__init__(catchment_ids=catchment_ids, nexus_ids=nexus_ids) + class Config: + arbitrary_types_allowed = True + + def __init__(self, catchment_ids: Collection[str], nexus_ids: Collection[str], hydrofabric: Hydrofabric, **data): if not self.validate_hydrofabric(hydrofabric): raise RuntimeError("Insufficient or wrongly formatted hydrofabric when trying to create {} object".format( self.__class__.__name__ )) - self._hydrofabric = hydrofabric + super().__init__(catchment_ids=catchment_ids, nexus_ids=nexus_ids, hydrofabric=hydrofabric, **data) - def __eq__(self, other): + def __eq__(self, other: object): if isinstance(other, self.__class__): return self.validate_hydrofabric() == other.validate_hydrofabric() and super().__eq__(other) else: From 9a28d4e88def0cc4d2077759eaf74a32a25ea0aa Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Wed, 25 Jan 2023 22:46:53 -0500 Subject: [PATCH 190/205] refactor SimpleHydrofabricSubset --- .../dmod/modeldata/subset/hydrofabric_subset.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/python/lib/modeldata/dmod/modeldata/subset/hydrofabric_subset.py b/python/lib/modeldata/dmod/modeldata/subset/hydrofabric_subset.py index dc6aa0f65..3cbb2f6bb 100644 --- a/python/lib/modeldata/dmod/modeldata/subset/hydrofabric_subset.py +++ b/python/lib/modeldata/dmod/modeldata/subset/hydrofabric_subset.py @@ -129,13 
+129,10 @@ def factory_create_from_base_and_hydrofabric(cls, subset_def: SubsetDefinition, return cls(catchment_ids=subset_def.catchment_ids, nexus_ids=subset_def.nexus_ids, hydrofabric=hydrofabric, *args, **kwargs) - __slots__ = ["_catchments", "_nexuses"] - - def __init__(self, catchment_ids: Collection[str], nexus_ids: Collection[str], hydrofabric: Hydrofabric, *args, - **kwargs): + def __init__(self, catchment_ids: Collection[str], nexus_ids: Collection[str], hydrofabric: Hydrofabric, **data): self._catchments: Set[Catchment] = set() self._nexuses: Set[Nexus] = set() - super().__init__(catchment_ids, nexus_ids, hydrofabric) + super().__init__(catchment_ids=catchment_ids, nexus_ids=nexus_ids, hydrofabric=hydrofabric, **data) # Since super __init__ validates, and validate function make sure ids are recognized, these won't ever be None for cid in catchment_ids: self._catchments.add(hydrofabric.get_catchment_by_id(cid)) @@ -186,11 +183,11 @@ def validate_hydrofabric(self, hydrofabric: Optional[Hydrofabric] = None) -> boo otherwise. 
""" if hydrofabric is None: - hydrofabric = self._hydrofabric - for cid in self._catchment_ids: + hydrofabric = self.hydrofabric + for cid in self.catchment_ids: if not hydrofabric.is_catchment_recognized(cid): return False - for nid in self._nexus_ids: + for nid in self.nexus_ids: if not hydrofabric.is_nexus_recognized(nid): return False return True From 681e6d40df897f5e13953a3d2deb6bd72057a578 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Wed, 25 Jan 2023 22:59:02 -0500 Subject: [PATCH 191/205] move HydrofabricSubset initializer --- .../lib/modeldata/dmod/modeldata/subset/hydrofabric_subset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/lib/modeldata/dmod/modeldata/subset/hydrofabric_subset.py b/python/lib/modeldata/dmod/modeldata/subset/hydrofabric_subset.py index 3cbb2f6bb..945e06291 100644 --- a/python/lib/modeldata/dmod/modeldata/subset/hydrofabric_subset.py +++ b/python/lib/modeldata/dmod/modeldata/subset/hydrofabric_subset.py @@ -28,11 +28,11 @@ class Config: arbitrary_types_allowed = True def __init__(self, catchment_ids: Collection[str], nexus_ids: Collection[str], hydrofabric: Hydrofabric, **data): + super().__init__(catchment_ids=catchment_ids, nexus_ids=nexus_ids, hydrofabric=hydrofabric, **data) if not self.validate_hydrofabric(hydrofabric): raise RuntimeError("Insufficient or wrongly formatted hydrofabric when trying to create {} object".format( self.__class__.__name__ )) - super().__init__(catchment_ids=catchment_ids, nexus_ids=nexus_ids, hydrofabric=hydrofabric, **data) def __eq__(self, other: object): if isinstance(other, self.__class__): From 0ee01408dd55fd7db7615d397e6b3395b98afe42 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Wed, 25 Jan 2023 23:00:35 -0500 Subject: [PATCH 192/205] add private attr class level declarations to SimpleHydrofabricSubset --- .../dmod/modeldata/subset/hydrofabric_subset.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git 
a/python/lib/modeldata/dmod/modeldata/subset/hydrofabric_subset.py b/python/lib/modeldata/dmod/modeldata/subset/hydrofabric_subset.py index 945e06291..2274d866a 100644 --- a/python/lib/modeldata/dmod/modeldata/subset/hydrofabric_subset.py +++ b/python/lib/modeldata/dmod/modeldata/subset/hydrofabric_subset.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from hypy import Catchment, Nexus from typing import Collection, Optional, Set, Tuple +from pydantic import PrivateAttr from ..hydrofabric import Hydrofabric from .subset_definition import SubsetDefinition @@ -102,6 +103,9 @@ class SimpleHydrofabricSubset(HydrofabricSubset): Simple ::class:`HydrofabricSubset` type. """ + _catchments: Set[Catchment] = PrivateAttr(default_factory=set) + _nexuses: Set[Nexus] = PrivateAttr(default_factory=set) + @classmethod def factory_create_from_base_and_hydrofabric(cls, subset_def: SubsetDefinition, hydrofabric: Hydrofabric, *args, **kwargs) \ @@ -130,13 +134,11 @@ def factory_create_from_base_and_hydrofabric(cls, subset_def: SubsetDefinition, *args, **kwargs) def __init__(self, catchment_ids: Collection[str], nexus_ids: Collection[str], hydrofabric: Hydrofabric, **data): - self._catchments: Set[Catchment] = set() - self._nexuses: Set[Nexus] = set() super().__init__(catchment_ids=catchment_ids, nexus_ids=nexus_ids, hydrofabric=hydrofabric, **data) # Since super __init__ validates, and validate function make sure ids are recognized, these won't ever be None - for cid in catchment_ids: + for cid in self.catchment_ids: self._catchments.add(hydrofabric.get_catchment_by_id(cid)) - for nid in nexus_ids: + for nid in self.nexus_ids: self._nexuses.add(hydrofabric.get_nexus_by_id(nid)) @property From 2ec8307c89026a58bf9eac25aa876338997282dd Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Wed, 25 Jan 2023 23:05:54 -0500 Subject: [PATCH 193/205] update Tuple type hints to be variadic --- .../dmod/modeldata/subset/hydrofabric_subset.py | 10 +++++----- 
.../dmod/modeldata/subset/subset_definition.py | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/python/lib/modeldata/dmod/modeldata/subset/hydrofabric_subset.py b/python/lib/modeldata/dmod/modeldata/subset/hydrofabric_subset.py index 2274d866a..283b99020 100644 --- a/python/lib/modeldata/dmod/modeldata/subset/hydrofabric_subset.py +++ b/python/lib/modeldata/dmod/modeldata/subset/hydrofabric_subset.py @@ -49,7 +49,7 @@ def __hash__(self): @property @abstractmethod - def catchments(self) -> Tuple[Catchment]: + def catchments(self) -> Tuple[Catchment, ...]: """ Get the associated catchments as ::class:`Catchment` objects. @@ -62,7 +62,7 @@ def catchments(self) -> Tuple[Catchment]: @property @abstractmethod - def nexuses(self) -> Tuple[Nexus]: + def nexuses(self) -> Tuple[Nexus, ...]: """ Get the associated nexuses as ::class:`Nexus` objects. @@ -148,19 +148,19 @@ def catchments(self) -> Tuple[Catchment]: Returns ------- - Tuple[Catchment] + Tuple[Catchment, ...] The associated catchments as ::class:`Catchment` objects. """ return tuple(self._catchments) @property - def nexuses(self) -> Tuple[Nexus]: + def nexuses(self) -> Tuple[Nexus, ...]: """ Get the associated nexuses as ::class:`Nexus` objects. Returns ------- - Tuple[Catchment] + Tuple[Catchment, ...] The associated nexuses as ::class:`Nexus` objects. """ return tuple(self._nexuses) diff --git a/python/lib/modeldata/dmod/modeldata/subset/subset_definition.py b/python/lib/modeldata/dmod/modeldata/subset/subset_definition.py index b9ec5b05a..baa128fea 100644 --- a/python/lib/modeldata/dmod/modeldata/subset/subset_definition.py +++ b/python/lib/modeldata/dmod/modeldata/subset/subset_definition.py @@ -13,11 +13,11 @@ class SubsetDefinition(Serializable): to be immutable. """ - catchment_ids: Tuple[str] - nexus_ids: Tuple[str] + catchment_ids: Tuple[str, ...] + nexus_ids: Tuple[str, ...] 
@validator("catchment_ids", "nexus_ids") - def _sort_and_dedupe_fields(cls, value: Tuple[str]) -> Tuple[str]: + def _sort_and_dedupe_fields(cls, value: Tuple[str, ...]) -> Tuple[str, ...]: return tuple(sorted(set(value))) def __init__(self, catchment_ids: Collection[str], nexus_ids: Collection[str], **data): From ad241c14f4105c5e484a82d9295b764f53e71444 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Tue, 7 Feb 2023 11:02:24 -0500 Subject: [PATCH 194/205] fix reference to non-existent field --- python/lib/modeldata/dmod/modeldata/hydrofabric/partition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/lib/modeldata/dmod/modeldata/hydrofabric/partition.py b/python/lib/modeldata/dmod/modeldata/hydrofabric/partition.py index dde8cb660..96e550206 100644 --- a/python/lib/modeldata/dmod/modeldata/hydrofabric/partition.py +++ b/python/lib/modeldata/dmod/modeldata/hydrofabric/partition.py @@ -105,7 +105,7 @@ def __eq__(self, other: object): if not isinstance(other, PartitionConfig): return False other_partitions_dict = dict() - for other_p in other._partitions: + for other_p in other.partitions: other_partitions_dict[other_p.partition_id] = other_p other_pids = set([p2.partition_id for p2 in other.partitions]) From 300ce9b5acb604140301dcd9b84009ef28d276f1 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Tue, 7 Feb 2023 11:05:14 -0500 Subject: [PATCH 195/205] add missing PrivateAttr field to Partition --- .../lib/modeldata/dmod/modeldata/hydrofabric/partition.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/lib/modeldata/dmod/modeldata/hydrofabric/partition.py b/python/lib/modeldata/dmod/modeldata/hydrofabric/partition.py index 96e550206..f66d0714b 100644 --- a/python/lib/modeldata/dmod/modeldata/hydrofabric/partition.py +++ b/python/lib/modeldata/dmod/modeldata/hydrofabric/partition.py @@ -1,5 +1,5 @@ -from typing import Collection, Dict, FrozenSet, List, Tuple, Union -from pydantic import Field, validator 
+from typing import Collection, Dict, FrozenSet, List, Optional, Tuple, Union +from pydantic import Field, PrivateAttr, validator from dmod.core.serializable import Serializable @@ -23,6 +23,8 @@ class Partition(Serializable): remote_upstream_nexus_ids: FrozenSet[str] = Field(default_factory=frozenset) remote_downstream_nexus_ids: FrozenSet[str] = Field(default_factory=frozenset) + _hash_val: Optional[int] = PrivateAttr(None) + class Config: fields = { "catchment_ids": {"alias": "cat-ids"}, @@ -35,8 +37,6 @@ class Config: def __init__(self, partition_id: int, catchment_ids: Collection[str], nexus_ids: Collection[str], remote_up_nexuses: Collection[str] = None, remote_down_nexuses: Collection[str] = None, **data): - self._hash_val = None - if remote_up_nexuses is None or remote_down_nexuses is None: super().__init__( partition_id=partition_id, From 58f0341642c7da81370855b3ef5aa411939790ce Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Tue, 7 Feb 2023 11:32:20 -0500 Subject: [PATCH 196/205] typo in Partition's fields alias map --- python/lib/modeldata/dmod/modeldata/hydrofabric/partition.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/lib/modeldata/dmod/modeldata/hydrofabric/partition.py b/python/lib/modeldata/dmod/modeldata/hydrofabric/partition.py index f66d0714b..0de16f749 100644 --- a/python/lib/modeldata/dmod/modeldata/hydrofabric/partition.py +++ b/python/lib/modeldata/dmod/modeldata/hydrofabric/partition.py @@ -30,8 +30,8 @@ class Config: "catchment_ids": {"alias": "cat-ids"}, "partition_id": {"alias": "id"}, "nexus_ids": {"alias": "nex-ids"}, - "remote_up_nexuses": {"alias": "remote-up"}, - "remote_down_nexuses": {"alias": "remote-down"}, + "remote_upstream_nexus_ids": {"alias": "remote-up"}, + "remote_downstream_nexus_ids": {"alias": "remote-down"}, } def __init__(self, partition_id: int, catchment_ids: Collection[str], nexus_ids: Collection[str], From 1362e2a75c49cb03b25d87b6563323bae2a93c4b Mon Sep 17 00:00:00 2001 From: 
Austin Raney Date: Tue, 7 Feb 2023 11:33:34 -0500 Subject: [PATCH 197/205] serialize Partition fields as list and fix it's intializer fn to allow factory init from dict --- .../dmod/modeldata/hydrofabric/partition.py | 28 +++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/python/lib/modeldata/dmod/modeldata/hydrofabric/partition.py b/python/lib/modeldata/dmod/modeldata/hydrofabric/partition.py index 0de16f749..533adc891 100644 --- a/python/lib/modeldata/dmod/modeldata/hydrofabric/partition.py +++ b/python/lib/modeldata/dmod/modeldata/hydrofabric/partition.py @@ -34,8 +34,32 @@ class Config: "remote_downstream_nexus_ids": {"alias": "remote-down"}, } - def __init__(self, partition_id: int, catchment_ids: Collection[str], nexus_ids: Collection[str], - remote_up_nexuses: Collection[str] = None, remote_down_nexuses: Collection[str] = None, **data): + def _serialize_frozenset(value: FrozenSet[str]) -> List[str]: + return list(value) + + field_serializers = { + "catchment_ids": _serialize_frozenset, + "nexus_ids": _serialize_frozenset, + "remote_upstream_nexus_ids": _serialize_frozenset, + "remote_downstream_nexus_ids": _serialize_frozenset, + } + + def __init__( + self, + # required, but for backwards compatibility, None + partition_id: int = None, + catchment_ids: Collection[str] = None, + nexus_ids: Collection[str] = None, + # non-required fields + remote_up_nexuses: Collection[str] = None, + remote_down_nexuses: Collection[str] = None, + **data + ): + # if data exists, assume fields specified using their alias; no backwards compatibility. 
+ if data: + super().__init__(**data) + return + if remote_up_nexuses is None or remote_down_nexuses is None: super().__init__( From cc57f916e3ee2ea2e4d2a01277362df41d0bd374 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Tue, 7 Feb 2023 12:10:47 -0500 Subject: [PATCH 198/205] add Partition unit tests --- .../lib/modeldata/dmod/test/test_partition.py | 95 +++++++++++++++++++ 1 file changed, 95 insertions(+) create mode 100644 python/lib/modeldata/dmod/test/test_partition.py diff --git a/python/lib/modeldata/dmod/test/test_partition.py b/python/lib/modeldata/dmod/test/test_partition.py new file mode 100644 index 000000000..546c5199f --- /dev/null +++ b/python/lib/modeldata/dmod/test/test_partition.py @@ -0,0 +1,95 @@ +import unittest +from ..modeldata.hydrofabric.partition import Partition, PartitionConfig + +class TestPartition(unittest.TestCase): + + @classmethod + @property + def partition_instance(cls) -> Partition: + return Partition( + nexus_ids=["2"], + catchment_ids=["42"], + partition_id=0, + remote_up_nexuses=["1"], + remote_down_nexuses=["3"] + ) + + @classmethod + @property + def serialized_partition(cls) -> dict: + return { + "cat-ids": ["42"], + "id": 0, + "remote-up": ["1"], + "nex-ids": ["2"], + "remote-down": ["3"] + } + + + def test_programmatically_create_partition(self): + """Test creating an instance programmatically""" + o = self.partition_instance + + self.assertEqual(len(o.catchment_ids), 1) + self.assertEqual(len(o.remote_upstream_nexus_ids), 1) + self.assertEqual(len(o.nexus_ids), 1) + self.assertEqual(len(o.remote_downstream_nexus_ids), 1) + + self.assertIn + self.assertEqual(o.partition_id, 0) + self.assertIn("42", o.catchment_ids) + self.assertIn("1", o.remote_upstream_nexus_ids) + self.assertIn("2", o.nexus_ids) + self.assertIn("3", o.remote_downstream_nexus_ids) + + def test_factory_init_from_deserialized_json(self): + """ + Test creating an instance from a dictionary, then re-serializing equals the original dict. 
+ """ + data = self.serialized_partition + o = Partition.factory_init_from_deserialized_json(data) + self.assertIsNotNone(o) + self.assertDictEqual(data, o.to_dict()) # type: ignore + + def test_eq(self): + """ + Test equality of instances. Tests instances created programmatically and from dict + deserialization. + """ + o1 = self.partition_instance + o2 = self.partition_instance + + o3 = Partition.factory_init_from_deserialized_json(self.serialized_partition) + self.assertEqual(o1, o1) + self.assertEqual(o1, o2) + self.assertEqual(o1, o3) + + def test_hash(self): + """ + Test instances hash to the same value based on their data, not the order of their data. + """ + catchment_ids = ["1", "2", "3"] + rev_catchment_ids = catchment_ids[::-1] + + o1 = Partition( + # these fields are used by __hash__ + catchment_ids=catchment_ids, + partition_id=0, + + nexus_ids=["2"], + remote_up_nexuses=["1"], + remote_down_nexuses=["3"] + ) + + o2 = Partition( + # these fields are used by __hash__ + catchment_ids=rev_catchment_ids, + partition_id=0, + + nexus_ids=["2"], + remote_up_nexuses=["1"], + remote_down_nexuses=["3"] + ) + + self.assertNotEqual(catchment_ids, rev_catchment_ids) + self.assertEqual(hash(o1), hash(o2)) From ed2556efb46b19e277ebb8d0cd542f61c700b369 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Tue, 7 Feb 2023 13:11:15 -0500 Subject: [PATCH 199/205] fix invalid reference --- python/lib/modeldata/dmod/modeldata/hydrofabric/partition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/lib/modeldata/dmod/modeldata/hydrofabric/partition.py b/python/lib/modeldata/dmod/modeldata/hydrofabric/partition.py index 533adc891..e7136ab8c 100644 --- a/python/lib/modeldata/dmod/modeldata/hydrofabric/partition.py +++ b/python/lib/modeldata/dmod/modeldata/hydrofabric/partition.py @@ -150,5 +150,5 @@ def __hash__(self) -> int: int Hash of instance """ - return hash(','.join([str(p.__hash__()) for p in sorted(self._partitions)])) + return 
hash(','.join([str(p.__hash__()) for p in sorted(self.partitions)])) From 9815c3a8b575c493ad6a65619f37a2b7854db344 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Tue, 7 Feb 2023 15:25:28 -0500 Subject: [PATCH 200/205] add PartionConfig unittests. format with black --- .../lib/modeldata/dmod/test/test_partition.py | 190 +++++++++++++++--- 1 file changed, 160 insertions(+), 30 deletions(-) diff --git a/python/lib/modeldata/dmod/test/test_partition.py b/python/lib/modeldata/dmod/test/test_partition.py index 546c5199f..3d1344ceb 100644 --- a/python/lib/modeldata/dmod/test/test_partition.py +++ b/python/lib/modeldata/dmod/test/test_partition.py @@ -1,30 +1,29 @@ import unittest from ..modeldata.hydrofabric.partition import Partition, PartitionConfig -class TestPartition(unittest.TestCase): +class TestPartition(unittest.TestCase): @classmethod @property def partition_instance(cls) -> Partition: return Partition( - nexus_ids=["2"], - catchment_ids=["42"], - partition_id=0, - remote_up_nexuses=["1"], - remote_down_nexuses=["3"] - ) + nexus_ids=["2"], + catchment_ids=["42"], + partition_id=0, + remote_up_nexuses=["1"], + remote_down_nexuses=["3"], + ) @classmethod @property def serialized_partition(cls) -> dict: return { - "cat-ids": ["42"], - "id": 0, - "remote-up": ["1"], - "nex-ids": ["2"], - "remote-down": ["3"] - } - + "cat-ids": ["42"], + "id": 0, + "remote-up": ["1"], + "nex-ids": ["2"], + "remote-down": ["3"], + } def test_programmatically_create_partition(self): """Test creating an instance programmatically""" @@ -49,7 +48,7 @@ def test_factory_init_from_deserialized_json(self): data = self.serialized_partition o = Partition.factory_init_from_deserialized_json(data) self.assertIsNotNone(o) - self.assertDictEqual(data, o.to_dict()) # type: ignore + self.assertDictEqual(data, o.to_dict()) # type: ignore def test_eq(self): """ @@ -72,24 +71,155 @@ def test_hash(self): rev_catchment_ids = catchment_ids[::-1] o1 = Partition( - # these fields are used by __hash__ - 
catchment_ids=catchment_ids, - partition_id=0, - - nexus_ids=["2"], - remote_up_nexuses=["1"], - remote_down_nexuses=["3"] - ) + # these fields are used by __hash__ + catchment_ids=catchment_ids, + partition_id=0, + nexus_ids=["2"], + remote_up_nexuses=["1"], + remote_down_nexuses=["3"], + ) o2 = Partition( - # these fields are used by __hash__ - catchment_ids=rev_catchment_ids, - partition_id=0, + # these fields are used by __hash__ + catchment_ids=rev_catchment_ids, + partition_id=0, + nexus_ids=["2"], + remote_up_nexuses=["1"], + remote_down_nexuses=["3"], + ) + + self.assertNotEqual(catchment_ids, rev_catchment_ids) + self.assertEqual(hash(o1), hash(o2)) + + def test_to_dict(self): + """Test serializing to dict""" + o = Partition.factory_init_from_deserialized_json(self.serialized_partition) + + self.assertIsNotNone(o) + self.assertDictEqual(o.to_dict(), self.serialized_partition) # type: ignore + + +class TestPartitionConfig(unittest.TestCase): + @classmethod + @property + def partition_config_instance(cls) -> PartitionConfig: + return PartitionConfig(partitions=[TestPartition.partition_instance]) + + @classmethod + @property + def serialized_partition_config(cls) -> dict: + return {"partitions": [TestPartition.serialized_partition]} + + def test_programmatically_create_partition(self): + """Test creating an instance programmatically""" + o = self.partition_config_instance + + self.assertEqual(len(o.partitions), 1) + + def test_factory_init_from_deserialized_json(self): + """Test creating an instance programmatically""" + data = self.serialized_partition_config + o = PartitionConfig.factory_init_from_deserialized_json(data) - nexus_ids=["2"], - remote_up_nexuses=["1"], - remote_down_nexuses=["3"] + self.assertIsNotNone(o) + self.assertDictEqual(data, o.to_dict()) # type: ignore + + def test_to_dict(self): + o = PartitionConfig.factory_init_from_deserialized_json( + self.serialized_partition_config + ) + self.assertIsNotNone(o) + 
self.assertDictEqual(self.serialized_partition_config, o.to_dict()) # type: ignore + + def test_hash(self): + """ + Test instances hash to the same value based on their data, not the order of their data. + """ + self.assertEqual( + hash(self.partition_config_instance), hash(self.partition_config_instance) + ) + + # from dictionary + o = PartitionConfig.factory_init_from_deserialized_json( + self.partition_config_instance.to_dict() + ) + self.assertEqual(hash(self.partition_config_instance), hash(o)) + + catchment_ids = ["1", "2", "3"] + + o1 = PartitionConfig( + partitions=[ + Partition( + nexus_ids=["1"], + remote_up_nexuses=["2"], + remote_down_nexuses=["3"], + partition_id=0, + catchment_ids=catchment_ids, + ) + ] + ) + + o2 = PartitionConfig( + partitions=[ + Partition( + nexus_ids=["2222"], + remote_up_nexuses=["1111"], + remote_down_nexuses=["3333"], + partition_id=0, + catchment_ids=catchment_ids, ) + ] + ) - self.assertNotEqual(catchment_ids, rev_catchment_ids) + # same partition and catchment ids + # NOTE: this is the expected behavior self.assertEqual(hash(o1), hash(o2)) + + def test_duplicate_partitions_removed_during_init(self): + catchment_ids = ["1", "2", "3"] + rev_catchment_ids = catchment_ids[::-1] + + o1 = Partition( + nexus_ids=["2"], + remote_up_nexuses=["1"], + remote_down_nexuses=["3"], + partition_id=0, + catchment_ids=catchment_ids, + ) + + o2 = Partition( + nexus_ids=["2"], + remote_up_nexuses=["1"], + remote_down_nexuses=["3"], + partition_id=0, + catchment_ids=rev_catchment_ids, + ) + + duplicate_partition_inst = PartitionConfig(partitions=[o1, o1]) + self.assertEqual(len(duplicate_partition_inst.partitions), 1) + + duplicate_partition_same_data_inst = PartitionConfig(partitions=[o1, o2]) + self.assertEqual(len(duplicate_partition_same_data_inst.partitions), 1) + catchment_ids = ["1", "2", "3"] + + o1 = Partition( + nexus_ids=["2"], + remote_up_nexuses=["1"], + remote_down_nexuses=["3"], + partition_id=0, + catchment_ids=catchment_ids, + 
) + + o3 = Partition( + nexus_ids=["2222"], + remote_up_nexuses=["1111"], + remote_down_nexuses=["3333"], + partition_id=0, + catchment_ids=catchment_ids, + ) + + same_catchment_id_and_partition_id = PartitionConfig(partitions=[o1, o3]) + + # NOTE: this is the expected behavior + self.assertEqual(len(same_catchment_id_and_partition_id.partitions), 1) + From f6379ee5810365ecf0f536a2cfbd3f65e5a3dc4d Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Tue, 7 Feb 2023 15:26:48 -0500 Subject: [PATCH 201/205] add PartionConfig field_serializer --- .../lib/modeldata/dmod/modeldata/hydrofabric/partition.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/python/lib/modeldata/dmod/modeldata/hydrofabric/partition.py b/python/lib/modeldata/dmod/modeldata/hydrofabric/partition.py index e7136ab8c..25d2dfb73 100644 --- a/python/lib/modeldata/dmod/modeldata/hydrofabric/partition.py +++ b/python/lib/modeldata/dmod/modeldata/hydrofabric/partition.py @@ -118,6 +118,14 @@ class PartitionConfig(Serializable): def _sort_partitions(cls, value: FrozenSet[Partition]) -> FrozenSet[Partition]: return frozenset(sorted(value)) + class Config: + def _serialize_frozenset(value: FrozenSet[Partition]) -> List[Partition]: + return list(value) + + field_serializers = { + "partitions": _serialize_frozenset + } + @classmethod def get_serial_property_key_partitions(cls) -> str: return "partitions" From 0a93138fc96cc139ae1ddc63c2bada7cf23f37c2 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Tue, 7 Feb 2023 15:27:57 -0500 Subject: [PATCH 202/205] override PartitionConfig dict method. reasons why dict is overridden here: pydantic will serialize from inner types outward, serializing each type as a dictionary, list, or primitive and replacing its previous type with the new "serialized" type. 
Consequently, this means hashable container types like tuples and frozensets that contain values that "serialize" to a non-hashable type (non-primitive, in this case) will raise a `TypeError: unhashable type: 'dict'`. In the case of PartitionConfig, FronzenSet[Partition] "serializes" inner Partition types as dictionaries which are not hashable. To get around this, we will momentarily swap the `partitions` field for a non-hashable container type, serialize using `.dict()`, and swap back in the original `partitions` container. --- .../dmod/modeldata/hydrofabric/partition.py | 49 ++++++++++++++++++- 1 file changed, 47 insertions(+), 2 deletions(-) diff --git a/python/lib/modeldata/dmod/modeldata/hydrofabric/partition.py b/python/lib/modeldata/dmod/modeldata/hydrofabric/partition.py index 25d2dfb73..0b9f22255 100644 --- a/python/lib/modeldata/dmod/modeldata/hydrofabric/partition.py +++ b/python/lib/modeldata/dmod/modeldata/hydrofabric/partition.py @@ -1,7 +1,10 @@ -from typing import Collection, Dict, FrozenSet, List, Optional, Tuple, Union +from typing import Collection, FrozenSet, List, Optional, TYPE_CHECKING, Union from pydantic import Field, PrivateAttr, validator from dmod.core.serializable import Serializable +if TYPE_CHECKING: + from pydantic.typing import AbstractSetIntStr, DictStrAny, MappingIntStrAny + class Partition(Serializable): """ @@ -158,5 +161,47 @@ def __hash__(self) -> int: int Hash of instance """ - return hash(','.join([str(p.__hash__()) for p in sorted(self.partitions)])) + return hash(",".join([str(p.__hash__()) for p in sorted(self.partitions)])) + + def dict( + self, + *, + include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + by_alias: bool = False, + skip_defaults: Optional[bool] = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False + ) -> "DictStrAny": + # reasons why dict is overridden here: + # 
pydantic will serialize from inner types outward, serializing each type as a dictionary, + # list, or primitive and replacing its previous type with the new "serialized" type. + # Consequently, this means hashable container types like tuples and frozensets that contain + # values that "serialize" to a non-hashable type (non-primitive, in this case) will raise a + # `TypeError: unhashable type: 'dict'`. In the case of PartitionConfig, + # FronzenSet[Partition] "serializes" inner Partition types as dictionaries which are not + # hashable. To get around this, we will momentarily swap the `partitions` field for a + # non-hashable container type, serialize using `.dict()`, and swap back in the original + # `partitions` container. + + # 1. take a reference to partitions: FrozenSet[Partition] + partitions = self.partitions + + # 2. cast and set partitions to a list, a non-hashable container type + self.partitions = list(partitions) + + # 3. serialize + serial = super().dict( + include=include, + exclude=exclude, + by_alias=by_alias, + skip_defaults=skip_defaults, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + # 4. 
replace partitions with its hashable representation + self.partitions = partitions + return serial From ff06874891577e1bb0d68ce9fc42346a37bb94c2 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Fri, 10 Feb 2023 15:58:34 -0500 Subject: [PATCH 203/205] replace private reference to Dataset field name --- .../lib/modeldata/dmod/modeldata/data/object_store_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/lib/modeldata/dmod/modeldata/data/object_store_manager.py b/python/lib/modeldata/dmod/modeldata/data/object_store_manager.py index dffb5a8cf..33045f461 100644 --- a/python/lib/modeldata/dmod/modeldata/data/object_store_manager.py +++ b/python/lib/modeldata/dmod/modeldata/data/object_store_manager.py @@ -560,8 +560,8 @@ def reload(self, reload_from: str, serialized_item: Optional[str] = None) -> Dat response_obj.release_conn() # If we can safely infer it, make sure the "type" key is set in cases when it is missing - if len(self.supported_dataset_types) == 1 and Dataset._KEY_TYPE not in response_data: - response_data[Dataset._KEY_TYPE] = list(self.supported_dataset_types)[0].name + if len(self.supported_dataset_types) == 1 and "type" not in response_data: + response_data["type"] = list(self.supported_dataset_types)[0].name dataset = Dataset.factory_init_from_deserialized_json(response_data) dataset.manager = self From 060e07666dbcd3bac3cd8970e2751d95677110d3 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Fri, 10 Feb 2023 15:59:31 -0500 Subject: [PATCH 204/205] replace private reference to Dataset field name in test --- .../lib/modeldata/dmod/test/it_object_store_dataset_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/lib/modeldata/dmod/test/it_object_store_dataset_manager.py b/python/lib/modeldata/dmod/test/it_object_store_dataset_manager.py index 3af8d9ddb..3fec85778 100644 --- a/python/lib/modeldata/dmod/test/it_object_store_dataset_manager.py +++ 
b/python/lib/modeldata/dmod/test/it_object_store_dataset_manager.py @@ -253,7 +253,7 @@ def test_get_data_1_b(self): data_dict = json.loads(self.manager.get_data(dataset_name, item_name=serial_file_name).decode()) - self.assertEqual(dataset_name, data_dict[Dataset._KEY_NAME]) + self.assertEqual(dataset_name, data_dict["name"]) def test_list_files_1_a(self): """ From 4e20d27fac48accd1e913d8f40487bc1ccd67fbd Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Fri, 10 Feb 2023 16:26:08 -0500 Subject: [PATCH 205/205] no longer use a double wrapped classmethod property. is not supported in 3.8 and is deprecated in 3.11 --- .../lib/modeldata/dmod/test/test_partition.py | 22 +++++-------------- 1 file changed, 5 insertions(+), 17 deletions(-) diff --git a/python/lib/modeldata/dmod/test/test_partition.py b/python/lib/modeldata/dmod/test/test_partition.py index 3d1344ceb..be9fa6f7f 100644 --- a/python/lib/modeldata/dmod/test/test_partition.py +++ b/python/lib/modeldata/dmod/test/test_partition.py @@ -3,10 +3,7 @@ class TestPartition(unittest.TestCase): - @classmethod - @property - def partition_instance(cls) -> Partition: - return Partition( + partition_instance = Partition( nexus_ids=["2"], catchment_ids=["42"], partition_id=0, @@ -14,10 +11,7 @@ def partition_instance(cls) -> Partition: remote_down_nexuses=["3"], ) - @classmethod - @property - def serialized_partition(cls) -> dict: - return { + serialized_partition = { "cat-ids": ["42"], "id": 0, "remote-up": ["1"], @@ -100,15 +94,9 @@ def test_to_dict(self): class TestPartitionConfig(unittest.TestCase): - @classmethod - @property - def partition_config_instance(cls) -> PartitionConfig: - return PartitionConfig(partitions=[TestPartition.partition_instance]) - - @classmethod - @property - def serialized_partition_config(cls) -> dict: - return {"partitions": [TestPartition.serialized_partition]} + partition_config_instance = PartitionConfig(partitions=[TestPartition.partition_instance]) + + serialized_partition_config = 
{"partitions": [TestPartition.serialized_partition]} def test_programmatically_create_partition(self): """Test creating an instance programmatically"""