diff --git a/python/lib/communication/dmod/communication/data_transmit_message.py b/python/lib/communication/dmod/communication/data_transmit_message.py index 6702bb7a0..9f216c796 100644 --- a/python/lib/communication/dmod/communication/data_transmit_message.py +++ b/python/lib/communication/dmod/communication/data_transmit_message.py @@ -1,10 +1,23 @@ +from dmod.core.serializable import Serializable +from pydantic import Extra +from dmod.core.serializable_dict import SerializableDict from .message import AbstractInitRequest, MessageEventType, Response -from typing import Dict, Optional, Union -from numbers import Number +from pydantic import Field +from typing import ClassVar, Type, Union from uuid import UUID -class DataTransmitMessage(AbstractInitRequest): +class DataTransmitUUID(Serializable): + series_uuid: UUID = Field(description="A unique id for the collective series of transmission message this instance is a part of.") + """ + The expectation is that a larger amount of data will be broken up into multiple messages in a series. + """ + + class Config: + field_serializers = {"series_uuid": lambda s: str(s)} + + +class DataTransmitMessage(DataTransmitUUID, AbstractInitRequest): """ Specialized message type for transmitting data. @@ -18,64 +31,14 @@ class DataTransmitMessage(AbstractInitRequest): ::class:`str` object. However, instances can be initialized using either ::class:`str` or ::class:`bytes` data. 
""" - _KEY_SERIES_UUID = 'series_uuid' + event_type: ClassVar[MessageEventType] = MessageEventType.DATA_TRANSMISSION - event_type: MessageEventType = MessageEventType.DATA_TRANSMISSION + data: str = Field(description="The data carried by this message, in decoded string form.") + is_last: bool = Field(False, description="Whether this is the last data transmission message in this series.") - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict) -> Optional['DataTransmitMessage']: - try: - return cls(data=json_obj['data'], series_uuid=UUID(json_obj[cls._KEY_SERIES_UUID]), - is_last=json_obj['is_last']) - except Exception as e: - return None - def __init__(self, data: Union[str, bytes], series_uuid: UUID, is_last: bool = False, *args, **kwargs): - super(DataTransmitMessage, self).__init__(*args, **kwargs) - self._data: str = data if isinstance(data, str) else data.decode() - self._series_uuid = series_uuid - self._is_last: bool = is_last - - @property - def data(self) -> str: - """ - The data carried by this message, in decoded string form. - - Returns - ------- - str - The data carried by this message, in decoded string form. - """ - return self._data - - @property - def is_last(self) -> bool: - """ - Whether this is the last data transmission message in this series. - - Returns - ------- - bool - Whether this is the last data transmission message in this series. - """ - return self._is_last - - @property - def series_uuid(self) -> UUID: - """ - A unique id for the collective series of transmission message this instance is a part of. - - The expectation is that a larger amount of data will be broken up into multiple messages in a series. - - Returns - ------- - UUID - A unique id for the collective series of transmission message this instance is a part of. 
- """ - return self._series_uuid - - def to_dict(self) -> Dict[str, Union[str, Number, dict, list]]: - return {'data': self.data, self._KEY_SERIES_UUID: str(self.series_uuid), 'is_last': self.is_last} +class DataTransmitResponseBody(SerializableDict, DataTransmitUUID): + ... class DataTransmitResponse(Response): @@ -86,38 +49,24 @@ class DataTransmitResponse(Response): series of which it is a part. """ - response_to_type = DataTransmitMessage - - _KEY_SERIES_UUID = response_to_type._KEY_SERIES_UUID - - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict) -> 'DataTransmitResponse': - """ - Factory create a new instance of this type based on a JSON object dictionary deserialized from received JSON. - - Parameters - ---------- - json_obj - - Returns - ------- - response_obj : Response - A new object of this type instantiated from the deserialize JSON object dictionary, or none if the provided - parameter could not be used to instantiated a new object. - """ - try: - return cls(success=json_obj['success'], reason=json_obj['reason'], message=json_obj['message'], - series_uuid=json_obj['data'][cls._KEY_SERIES_UUID], data=json_obj['data']) - except Exception as e: - return None - - def __init__(self, series_uuid: Union[str, UUID], *args, **kwargs): - if 'data' not in kwargs: - kwargs['data'] = dict() - kwargs['data'][self._KEY_SERIES_UUID] = str(series_uuid) - super(DataTransmitResponse, self).__init__(*args, **kwargs) + response_to_type: ClassVar[Type[AbstractInitRequest]] = DataTransmitMessage + + data: DataTransmitResponseBody + + # `series_uuid` required in prior version of code + def __init__(self, series_uuid: Union[str, UUID] = None, **kwargs): + # assume no need for backwards compatibility + if series_uuid is None: + super().__init__(**kwargs) + return + + if "data" not in kwargs: + kwargs["data"] = dict() + + kwargs["data"]["series_uuid"] = series_uuid + super().__init__(**kwargs) @property def series_uuid(self) -> UUID: - return 
UUID(self.data[self._KEY_SERIES_UUID]) + return self.data.series_uuid diff --git a/python/lib/communication/dmod/communication/dataset_management_message.py b/python/lib/communication/dmod/communication/dataset_management_message.py index d148b6d0e..e04352cc5 100644 --- a/python/lib/communication/dmod/communication/dataset_management_message.py +++ b/python/lib/communication/dmod/communication/dataset_management_message.py @@ -1,13 +1,14 @@ from .message import AbstractInitRequest, MessageEventType, Response from dmod.core.serializable import Serializable +from dmod.core.serializable_dict import SerializableDict from .maas_request import ExternalRequest, ExternalRequestResponse from dmod.core.meta_data import DataCategory, DataDomain, DataFormat, DataRequirement -from numbers import Number -from enum import Enum -from typing import Dict, Optional, Union, List +from dmod.core.enum import PydanticEnum +from pydantic import root_validator, Field +from typing import Any, ClassVar, Dict, Optional, Type, Union, List -class QueryType(Enum): +class QueryType(PydanticEnum): LIST_FILES = 1 GET_CATEGORY = 2 GET_FORMAT = 3 @@ -41,31 +42,13 @@ def get_for_name(cls, name_str: str) -> 'QueryType': class DatasetQuery(Serializable): - _KEY_QUERY_TYPE = 'query_type' - - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict) -> Optional['DatasetQuery']: - try: - return cls(query_type=QueryType.get_for_name(json_obj[cls._KEY_QUERY_TYPE])) - except Exception as e: - return None + query_type: QueryType def __hash__(self): return hash(self.query_type) - def __eq__(self, other): - return isinstance(other, DatasetQuery) and self.query_type == other.query_type - - def __init__(self, query_type: QueryType): - self.query_type = query_type - - def to_dict(self) -> Dict[str, Union[str, Number, dict, list]]: - serial = dict() - serial[self._KEY_QUERY_TYPE] = self.query_type.name - return serial - -class ManagementAction(Enum): +class ManagementAction(PydanticEnum): """ Type 
enumerating the standard actions that can be requested via ::class:`DatasetManagementMessage`. """ @@ -175,65 +158,43 @@ class DatasetManagementMessage(AbstractInitRequest): Valid actions are enumerated by the ::class:`ManagementAction`. """ - event_type: MessageEventType = MessageEventType.DATASET_MANAGEMENT - - _SERIAL_KEY_ACTION = 'action' - _SERIAL_KEY_CATEGORY = 'category' - _SERIAL_KEY_DATA_DOMAIN = 'data_domain' - _SERIAL_KEY_DATA_LOCATION = 'data_location' - _SERIAL_KEY_DATASET_NAME = 'dataset_name' - _SERIAL_KEY_IS_PENDING_DATA = 'pending_data' - _SERIAL_KEY_QUERY = 'query' - _SERIAL_KEY_IS_READ_ONLY = 'read_only' + event_type: ClassVar[MessageEventType] = MessageEventType.DATASET_MANAGEMENT - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict) -> Optional['DatasetManagementMessage']: - """ - Inflate serialized representation back to a full object, if serial representation is valid. + management_action: ManagementAction = Field(description="The type of ::class:`ManagementAction` this message embodies or requests.") + dataset_name: Optional[str] = Field(description="The name of the involved dataset, if applicable.") + is_read_only_dataset: bool = Field(False, description="Whether the dataset involved is, should be, or must be (depending on action) read-only.") + data_category: Optional[DataCategory] = Field(description="The category of the involved data, if applicable.") + data_domain: Optional[DataDomain] = Field(description="The domain of the involved data, if applicable.") + data_location: Optional[str] = Field(description="Location for acted-upon data.") + is_pending_data: bool = Field(False, description="Whether the sender has data pending transmission after this message.") + """ + Whether the sender has data it wants to transmit after this message. The typical use case is during a + ``CREATE`` action, where this indicates there is already data to add to the newly created dataset. 
+ """ + query: Optional[DatasetQuery] - Parameters - ---------- - json_obj : dict - Serialized representation of a ::class:`DatasetManagementMessage` instance. + @root_validator() + def _post_init_validate_dependent_fields(cls, values): + # Sanity check certain param values depending on the action; e.g., can't CREATE a dataset without a name + action: ManagementAction = values["management_action"] + name, category, domain = values["dataset_name"], values["data_category"], values["data_domain"] + err_msg_template = "Cannot create {} for action {} without {}" + if name is None and action.requires_dataset_name: + raise RuntimeError(err_msg_template.format(cls.__name__, action, "a dataset name")) + if category is None and action.requires_data_category: + raise RuntimeError(err_msg_template.format(cls.__name__, action, "a data category")) + if domain is None and action.requires_data_domain: + raise RuntimeError(err_msg_template.format(cls.__name__, action, "a data domain")) - Returns - ------- - Optional[DatasetManagementMessage] - The inflated ::class:`DatasetManagementMessage`, or ``None`` if the serialized form was invalid. 
- """ - try: - # Grab the class to deserialize, popping it from the json obj (it was temp injected by a subclass) if there - deserialized_class = json_obj.pop('deserialized_class', cls) - - # Similarly, get/pop any temporarily injected kwargs values to pass to deserialized_class's init function - deserialized_class_kwargs = json_obj.pop('deserialized_class_kwargs', dict()) - - action = ManagementAction.get_for_name(json_obj[cls._SERIAL_KEY_ACTION]) - if json_obj[cls._SERIAL_KEY_ACTION] != action.name: - raise RuntimeError("Unparseable serialized {} value: {}".format(ManagementAction.__name__, - json_obj[cls._SERIAL_KEY_ACTION])) - - dataset_name = json_obj.get(cls._SERIAL_KEY_DATASET_NAME) - category_str = json_obj.get(cls._SERIAL_KEY_CATEGORY) - category = None if category_str is None else DataCategory.get_for_name(category_str) - data_loc = json_obj.get(cls._SERIAL_KEY_DATA_LOCATION) - #page = json_obj[cls._SERIAL_KEY_PAGE] if cls._SERIAL_KEY_PAGE in json_obj else None - if cls._SERIAL_KEY_QUERY in json_obj: - query = DatasetQuery.factory_init_from_deserialized_json(json_obj[cls._SERIAL_KEY_QUERY]) - else: - query = None - if cls._SERIAL_KEY_DATA_DOMAIN in json_obj: - domain = DataDomain.factory_init_from_deserialized_json(json_obj[cls._SERIAL_KEY_DATA_DOMAIN]) - else: - domain = None + return values - return deserialized_class(action=action, dataset_name=dataset_name, category=category, - is_read_only_dataset=json_obj[cls._SERIAL_KEY_IS_READ_ONLY], domain=domain, - data_location=data_loc, - is_pending_data=json_obj.get(cls._SERIAL_KEY_IS_PENDING_DATA), #page=page, - query=query, **deserialized_class_kwargs) - except Exception as e: - return None + class Config: + fields = { + "management_action": {"alias": "action"}, + "data_category": {"alias": "category"}, + "is_read_only_dataset": {"alias": "read_only"}, + "is_pending_data": {"alias": "pending_data"}, + } def __eq__(self, other): try: @@ -259,10 +220,20 @@ def __hash__(self): self.data_category.name, 
str(hash(self.data_domain)), self.data_location, str(self.is_pending_data), self.query.to_json()])) - def __init__(self, action: ManagementAction, dataset_name: Optional[str] = None, is_read_only_dataset: bool = False, - category: Optional[DataCategory] = None, domain: Optional[DataDomain] = None, - data_location: Optional[str] = None, is_pending_data: bool = False, - query: Optional[DatasetQuery] = None, *args, **kwargs): + def __init__( + self, + *, + # NOTE: default is None for backwards compatibility. could be specified using alias. + action: ManagementAction = None, + dataset_name: Optional[str] = None, + is_read_only_dataset: bool = False, + category: Optional[DataCategory] = None, + domain: Optional[DataDomain] = None, + data_location: Optional[str] = None, + is_pending_data: bool = False, + query: Optional[DatasetQuery] = None, + **data + ): """ Initialize this instance. @@ -283,171 +254,58 @@ def __init__(self, action: ManagementAction, dataset_name: Optional[str] = None, query : Optional[DatasetQuery] Optional ::class:`DatasetQuery` object for query messages. 
""" - # Sanity check certain param values depending on the action; e.g., can't CREATE a dataset without a name - err_msg_template = "Cannot create {} for action {} without {}" - if dataset_name is None and action.requires_dataset_name: - raise RuntimeError(err_msg_template.format(self.__class__.__name__, action, "a dataset name")) - if category is None and action.requires_data_category: - raise RuntimeError(err_msg_template.format(self.__class__.__name__, action, "a data category")) - if domain is None and action.requires_data_domain: - raise RuntimeError(err_msg_template.format(self.__class__.__name__, action, "a data domain")) - - super(DatasetManagementMessage, self).__init__(*args, **kwargs) - - # TODO: raise exceptions for actions for which the workflow is not yet supported (e.g., REMOVE_DATA) - - self._action = action - self._dataset_name = dataset_name - self._is_read_only_dataset = is_read_only_dataset - self._category = category - self._domain = domain - self._data_location = data_location - self._query = query - self._is_pending_data = is_pending_data - - @property - def data_location(self) -> Optional[str]: - """ - Location for acted-upon data. - - Returns - ------- - Optional[str] - Location for acted-upon data. - """ - return self._data_location - - @property - def is_pending_data(self) -> bool: - """ - Whether the sender has data pending transmission after this message. - - Whether the sender has data it wants to transmit after this message. The typical use case is during a - ``CREATE`` action, where this indicates there is already data to add to the newly created dataset. - - Returns - ------- - bool - Whether the sender has data pending transmission after this message. - """ - return self._is_pending_data - - @property - def data_category(self) -> Optional[DataCategory]: - """ - The category of the involved data, if applicable. - - Returns - ------- - bool - The category of the involved data, if applicable. 
- """ - return self._category - - @property - def data_domain(self) -> Optional[DataDomain]: - """ - The domain of the involved data, if applicable. - - Returns - ------- - Optional[DataDomain] - The domain of the involved data, if applicable. - """ - return self._domain - - @property - def dataset_name(self) -> Optional[str]: - """ - The name of the involved dataset, if applicable. - - Returns - ------- - Optional - The name of the involved dataset, if applicable. - """ - return self._dataset_name - - @property - def is_read_only_dataset(self) -> bool: - """ - Whether the dataset involved is, should be, or must be (depending on action) read-only. - - Returns - ------- - bool - Whether the dataset involved is, should be, or must be (depending on action) read-only. - """ - return self._is_read_only_dataset - - @property - def management_action(self) -> ManagementAction: - """ - The type of ::class:`ManagementAction` this message embodies or requests. - - Returns - ------- - ManagementAction - The type of ::class:`ManagementAction` this message embodies or requests. 
- """ - return self._action - - @property - def query(self) -> Optional[DatasetQuery]: - return self._query - - def to_dict(self) -> Dict[str, Union[str, Number, dict, list]]: - serial = {self._SERIAL_KEY_ACTION: self.management_action.name, - self._SERIAL_KEY_IS_READ_ONLY: self.is_read_only_dataset, - self._SERIAL_KEY_IS_PENDING_DATA: self.is_pending_data} - if self.dataset_name is not None: - serial[self._SERIAL_KEY_DATASET_NAME] = self.dataset_name - if self.data_category is not None: - serial[self._SERIAL_KEY_CATEGORY] = self.data_category.name - if self.data_location is not None: - serial[self._SERIAL_KEY_DATA_LOCATION] = self.data_location - if self.data_domain is not None: - serial[self._SERIAL_KEY_DATA_DOMAIN] = self.data_domain.to_dict() - if self.query is not None: - serial[self._SERIAL_KEY_QUERY] = self.query.to_dict() - return serial + super().__init__( + management_action=action or data.pop("management_action", None), + dataset_name=dataset_name, + is_read_only_dataset=is_read_only_dataset or data.pop("read_only", False), + data_category=category or data.pop("data_category", None), + data_domain=domain or data.pop("data_domain", None), + data_location=data_location, + is_pending_data=is_pending_data or data.pop("pending_data", False), + query=query, + **data + ) + + +class DatasetManagementResponseBody(SerializableDict): + action: Optional[ManagementAction] + data_id: Optional[str] + dataset_name: Optional[str] + item_name: Optional[str] + # TODO: in the future, tighten the type restrictions of this field + query_results: Optional[Dict[str, Any]] + is_awaiting: bool = False class DatasetManagementResponse(Response): - _DATA_KEY_ACTION= 'action' - _DATA_KEY_DATA_ID = 'data_id' - _DATA_KEY_DATASET_NAME = 'dataset_name' - _DATA_KEY_ITEM_NAME = 'item_name' - _DATA_KEY_QUERY_RESULTS = 'query_results' - _DATA_KEY_IS_AWAITING = 'is_awaiting' - response_to_type = DatasetManagementMessage + response_to_type: ClassVar[Type[AbstractInitRequest]] = 
DatasetManagementMessage - def __init__(self, action: Optional[ManagementAction] = None, is_awaiting: bool = False, - data_id: Optional[str] = None, dataset_name: Optional[str] = None, data: Optional[dict] = None, - **kwargs): - if data is None: - data = {} + data: DatasetManagementResponseBody + + def __init__( + self, + action: Optional[ManagementAction] = None, + is_awaiting: bool = False, + data_id: Optional[str] = None, + dataset_name: Optional[str] = None, + data: Optional[Union[dict, DatasetManagementResponseBody]] = None, + **kwargs + ): + data = data if isinstance(data, DatasetManagementResponseBody) else DatasetManagementResponseBody(**data or {}) # Make sure 'action' param and action string within 'data' param aren't both present and conflicting if action is not None: - if action.name != data.get(self._DATA_KEY_ACTION, action.name): + if action != data.action: msg = '{} initialized with {} action param, but {} action in initial data.' - raise ValueError(msg.format(self.__class__.__name__, action.name, data.get(self._DATA_KEY_ACTION))) - data[self._DATA_KEY_ACTION] = action.name - # Additionally, if not using an explicit 'action', make sure it's a valid action string in 'data', or bail - else: - data_action_str = data.get(self._DATA_KEY_ACTION, '') - # Compare the string to the 'name' string of the action value obtain by passing the string to get_for_name() - if data_action_str.strip().upper() != ManagementAction.get_for_name(data_action_str).name.upper(): - msg = "No valid action param or within 'data' when initializing {} instance (received only '{}')" - raise ValueError(msg.format(self.__class__.__name__, data_action_str)) - - data[self._DATA_KEY_IS_AWAITING] = is_awaiting + raise ValueError(msg.format(self.__class__.__name__, action.name, data.action.name if data.action else data.action)) + data.action = action + + data.is_awaiting = is_awaiting if data_id is not None: - data[self._DATA_KEY_DATA_ID] = data_id + data.data_id = data_id if dataset_name 
is not None: - data[self._DATA_KEY_DATASET_NAME] = dataset_name + data.dataset_name = dataset_name super().__init__(data=data, **kwargs) @property @@ -460,16 +318,9 @@ def action(self) -> ManagementAction: ManagementAction The action requested by the ::class:`DatasetManagementMessage` for which this instance is the response. """ - if self._DATA_KEY_ACTION not in self.data: - return ManagementAction.UNKNOWN - elif isinstance(self.data[self._DATA_KEY_ACTION], str): - return ManagementAction.get_for_name(self.data[self._DATA_KEY_ACTION]) - elif isinstance(self.data[self._DATA_KEY_ACTION], ManagementAction): - val = self.data[self._DATA_KEY_ACTION] - self.data[self._DATA_KEY_ACTION] = val.name - return val - else: + if self.data.action is None: return ManagementAction.UNKNOWN + return self.data.action @property def data_id(self) -> Optional[str]: @@ -481,7 +332,7 @@ def data_id(self) -> Optional[str]: Optional[str] When available, the 'data_id' of the related dataset. """ - return self.data[self._DATA_KEY_DATA_ID] if self._DATA_KEY_DATA_ID in self.data else None + return self.data.data_id @property def dataset_name(self) -> Optional[str]: @@ -493,7 +344,7 @@ def dataset_name(self) -> Optional[str]: Optional[str] When available, the name of the relevant dataset; otherwise ``None``. """ - return self.data[self._DATA_KEY_DATASET_NAME] if self._DATA_KEY_DATASET_NAME in self.data else None + return self.data.dataset_name @property def item_name(self) -> Optional[str]: @@ -505,11 +356,11 @@ def item_name(self) -> Optional[str]: Optional[str] The name of the relevant dataset item/object/file, or ``None``. 
""" - return self.data.get(self._DATA_KEY_ITEM_NAME) + return self.data.item_name @property def query_results(self) -> Optional[dict]: - return self.data.get(self._DATA_KEY_QUERY_RESULTS) + return self.data.query_results @property def is_awaiting(self) -> bool: @@ -525,7 +376,7 @@ def is_awaiting(self) -> bool: bool Whether the response indicates the response sender is awaiting something additional. """ - return self.data[self._DATA_KEY_IS_AWAITING] + return self.data.is_awaiting class MaaSDatasetManagementMessage(DatasetManagementMessage, ExternalRequest): @@ -536,9 +387,26 @@ class MaaSDatasetManagementMessage(DatasetManagementMessage, ExternalRequest): the superclass. """ - _SERIAL_KEY_DATA_REQUIREMENTS = 'data_requirements' - _SERIAL_KEY_OUTPUT_FORMATS = 'output_formats' - _SERIAL_KEY_SESSION_SECRET = 'session_secret' + data_requirements: List[DataRequirement] = Field( + default_factory=list, + description="List of all the explicit and implied data requirements for this request.", + ) + """ + By default, this is an empty list, though it is possible to append requirements to the list. + """ + + output_formats: List[DataFormat] = Field( + default_factory=list, + description="List of the formats of each required output dataset for the requested task." + ) + """ + By default, this will be an empty list, though if any request does need to produce output, + formats can be appended to it. 
+ """ + + class Config: + # NOTE: in parent class, `ExternalRequest`, `session_secret` is aliased using `session-secret` + fields = {"session_secret": {"alias": "session_secret"}} @classmethod def factory_create(cls, mgmt_msg: DatasetManagementMessage, session_secret: str) -> 'MaaSDatasetManagementMessage': @@ -562,90 +430,33 @@ def factory_init_correct_response_subtype(cls, json_obj: dict) -> 'MaaSDatasetMa """ return MaaSDatasetManagementResponse.factory_init_from_deserialized_json(json_obj=json_obj) - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict) -> Optional['MaaSDatasetManagementMessage']: - try: - # Inject this if necessary before passing to supertype - if 'deserialized_class' not in json_obj: - json_obj['deserialized_class'] = cls - elif isinstance(json_obj['deserialized_class'], str): - json_obj['deserialized_class'] = globals()[json_obj['deserialized_class']] - # Also inject things that will be used as additional kwargs to the eventual class init - if 'deserialized_class_kwargs' not in json_obj: - json_obj['deserialized_class_kwargs'] = dict() - if 'session_secret' not in json_obj['deserialized_class_kwargs']: - json_obj['deserialized_class_kwargs']['session_secret'] = json_obj[cls._SERIAL_KEY_SESSION_SECRET] - - obj = super().factory_init_from_deserialized_json(json_obj=json_obj) - - # Also add these if there happened to be any present - if cls._SERIAL_KEY_DATA_REQUIREMENTS in json_obj: - obj.data_requirements.extend([DataRequirement.factory_init_from_deserialized_json(json) for json in - json_obj[cls._SERIAL_KEY_DATA_REQUIREMENTS]]) - if cls._SERIAL_KEY_OUTPUT_FORMATS in json_obj: - obj.output_formats.extend( - [DataFormat.get_for_name(f) for f in json_obj[cls._SERIAL_KEY_OUTPUT_FORMATS]]) - - # Finally, return the object - return obj - except Exception as e: - return None - - def __init__(self, session_secret: str, *args, **kwargs): - """ - - Keyword Args - ---------- - session_secret : str - action : ManagementAction - 
dataset_name : Optional[str] - is_read_only_dataset : bool - category : Optional[DataCategory] - data_location : Optional[str] - is_pending_data : bool - query : Optional[DataQuery] - """ - super(MaaSDatasetManagementMessage, self).__init__(session_secret=session_secret, *args, **kwargs) - self._data_requirements = [] - self._output_formats = [] - - @property - def data_requirements(self) -> List[DataRequirement]: - """ - List of all the explicit and implied data requirements for this request. - - By default, this is an empty list, though it is possible to append requirements to the list. - - Returns - ------- - List[DataRequirement] - List of all the explicit and implied data requirements for this request. - """ - return self._data_requirements - - @property - def output_formats(self) -> List[DataFormat]: - """ - List of the formats of each required output dataset for the requested task. - - By default, this will be an empty list, though if any request does need to produce output, formats can be - appended to it - - Returns - ------- - List[DataFormat] - List of the formats of each required output dataset for the requested. 
- """ - return self._output_formats - - def to_dict(self) -> Dict[str, Union[str, Number, dict, list]]: - serial = super(MaaSDatasetManagementMessage, self).to_dict() - serial[self._SERIAL_KEY_SESSION_SECRET] = self.session_secret - if len(self.data_requirements) > 0: - serial[self._SERIAL_KEY_DATA_REQUIREMENTS] = [r.to_dict() for r in self.data_requirements] - if len(self.output_formats) > 0: - serial[self._SERIAL_KEY_OUTPUT_FORMATS] = [f.name for f in self.output_formats] - return serial + def dict( + self, + *, + include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + by_alias: bool = True, # Note this follows Serializable convention + skip_defaults: Optional[bool] = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False + ) -> Dict[str, Union[str, int]]: + exclude = exclude or set() + + if not self.data_requirements: + exclude.add("data_requirements") + if not self.output_formats: + exclude.add("output_formats") + + return super().dict( + include=include, + exclude=exclude, + by_alias=by_alias, + skip_defaults=skip_defaults, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) class MaaSDatasetManagementResponse(ExternalRequestResponse, DatasetManagementResponse): @@ -653,7 +464,7 @@ class MaaSDatasetManagementResponse(ExternalRequestResponse, DatasetManagementRe Analog of ::class:`DatasetManagementResponse`, but for the ::class:`MaaSDatasetManagementMessage` message type. 
""" - response_to_type = MaaSDatasetManagementMessage + response_to_type: ClassVar[Type[AbstractInitRequest]] = MaaSDatasetManagementMessage @classmethod def factory_create(cls, dataset_mgmt_response: DatasetManagementResponse) -> 'MaaSDatasetManagementResponse': @@ -670,4 +481,4 @@ def factory_create(cls, dataset_mgmt_response: DatasetManagementResponse) -> 'Ma MaaSDatasetManagementResponse Factory-created analog of this instance type. """ - return cls.factory_init_from_deserialized_json(dataset_mgmt_response.to_dict()) \ No newline at end of file + return cls.factory_init_from_deserialized_json(dataset_mgmt_response.to_dict()) diff --git a/python/lib/communication/dmod/communication/evaluation_request.py b/python/lib/communication/dmod/communication/evaluation_request.py index 4c633d259..78f76b6d4 100644 --- a/python/lib/communication/dmod/communication/evaluation_request.py +++ b/python/lib/communication/dmod/communication/evaluation_request.py @@ -3,9 +3,9 @@ import json from numbers import Number -from typing import Dict -from typing import Union +from pydantic import Field, validator +from dmod.core.serializable import Serializable from .message import Message, MessageEventType, Response SERIALIZABLE_DICT = typing.Dict[str, typing.Union[str, Number, dict, typing.List]] @@ -16,35 +16,32 @@ class EvaluationRequest(Message, abc.ABC): A request to be forwarded to the evaluation service """ - event_type: MessageEventType = MessageEventType.EVALUATION_REQUEST + event_type: typing.ClassVar[MessageEventType] = MessageEventType.EVALUATION_REQUEST """ :class:`MessageEventType`: the event type for this message implementation """ + action: str + @classmethod @abc.abstractmethod def get_action(cls) -> str: ... 
- @property - def action(self) -> str: - return self.get_action() - class EvaluationConnectionRequest(EvaluationRequest): """ A request used to communicate through a chained websocket connection """ - _action_parameters: typing.Dict[str, typing.Any] + action: typing.Literal["connect"] = "connect" + parameters: typing.Dict[str, typing.Any] = Field(default_factory=dict) - def __init__(self, **kwargs): - self._action_parameters = kwargs or dict() + class Config: + fields = { + "parameters": {"alias": "action_parameters"} + } @classmethod def get_action(cls) -> str: - return "connect" - - @property - def parameters(self) -> typing.Dict[str, typing.Any]: - return self._action_parameters + return cls.__fields__["action"].default @classmethod def factory_init_from_deserialized_json(cls, json_obj: dict) -> typing.Optional[EvaluationRequest]: @@ -64,20 +61,6 @@ def factory_init_from_deserialized_json(cls, json_obj: dict) -> typing.Optional[ return cls(**json_obj) - def to_dict(self) -> Dict[str, Union[str, Number, dict, list]]: - """ - Returns: - A dictionary representation of this request - """ - dictionary_representation = { - "action": self.action - } - - if self._action_parameters: - dictionary_representation['action_parameters'] = self._action_parameters.copy() - - return dictionary_representation - class EvaluationConnectionRequestResponse(Response): pass @@ -86,64 +69,64 @@ class EvaluationConnectionRequestResponse(Response): class SaveEvaluationRequest(EvaluationRequest): pass +class ActionParameters(Serializable): + evaluation_name: str + instructions: str + + @validator("instructions", pre=True) + def _coerce_instructions(cls, value): + if isinstance(value, dict): + return json.dumps(value, indent=4) + return value + class StartEvaluationRequest(EvaluationRequest): - @classmethod - def get_action(cls) -> str: - return "launch" + action: typing.Literal["launch"] = "launch" + parameters: typing.Dict[str, typing.Any] + + class Config: + fields = { + "parameters": 
{"alias": "action_parameters"} + } - evaluation_name: str = None + # Note: `parameters` is a dictionary representation of `ActionParameters` plus arbitrary keys + # and values + @validator("parameters", pre=True) + def _coerce_action_parameters(cls, value: typing.Union[typing.Dict[str, typing.Any], ActionParameters]): + if isinstance(value, ActionParameters): + return value.to_dict() - instructions: typing.Union[str, dict] = None + parameters = ActionParameters(**value) + return {**value, **parameters.to_dict()} - action_parameters: dict = None + @classmethod + def get_action(cls) -> str: + return cls.__fields__["action"].default @classmethod def factory_init_from_deserialized_json(cls, json_obj: dict) -> typing.Optional[EvaluationRequest]: try: - if "action" in json_obj and json_obj['action'] != cls.get_action(): + if "action" in json_obj and json_obj["action"] != cls.get_action(): return None - if "action_parameters" in json_obj: - parameters = json_obj['action_parameters'] - else: - parameters = json_obj - - missing_instructions = not parameters.get("instructions") \ - or not isinstance(parameters.get("instructions"), (str, dict)) - missing_name = not parameters.get("evaluation_name") - - if missing_instructions or missing_name: - return None - - return cls( - instructions=parameters.get("instructions"), - evaluation_name=parameters.get("evaluation_name"), - **parameters - ) - except Exception as e: + return cls(**json_obj) + except Exception: return None - def to_dict(self) -> SERIALIZABLE_DICT: - return { - "action": self.action, - "action_parameters": self.action_parameters.update( - { - "evaluation_name": self.evaluation_name, - "instructions": self.instructions - } - ) - } - def __init__( self, - instructions: str, - evaluation_name: str, + # NOTE: None for backwards compatibility + instructions: str = None, + evaluation_name: str = None, **kwargs ): - self._instructions = json.dumps(instructions, indent=4) if isinstance(instructions, dict) else 
instructions - self._evaluation_name = evaluation_name - self._action_parameters = kwargs + # assume no need for backwards compatibility + if instructions is None or evaluation_name is None: + super().__init__(**kwargs) + return + + parameters = ActionParameters(instructions=instructions, evaluation_name=evaluation_name, **kwargs) + super().__init__(parameters=parameters.to_dict()) class FindEvaluationRequest(EvaluationRequest): diff --git a/python/lib/communication/dmod/communication/maas_request/distribution.py b/python/lib/communication/dmod/communication/maas_request/distribution.py index aa1bef5a2..d088302bc 100644 --- a/python/lib/communication/dmod/communication/maas_request/distribution.py +++ b/python/lib/communication/dmod/communication/maas_request/distribution.py @@ -1,8 +1,28 @@ -class Distribution: +from dmod.core.serializable import Serializable + +from typing import Literal + + +class DistributionBounds(Serializable): + minimum: int = 0 + maximum: int = 0 + distribution_type: Literal["normal"] = "normal" + + class Config: + fields = { + "distribution_type": {"alias": "type"}, + "minimum": {"alias": "min"}, + "maximum": {"alias": "max"}, + } + + +class Distribution(Serializable): """ Represents the definition of a distribution of numbers """ + distribution: DistributionBounds + def __init__( self, minimum: int = 0, maximum: int = 0, distribution_type: str = "normal" ): @@ -11,18 +31,38 @@ def __init__( :param int maximum: The upper bound of the distribution :param str distribution_type: The type of the distribution """ - self.minimum = minimum - self.maximum = maximum - self.distribution_type = distribution_type - - def to_dict(self): - return { - "distribution": { - "min": self.minimum, - "max": self.maximum, - "type": self.distribution_type, - } - } + super().__init__( + distribution=DistributionBounds( + minimum=minimum, maximum=maximum, distribution_type=distribution_type + ) + ) + + @property + def minimum(self) -> int: + """The lower bound for
the distribution""" + return self.distribution.minimum + + @minimum.setter + def minimum(self, value: int): + self.distribution.minimum = value + + @property + def maximum(self) -> int: + """The upper bound for the distribution""" + return self.distribution.maximum + + @maximum.setter + def maximum(self, value: int): + self.distribution.maximum = value + + @property + def distribution_type(self) -> str: + """The type of the distribution""" + return self.distribution.distribution_type + + @distribution_type.setter + def distribution_type(self, value: str): + self.distribution.distribution_type = value def __str__(self): return str(self.to_dict()) diff --git a/python/lib/communication/dmod/communication/maas_request/dmod_job_request.py b/python/lib/communication/dmod/communication/maas_request/dmod_job_request.py index 5288e7941..451e59e56 100644 --- a/python/lib/communication/dmod/communication/maas_request/dmod_job_request.py +++ b/python/lib/communication/dmod/communication/maas_request/dmod_job_request.py @@ -11,9 +11,6 @@ class DmodJobRequest(AbstractInitRequest, ABC): The base class underlying all types of messages requesting execution of some kind of workflow job. """ - def __int__(self, *args, **kwargs): - super(DmodJobRequest, self).__int__(*args, **kwargs) - @property @abstractmethod def data_requirements(self) -> List[DataRequirement]: diff --git a/python/lib/communication/dmod/communication/maas_request/external_request.py b/python/lib/communication/dmod/communication/maas_request/external_request.py index 0b6df842a..b84093e55 100644 --- a/python/lib/communication/dmod/communication/maas_request/external_request.py +++ b/python/lib/communication/dmod/communication/maas_request/external_request.py @@ -6,6 +6,11 @@ class ExternalRequest(AbstractInitRequest, ABC): """ The base class underlying all types of externally-initiated (and, therefore, authenticated) MaaS system requests.
""" + # NOTE: in some places this is serialized as `session-secret` + session_secret: str + + class Config: + fields = {"session_secret": {"alias": "session-secret"}} @classmethod @abstractmethod @@ -23,19 +28,7 @@ def factory_init_correct_response_subtype(cls, json_obj: dict): """ pass - def __init__(self, session_secret: str, *args, **kwargs): - """ - Initialize the base attributes and state of this request object. - - Parameters - ---------- - session_secret : str - The session secret for the right session when communicating with the MaaS request handler - """ - super(ExternalRequest, self).__init__(*args, **kwargs) - self.session_secret = session_secret - - def _check_class_compatible_for_equality(self, other) -> bool: + def _check_class_compatible_for_equality(self, other: object) -> bool: """ Check and return whether another object is of some class that is compatible for equality checking with the class of this instance, such that the class difference does not independently imply the other object and this instance diff --git a/python/lib/communication/dmod/communication/maas_request/external_request_response.py b/python/lib/communication/dmod/communication/maas_request/external_request_response.py index 36b320de1..f03d10d2d 100644 --- a/python/lib/communication/dmod/communication/maas_request/external_request_response.py +++ b/python/lib/communication/dmod/communication/maas_request/external_request_response.py @@ -1,13 +1,12 @@ from abc import ABC -from ..message import Response +from ..message import AbstractInitRequest, Response from .external_request import ExternalRequest +from typing import ClassVar, Type + class ExternalRequestResponse(Response, ABC): - response_to_type = ExternalRequest + response_to_type: ClassVar[Type[AbstractInitRequest]] = ExternalRequest """ The type of :class:`AbstractInitRequest` for which this type is the response""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) diff --git 
a/python/lib/communication/dmod/communication/maas_request/model_exec_request.py b/python/lib/communication/dmod/communication/maas_request/model_exec_request.py index de537091b..b98c3722a 100644 --- a/python/lib/communication/dmod/communication/maas_request/model_exec_request.py +++ b/python/lib/communication/dmod/communication/maas_request/model_exec_request.py @@ -1,14 +1,15 @@ from abc import ABC -from typing import Optional, Union +from typing import ClassVar, Dict, Optional, Union from dmod.core.execution import AllocationParadigm from ..message import MessageEventType from .dmod_job_request import DmodJobRequest from .external_request import ExternalRequest +from .model_exec_request_body import ModelExecRequestBody -def get_available_models() -> dict: +def get_available_models() -> Dict[str, "ModelExecRequest"]: """ :return: The names of all models mapped to their class """ @@ -25,14 +26,16 @@ class ModelExecRequest(ExternalRequest, DmodJobRequest, ABC): An abstract extension of ::class:`DmodJobRequest` for requesting model execution jobs. """ - event_type: MessageEventType = MessageEventType.MODEL_EXEC_REQUEST + event_type: ClassVar[MessageEventType] = MessageEventType.MODEL_EXEC_REQUEST - model_name = None + model_name: ClassVar[str] = None """(:class:`str`) The name of the model to be used""" - _DEFAULT_CPU_COUNT = 1 + _DEFAULT_CPU_COUNT: ClassVar[int] = 1 """ The default number of CPUs to assume are being requested for the job, when not explicitly provided. """ + model: ModelExecRequestBody + @classmethod def factory_init_correct_subtype_from_deserialized_json( cls, json_obj: dict @@ -55,14 +58,13 @@ def factory_init_correct_subtype_from_deserialized_json( A deserialized ::class:`ModelExecRequest` of the appropriate subtype. 
""" try: - for model in get_available_models(): - if model in json_obj["model"] or ( - "name" in json_obj["model"] and json_obj["model"]["name"] == model - ): - return get_available_models()[ - model - ].factory_init_from_deserialized_json(json_obj) - return None + model = json_obj["model"] + + # TODO: remove logic once `nwm` ModelExecRequest changes where it store the model name. + model_name = model["name"] if "name" in model else "nwm" + models = get_available_models() + + return models[model_name].factory_init_from_deserialized_json(json_obj) except: return None @@ -71,15 +73,16 @@ def get_model_name(cls) -> str: """ :return: The name of this model """ - return cls.model_name + return cls.__fields__["model"].type_.__fields__["name"].default def __init__( self, - config_data_id: str, + # required in prior version of code + config_data_id: str = None, + # optional in prior version of code cpu_count: Optional[int] = None, allocation_paradigm: Optional[Union[str, AllocationParadigm]] = None, - *args, - **kwargs + **data ): """ Initialize model-exec-specific attributes and state of this request object common to all model exec requests. @@ -89,19 +92,20 @@ def __init__( session_secret : str The session secret for the right session when communicating with the request handler. 
""" - super(ModelExecRequest, self).__init__(*args, **kwargs) - self._config_data_id = config_data_id - self._cpu_count = ( - cpu_count if cpu_count is not None else self._DEFAULT_CPU_COUNT - ) - if allocation_paradigm is None: - self._allocation_paradigm = AllocationParadigm.get_default_selection() - elif isinstance(allocation_paradigm, str): - self._allocation_paradigm = AllocationParadigm.get_from_name( - allocation_paradigm - ) - else: - self._allocation_paradigm = allocation_paradigm + # assume no need for backwards compatibility + if "model" in data: + super().__init__(**data) + return + + data["model"] = {"config_data_id": config_data_id} + + if cpu_count is not None: + data["model"]["cpu_count"] = cpu_count + + if allocation_paradigm is not None: + data["model"]["allocation_paradigm"] = allocation_paradigm + + super().__init__(**data) def __eq__(self, other): if not self._check_class_compatible_for_equality(other): @@ -132,7 +136,7 @@ def allocation_paradigm(self) -> AllocationParadigm: AllocationParadigm The allocation paradigm desired for use with this request. """ - return self._allocation_paradigm + return self.model.allocation_paradigm @property def config_data_id(self) -> str: @@ -144,7 +148,7 @@ def config_data_id(self) -> str: str Value of ``data_id`` identifying the dataset with the primary configuration applicable to this request. """ - return self._config_data_id + return self.model.config_data_id @property def cpu_count(self) -> int: @@ -156,4 +160,4 @@ def cpu_count(self) -> int: int The number of processors requested for this job.
""" - return self._cpu_count + return self.model.cpu_count diff --git a/python/lib/communication/dmod/communication/maas_request/model_exec_request_body.py b/python/lib/communication/dmod/communication/maas_request/model_exec_request_body.py new file mode 100644 index 000000000..4d063a448 --- /dev/null +++ b/python/lib/communication/dmod/communication/maas_request/model_exec_request_body.py @@ -0,0 +1,31 @@ +from abc import ABC +from pydantic import Field, validator + +from dmod.core.serializable import Serializable +from dmod.core.execution import AllocationParadigm + +from typing import ClassVar + + +class ModelExecRequestBody(Serializable, ABC): + _DEFAULT_CPU_COUNT: ClassVar[int] = 1 + """ The default number of CPUs to assume are being requested for the job, when not explicitly provided. """ + + # model type discriminator field. enables constructing correct subclass based on `name` field + # value. + # override `name` in subclasses using `typing.Literal` + # e.g. `name: Literal["ngen"] = "ngen"` + name: str = Field("", description="The name of the model to be used") + + config_data_id: str = Field(description="Uniquely identifies the dataset with the primary configuration for this request.") + cpu_count: int = Field(_DEFAULT_CPU_COUNT, gt=0, description="The number of processors requested for this job.") + allocation_paradigm: AllocationParadigm = Field( + default_factory=AllocationParadigm.get_default_selection, + description="The allocation paradigm desired for use when allocating resources for this request." + ) + + @validator("name", pre=True) + def _lower_model_name_(cls, value: str): + # NOTE: this should enable case insensitive subclass construction based on `name`, that is + # if all `name` field's are lowercase. 
+ return str(value).lower() diff --git a/python/lib/communication/dmod/communication/maas_request/model_exec_request_response.py b/python/lib/communication/dmod/communication/maas_request/model_exec_request_response.py index f82593cbf..c8514bb12 100644 --- a/python/lib/communication/dmod/communication/maas_request/model_exec_request_response.py +++ b/python/lib/communication/dmod/communication/maas_request/model_exec_request_response.py @@ -1,34 +1,30 @@ from abc import ABC +from typing import Any, ClassVar, Dict, Optional, Type, Union +from pydantic import validator -from typing import Optional - -from ..message import InitRequestResponseReason +from ..scheduler_request import SchedulerRequestResponse, UNSUCCESSFUL_JOB +from ..message import AbstractInitRequest, InitRequestResponseReason from .external_request_response import ExternalRequestResponse from .model_exec_request import ModelExecRequest +from .model_exec_request_response_body import ModelExecRequestResponseBody class ModelExecRequestResponse(ExternalRequestResponse, ABC): - _data_dict_key_job_id = "job_id" - _data_dict_key_output_data_id = "output_data_id" - _data_dict_key_scheduler_response = "scheduler_response" - response_to_type = ModelExecRequest + response_to_type: ClassVar[Type[AbstractInitRequest]] = ModelExecRequest """ The type of :class:`AbstractInitRequest` for which this type is the response""" - @classmethod - def _convert_scheduler_response_to_data_attribute(cls, scheduler_response=None): - if scheduler_response is None: - return None - elif isinstance(scheduler_response, dict) and len(scheduler_response) == 0: - return {} - elif isinstance(scheduler_response, dict): - return scheduler_response - else: - return { - cls._data_dict_key_job_id: scheduler_response.job_id, - cls._data_dict_key_output_data_id: scheduler_response.output_data_id, - cls._data_dict_key_scheduler_response: scheduler_response.to_dict(), - } + data: Optional[Union[ModelExecRequestResponseBody, Dict[str, Any]]] = None 
+ + @validator("data", pre=True) + def _convert_data_field(cls, value: Optional[Union[SchedulerRequestResponse, ModelExecRequestResponseBody, Dict[str, Any]]]) -> Optional[Union[ModelExecRequestResponseBody, Dict[str, Any]]]: + if value is None: + return value + + elif isinstance(value, SchedulerRequestResponse): + return ModelExecRequestResponseBody.from_scheduler_request_response(value) + + return value @classmethod def get_job_id_key(cls) -> str: @@ -40,7 +36,7 @@ def get_job_id_key(cls) -> str: str Serialization dictionary key for the field containing the ::attribute:`job_id` property. """ - return str(cls._data_dict_key_job_id) + return "job_id" @classmethod def get_output_data_id_key(cls) -> str: @@ -52,7 +48,7 @@ def get_output_data_id_key(cls) -> str: str Serialization dictionary key for the field containing the ::attribute:`output_data_id` property. """ - return str(cls._data_dict_key_output_data_id) + return "output_data_id" @classmethod def get_scheduler_response_key(cls) -> str: @@ -64,52 +60,35 @@ def get_scheduler_response_key(cls) -> str: str Serialization dictionary key for the field containing the 'scheduler_response' value. """ - return str(cls._data_dict_key_scheduler_response) - - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict): - """ - Factory create a new instance of this type based on a JSON object dictionary deserialized from received JSON. + return "scheduler_response" + + def __init__( + self, + scheduler_response: Optional[ + Union[ + SchedulerRequestResponse, ModelExecRequestResponseBody, Dict[str, Any] + ] + ] = None, + **kwargs + ): + if scheduler_response is None: + super().__init__(**kwargs) + return - Parameters - ---------- - json_obj + # NOTE: if `scheduler_response` is not None, it is given precedence over "data" that might + # be present in `kwargs`. 
+ kwargs["data"] = scheduler_response + super().__init__(**kwargs) - Returns - ------- - response_obj : Response - A new object of this type instantiated from the deserialize JSON object dictionary, or none if the provided - parameter could not be used to instantiated a new object. + @property + def job_id(self) -> int: + if isinstance(self.data, ModelExecRequestResponseBody): + return self.data.job_id - See Also - ------- - _factory_init_data_attribute - """ - try: - return cls( - success=json_obj["success"], - reason=json_obj["reason"], - message=json_obj["message"], - scheduler_response=json_obj["data"], - ) - except Exception as e: - return None - - def __init__(self, scheduler_response=None, *args, **kwargs): - data = self._convert_scheduler_response_to_data_attribute(scheduler_response) - if data is not None: - kwargs["data"] = data - super().__init__(*args, **kwargs) + elif isinstance(self.data, dict) and "job_id" in self.data: + return self.data["job_id"] - @property - def job_id(self): - if ( - not isinstance(self.data, dict) - or self._data_dict_key_job_id not in self.data - ): - return -1 - else: - return self.data[self._data_dict_key_job_id] + return UNSUCCESSFUL_JOB @property def output_data_id(self) -> Optional[str]: @@ -121,13 +100,13 @@ def output_data_id(self) -> Optional[str]: Optional[str] The 'data_id' of the output dataset for requested job, if request was successful; otherwise ``None``. 
""" - if ( - not isinstance(self.data, dict) - or self._data_dict_key_output_data_id not in self.data - ): - return None - else: - return self.data[self._data_dict_key_output_data_id] + if isinstance(self.data, ModelExecRequestResponseBody): + return self.data.output_data_id + + elif isinstance(self.data, dict) and "output_data_id" in self.data: + return self.data["output_data_id"] + + return None @property def reason_enum(self): diff --git a/python/lib/communication/dmod/communication/maas_request/model_exec_request_response_body.py b/python/lib/communication/dmod/communication/maas_request/model_exec_request_response_body.py new file mode 100644 index 000000000..9f2c82759 --- /dev/null +++ b/python/lib/communication/dmod/communication/maas_request/model_exec_request_response_body.py @@ -0,0 +1,23 @@ +from ..scheduler_request import SchedulerRequestResponse, SchedulerRequestResponseBody + + +class ModelExecRequestResponseBody(SchedulerRequestResponseBody): + scheduler_response: SchedulerRequestResponse + + @classmethod + def from_scheduler_request_response( + cls, scheduler_response: SchedulerRequestResponse + ) -> "ModelExecRequestResponseBody": + return cls( + job_id=scheduler_response.job_id, + output_data_id=scheduler_response.output_data_id, + scheduler_response=scheduler_response.copy(), + ) + + # NOTE: legacy support. previously this class was treated as a dictionary + def __contains__(self, element: str) -> bool: + return element in self.__dict__ + + # NOTE: legacy support. 
previously this class was treated as a dictionary + def __getitem__(self, item: str): + return self.__dict__[item] diff --git a/python/lib/communication/dmod/communication/maas_request/ngen/ngen_exec_request_body.py b/python/lib/communication/dmod/communication/maas_request/ngen/ngen_exec_request_body.py new file mode 100644 index 000000000..a10d5d75f --- /dev/null +++ b/python/lib/communication/dmod/communication/maas_request/ngen/ngen_exec_request_body.py @@ -0,0 +1,33 @@ +from pydantic import validator + +from dmod.core.meta_data import TimeRange +from ..model_exec_request_body import ModelExecRequestBody + +from typing import List, Literal, Optional + + +class NGENRequestBody(ModelExecRequestBody): + name: Literal["ngen"] = "ngen" + + time_range: TimeRange + hydrofabric_uid: str + hydrofabric_data_id: str + bmi_config_data_id: str + # NOTE: consider pydantic.conlist to constrain this type rather than using validators + catchments: Optional[List[str]] + partition_cfg_data_id: Optional[str] + + @validator("catchments") + def validate_deduplicate_and_sort_catchments( + cls, value: List[str] + ) -> Optional[List[str]]: + if value is None: + return None + + deduped = set(value) + return sorted(list(deduped)) + + class Config: + fields = { + "partition_cfg_data_id": {"alias": "partition_config_data_id"}, + } diff --git a/python/lib/communication/dmod/communication/maas_request/ngen/ngen_request.py b/python/lib/communication/dmod/communication/maas_request/ngen/ngen_request.py index 074045370..d0b0a22ce 100644 --- a/python/lib/communication/dmod/communication/maas_request/ngen/ngen_request.py +++ b/python/lib/communication/dmod/communication/maas_request/ngen/ngen_request.py @@ -1,6 +1,6 @@ -from numbers import Number +from pydantic import PrivateAttr -from typing import Dict, List, Optional, Set, Union +from typing import ClassVar, List, Optional, Set, Type, Union from dmod.core.meta_data import ( DataCategory, @@ -10,9 +10,10 @@ DiscreteRestriction, TimeRange, ) 
-from ...message import MessageEventType +from ...message import AbstractInitRequest, MessageEventType from ..model_exec_request import ModelExecRequest from ..model_exec_request_response import ModelExecRequestResponse +from .ngen_exec_request_body import NGENRequestBody class NGENRequest(ModelExecRequest): @@ -23,58 +24,13 @@ class NGENRequest(ModelExecRequest): model_name = "ngen" # FIXME case sentitivity """(:class:`str`) The name of the model to be used""" - @classmethod - def factory_init_from_deserialized_json( - cls, json_obj: dict - ) -> Optional["NGENRequest"]: - """ - Deserialize request formated as JSON to an instance. - - See the documentation of this type's ::method:`to_dict` for an example of the format of valid JSON. - - Parameters - ---------- - json_obj : dict - The serialized JSON representation of a request object. + model: NGENRequestBody - Returns - ------- - The deserialized ::class:`NGENRequest`, or ``None`` if the JSON was not valid for deserialization. - - See Also - ------- - ::method:`to_dict` - """ - try: - optional_kwargs_w_defaults = dict() - if "cpu_count" in json_obj["model"]: - optional_kwargs_w_defaults["cpu_count"] = json_obj["model"]["cpu_count"] - if "allocation_paradigm" in json_obj["model"]: - optional_kwargs_w_defaults["allocation_paradigm"] = json_obj["model"][ - "allocation_paradigm" - ] - if "catchments" in json_obj["model"]: - optional_kwargs_w_defaults["catchments"] = json_obj["model"][ - "catchments" - ] - if "partition_config_data_id" in json_obj["model"]: - optional_kwargs_w_defaults["partition_config_data_id"] = json_obj[ - "model" - ]["partition_config_data_id"] - - return cls( - time_range=TimeRange.factory_init_from_deserialized_json( - json_obj["model"]["time_range"] - ), - hydrofabric_uid=json_obj["model"]["hydrofabric_uid"], - hydrofabric_data_id=json_obj["model"]["hydrofabric_data_id"], - config_data_id=json_obj["model"]["config_data_id"], - bmi_cfg_data_id=json_obj["model"]["bmi_config_data_id"], - 
session_secret=json_obj["session-secret"], - **optional_kwargs_w_defaults - ) - except Exception as e: - return None + _hydrofabric_data_requirement = PrivateAttr(None) + _forcing_data_requirement = PrivateAttr(None) + _realization_cfg_data_requirement = PrivateAttr(None) + _bmi_cfg_data_requirement = PrivateAttr(None) + _partition_cfg_data_requirement = PrivateAttr(None) @classmethod def factory_init_correct_response_subtype( @@ -95,20 +51,23 @@ def factory_init_correct_response_subtype( json_obj=json_obj ) - def __eq__(self, other): - return ( - self.time_range == other.time_range - and self.hydrofabric_data_id == other.hydrofabric_data_id - and self.hydrofabric_uid == other.hydrofabric_uid - and self.config_data_id == other.config_data_id - and self.bmi_config_data_id == other.bmi_config_data_id - and self.session_secret == other.session_secret - and self.cpu_count == other.cpu_count - and self.partition_cfg_data_id == other.partition_cfg_data_id - and self.catchments == other.catchments - ) + def __eq__(self, other: "NGENRequest"): + try: + return ( + self.time_range == other.time_range + and self.hydrofabric_data_id == other.hydrofabric_data_id + and self.hydrofabric_uid == other.hydrofabric_uid + and self.config_data_id == other.config_data_id + and self.bmi_config_data_id == other.bmi_config_data_id + and self.session_secret == other.session_secret + and self.cpu_count == other.cpu_count + and self.partition_cfg_data_id == other.partition_cfg_data_id + and self.catchments == other.catchments + ) + except AttributeError: + return False - def __hash__(self): + def __hash__(self) -> int: hash_str = "{}-{}-{}-{}-{}-{}-{}-{}-{}".format( self.time_range.to_json(), self.hydrofabric_data_id, @@ -124,14 +83,15 @@ def __hash__(self): def __init__( self, - time_range: TimeRange, - hydrofabric_uid: str, - hydrofabric_data_id: str, - bmi_cfg_data_id: str, + # required in prior version of code + time_range: TimeRange = None, + hydrofabric_uid: str = None, + 
hydrofabric_data_id: str = None, + bmi_cfg_data_id: str = None, + # optional in prior version of code catchments: Optional[Union[Set[str], List[str]]] = None, partition_cfg_data_id: Optional[str] = None, - *args, - **kwargs + **data ): """ Initialize an instance. @@ -159,28 +119,24 @@ def __init__( session_secret : str The session secret for the right session when communicating with the MaaS request handler """ - super().__init__(*args, **kwargs) - self._time_range = time_range - self._hydrofabric_uid = hydrofabric_uid - self._hydrofabric_data_id = hydrofabric_data_id - self._bmi_config_data_id = bmi_cfg_data_id - self._part_config_data_id = partition_cfg_data_id - # Convert an initial list to a set to remove duplicates - try: - catchments = set(catchments) - # TypeError should mean that we received `None`, so just use that to set _catchments - except TypeError: - self._catchments = catchments - # Assuming we have a set now, move this set back to list and sort - else: - self._catchments = list(catchments) - self._catchments.sort() - - self._hydrofabric_data_requirement = None - self._forcing_data_requirement = None - self._realization_cfg_data_requirement = None - self._bmi_cfg_data_requirement = None - self._partition_cfg_data_requirement = None + # If `model` key is present, assume there is not a need for backwards compatibility + if "model" in data: + super().__init__(**data) + return + + # NOTE: backwards compatibility support. + model = NGENRequestBody( + time_range=time_range, + hydrofabric_uid=hydrofabric_uid, + hydrofabric_data_id=hydrofabric_data_id, + catchments=catchments, + partition_cfg_data_id=partition_cfg_data_id, + # previous version of code used `bmi_cfg_data_id` as parameter name. 
+ bmi_config_data_id=bmi_cfg_data_id, + **data + ) + + super().__init__(model=model, **data) def _gen_catchments_domain_restriction( self, var_name: str = "catchment_id" @@ -237,7 +193,7 @@ def bmi_config_data_id(self) -> str: str Index value of ``data_id`` to uniquely identify sets of BMI module config data that are otherwise similar. """ - return self._bmi_config_data_id + return self.model.bmi_config_data_id @property def bmi_cfg_data_requirement(self) -> DataRequirement: @@ -252,7 +208,7 @@ def bmi_cfg_data_requirement(self) -> DataRequirement: if self._bmi_cfg_data_requirement is None: bmi_config_restrict = [ DiscreteRestriction( - variable="data_id", values=[self._bmi_config_data_id] + variable="data_id", values=[self.bmi_config_data_id] ) ] bmi_config_domain = DataDomain( @@ -260,7 +216,7 @@ def bmi_cfg_data_requirement(self) -> DataRequirement: discrete_restrictions=bmi_config_restrict, ) self._bmi_cfg_data_requirement = DataRequirement( - bmi_config_domain, True, DataCategory.CONFIG + domain=bmi_config_domain, is_input=True, category=DataCategory.CONFIG ) return self._bmi_cfg_data_requirement @@ -276,7 +232,7 @@ def catchments(self) -> Optional[List[str]]: Optional[List[str]] An optional list of catchment ids for those catchments in the request ngen execution. 
""" - return self._catchments + return self.model.catchments @property def forcing_data_requirement(self) -> DataRequirement: @@ -292,7 +248,7 @@ def forcing_data_requirement(self) -> DataRequirement: # TODO: going to need to address the CSV usage later forcing_domain = DataDomain( data_format=DataFormat.AORC_CSV, - continuous_restrictions=[self._time_range], + continuous_restrictions=[self.model.time_range], discrete_restrictions=[self._gen_catchments_domain_restriction()], ) self._forcing_data_requirement = DataRequirement( @@ -313,10 +269,10 @@ def hydrofabric_data_requirement(self) -> DataRequirement: if self._hydrofabric_data_requirement is None: hydro_restrictions = [ DiscreteRestriction( - variable="hydrofabric_id", values=[self._hydrofabric_uid] + variable="hydrofabric_id", values=[self.model.hydrofabric_uid] ), DiscreteRestriction( - variable="data_id", values=[self._hydrofabric_data_id] + variable="data_id", values=[self.model.hydrofabric_data_id] ), ] hydro_domain = DataDomain( @@ -343,7 +299,7 @@ def hydrofabric_data_id(self) -> str: str The data format ``data_id`` for the hydrofabric dataset to use in requested modeling. """ - return self._hydrofabric_data_id + return self.model.hydrofabric_data_id @property def hydrofabric_uid(self) -> str: @@ -355,7 +311,7 @@ def hydrofabric_uid(self) -> str: str The unique id of the hydrofabric for this modeling request. """ - return self._hydrofabric_uid + return self.model.hydrofabric_uid @property def output_formats(self) -> List[DataFormat]: @@ -384,7 +340,7 @@ def partition_cfg_data_id(self) -> Optional[str]: Optional[str] The data format ``data_id`` for the partition config dataset to use in requested modeling, or ``None``. """ - return self._part_config_data_id + return self.model.partition_cfg_data_id @property def partition_cfg_data_requirement(self) -> DataRequirement: @@ -480,54 +436,7 @@ def time_range(self) -> TimeRange: TimeRange The time range for the requested model execution. 
""" - return self._time_range - - def to_dict(self) -> Dict[str, Union[str, Number, dict, list]]: - """ - Converts the request to a dictionary that may be passed to web requests - - Will look like: - - { - 'model': { - 'name': 'ngen', - 'allocation_paradigm': , - 'cpu_count': , - 'time_range': { }, - 'hydrofabric_data_id': 'hy-data-id-val', - 'hydrofabric_uid': 'hy-uid-val', - 'config_data_id': 'config-data-id-val', - 'bmi_config_data_id': 'bmi-config-data-id', - 'partition_config_data_id': 'partition_config_data_id', - ['catchments': { },] - 'version': 4.0 - }, - 'session-secret': 'secret-string-val' - } - - As a reminder, the ``catchments`` item may be absent, which implies the object does not have a specified list of - catchment ids. - - Returns - ------- - Dict[str, Union[str, Number, dict, list]] - A dictionary containing all the data in such a way that it may be used by a web request - """ - model = dict() - model["name"] = self.get_model_name() - model["allocation_paradigm"] = self.allocation_paradigm.name - model["cpu_count"] = self.cpu_count - model["time_range"] = self.time_range.to_dict() - model["hydrofabric_data_id"] = self.hydrofabric_data_id - model["hydrofabric_uid"] = self.hydrofabric_uid - model["config_data_id"] = self.config_data_id - model["bmi_config_data_id"] = self._bmi_config_data_id - if self.catchments is not None: - model["catchments"] = self.catchments - if self.partition_cfg_data_id is not None: - model["partition_config_data_id"] = self.partition_cfg_data_id - - return {"model": model, "session-secret": self.session_secret} + return self.model.time_range class NGENRequestResponse(ModelExecRequestResponse): @@ -565,4 +474,4 @@ class NGENRequestResponse(ModelExecRequestResponse): } """ - response_to_type = NGENRequest + response_to_type: ClassVar[Type[AbstractInitRequest]] = NGENRequest diff --git a/python/lib/communication/dmod/communication/maas_request/nwm/nwm_exec_request_body.py 
b/python/lib/communication/dmod/communication/maas_request/nwm/nwm_exec_request_body.py new file mode 100644 index 000000000..b6e1d2705 --- /dev/null +++ b/python/lib/communication/dmod/communication/maas_request/nwm/nwm_exec_request_body.py @@ -0,0 +1,75 @@ +from pydantic import root_validator + +from dmod.core.meta_data import ( + DataCategory, + DataDomain, + DataFormat, + DataRequirement, + DiscreteRestriction, +) +from dmod.core.execution import AllocationParadigm +from dmod.core.serializable import Serializable +from ..model_exec_request_body import ModelExecRequestBody + +from typing import List, Literal + + +class NWMInnerRequestBody(ModelExecRequestBody): + name: Literal["nwm"] = "nwm" + + # NOTE: default value, `None`, is not validated by pydantic + data_requirements: List[DataRequirement] = None + + @root_validator() + def _add_data_requirements_if_missing(cls, values: dict): + data_requirements = values["data_requirements"] + + # None is non-validated default + if data_requirements is None: + config_data_id: str = values["config_data_id"] + + data_id_restriction = DiscreteRestriction( + variable="data_id", values=[config_data_id] + ) + values["data_requirements"] = [ + DataRequirement( + domain=DataDomain( + data_format=DataFormat.NWM_CONFIG, + discrete_restrictions=[data_id_restriction], + ), + is_input=True, + category=DataCategory.CONFIG, + ) + ] + + return values + + class Config: + # NOTE: `name` field is not included at this point for backwards compatibility sake. This + # may change in the future. + fields = {"name": {"exclude": True}} + + +class NWMRequestBody(Serializable): + # TODO: flatten this hierarchy by replacing NWMRequestBody with NWMInnerRequestBody. 
+ nwm: NWMInnerRequestBody + + @property + def name(self) -> str: + return self.nwm.name + + @property + def config_data_id(self) -> str: + return self.nwm.config_data_id + + @property + def cpu_count(self) -> int: + return self.nwm.cpu_count + + @property + def allocation_paradigm(self) -> AllocationParadigm: + return self.nwm.allocation_paradigm + + @property + def data_requirements(self) -> List[DataRequirement]: + return self.nwm.data_requirements diff --git a/python/lib/communication/dmod/communication/maas_request/nwm/nwm_request.py b/python/lib/communication/dmod/communication/maas_request/nwm/nwm_request.py index e9d947601..1d8acdac4 100644 --- a/python/lib/communication/dmod/communication/maas_request/nwm/nwm_request.py +++ b/python/lib/communication/dmod/communication/maas_request/nwm/nwm_request.py @@ -1,26 +1,27 @@ -from typing import List +from typing import ClassVar, List, Optional, Union +from dmod.core.execution import AllocationParadigm from dmod.core.meta_data import ( - DataCategory, - DataDomain, DataFormat, DataRequirement, - DiscreteRestriction, ) from ...message import MessageEventType from ..model_exec_request import ModelExecRequest from ..model_exec_request_response import ModelExecRequestResponse +from .nwm_exec_request_body import NWMRequestBody class NWMRequest(ModelExecRequest): - event_type = MessageEventType.MODEL_EXEC_REQUEST + event_type: ClassVar[MessageEventType] = MessageEventType.MODEL_EXEC_REQUEST """(:class:`MessageEventType`) The type of event for this message""" # Once more the case senstivity of this model name is called into question # note: this is essentially keyed to image_and_domain.yml and the cases must match! 
- model_name = "nwm" + model_name: ClassVar[str] = "nwm" """(:class:`str`) The name of the model to be used""" + model: NWMRequestBody + @classmethod def factory_init_correct_response_subtype( cls, json_obj: dict @@ -38,65 +39,38 @@ def factory_init_correct_response_subtype( """ return NWMRequestResponse.factory_init_from_deserialized_json(json_obj=json_obj) - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict): - """ - Factory create a new instance of this type based on a JSON object dictionary deserialized from received JSON. - - Recall this will look something like: - - { - 'model': { - 'NWM': { - 'allocation_paradigm': '', - 'config_data_id': '', - 'cpu_count': , - 'data_requirements': [ ... (serialized DataRequirement objects) ... ] - } - } - 'session-secret': 'secret-string-val' - } + def __init__( + self, + # required in prior version of code + config_data_id: str = None, + # optional in prior version of code + cpu_count: Optional[int] = None, + allocation_paradigm: Optional[Union[str, AllocationParadigm]] = None, + **data + ): + # assume no need for backwards compatibility + if "model" in data: + super().__init__(**data) + return - Parameters - ---------- - json_obj + data["model"] = dict() + nwm_inner_request_body = {"config_data_id": config_data_id} - Returns - ------- - A new object of this type instantiated from the deserialize JSON object dictionary, or none if the provided - parameter could not be used to instantiated a new object. 
- """ - try: - nwm_element = json_obj["model"][cls.model_name] - additional_kwargs = dict() - if "cpu_count" in nwm_element: - additional_kwargs["cpu_count"] = nwm_element["cpu_count"] - - if "allocation_paradigm" in nwm_element: - additional_kwargs["allocation_paradigm"] = nwm_element[ - "allocation_paradigm" - ] - - obj = cls( - config_data_id=nwm_element["config_data_id"], - session_secret=json_obj["session-secret"], - **additional_kwargs - ) - - reqs = [ - DataRequirement.factory_init_from_deserialized_json(req_json) - for req_json in json_obj["model"][cls.model_name]["data_requirements"] - ] - - obj._data_requirements = reqs - - return obj - except Exception as e: - return None - - def __init__(self, *args, **kwargs): - super(NWMRequest, self).__init__(*args, **kwargs) - self._data_requirements = None + if cpu_count is not None: + nwm_inner_request_body["cpu_count"] = cpu_count + + if allocation_paradigm is not None: + nwm_inner_request_body["allocation_paradigm"] = allocation_paradigm + + data["model"]["nwm"] = nwm_inner_request_body + + super().__init__(**data) + + @classmethod + def get_model_name(cls) -> str: + # NOTE: overridden b.c. nwm request has nested model field. In the future we should be able + # to remove this. + return cls.__fields__["model"].type_.__fields__["nwm"].type_.__fields__["name"].default @property def data_requirements(self) -> List[DataRequirement]: @@ -108,21 +82,7 @@ def data_requirements(self) -> List[DataRequirement]: List[DataRequirement] List of all the explicit and implied data requirements for this request. 
""" - if self._data_requirements is None: - data_id_restriction = DiscreteRestriction( - variable="data_id", values=[self.config_data_id] - ) - self._data_requirements = [ - DataRequirement( - domain=DataDomain( - data_format=DataFormat.NWM_CONFIG, - discrete_restrictions=[data_id_restriction], - ), - is_input=True, - category=DataCategory.CONFIG, - ) - ] - return self._data_requirements + return self.model.data_requirements @property def output_formats(self) -> List[DataFormat]: @@ -136,40 +96,6 @@ def output_formats(self) -> List[DataFormat]: """ return [DataFormat.NWM_OUTPUT] - def to_dict(self) -> dict: - """ - Converts the request to a dictionary that may be passed to web requests. - - Will look like: - - { - 'model': { - 'NWM': { - 'allocation_paradigm': '', - 'config_data_id': '', - 'cpu_count': , - 'data_requirements': [ ... (serialized DataRequirement objects) ... ] - } - } - 'session-secret': 'secret-string-val' - } - - Returns - ------- - dict - A dictionary containing all the data in such a way that it may be used by a web request - """ - model = dict() - model[self.get_model_name()] = dict() - model[self.get_model_name()][ - "allocation_paradigm" - ] = self.allocation_paradigm.name - model[self.get_model_name()]["config_data_id"] = self.config_data_id - model[self.get_model_name()]["cpu_count"] = self.cpu_count - model[self.get_model_name()]["data_requirements"] = [ - r.to_dict() for r in self.data_requirements - ] - return {"model": model, "session-secret": self.session_secret} class NWMRequestResponse(ModelExecRequestResponse): diff --git a/python/lib/communication/dmod/communication/maas_request/parameter.py b/python/lib/communication/dmod/communication/maas_request/parameter.py index 6ac1f5ea3..71362f8ac 100644 --- a/python/lib/communication/dmod/communication/maas_request/parameter.py +++ b/python/lib/communication/dmod/communication/maas_request/parameter.py @@ -1,16 +1,11 @@ -class Scalar: +from dmod.core.serializable import Serializable + + 
+class Scalar(Serializable): """ Represents a parameter value that is bound to a single number """ - - def __init__(self, scalar: int): - """ - :param int scalar: The value for the parameter - """ - self.scalar = scalar - - def to_dict(self): - return {"scalar": self.scalar} + scalar: int def __str__(self): return str(self.scalar) @@ -19,16 +14,11 @@ def __repr__(self): return self.__str__() -class Parameter: +class Parameter(Serializable): """ Base clase for model parameter descriptions that a given model may expose to DMOD for dynamic parameter selection. """ - - def __init__(self, name): - """ - Set the base meta data of the parameter - """ - self.name = name + name: str class ScalarParameter(Parameter): @@ -36,7 +26,5 @@ class ScalarParameter(Parameter): A Scalar parameter is a simple interger parameter who's valid range are integer increments between min and max, inclusive. """ - - def __init__(self, min, max): - self.min = min - self.max = max + min: int + max: int diff --git a/python/lib/communication/dmod/communication/message.py b/python/lib/communication/dmod/communication/message.py index bad2e4869..3d4db4074 100644 --- a/python/lib/communication/dmod/communication/message.py +++ b/python/lib/communication/dmod/communication/message.py @@ -1,12 +1,14 @@ from abc import ABC -from enum import Enum -from typing import Type +from typing import Any, ClassVar, Dict, Literal, Optional, Type +from pydantic import Field from dmod.core.serializable import Serializable, ResultIndicator +from dmod.core.serializable_dict import SerializableDict +from dmod.core.enum import PydanticEnum #FIXME make an independent enum of model request types??? 
-class MessageEventType(Enum): +class MessageEventType(PydanticEnum): SESSION_INIT = 1 MODEL_EXEC_REQUEST = 2 @@ -36,7 +38,7 @@ class MessageEventType(Enum): INVALID = -1 -class InitRequestResponseReason(Enum): +class InitRequestResponseReason(PydanticEnum): """ Values for the ``reason`` attribute in responses to ``AbstractInitRequest`` messages. """ @@ -62,7 +64,7 @@ class Message(Serializable, ABC): Class representing communication message of some kind between parts of the NWM MaaS system. """ - event_type: MessageEventType = None + event_type: ClassVar[MessageEventType] = MessageEventType.INVALID """ :class:`MessageEventType`: the event type for this message implementation """ @classmethod @@ -77,9 +79,6 @@ def get_message_event_type(cls) -> MessageEventType: """ return cls.event_type - def __init__(self, *args, **kwargs): - pass - class AbstractInitRequest(Message, ABC): """ @@ -92,9 +91,6 @@ class AbstractInitRequest(Message, ABC): interactions. """ - def __int__(self, *args, **kwargs): - super(AbstractInitRequest, self).__int__(*args, **kwargs) - class Response(ResultIndicator, Message, ABC): """ @@ -124,65 +120,10 @@ class Response(ResultIndicator, Message, ABC): """ - response_to_type = AbstractInitRequest + response_to_type: ClassVar[Type[AbstractInitRequest]] = AbstractInitRequest """ The type of :class:`AbstractInitRequest` for which this type is the response""" - @classmethod - def _factory_init_data_attribute(cls, json_obj: dict): - """ - Initialize the argument value for a constructor param used to set the :attr:`data` attribute appropriate for - this type, given the parent JSON object, which may mean simply returning the value or may mean deserializing the - value to some object type, depending on the implementation. - - The intent is for this to be used by :meth:`factory_init_from_deserialized_json`, where initialization logic for - the value to be set as :attr:`data` from the provided param may vary depending on the particular class. 
- - In the default implementation, the value found at the 'data' key is simply directly returned, or None is - returned if the 'data' key is not found. - - Parameters - ---------- - json_obj : dict - the parent JSON object containing the desired data value under the 'data' key - - Returns - ------- - data : dict - the resulting data value object - - See Also - ------- - factory_init_from_deserialized_json - """ - try: - return json_obj['data'] - except Exception as e: - return None - - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict): - """ - Factory create a new instance of this type based on a JSON object dictionary deserialized from received JSON. - - Parameters - ---------- - json_obj - - Returns - ------- - response_obj : Response - A new object of this type instantiated from the deserialize JSON object dictionary, or none if the provided - parameter could not be used to instantiated a new object. - - See Also - ------- - _factory_init_data_attribute - """ - try: - return cls(success=json_obj['success'], reason=json_obj['reason'], message=json_obj['message'], - data=cls._factory_init_data_attribute(json_obj)) - except Exception as e: - return None + data: Optional[SerializableDict] @classmethod def get_message_event_type(cls) -> MessageEventType: @@ -211,24 +152,6 @@ def get_response_to_type(cls) -> Type[AbstractInitRequest]: """ return cls.response_to_type - def __init__(self, data=None, *args, **kwargs): - super(Response, self).__init__(*args, **kwargs) - self.data = data - - def __eq__(self, other): - return self.success == other.success and self.reason == other.reason and self.message == other.message \ - and self.data == other.data - - def to_dict(self) -> dict: - serial = super(Response, self).to_dict() - if self.data is None: - serial['data'] = {} - elif isinstance(self.data, dict): - serial['data'] = self.data - else: - serial['data'] = self.data.to_dict() - return serial - class InvalidMessage(AbstractInitRequest): """ @@ 
-236,58 +159,40 @@ class InvalidMessage(AbstractInitRequest): type. """ - event_type: MessageEventType = MessageEventType.INVALID + event_type: ClassVar[MessageEventType] = MessageEventType.INVALID """ :class:`MessageEventType`: the type of ``MessageEventType`` for which this message is applicable. """ - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict): - """ - Factory create a new instance of this type based on a JSON object dictionary deserialized from received JSON. - - Parameters - ---------- - json_obj - - Returns - ------- - A new object of this type instantiated from the deserialize JSON object dictionary, or none if the provided - parameter could not be used to instantiated a new object. - """ - try: - return cls(content=json_obj['content']) - except: - return None - - def __init__(self, content: dict): - self.content = content - - def to_dict(self) -> dict: - return {'content': self.content} + content: Dict[str, Any] class InvalidMessageResponse(Response): - response_to_type = InvalidMessage + response_to_type: ClassVar[Type[AbstractInitRequest]] = InvalidMessage """ The type of :class:`AbstractInitRequest` for which this type is the response""" - def __init__(self, data=None): - super().__init__(success=False, - reason='Invalid Request Message', - message='Request message was not formatted as any known valid type', - data=data) + success = False + reason: Literal["Invalid Request message"] = "Invalid Request message" + message: Literal["Request message was not formatted as any known valid type"] = "Request message was not formatted as any known valid type" + data: Optional[SerializableDict] + def __init__(self, data: Optional[Serializable]=None, **kwargs): + super().__init__(data=data) + + +class HttpCode(SerializableDict): + http_code: int = Field(ge=100, le=599) class ErrorResponse(Response): """ A response to inform a client of an error that has occured within a request """ - def __init__(self, message: str, http_code: int = 
None): - if not http_code: - http_code = 500 - - if not isinstance(http_code, int): - try: - http_code = int(float(http_code)) - except: - http_code = str(http_code) - super().__init__(success=False, reason="Error", message=message, data={"http_code": http_code}) + success = False + reason: Literal["Error"] = "Error" + data: HttpCode = Field(default_factory=lambda: HttpCode(http_code=500)) + + def __init__(self, message: str, http_code: int = None, **kwargs): + if http_code is None: + super().__init__(message=message) + return + + super().__init__(message=message, data={"http_code": http_code}) diff --git a/python/lib/communication/dmod/communication/metadata_message.py b/python/lib/communication/dmod/communication/metadata_message.py index 193fa5ee4..2adf75b73 100644 --- a/python/lib/communication/dmod/communication/metadata_message.py +++ b/python/lib/communication/dmod/communication/metadata_message.py @@ -1,10 +1,13 @@ from .message import AbstractInitRequest, MessageEventType, Response -from enum import Enum +from dmod.core.serializable_dict import SerializableDict from numbers import Number -from typing import Dict, Optional, Union +from typing import ClassVar, Dict, Optional, Type, Union +from pydantic import Field, root_validator +from dmod.core.enum import PydanticEnum -class MetadataPurpose(Enum): + +class MetadataPurpose(PydanticEnum): CONNECT = 1, """ The metadata relates to the opening of a connection. 
""" DISCONNECT = 2, @@ -25,32 +28,53 @@ def get_value_for_name(cls, name_str: str) -> Optional['MetadataPurpose']: return None -class MetadataMessage(AbstractInitRequest): +class MetadataSignal(SerializableDict): + purpose: MetadataPurpose + metadata_follows: bool = False - event_type: MessageEventType = MessageEventType.METADATA + class Config: + fields = { + "metadata_follows": { + "alias": "additional_metadata", + "description": ( + "An indication of whether there is more metadata the sender needs to communicate beyond what is contained in this" + "message, thus letting the receiver know whether it should continue receiving after sending the response to this." + ), + } + } - _purpose_serial_key = 'purpose' - _description_serial_key = 'description' - _metadata_follows_serial_key = 'additional_metadata' - _config_changes_serial_key = 'config_changes' - _config_change_dict_type_key = 'config_value_dict_type' - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict) -> Optional['MetadataMessage']: - if cls._purpose_serial_key not in json_obj: - return None - purpose = MetadataPurpose.get_value_for_name(json_obj[cls._purpose_serial_key]) - if purpose is None: - return None - if cls._metadata_follows_serial_key in json_obj: - metadata_follows = json_obj[cls._metadata_follows_serial_key] - else: - # default to False for this, as this is pretty safe assumption if we don't see it explicit - metadata_follows = False - description = json_obj[cls._description_serial_key] if cls._description_serial_key in json_obj else None - cfg_changes = json_obj[cls._config_changes_serial_key] if cls._config_changes_serial_key in json_obj else None - return cls(purpose=purpose, description=description, metadata_follows=metadata_follows, - config_changes=cfg_changes) +class MetadataMessage(MetadataSignal, AbstractInitRequest): + + event_type: ClassVar[MessageEventType] = MessageEventType.INVALID + + description: Optional[str] + + config_changes: Optional[Dict[str, 
Union[None, str, bool, int, float, dict, list]]] = Field(description="A dictionary, keyed by strings, representing some configurable setting(s) that need their value(s) changed.") + """ + A dictionary, keyed by strings, representing some configurable setting(s) that need their value(s) changed. + + This will mainly be applicable when the purpose property is ``CHANGE_CONFIG``, and frequently can otherwise be + left to/expected to be ``None``. However, it should not be ``None`` when the purpose is ``CHANGE_CONFIG``. + + Note that the main dictionary can contain nested dictionaries also. These should essentially be the serialized + representations of ::class:`Serializable` object. While the type hinting does not explicitly note this due to + the recursive nature of the definition, nested dictionaries at any depth should have string keys and values of + one of the types allowed for values in the top-level dictionary. + + It is recommended that an additional value be added to such nested dictionaries, under the key returned by + ::method:`get_config_change_dict_type_key`. This should be the string representation of the class type of the + nested, serialized object. 
+ """ + + @root_validator() + def validate_purpose(cls, values): + if values["purpose"] == MetadataPurpose.CHANGE_CONFIG and not values["config_changes"]: + raise RuntimeError('Invalid {} initialization, setting {} to {} but without any config changes.'.format( + cls.__class__, values["purpose"].__class__, values["purpose"].name)) + return values + + _config_change_dict_type_key: ClassVar[str] = 'config_value_dict_type' @classmethod def get_config_change_dict_type_key(cls) -> str: @@ -68,84 +92,14 @@ def get_config_change_dict_type_key(cls) -> str: """ return cls._config_change_dict_type_key - def __init__(self, purpose: MetadataPurpose, description: Optional[str] = None, metadata_follows: bool = False, - config_changes: Optional[Dict[str, Union[None, str, bool, Number, dict, list]]] = None): - self._purpose = purpose - self._description = description - self._metadata_follows = metadata_follows - self._config_changes = config_changes - if self._purpose == MetadataPurpose.CHANGE_CONFIG and not self._config_changes: - raise RuntimeError('Invalid {} initialization, setting {} to {} but without any config changes.'.format( - self.__class__, self._purpose.__class__, self._purpose.name)) - - @property - def config_changes(self) -> Optional[Dict[str, Union[None, str, bool, Number, dict, list]]]: - """ - A dictionary, keyed by strings, representing some configurable setting(s) that need their value(s) changed. - - This will mainly be applicable when the purpose property is ``CHANGE_CONFIG``, and frequently can otherwise be - left to/expected to be ``None``. However, it should not be ``None`` when the purpose is ``CHANGE_CONFIG``. - - Note that the main dictionary can contain nested dictionaries also. These should essentially be the serialized - representations of ::class:`Serializable` object. 
While the type hinting does not explicitly note this due to - the recursive nature of the definition, nested dictionaries at any depth should have string keys and values of - one of the types allowed for values in the top-level dictionary. - - It is recommended that an additional value be added to such nested dictionaries, under the key returned by - ::method:`get_config_change_dict_type_key`. This should be the string representation of the class type of the - nested, serialized object. - - Returns - ------- - Optional[Dict[str, Union[None, str, bool, Number, dict]]] - A dictionary, keyed by strings, representing some configurable setting(s) that need their value(s) changed. - """ - # This should get handled in __init__ but put here anyway - if self._purpose == MetadataPurpose.CHANGE_CONFIG and not self._config_changes: - raise RuntimeError('Invalid {} initialization, setting {} to {} but without any config changes.'.format( - self.__class__, self._purpose.__class__, self._purpose.name)) - return self._config_changes - - @property - def description(self) -> Optional[str]: - return self._description - - @property - def metadata_follows(self) -> bool: - """ - An indication of whether there is more metadata the sender needs to communicate beyond what is contained in this - message, thus letting the receiver know whether it should continue receiving after sending the response to this. - - Returns - ------- - bool - An indication of whether there is more metadata the sender needs to communicate beyond what is contained in - this message, thus letting the receiver know whether it should continue receiving after sending the response - to this. 
- """ - return self._metadata_follows - - @property - def purpose(self) -> MetadataPurpose: - return self._purpose - - def to_dict(self) -> dict: - result = {self._purpose_serial_key: self.purpose.name, self._metadata_follows_serial_key: self.metadata_follows} - if self.description: - result[self._description_serial_key] = self.description - if self.config_changes: - result[self._config_changes_serial_key] = self.config_changes - return result - class MetadataResponse(Response): """ The subtype of ::class:`Response` appropriate for ::class:`MetadataMessage` objects. """ - _metadata_follows_serial_key = MetadataMessage._metadata_follows_serial_key - _purpose_serial_key = MetadataMessage._purpose_serial_key - response_to_type = MetadataMessage + response_to_type: ClassVar[Type[AbstractInitRequest]] = MetadataMessage + data: MetadataSignal @classmethod def factory_create(cls, success: bool, reason: str, purpose: MetadataPurpose, expect_more: bool, message: str = ''): @@ -165,16 +119,14 @@ def factory_create(cls, success: bool, reason: str, purpose: MetadataPurpose, ex ------- """ - data = {cls._purpose_serial_key: purpose.name, cls._metadata_follows_serial_key: expect_more} - return cls(success=success, reason=reason, data=data, message=message) + data = MetadataSignal(purpose=purpose, metadata_follows=expect_more) - def __init__(self, success: bool, reason: str, data: dict, message: str = ''): - super().__init__(success=success, reason=reason, message=message, data=data) + return cls(success=success, reason=reason, data=data, message=message) @property def metadata_follows(self) -> bool: - return self.data[self._metadata_follows_serial_key] + return self.data.metadata_follows @property def purpose(self) -> MetadataPurpose: - return MetadataPurpose.get_value_for_name(self.data[self._purpose_serial_key]) + return self.data.purpose diff --git a/python/lib/communication/dmod/communication/partition_request.py 
b/python/lib/communication/dmod/communication/partition_request.py index 4184eaa92..5a85342ac 100644 --- a/python/lib/communication/dmod/communication/partition_request.py +++ b/python/lib/communication/dmod/communication/partition_request.py @@ -1,6 +1,7 @@ from uuid import uuid4 -from numbers import Number -from typing import Optional, Union, Dict +from pydantic import Field, Extra +from typing import ClassVar, Dict, Optional, Type, Union +from dmod.core.serializable_dict import SerializableDict from .message import AbstractInitRequest, MessageEventType, Response from .maas_request import ExternalRequest @@ -11,27 +12,23 @@ class PartitionRequest(AbstractInitRequest): Request for partitioning of the catchments in a hydrofabric, typically for distributed processing. """ - event_type = MessageEventType.PARTITION_REQUEST - _KEY_NUM_PARTS = 'partition_count' - _KEY_NUM_CATS = 'catchment_count' - _KEY_UUID = 'uuid' - _KEY_HYDROFABRIC_UID = 'hydrofabric_uid' - _KEY_HYDROFABRIC_DATA_ID = 'hydrofabric_data_id' - _KEY_HYDROFABRIC_DESC = 'hydrofabric_description' + event_type: ClassVar[MessageEventType] = MessageEventType.PARTITION_REQUEST - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict, **kwargs): - hy_data_id = json_obj[cls._KEY_HYDROFABRIC_DATA_ID] if cls._KEY_HYDROFABRIC_DATA_ID in json_obj else None - - try: - return cls(hydrofabric_uid=json_obj[cls._KEY_HYDROFABRIC_UID], - hydrofabric_data_id=hy_data_id, - num_partitions=json_obj[cls._KEY_NUM_PARTS], - description=json_obj.get(cls._KEY_HYDROFABRIC_DESC), - uuid=json_obj[cls._KEY_UUID], - **kwargs) - except: - return None + num_partitions: int + uuid: Optional[str] = Field(default_factory=lambda: str(uuid4()), description="Get (as a string) the UUID for this instance.") + hydrofabric_uid: str = Field(description="The unique identifier for the hydrofabric that is to be partitioned.") + hydrofabric_data_id: Optional[str] = Field(description="When known, the 'data_id' for the dataset 
containing the associated hydrofabric.") + description: Optional[str] = Field(description="The optional description or name of the hydrofabric that is to be partitioned.") + + class Config: + fields = { + "num_partitions": {"alias": "partition_count"}, + "description": {"alias": "hydrofabric_description"} + } + + # QUESTION: is this unused? + # catchment_count: str + # _KEY_NUM_CATS = 'catchment_count' @classmethod def factory_init_correct_response_subtype(cls, json_obj: dict): @@ -48,8 +45,17 @@ def factory_init_correct_response_subtype(cls, json_obj: dict): """ return PartitionResponse.factory_init_from_deserialized_json(json_obj=json_obj) - def __init__(self, num_partitions: int, hydrofabric_uid: str, hydrofabric_data_id: Optional[str] = None, - uuid: Optional[str] = None, description: Optional[str] = None, *args, **kwargs): + def __init__( + self, + *, + hydrofabric_uid: str, + # NOTE: default is None for backwards compatibility. could be specified using alias. + num_partitions: int = None, + hydrofabric_data_id: Optional[str] = None, + uuid: Optional[str] = None, + description: Optional[str] = None, + **data + ): """ Initialize the request. @@ -66,12 +72,15 @@ def __init__(self, num_partitions: int, hydrofabric_uid: str, hydrofabric_data_i description : Optional[str] An optional description or name for the hydrofabric. 
""" - super(PartitionRequest, self).__init__(*args, **kwargs) - self._hydrofabric_uid = hydrofabric_uid - self._hydrofabric_data_id = hydrofabric_data_id - self._num_partitions = num_partitions - self._uuid = uuid if uuid else str(uuid4()) - self._description = description + + super().__init__( + num_partitions=num_partitions or data.pop("partition_count", None), + hydrofabric_uid=hydrofabric_uid, + hydrofabric_data_id=hydrofabric_data_id, + uuid=uuid, + description=description or data.pop("hydrofabric_description", None), + **data + ) def __eq__(self, other): return self.uuid == other.uuid and self.hydrofabric_uid == other.hydrofabric_uid and self.hydrofabric_data_id == other.hydrofabric_data_id @@ -79,72 +88,37 @@ def __eq__(self, other): def __hash__(self): return hash("{}{}{}".format(self.uuid, self.hydrofabric_uid, self.hydrofabric_data_id)) - @property - def description(self) -> Optional[str]: - """ - The optional description or name of the hydrofabric that is to be partitioned. - - Returns - ------- - Optional[str] - The optional description or name of the hydrofabric that is to be partitioned. - """ - return self._description - - @property - def hydrofabric_data_id(self) -> Optional[str]: - """ - When known, the 'data_id' for the dataset containing the associated hydrofabric. - - Returns - ------- - Optional[str] - When known, the 'data_id' for the dataset containing the associated hydrofabric. 
- """ - return self._hydrofabric_data_id + def dict( + self, + *, + include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + by_alias: bool = True, # Note this follows Serializable convention + skip_defaults: Optional[bool] = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False + ) -> Dict[str, Union[str, int]]: + serial = super().dict( + include=include, + exclude=exclude, + by_alias=by_alias, + skip_defaults=skip_defaults, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + + # include "class_name" if not in excludes + if exclude is not None and "class_name" not in exclude: + serial["class_name"] = self.__class__.__name__ - @property - def hydrofabric_uid(self) -> str: - """ - The unique identifier for the hydrofabric that is to be partitioned. - - Returns - ------- - str - The unique identifier for the hydrofabric that is to be partitioned. - """ - return self._hydrofabric_uid - - @property - def num_partitions(self) -> int: - return self._num_partitions - - def to_dict(self) -> Dict[str, Union[str, Number, dict, list]]: - serialized = { - 'class_name': self.__class__.__name__, - self._KEY_HYDROFABRIC_UID: self.hydrofabric_uid, - self._KEY_NUM_PARTS: self.num_partitions, - #self._KEY_SECRET: self.session_secret, - self._KEY_UUID: self.uuid, - } - if self.description is not None: - serialized[self._KEY_HYDROFABRIC_DESC] = self.description - if self.hydrofabric_data_id is not None: - serialized[self._KEY_HYDROFABRIC_DATA_ID] = self.hydrofabric_data_id - return serialized - - @property - def uuid(self) -> str: - """ - Get (as a string) the UUID for this instance. + return serial - Returns - ------- - str - The UUID for this instance, as a string. 
- """ - return self._uuid +class PartitionResponseBody(SerializableDict): + data_id: Optional[str] + dataset_name: Optional[str] class PartitionResponse(Response): """ @@ -152,24 +126,24 @@ class PartitionResponse(Response): A successful response will contain the serialized partition representation within the ::attribute:`data` property. """ - _DATA_KEY_DATASET_DATA_ID = 'data_id' - _DATA_KEY_DATASET_NAME = 'dataset_name' - response_to_type = PartitionRequest + data: PartitionResponseBody + + response_to_type: ClassVar[Type[AbstractInitRequest]] = PartitionRequest @classmethod def factory_create(cls, dataset_name: Optional[str], dataset_data_id: Optional[str], reason: str, message: str = '', data: Optional[dict] = None): - data_dict = {cls._DATA_KEY_DATASET_DATA_ID: dataset_data_id, cls._DATA_KEY_DATASET_NAME: dataset_name} + data_dict = {"data_id": dataset_data_id, "dataset_name": dataset_name} if data is not None: data_dict.update(data) return cls(success=(dataset_data_id is not None), reason=reason, message=message, data=data_dict) - def __init__(self, success: bool, reason: str, message: str = '', data: Optional[dict] = None): - if data is None: - data = {} + def __init__(self, success: bool, reason: str, message: str = '', data: Optional[Union[dict, PartitionResponseBody]] = None): + data = data if isinstance(data, PartitionResponseBody) else PartitionResponseBody(**data or {}) + if not success: - data[self._DATA_KEY_DATASET_DATA_ID] = None - data[self._DATA_KEY_DATASET_NAME] = None + data.data_id = None + data.dataset_name = None super().__init__(success=success, reason=reason, message=message, data=data) @property @@ -182,7 +156,7 @@ def dataset_data_id(self) -> Optional[str]: Optional[str] The 'data_id' of the dataset where the partition config is saved when requests are successful. 
""" - return self.data[self._DATA_KEY_DATASET_DATA_ID] + return self.data.data_id @property def dataset_name(self) -> Optional[str]: @@ -194,35 +168,44 @@ def dataset_name(self) -> Optional[str]: Optional[str] The name of the dataset where the partitioning config is saved when requests are successful. """ - return self.data[self._DATA_KEY_DATASET_NAME] + return self.data.dataset_name + + def dict( + self, + *, + include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + by_alias: bool = True, # Note this follows Serializable convention + skip_defaults: Optional[bool] = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False + ) -> Dict[str, Union[str, int]]: + class_name_in_exclude = exclude is not None and "class_name" in exclude + + serial = super().dict( + include=include, + exclude=exclude, + by_alias=by_alias, + skip_defaults=skip_defaults, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + + if not class_name_in_exclude: + serial["class_name"] = self.__class__.__name__ - def to_dict(self) -> Dict[str, Union[str, Number, dict, list]]: - serial = super(PartitionResponse, self).to_dict() - serial['class_name'] = self.__class__.__name__ return serial class PartitionExternalRequest(PartitionRequest, ExternalRequest): - _KEY_SECRET = 'session_secret' - - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict, **kwargs): - try: - kwargs['session_secret'] = json_obj[cls._KEY_SECRET] - return super(PartitionExternalRequest, cls).factory_init_from_deserialized_json(json_obj, **kwargs) - except: - return None - - def __init__(self, *args, **kwargs): - super(PartitionExternalRequest, self).__init__(*args, **kwargs) - - def to_dict(self) -> Dict[str, Union[str, Number, dict, list]]: - serial = super(PartitionExternalRequest, self).to_dict() - serial[self._KEY_SECRET] 
= self.session_secret - return serial + class Config: + # NOTE: in parent class, `ExternalRequest`, `session_secret` is aliased using `session-secret` + fields = {"session_secret": {"alias": "session_secret"}} class PartitionExternalResponse(PartitionResponse): - response_to_type = PartitionExternalRequest + response_to_type: ClassVar[Type[AbstractInitRequest]] = PartitionExternalRequest diff --git a/python/lib/communication/dmod/communication/registered/registered_message.py b/python/lib/communication/dmod/communication/registered/registered_message.py index 12c20233e..c80f58982 100644 --- a/python/lib/communication/dmod/communication/registered/registered_message.py +++ b/python/lib/communication/dmod/communication/registered/registered_message.py @@ -5,8 +5,7 @@ import abc import typing from numbers import Number -from typing import Dict -from typing import Union +from typing import ClassVar, Dict, Union from ..message import AbstractInitRequest from ..message import MessageEventType @@ -296,7 +295,7 @@ class FieldedMessage(AbstractInitRequest): """ A message formed by dictated fields coming from subclasses """ - event_type: MessageEventType = MessageEventType.INFORMATION_UPDATE + event_type: ClassVar[MessageEventType] = MessageEventType.INFORMATION_UPDATE """ The event type for this message; this shouldn't have as much bearing on how to handle this message. Use members and class type instead. 
diff --git a/python/lib/communication/dmod/communication/scheduler_request.py b/python/lib/communication/dmod/communication/scheduler_request.py index 790534475..9fb41ba91 100644 --- a/python/lib/communication/dmod/communication/scheduler_request.py +++ b/python/lib/communication/dmod/communication/scheduler_request.py @@ -1,14 +1,42 @@ from dmod.core.execution import AllocationParadigm -from .maas_request import ModelExecRequest, ModelExecRequestResponse +from .maas_request import ModelExecRequest from .message import AbstractInitRequest, MessageEventType, Response -from typing import Optional, Union - +from .scheduler_request_response_body import SchedulerRequestResponseBody, UNSUCCESSFUL_JOB +from pydantic import Field, PrivateAttr, validator +from typing import ClassVar, Dict, Optional, Type, Union class SchedulerRequestMessage(AbstractInitRequest): - event_type: MessageEventType = MessageEventType.SCHEDULER_REQUEST + event_type: ClassVar[MessageEventType] = MessageEventType.SCHEDULER_REQUEST """ :class:`MessageEventType`: the event type for this message implementation """ + model_request: ModelExecRequest = Field(description="The underlying request for a job to be scheduled.") + user_id: str = Field(description="The associated user id for this scheduling request.") + memory: int = Field(500_000, description="The amount of memory, in bytes, requested for the scheduling of this job.") + cpus_: Optional[int] = Field(description="The number of processors requested for the scheduling of this job.") + allocation_paradigm_: Optional[AllocationParadigm] + + _memory_unset: bool = PrivateAttr() + + @validator("model_request", pre=True) + def _factory_init_model_request(cls, value): + if isinstance(value, ModelExecRequest): + return value + return ModelExecRequest.factory_init_correct_subtype_from_deserialized_json(value) + + @validator("allocation_paradigm_", pre=True) + def _dekabob_input(cls, value: Optional[Union[AllocationParadigm, str]]) -> 
Optional[Union[AllocationParadigm, str]]: + if isinstance(value, str): + return value.replace("-", "_") + return value + + class Config: + fields = { + "memory": {"alias": "mem"}, + "cpus_": {"alias": "cpus"}, + "allocation_paradigm_": {"alias": "allocation_paradigm"}, + } + @classmethod def default_allocation_paradigm_str(cls) -> str: """ @@ -27,61 +55,28 @@ def default_allocation_paradigm_str(cls) -> str: """ return AllocationParadigm.get_default_selection().name - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict): - """ - Factory create a new instance of this type based on a JSON object dictionary deserialized from received JSON. - - Parameters - ---------- - json_obj - - Returns - ------- - SchedulerRequestMessage - A new object of this type instantiated from the deserialize JSON object dictionary, or ``None`` if the - provided parameter could not be used to instantiated a new object of this type. - """ - try: - model_request = ModelExecRequest.factory_init_correct_subtype_from_deserialized_json(json_obj['model_request']) - if model_request is not None: - alloc_paradigm = json_obj['allocation_paradigm'] if 'allocation_paradigm' in json_obj else None - return cls(model_request=model_request, - user_id=json_obj['user_id'], - # This may be absent to indicate use the value from the backing model request - cpus=json_obj['cpus'] if 'cpus' in json_obj else None, - # This may be absent to indicate it should be marked "unset" and a default should be used - mem=json_obj['mem'] if 'mem' in json_obj else None, - allocation_paradigm=alloc_paradigm) - else: - return None - except: - return None - # TODO: may need to generalize the underlying request to support, say, scheduling evaluation jobs - def __init__(self, model_request: ModelExecRequest, user_id: str, cpus: Optional[int] = None, mem: Optional[int] = None, - allocation_paradigm: Optional[Union[str, AllocationParadigm]] = None): - self._model_request = model_request - self._user_id = user_id - 
self._cpus = cpus + def __init__( + self, + model_request: ModelExecRequest, + user_id: str, + cpus: Optional[int] = None, + mem: Optional[int] = None, + allocation_paradigm: Optional[Union[str, AllocationParadigm]] = None, + **data + ): + super().__init__( + model_request=model_request, + user_id=user_id, + cpus=cpus or data.pop("cpus_", None), + memory=mem or data.pop("memory", None) or self.__fields__["memory"].default, + allocation_paradigm=allocation_paradigm or data.pop("allocation_paradigm_", None), + **data + ) if mem is None: self._memory_unset = True - self._memory = 500000 else: self._memory_unset = False - self._memory = mem - if isinstance(allocation_paradigm, str): - self._allocation_paradigm = AllocationParadigm.get_from_name(allocation_paradigm) - else: - self._allocation_paradigm = allocation_paradigm - - def __eq__(self, other): - return self.__class__ == other.__class__ \ - and self.model_request == other.model_request \ - and self.cpus == other.cpus \ - and self.memory == other.memory \ - and self.user_id == other.user_id \ - and self.allocation_paradigm == other.allocation_paradigm @property def allocation_paradigm(self) -> AllocationParadigm: @@ -93,10 +88,10 @@ def allocation_paradigm(self) -> AllocationParadigm: AllocationParadigm The allocation paradigm requested for the job to be scheduled. """ - if self._allocation_paradigm is None: + if self.allocation_paradigm_ is None: return self.model_request.allocation_paradigm else: - return self._allocation_paradigm + return self.allocation_paradigm_ @property def cpus(self) -> int: @@ -111,19 +106,7 @@ def cpus(self) -> int: int The number of processors requested for the scheduling of this job. """ - return self.model_request.cpu_count if self._cpus is None else self._cpus - - @property - def memory(self) -> int: - """ - The amount of memory, in bytes, requested for the scheduling of this job. - - Returns - ------- - int - The amount of memory, in bytes, requested for the scheduling of this job. 
- """ - return self._memory + return self.model_request.cpu_count if self.cpus_ is None else self.cpus_ @property def memory_unset(self) -> bool: @@ -137,18 +120,6 @@ def memory_unset(self) -> bool: """ return self._memory_unset - @property - def model_request(self) -> ModelExecRequest: - """ - The underlying request for a job to be scheduled. - - Returns - ------- - ModelExecRequest - The underlying request for a job to be scheduled. - """ - return self._model_request - @property def nested_event(self) -> MessageEventType: """ @@ -161,61 +132,68 @@ def nested_event(self) -> MessageEventType: """ return self.model_request.get_message_event_type() - @property - def user_id(self) -> str: - """ - The associated user id for this scheduling request. - - Returns - ------- - str - The associated user id for this scheduling request. - """ - return self._user_id - - def to_dict(self) -> dict: - serial = {'model_request': self.model_request.to_dict(), 'user_id': self.user_id} - if self._allocation_paradigm is not None: - serial['allocation_paradigm'] = self._allocation_paradigm.name - # Don't include this in serial form if property value is sourced from underlying model request - if self._cpus is not None: - serial['cpus'] = self._cpus + def dict( + self, + *, + include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + by_alias: bool = True, # Note this follows Serializable convention + skip_defaults: Optional[bool] = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False + ) -> Dict[str, Union[str, int]]: # Only including memory value in serial form if it was explicitly set in the first place - if not self.memory_unset: - serial['mem'] = self.memory - return serial - + if self.memory_unset: + exclude = {"memory"} if exclude is None else {"memory", *exclude} + + return super().dict( + include=include, + exclude=exclude, + 
by_alias=by_alias, + skip_defaults=skip_defaults, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) class SchedulerRequestResponse(Response): - response_to_type = SchedulerRequestMessage + + response_to_type: ClassVar[Type[AbstractInitRequest]] = SchedulerRequestMessage + + data: Union[SchedulerRequestResponseBody, Dict[None, None], None] def __init__(self, job_id: Optional[int] = None, output_data_id: Optional[str] = None, data: dict = None, **kwargs): # TODO: how to handle if kwargs has success=True, but job_id value (as param or in data) implies success=False - key_job_id = ModelExecRequestResponse.get_job_id_key() + # Create an empty data if not supplied a dict, but only if there is a job_id or output_data_id to insert if data is None and (job_id is not None or output_data_id is not None): data = {} + # Prioritize provided job_id over something already in data # Note that this condition implies that either a data dict was passed as param, or one just got created above if job_id is not None: - data[key_job_id] = job_id + data["job_id"] = job_id + # Insert this into dict if present also (again, it being non-None implies data must be a dict object) if output_data_id is not None: - data[ModelExecRequestResponse.get_output_data_id_key()] = output_data_id + data["output_data_id"] = output_data_id + # Ensure that 'success' is being passed as a kwarg to the superclass constructor - if 'success' not in kwargs: - kwargs['success'] = data is not None and key_job_id in data and data[key_job_id] > 0 - super(SchedulerRequestResponse, self).__init__(data=data, **kwargs) + if "success" not in kwargs: + kwargs["success"] = data is not None and "job_id" in data and data["job_id"] > 0 + + super().__init__(data=data, **kwargs) def __eq__(self, other): return self.__class__ == other.__class__ and self.success == other.success and self.job_id == other.job_id @property - def job_id(self): - if self.success and self.data is not None: 
- return self.data[ModelExecRequestResponse.get_job_id_key()] + def job_id(self) -> int: + if self.success: + return self.data.job_id else: - return -1 + return UNSUCCESSFUL_JOB # TODO: make sure this value gets included in the data dict @property @@ -228,7 +206,23 @@ def output_data_id(self) -> Optional[str]: Optional[str] The 'data_id' of the output dataset for requested job, or ``None`` if not known. """ - if self.data is not None and ModelExecRequestResponse.get_output_data_id_key() in self.data: - return self.data[ModelExecRequestResponse.get_output_data_id_key()] - else: + if self.data is None: return None + return self.data.get("output_data_id") + + @classmethod + def factory_init_from_deserialized_json(cls, json_obj: dict) -> "SchedulerRequestResponse": + # TODO: remove in future. necessary for backwards compatibility + if isinstance(json_obj, SchedulerRequestResponse): + return json_obj + + return super().factory_init_from_deserialized_json(json_obj=json_obj) + + # NOTE: legacy support. previously this class was treated as a dictionary + def __contains__(self, element: str) -> bool: + return element in self.__dict__ + + # NOTE: legacy support. 
previously this class was treated as a dictionary + def __getitem__(self, item: str): + return self.__dict__[item] + diff --git a/python/lib/communication/dmod/communication/scheduler_request_response_body.py b/python/lib/communication/dmod/communication/scheduler_request_response_body.py new file mode 100644 index 000000000..2eee90a69 --- /dev/null +++ b/python/lib/communication/dmod/communication/scheduler_request_response_body.py @@ -0,0 +1,43 @@ +from dmod.core.serializable_dict import SerializableDict + +from typing import Optional, Union, TYPE_CHECKING + +if TYPE_CHECKING: + from pydantic.typing import AbstractSetIntStr, MappingIntStrAny, DictStrAny + +UNSUCCESSFUL_JOB = -1 + + +class SchedulerRequestResponseBody(SerializableDict): + job_id: int = UNSUCCESSFUL_JOB + output_data_id: Optional[str] + + def __eq__(self, other: object): + if isinstance(other, dict): + return self.to_dict() == other + return super().__eq__(other) + + def __getattr__(self, key: str): + return self.__dict__[key] + + def dict( + self, + *, + include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + by_alias: bool = True, # Note this follows Serializable convention + skip_defaults: Optional[bool] = None, + exclude_unset: bool = True, # noop + exclude_defaults: bool = False, + exclude_none: bool = False + ) -> "DictStrAny": + # Note: for backwards compatibility, unset fields are excluded by default + return super().dict( + include=include, + exclude=exclude, + by_alias=by_alias, + skip_defaults=skip_defaults, + exclude_unset=True, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) diff --git a/python/lib/communication/dmod/communication/session.py b/python/lib/communication/dmod/communication/session.py index d6adbec17..32bbe3ea8 100644 --- a/python/lib/communication/dmod/communication/session.py +++ b/python/lib/communication/dmod/communication/session.py @@ -3,13 +3,30 @@ 
import random from .message import AbstractInitRequest, MessageEventType, Response from dmod.core.serializable import Serializable +from dmod.core.serializable_dict import SerializableDict +from dmod.core.enum import PydanticEnum from abc import ABC, abstractmethod -from enum import Enum from numbers import Number -from typing import Dict, Optional, Union +from typing import ClassVar, Dict, Optional, List, Type, TYPE_CHECKING, Union +from pydantic import Field, IPvAnyAddress, validator, root_validator +if TYPE_CHECKING: + from pydantic.fields import ModelField -class SessionInitFailureReason(Enum): + +def _generate_secret() -> str: + """Generate random sha256 session secret. + + Returns + ------- + str + sha256 digest + """ + random.seed() + return hashlib.sha256(str(random.random()).encode('utf-8')).hexdigest() + + +class SessionInitFailureReason(PydanticEnum): AUTHENTICATION_SYS_FAIL = 1, # some error other than bad credentials prevented successful user authentication AUTHENTICATION_DENIED = 2, # the user's asserted identity was not authenticated due to the provided credentials USER_NOT_AUTHORIZED = 3, # the user was authenticated, but does not have authorized permission for a session @@ -21,58 +38,57 @@ class SessionInitFailureReason(Enum): UNKNOWN = -1 -class Session(Serializable): +class Session(SerializableDict): """ A bare-bones representation of a session between some compatible server and client, over which various requests may be made, and potentially other communication may take place. """ - _DATETIME_FORMAT = '%Y-%m-%d %H:%M:%S.%f' + _DATETIME_FORMAT: ClassVar[str] = '%Y-%m-%d %H:%M:%S.%f' + + session_id: int = Field(description="The unique identifier for this session.") + # QUESTION: we are using UUID4's elsewhere, do we want to use that instead here? Or perhaps a ULID? 
+ session_secret: str = Field(default_factory=_generate_secret, min_length=64, max_length=64, description="The unique random secret for this session.") + created: datetime.datetime = Field(default_factory=datetime.datetime.now, description="The date and time this session was created.") + last_accessed: datetime.datetime = Field(default_factory=datetime.datetime.now) - _full_equality_attributes = ['session_id', 'session_secret', 'created', 'last_accessed'] + _full_equality_attributes: ClassVar[List[str]]= ['session_id', 'session_secret', 'created', 'last_accessed'] """ list of str: the names of attributes/properties to include when testing instances for complete equality """ - _serialized_attributes = ['session_id', 'session_secret', 'created', 'last_accessed'] + _serialized_attributes: ClassVar[List[str]]= ['session_id', 'session_secret', 'created', 'last_accessed'] """ list of str: the names of attributes/properties to include when serializing an instance """ - _session_timeout_delta = datetime.timedelta(minutes=30.0) + _session_timeout_delta: ClassVar[datetime.timedelta] = datetime.timedelta(minutes=30.0) - @classmethod - def _init_datetime_val(cls, value): - try: - if value is None: - return datetime.datetime.now() - elif isinstance(value, str): - return datetime.datetime.strptime(value, Session._DATETIME_FORMAT) - elif not isinstance(value, datetime.datetime): - raise RuntimeError() - else: - return value - except Exception as e: - return datetime.datetime.now() + @validator("session_secret", pre=True) + def _populate_session_secret_if_none(cls, value: Optional[str], field: "ModelField") -> str: + # NOTE: pre-pydantic, this field was a computed optional: + # (i.e. `__init__(..., session_secret: str = None)`) but if None, a value was generated. 
+ # this validator handles that case + if value is None: + return field.default_factory() # type: ignore - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict): - """ - Factory create a new instance of this type based on a JSON object dictionary deserialized from received JSON. + return value - Parameters - ---------- - json_obj + @validator("created", "last_accessed", pre=True) + def validate_date(cls, value): + if isinstance(value, datetime.datetime): + return value - Returns - ------- - A new object of this type instantiated from the deserialize JSON object dictionary - """ - int_converter = lambda x: int(x) - str_converter = lambda s: str(s) - date_converter = lambda date_str: datetime.datetime.strptime(date_str, cls.get_datetime_str_format()) + try: + return datetime.datetime.strptime(value, cls.get_datetime_str_format()) + # TODO: improve error handling, or throw something know for downstream users. + except: + return datetime.datetime.now() - return cls(session_id=cls.parse_simple_serialized(json_obj, 'session_id', int, True, int_converter), - session_secret=cls.parse_simple_serialized(json_obj, 'session_secret', str, False, str_converter), - created=cls.parse_simple_serialized(json_obj, 'created', datetime.datetime, False, date_converter), - last_accessed=cls.parse_simple_serialized(json_obj, 'last_accessed', datetime.datetime, False, - date_converter)) + class Config: + def _serialize_datetime(self: "Session", value: datetime.datetime) -> str: + return value.strftime(self.get_datetime_str_format()) + + field_serializers = { + "created": _serialize_datetime, + "last_accessed": _serialize_datetime, + } @classmethod def get_datetime_str_format(cls): @@ -91,7 +107,7 @@ def get_full_equality_attributes(cls) -> tuple: a tuple-ized (and therefore immutable) collection of attribute names for those attributes used for determining full/complete equality between instances. 
""" - return tuple(cls._full_equality_attributes) + return tuple(cls.__fields__) @classmethod def get_serialized_attributes(cls) -> tuple: @@ -106,7 +122,7 @@ def get_serialized_attributes(cls) -> tuple: tuple of str: a tuple-ized (and therefore immutable) collection of attribute names for attributes used in serialization """ - return tuple(cls._serialized_attributes) + return tuple(cls.__fields__) @classmethod def get_session_timeout_delta(cls) -> datetime.timedelta: @@ -115,46 +131,10 @@ def get_session_timeout_delta(cls) -> datetime.timedelta: def __eq__(self, other): return isinstance(other, Session) and self.session_id == other.session_id - def __init__(self, - session_id: Union[str, int], - session_secret: str = None, - created: Union[datetime.datetime, str, None] = None, - last_accessed: Union[datetime.datetime, str, None] = None): - """ - Instantiate, either from an existing record - in which case values for 'secret' and 'created' are provided - or - from a newly acquired session id - in which case 'secret' is randomly generated, 'created' is set to now(), and - the expectation is that a new session record will be created from this instance. 
- - Parameters - ---------- - session_id : Union[str, int] - numeric session id value - session_secret : :obj:`str`, optional - the session secret, if deserializing this object from an existing session record - created : Union[:obj:`datetime.datetime`, :obj:`str`] - the date and time of session creation, either as a datetime object or parseable string, set to - :method:`datetime.datetime.now()` by default - """ - - self._session_id = int(session_id) - if session_secret is None: - random.seed() - self._session_secret = hashlib.sha256(str(random.random()).encode('utf-8')).hexdigest() - else: - self._session_secret = session_secret - - self._created = self._init_datetime_val(created) - self._last_accessed = self._init_datetime_val(last_accessed) - def __hash__(self): return self.session_id - @property - def created(self): - """:obj:`datetime.datetime`: The date and time this session was created.""" - return self._created - - def full_equals(self, other) -> bool: + def full_equals(self, other: object) -> bool: """ Test if this object and another are both of the exact same type and are more "fully" equal than can be determined from the standard equality implementation, by comparing all the attributes from @@ -172,16 +152,7 @@ def full_equals(self, other) -> bool: fully_equal : bool whether the objects are of the same type and with equal values for all serialized attributes """ - if self.__class__ != other.__class__: - return False - try: - for attr in self.get_full_equality_attributes(): - if getattr(self, attr) != getattr(other, attr): - return False - return True - except Exception as e: - # TODO: do something with this exception - return False + return super().__eq__(other) def get_as_dict(self) -> dict: """ @@ -192,17 +163,7 @@ def get_as_dict(self) -> dict: dict a serialized representation of this instance """ - attributes = {} - for attr in self._serialized_attributes: - attr_val = getattr(self, attr) - if isinstance(attr_val, datetime.datetime): - 
attributes[attr] = attr_val.strftime(self.get_datetime_str_format()) - elif isinstance(attr_val, Number) or isinstance(attr_val, str): - attributes[attr] = attr_val - else: - attributes[attr] = str(attr_val) - - return attributes + return self.dict() def get_as_json(self) -> str: """ @@ -219,12 +180,12 @@ def get_created_serialized(self): return self.created.strftime(Session._DATETIME_FORMAT) def get_last_accessed_serialized(self): - return self._last_accessed.strftime(Session._DATETIME_FORMAT) + return self.last_accessed.strftime(Session._DATETIME_FORMAT) def is_expired(self): - return self._last_accessed + self.get_session_timeout_delta() < datetime.datetime.now() + return self.last_accessed + self.get_session_timeout_delta() < datetime.datetime.now() - def is_serialized_attribute(self, attribute) -> bool: + def is_serialized_attribute(self, attribute: str) -> bool: """ Test whether an attribute of the given name is included in the serialized version of the instance returned by :method:`get_as_dict` and/or :method:`get_as_json` (at the top level). 
@@ -239,30 +200,21 @@ def is_serialized_attribute(self, attribute) -> bool: True if there is an attribute with the given name in the :attr:`_serialized_attributes` list, or False otherwise """ - for attr in self._serialized_attributes: - if attribute == attr: - return True - return False - - @property - def session_id(self): - """int: The unique identifier for this session.""" - return int(self._session_id) - - @property - def session_secret(self): - """str: The unique random secret for this session.""" - return self._session_secret - - def to_dict(self) -> dict: - return self.get_as_dict() + if not isinstance(attribute, str): + return False + return attribute in self.__fields__ # TODO: work more on this later, when authentication becomes more important class FullAuthSession(Session): - _full_equality_attributes = ['session_id', 'session_secret', 'created', 'ip_address', 'user', 'last_accessed'] - _serialized_attributes = ['session_id', 'session_secret', 'created', 'ip_address', 'user', 'last_accessed'] + ip_address: str + user: str = 'default' + + @validator("ip_address", pre=True) + def cast_ip_address_to_str(cls, value: str) -> str: + # this will raise if cannot be coerced into IPv(4|6)Address + return str(IPvAnyAddress.validate(value)) @classmethod def factory_init_from_deserialized_json(cls, json_obj: dict): @@ -277,45 +229,12 @@ def factory_init_from_deserialized_json(cls, json_obj: dict): ------- A new object of this type instantiated from the deserialize JSON object dictionary """ - # TODO: these are duplicated ... 
try to improve on that - int_converter = lambda x: int(x) - str_converter = lambda s: str(s) - date_converter = lambda date_str: datetime.datetime.strptime(date_str, cls.get_datetime_str_format()) try: - return cls(session_id=cls.parse_simple_serialized(json_obj, 'session_id', int, True, int_converter), - session_secret=cls.parse_simple_serialized(json_obj, 'session_secret', str, False, str_converter), - created=cls.parse_simple_serialized(json_obj, 'created', datetime.datetime, False, date_converter), - ip_address=cls.parse_simple_serialized(json_obj, 'ip_address', str, True, str_converter), - user=cls.parse_simple_serialized(json_obj, 'user', str, True, str_converter), - last_accessed=cls.parse_simple_serialized(json_obj, 'last_accessed', datetime.datetime, False, date_converter)) + return cls(**json_obj) except: return Session.factory_init_from_deserialized_json(json_obj) - def __init__(self, - ip_address: str, - session_id: Union[str, int], - session_secret: str = None, - user: str = 'default', - created: Union[datetime.datetime, str, None] = None, - last_accessed: Union[datetime.datetime, str, None] = None): - super().__init__(session_id=session_id, session_secret=session_secret, created=created, - last_accessed=last_accessed) - self._user = user if user is not None else 'default' - self._ip_address = ip_address - - @property - def ip_address(self): - return self._ip_address - - @property - def last_accessed(self): - return self._last_accessed - - @property - def user(self): - return self._user - class SessionInitMessage(AbstractInitRequest): """ @@ -336,108 +255,38 @@ class SessionInitMessage(AbstractInitRequest): The secret through which the client entity establishes the authenticity of its username assertion """ - event_type: MessageEventType = MessageEventType.SESSION_INIT - """ :class:`MessageEventType`: the event type for this message implementation """ - - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict): - """ - Factory 
create a new instance of this type based on a JSON object dictionary deserialized from received JSON. - - Parameters - ---------- - json_obj + username: str + user_secret: str - Returns - ------- - A new object of this type instantiated from the deserialize JSON object dictionary - """ - try: - return SessionInitMessage(username=json_obj['username'], user_secret=json_obj['user_secret']) - except: - return None - - def __init__(self, username: str, user_secret: str): - self.username = username - self.user_secret = user_secret - - def to_dict(self) -> dict: - return {'username': self.username, 'user_secret': self.user_secret} + event_type: ClassVar[MessageEventType] = MessageEventType.SESSION_INIT + """ :class:`MessageEventType`: the event type for this message implementation """ -class FailedSessionInitInfo(Serializable): +class FailedSessionInitInfo(SerializableDict): """ A :class:`~.serializeable.Serializable` type for representing details on why a :class:`SessionInitMessage` didn't successfully init a session. 
""" + user: str + reason: SessionInitFailureReason = SessionInitFailureReason.UNKNOWN + fail_time: datetime.datetime = Field(default_factory=datetime.datetime.now) + details: Optional[str] + @classmethod def get_datetime_str_format(cls): return Session.get_datetime_str_format() - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict): - date_converter = lambda date_str: datetime.datetime.strptime(date_str, cls.get_datetime_str_format()) - reason_converter = lambda r: SessionInitFailureReason[r] - try: - user = cls.parse_simple_serialized(json_obj, 'user', str, True) - fail_time = cls.parse_simple_serialized(json_obj, 'fail_time', datetime.datetime, False, date_converter) - reason = cls.parse_simple_serialized(json_obj, 'reason', SessionInitFailureReason, False, reason_converter) - details = cls.parse_simple_serialized(json_obj, 'details', str, False) - - if reason is None: - FailedSessionInitInfo(user=user, fail_time=fail_time, details=details) - else: - return FailedSessionInitInfo(user=user, reason=reason, fail_time=fail_time, details=details) - except: - return None - - def __eq__(self, other): - if self.__class__ != other.__class__ or self.user != other.user or self.reason != other.reason: - return False - if self.fail_time is not None and other.fail_time is not None and self.fail_time != other.fail_time: - return False - return True - - def __init__(self, user: str, reason: SessionInitFailureReason = SessionInitFailureReason.UNKNOWN, - fail_time: Optional[datetime.datetime] = None, details: Optional[str] = None): - self.user = user - self.reason = reason - self.fail_time = fail_time if fail_time is not None else datetime.datetime.now() - self.details = details - - def to_dict(self) -> Dict[str, str]: - """ - Get the representation of this instance as a serialized dictionary or dictionary-like object (e.g., a JSON - object). - - Since the returned value must be serializable and JSON-like, key and value types are restricted. 
For this - implementation, all keys and values in the returned dictionary must be strings. Thus, for the - ::attribute:`fail_time` and ::attribute:`details` attributes, there should be no key or value if the attribute - has a current value of ``None``. - - Returns - ------- - Dict[str, str] - The representation of this instance as a serialized dictionary or dictionary-like object, with valid types - of keys and values. - - See Also - ------- - ::method:`Serializable.to_dict` - """ - result = {'user': self.user, 'reason': self.reason.value} - if self.fail_time is not None: - result['fail_time'] = self.fail_time.strftime(self.get_datetime_str_format()) - if self.details is not None: - result['details'] = self.details - return result + class Config: + def _serialize_datetime(self: "Session", value: datetime.datetime) -> str: + return value.strftime(self.get_datetime_str_format()) + + field_serializers = {"fail_time": _serialize_datetime} # Define this custom type here for hinting SessionInitDataType = Union[Session, FailedSessionInitInfo] - class SessionInitResponse(Response): """ The :class:`~.message.Response` subtype used to response to a :class:`.SessionInitMessage`, either @@ -481,42 +330,55 @@ class SessionInitResponse(Response): """ - response_to_type = SessionInitMessage + response_to_type: ClassVar[Type[AbstractInitRequest]] = SessionInitMessage """ Type[`SessionInitMessage`]: the type or subtype of :class:`Message` for which this type is the response""" - @classmethod - def _factory_init_data_attribute(cls, json_obj: dict) -> Optional[SessionInitDataType]: - """ - Initialize the argument value for a constructor param used to set the :attr:`data` attribute appropriate for - this type, given the parent JSON object, which for this type means deserializing the dict value to either a - session object or a failure info object. 
- - Parameters - ---------- - json_obj : dict - the parent JSON object containing the desired session data serialized value - - Returns - ------- - data - the resulting :class:`Session` or :class:`FailedSessionInitInfo` object obtained after processing, - or None if no valid object could be processed of either type - """ - data = None - try: - data = json_obj['data'] - except: - det = 'Received serialized JSON response object that did not contain expected key for serialized session.' - return FailedSessionInitInfo(user='', reason=SessionInitFailureReason.SESSION_DETAILS_MISSING, details=det) - - try: - # If we can, return the FullAuthSession or Session obtained by this class method - return FullAuthSession.factory_init_from_deserialized_json(data) - except: + # NOTE: this field _is_ optional, however `data` will be FailedSessionInitInfo if it is not + # provided or set to None. + # NOTE: order of this Union matters. types will be coerced from left to right. meaning, more + # specific types (i.e. subtypes) should be listed before more general types. 
see `SmartUnion` + # for more detail: https://docs.pydantic.dev/usage/model_config/#smart-union + data: Union[FullAuthSession, Session, FailedSessionInitInfo] + + @root_validator(pre=True) + def _coerce_data_field(cls, values): + data = values.get("data") + + if data is None: + details = "Instantiated SessionInitResponse object without session data; defaulting to failure" + values["data"] = FailedSessionInitInfo( + user="", + reason=SessionInitFailureReason.SESSION_DETAILS_MISSING, + details=details, + ) + return values + + # run `data` field validators + coerced_data, errors = cls.__fields__["data"].validate(data, {}, loc="") + if errors is not None: + details = 'Instantiated SessionInitResponse object using unexpected type for data ({})'.format( + data.__class__.__name__) try: - return FailedSessionInitInfo.factory_init_from_deserialized_json(data) + as_str = '; converted to string: \n{}'.format(str(data)) + details += as_str except: - return None + # If we can't cast to string, don't worry; just leave out that part in details + pass + values["data"] = FailedSessionInitInfo( + user="", + reason=SessionInitFailureReason.SESSION_DETAILS_MISSING, + details=details, + ) + return values + + values["data"] = coerced_data + return values + + @root_validator() + def _update_success(cls, values): + # Make sure to reset/change self.success if self.data ends up being a failure info object + values["success"] = values["success"] and isinstance(values["data"], Session) + return values def __eq__(self, other): return self.__class__ == other.__class__ \ @@ -525,34 +387,6 @@ def __eq__(self, other): and self.message == other.message \ and self.data.full_equals(other.data) if isinstance(self.data, Session) else self.data == other.data - def __init__(self, success: bool, reason: str, message: str = '', data: Optional[SessionInitDataType] = None): - super().__init__(success=success, reason=reason, message=message, data=data) - - # If we received a dict for data, try to deserialize 
using the class method (failures will set to None, - # which will get handled by the next conditional logic) - if isinstance(self.data, dict): - # Remember, the class method expects a JSON obj dict with the data as a child element, not the data directly - self.data = self.__class__._factory_init_data_attribute({'success': self.success, 'data': data}) - - if self.data is None: - details = 'Instantiated SessionInitResponse object without session data; defaulting to failure' - self.data = FailedSessionInitInfo(user='', reason=SessionInitFailureReason.SESSION_DETAILS_MISSING, - details=details) - elif not (isinstance(self.data, Session) or isinstance(self.data, FailedSessionInitInfo)): - details = 'Instantiated SessionInitResponse object using unexpected type for data ({})'.format( - self.data.__class__.__name__) - try: - as_str = '; converted to string: \n{}'.format(str(self.data)) - details += as_str - except: - # If we can't cast to string, don't worry; just leave out that part in details - pass - self.data = FailedSessionInitInfo(user='', reason=SessionInitFailureReason.SESSION_DETAILS_MISSING, - details=details) - - # Make sure to reset/change self.success if self.data ends up being a failure info object - self.success = self.success and isinstance(self.data, Session) - class SessionManager(ABC): """ diff --git a/python/lib/communication/dmod/communication/unsupported_message.py b/python/lib/communication/dmod/communication/unsupported_message.py index 88295ad74..e1c64ec2b 100644 --- a/python/lib/communication/dmod/communication/unsupported_message.py +++ b/python/lib/communication/dmod/communication/unsupported_message.py @@ -4,12 +4,28 @@ class UnsupportedMessageTypeResponse(Response): + actual_event_type: MessageEventType + listener_type: Type[WebSocketInterface] + message: str - def __init__(self, actual_event_type: MessageEventType, listener_type: Type[WebSocketInterface], - message: str = None, data=None): + success = False + reason = "Message Event Type 
Unsupported" + + def __init__( + self, + actual_event_type: MessageEventType, + listener_type: Type[WebSocketInterface], + message: str = None, + data=None, + **kwargs + ): if message is None: - message = 'The {} event type is not supported by this {} listener'.format( - actual_event_type, listener_type.__name__) - super().__init__(success=False, reason='Message Event Type Unsupported', message=message, data=data) - self.actual_event_type = actual_event_type - self.listener_type = listener_type \ No newline at end of file + message = "The {} event type is not supported by this {} listener".format( + actual_event_type, listener_type.__name__ + ) + super().__init__( + message=message, + data=data, + actual_event_type=actual_event_type, + listener_type=listener_type, + ) diff --git a/python/lib/communication/dmod/communication/update_message.py b/python/lib/communication/dmod/communication/update_message.py index 787f5cfeb..c639c06fa 100644 --- a/python/lib/communication/dmod/communication/update_message.py +++ b/python/lib/communication/dmod/communication/update_message.py @@ -1,8 +1,11 @@ from .message import AbstractInitRequest, MessageEventType, Response from pydoc import locate -from typing import Dict, Optional, Type, Union +from typing import ClassVar, Dict, Optional, Type, Union +from pydantic import Field, validator import uuid +from dmod.core.serializable_dict import SerializableDict + class UpdateMessage(AbstractInitRequest): """ @@ -28,132 +31,56 @@ class type, but note that when messages are serialized, it is converted to the f update it conveys. 
""" - event_type: MessageEventType = MessageEventType.INFORMATION_UPDATE + event_type: ClassVar[MessageEventType] = MessageEventType.INFORMATION_UPDATE + + object_id: str = Field(description="The identifier for the object being updated, as a string.") + object_type: Type[object] = Field(description="The type of object being updated.") + # NOTE: updated_data must container at least one key + updated_data: Dict[str, str] = Field(description="A serialized dictionary of properties to new values.") + digest: str = Field(default_factory=lambda: uuid.uuid4().hex) + + @validator("object_type", pre=True) + def _coerce_object_type(cls, value): + if isinstance(value, str): + obj_type = locate(value) + if obj_type is None: + raise ValueError("could not resolve `object_type`") + return obj_type + return value + + @validator("updated_data") + def _validate_updated_data_has_keys(cls, value: Dict[str, str]): + if not value.keys(): + raise ValueError("`updated_data` must have at least one key.") + return value - _DIGEST_KEY = 'digest' - _OBJECT_ID_KEY = 'object_id' - _OBJECT_TYPE_KEY = 'object_type' - _UPDATED_DATA_KEY = 'updated_data' + class Config: + field_serializers = {"object_type": lambda self, _: self.object_type_string} @classmethod def get_digest_key(cls) -> str: - return cls._DIGEST_KEY + return cls.__fields__["digest"].alias @classmethod def get_object_id_key(cls) -> str: - return cls._OBJECT_ID_KEY + return cls.__fields__["object_id"].alias @classmethod def get_object_type_key(cls) -> str: - return cls._OBJECT_TYPE_KEY + return cls.__fields__["object_type"].alias @classmethod def get_updated_data_key(cls) -> str: - return cls._UPDATED_DATA_KEY - - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict): - """ - Factory create a new instance of this type based on a JSON object dictionary deserialized from received JSON. 
- - The method expects the ::attribute:`object_type` to be represented as the fully-qualified name string for the - particular class type. If the method cannot located the actual class type by this string, the JSON is - considered invalid. - - Additionally, if the representation of the ::attribute:`updated_data` property is not a (serialized) nested - dictionary, or is an empty dictionary, this is also considered invalid. - - Both ::attribute:`digest` and ::attribute:`object_id` representations are valid if they can be cast to strings. - - The JSON is not considered invalid if it has other keys/values at the root level beyond those for the standard - properties. - - For invalid JSON representations, ``None`` is returned. - - Parameters - ---------- - json_obj - - Returns - ------- - Optional[UpdateMessage] - A new object of this type instantiated from the deserialize JSON object dictionary, or ``None`` if the JSON - is not a valid serialized representation of this type. - """ - try: - obj_type = locate(json_obj[cls.get_object_type_key()]) - if obj_type is None: - return None - obj_id = str(json_obj[cls.get_object_id_key()]) - updated_data = json_obj[cls.get_updated_data_key()] - if not isinstance(updated_data, dict) or len(updated_data.keys()) == 0: - return None - message = cls(object_id=obj_id, object_type=obj_type, updated_data=updated_data) - message._digest = str(json_obj[cls.get_digest_key()]) - except: - return None - - def __init__(self, object_id: str, object_type: Type, updated_data: Dict[str, str]): - """ - Initialize a new object. - - Parameters - ---------- - object_id : str - The identifier for the object being updated, as a string. - object_type : Type - The type of object being updated. - updated_data : Dict[str, str] - A serialized dictionary of properties to new values. 
- """ - self._digest = None - self._object_type = object_type - self._object_id = object_id - self._updated_data = updated_data - - @property - def digest(self) -> str: - if self._digest is None: - self._digest = uuid.uuid4().hex - return self._digest - - @property - def object_id(self) -> str: - return self._object_id - - @property - def object_type(self) -> Type: - return self._object_type + return cls.__fields__["updated_data"].alias @property def object_type_string(self) -> str: return '{}.{}'.format(self.object_type.__module__, self.object_type.__name__) - def to_dict(self) -> dict: - """ - Get the representation of this instance as a dictionary or dictionary-like object (e.g., a JSON object). - - Returns - ------- - dict - The representation of this instance as a dictionary or dictionary-like object (e.g., a JSON object). - """ - return {self.get_object_id_key(): self.object_id, self.get_digest_key(): self.digest, - self.get_object_type_key(): self.object_type_string, self.get_updated_data_key(): self.updated_data} - - @property - def updated_data(self) -> Dict[str, str]: - """ - Get the updated properties of the updated entity and the new values, as a dictionary of string property name - keys mapped to string representations of the values. - Returns - ------- - Dict[str, str] - The updated properties of the updated entity and the new values, as a dictionary of string property name - keys mapped to string representations of the values. - """ - return self._updated_data +class UpdateMessageData(SerializableDict): + digest: Optional[str] + object_found: Optional[bool] class UpdateMessageResponse(Response): @@ -161,10 +88,9 @@ class UpdateMessageResponse(Response): The subtype of ::class:`Response` appropriate for ::class:`UpdateMessage` objects. 
""" - _DIGEST_SUBKEY = 'digest' - _OBJECT_FOUND_SUBKEY = 'object_found' + response_to_type: ClassVar[Type[AbstractInitRequest]] = UpdateMessage - response_to_type = UpdateMessage + data: UpdateMessageData = Field(default_factory=UpdateMessageData) @classmethod def get_digest_subkey(cls) -> str: @@ -178,7 +104,7 @@ def get_digest_subkey(cls) -> str: The "subkey" (i.e., the key for the value within the nested ``data`` dictionary) for the ``digest`` in serialized representations. """ - return cls._DIGEST_SUBKEY + return UpdateMessageData.__fields__["digest"].alias @classmethod def get_object_found_subkey(cls) -> str: @@ -192,25 +118,33 @@ def get_object_found_subkey(cls) -> str: The "subkey" (i.e., the key for the value within the nested ``data`` dictionary) for the ``digest`` in serialized representations. """ - return cls._OBJECT_FOUND_SUBKEY + return UpdateMessageData.__fields__["object_found"].alias def __init__(self, success: bool, reason: str, response_text: str = '', data: Optional[Dict[str, Union[str, bool]]] = None, digest: Optional[str] = None, - object_found: Optional[bool] = None): + object_found: Optional[bool] = None, **kwargs): # Work with digest/found either as params or contained within data param # However, move explicit params into the data dict param, allowing non-None params to overwrite data = dict() if data is None else data - digest = data[self.get_digest_subkey()] if digest is None and self.get_digest_subkey() in data else digest - if object_found is None and self.get_object_found_subkey(): + + if digest is None and self.get_digest_subkey() in data: + digest = data[self.get_digest_subkey()] + + if object_found is None and self.get_object_found_subkey() in data: object_found = data[self.get_object_found_subkey()] - super().__init__(success=success, reason=reason, message=response_text, - data={self.get_digest_subkey(): digest, self.get_object_found_subkey(): object_found}) + super().__init__( + success=success, + reason=reason, + 
message=response_text, + data=UpdateMessageData(digest=digest, object_found=object_found), + **kwargs + ) @property - def digest(self) -> str: - return self.data[self.get_digest_subkey()] + def digest(self) -> Optional[str]: + return self.data.digest @property - def object_found(self) -> bool: - return self.data[self.get_object_found_subkey()] + def object_found(self) -> Optional[bool]: + return self.data.object_found diff --git a/python/lib/communication/dmod/test/test_dataset_query.py b/python/lib/communication/dmod/test/test_dataset_query.py index f313e97fe..8334e953e 100644 --- a/python/lib/communication/dmod/test/test_dataset_query.py +++ b/python/lib/communication/dmod/test/test_dataset_query.py @@ -10,7 +10,7 @@ def setUp(self) -> None: self.examples = [] self.ex_query_types.append(QueryType.LIST_FILES) - self.ex_json_data.append({DatasetQuery._KEY_QUERY_TYPE: 'LIST_FILES'}) + self.ex_json_data.append({"query_type": 'LIST_FILES'}) self.examples.append(DatasetQuery(query_type=QueryType.LIST_FILES)) def test_factory_init_from_deserialized_json_0_a(self): diff --git a/python/lib/communication/dmod/test/test_decorated_interface.py b/python/lib/communication/dmod/test/test_decorated_interface.py index 48524ac1b..0850c1a97 100644 --- a/python/lib/communication/dmod/test/test_decorated_interface.py +++ b/python/lib/communication/dmod/test/test_decorated_interface.py @@ -7,8 +7,8 @@ import sys import unittest from ..communication.message import MessageEventType -from dmod.communication import ModelExecRequest, SessionInitMessage -from dmod.communication.dataset_management_message import MaaSDatasetManagementMessage +from ..communication import ModelExecRequest, SessionInitMessage +from ..communication.dataset_management_message import MaaSDatasetManagementMessage from ..communication.websocket_interface import NoOpHandler from pathlib import Path from socket import gethostname @@ -153,7 +153,10 @@ def setUp(self): "model": { "nwm": { "config_data_id": "1", - 
"data_requirements": [{"domain": {"data_format": "NWM_CONFIG", "continuous": [], + "data_requirements": [{ + "category": "CONFIG", + "is_input": True, + "domain": {"data_format": "NWM_CONFIG", "continuous": [], "discrete": [{"variable": "data_id", "values": ["1"]}]}}] } }, diff --git a/python/lib/communication/dmod/test/test_ngen_request_response.py b/python/lib/communication/dmod/test/test_ngen_request_response.py index 25d3c24e0..96e63dd15 100644 --- a/python/lib/communication/dmod/test/test_ngen_request_response.py +++ b/python/lib/communication/dmod/test/test_ngen_request_response.py @@ -1,6 +1,7 @@ import json import unittest from ..communication.maas_request import NGENRequestResponse +from ..communication.maas_request.model_exec_request_response_body import ModelExecRequestResponseBody from ..communication.message import InitRequestResponseReason from ..communication.scheduler_request import SchedulerRequestResponse @@ -95,7 +96,7 @@ def test_factory_init_from_deserialized_json_2_e(self): the expected dictionary value for ``data``. """ obj = NGENRequestResponse.factory_init_from_deserialized_json(self.response_jsons[2]) - self.assertEqual(obj.data.__class__, dict) + self.assertEqual(obj.data.__class__, ModelExecRequestResponseBody) def test_factory_init_from_deserialized_json_2_f(self): """ @@ -127,7 +128,7 @@ def test_factory_init_from_deserialized_json_2_i(self): the expected dictionary value for ``data``, with the ``scheduler_response`` being of the right type. 
""" obj = NGENRequestResponse.factory_init_from_deserialized_json(self.response_jsons[2]) - self.assertEqual(obj.data['scheduler_response'].__class__, dict) + self.assertEqual(obj.data['scheduler_response'].__class__, SchedulerRequestResponse) def test_factory_init_from_deserialized_json_2_j(self): """ diff --git a/python/lib/communication/dmod/test/test_nwm_request_response.py b/python/lib/communication/dmod/test/test_nwm_request_response.py index 848ddc709..d3a8e224e 100644 --- a/python/lib/communication/dmod/test/test_nwm_request_response.py +++ b/python/lib/communication/dmod/test/test_nwm_request_response.py @@ -1,6 +1,7 @@ import json import unittest from ..communication.maas_request import NWMRequestResponse +from ..communication.maas_request.model_exec_request_response_body import ModelExecRequestResponseBody from ..communication.message import InitRequestResponseReason from ..communication.scheduler_request import SchedulerRequestResponse @@ -92,10 +93,11 @@ def test_factory_init_from_deserialized_json_2_d(self): def test_factory_init_from_deserialized_json_2_e(self): """ Test ``factory_init_from_deserialized_json()`` on raw string example 2 to make sure the deserialized object has - the expected dictionary value for ``data``. + the expected ModelExecRequestResponseBody value for ``data``. For legacy support, this can still be + treated like a dictionary. 
""" obj = NWMRequestResponse.factory_init_from_deserialized_json(self.response_jsons[2]) - self.assertEqual(obj.data.__class__, dict) + self.assertEqual(obj.data.__class__, ModelExecRequestResponseBody) def test_factory_init_from_deserialized_json_2_f(self): """ @@ -124,10 +126,11 @@ def test_factory_init_from_deserialized_json_2_h(self): def test_factory_init_from_deserialized_json_2_i(self): """ Test ``factory_init_from_deserialized_json()`` on raw string example 2 to make sure the deserialized object has - the expected dictionary value for ``data``, with the ``scheduler_response`` being of the right type. + the expected SchedulerRequestResponse value for ``data``, with the ``scheduler_response`` being of the right type. + For legacy support, ``SchedulerRequestResponse`` can still be treated as a dictionary. """ obj = NWMRequestResponse.factory_init_from_deserialized_json(self.response_jsons[2]) - self.assertEqual(obj.data['scheduler_response'].__class__, dict) + self.assertEqual(obj.data['scheduler_response'].__class__, SchedulerRequestResponse) def test_factory_init_from_deserialized_json_2_j(self): """ diff --git a/python/lib/communication/dmod/test/test_websocket_interface.py b/python/lib/communication/dmod/test/test_websocket_interface.py index 3fb8a973f..b3908b6fb 100644 --- a/python/lib/communication/dmod/test/test_websocket_interface.py +++ b/python/lib/communication/dmod/test/test_websocket_interface.py @@ -7,8 +7,9 @@ import sys import unittest from ..communication.message import MessageEventType -from dmod.communication import ModelExecRequest, SessionInitMessage -from dmod.communication.dataset_management_message import MaaSDatasetManagementMessage +from ..communication.maas_request import ModelExecRequest +from ..communication.session import SessionInitMessage +from ..communication.dataset_management_message import MaaSDatasetManagementMessage from ..communication.websocket_interface import NoOpHandler from pathlib import Path from socket import 
gethostname @@ -155,7 +156,10 @@ def setUp(self): "allocation_paradigm": "ROUND_ROBIN", "config_data_id": "1", "cpu_count": 2, - "data_requirements": [{"domain": {"data_format": "NWM_CONFIG", "continuous": [], + "data_requirements": [{ + "category": "CONFIG", + "is_input": True, + "domain": {"data_format": "NWM_CONFIG", "continuous": [], "discrete": [{"variable": "data_id", "values": ["1"]}]}}] } }, diff --git a/python/lib/communication/setup.py b/python/lib/communication/setup.py index 839d3e133..45edb9a0d 100644 --- a/python/lib/communication/setup.py +++ b/python/lib/communication/setup.py @@ -21,7 +21,6 @@ url='', license='', include_package_data=True, - #install_requires=['websockets', 'jsonschema'],vi - install_requires=['dmod-core>=0.1.2', 'websockets>=8.1', 'jsonschema', 'redis'], + install_requires=['dmod-core>=0.4.2', 'websockets>=8.1', 'jsonschema', 'redis', 'pydantic'], packages=find_namespace_packages(include=['dmod.*'], exclude=['dmod.test']) ) diff --git a/python/lib/core/dmod/core/_version.py b/python/lib/core/dmod/core/_version.py index b703f5c96..a98734733 100644 --- a/python/lib/core/dmod/core/_version.py +++ b/python/lib/core/dmod/core/_version.py @@ -1 +1 @@ -__version__ = '0.4.1' \ No newline at end of file +__version__ = '0.4.2' diff --git a/python/lib/core/dmod/core/dataset.py b/python/lib/core/dmod/core/dataset.py index b1e5ca2aa..142bd88b5 100644 --- a/python/lib/core/dmod/core/dataset.py +++ b/python/lib/core/dmod/core/dataset.py @@ -5,13 +5,13 @@ from datetime import datetime, timedelta from .serializable import Serializable, ResultIndicator -from enum import Enum -from numbers import Number -from typing import Any, Callable, Dict, FrozenSet, List, Optional, Set, Tuple, Type, Union +from .enum import PydanticEnum +from typing import Any, Callable, ClassVar, Dict, FrozenSet, List, Optional, Set, Tuple, Type, Union +from pydantic import Field, validator, root_validator from uuid import UUID, uuid4 -class DatasetType(Enum): +class 
DatasetType(PydanticEnum): UNKNOWN = (-1, False, lambda dataset: None) OBJECT_STORE = (0, True, lambda dataset: dataset.name) FILESYSTEM = (1, True, lambda dataset: dataset.access_location) @@ -59,85 +59,70 @@ class Dataset(Serializable): Rrepresentation of the descriptive metadata for a grouped collection of data. """ - _SERIAL_DATETIME_STR_FORMAT = '%Y-%m-%d %H:%M:%S' - - _KEY_ACCESS_LOCATION = 'access_location' - _KEY_CREATED_ON = 'create_on' - _KEY_DATA_CATEGORY = 'data_category' - _KEY_DATA_DOMAIN = 'data_domain' - _KEY_DERIVED_FROM = 'derived_from' - _KEY_DERIVATIONS = 'derivations' - _KEY_DESCRIPTION = 'description' - _KEY_EXPIRES = 'expires' - _KEY_IS_READ_ONLY = 'is_read_only' - _KEY_LAST_UPDATE = 'last_updated' - _KEY_MANAGER_UUID = 'manager_uuid' - _KEY_NAME = 'name' - _KEY_TYPE = 'type' - _KEY_UUID = 'uuid' - - # TODO: move this (and something more to better automatically handle Serializable subtypes) to Serializable directly - @classmethod - def _date_parse_helper(cls, json_obj: dict, key: str) -> Optional[datetime]: - if key in json_obj: - return datetime.strptime(json_obj[key], cls.get_datetime_str_format()) - else: - return None - - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict): - try: - manager_uuid = UUID(json_obj[cls._KEY_MANAGER_UUID]) if cls._KEY_MANAGER_UUID in json_obj else None - return cls(name=json_obj[cls._KEY_NAME], - category=DataCategory.get_for_name(json_obj[cls._KEY_DATA_CATEGORY]), - data_domain=DataDomain.factory_init_from_deserialized_json(json_obj[cls._KEY_DATA_DOMAIN]), - dataset_type=DatasetType.get_for_name(json_obj[cls._KEY_TYPE]), - access_location=json_obj[cls._KEY_ACCESS_LOCATION], - description=json_obj.get(cls._KEY_DESCRIPTION, None), - uuid=UUID(json_obj[cls._KEY_UUID]), - manager_uuid=manager_uuid, - is_read_only=json_obj[cls._KEY_IS_READ_ONLY], - expires=cls._date_parse_helper(json_obj, cls._KEY_EXPIRES), - derived_from=json_obj[cls._KEY_DERIVED_FROM] if cls._KEY_DERIVED_FROM in json_obj 
else None, - derivations=json_obj[cls._KEY_DERIVATIONS] if cls._KEY_DERIVATIONS in json_obj else [], - created_on=cls._date_parse_helper(json_obj, cls._KEY_CREATED_ON), - last_updated=cls._date_parse_helper(json_obj, cls._KEY_LAST_UPDATE)) - except Exception as e: + _SERIAL_DATETIME_STR_FORMAT: ClassVar = '%Y-%m-%d %H:%M:%S' + name: str = Field(description="The name for this dataset, which also should be a unique identifier.") + category: DataCategory = Field(None, alias="data_category", description="The ::class:`DataCategory` type value for this instance.") + data_domain: DataDomain + dataset_type: DatasetType = Field(DatasetType.UNKNOWN, alias="type") + access_location: str = Field(description="String representation of the location at which this dataset is accessible.") + uuid: Optional[UUID] = Field(default_factory=uuid4) + # manager can only be passed as constructed DatasetManager subtype. Manager not included in `dict` or `json` deserialization. + # TODO: don't include `manager` in `Dataset.schema()`. Inclusion is not reflective of the de/serialization behavior. 
+ manager: Optional['DatasetManager'] = Field(exclude=True) + manager_uuid: Optional[UUID] + is_read_only: bool = Field(True, description="Whether this is a dataset that can only be read from.") + description: Optional[str] + expires: Optional[datetime] = Field(description='The time after which a dataset may "expire" and be removed, or ``None`` if the dataset is not temporary.') + derived_from: Optional[str] = Field(description="The name of the dataset from which this dataset was derived, if it is known to have been derived.") + derivations: Optional[List[str]] = Field(default_factory=list, description="""List of names of datasets which were derived from this dataset.\n + Note that it is not guaranteed that any such dataset still exist and/or are still available.""") + created_on: Optional[datetime] = Field(description="When this dataset was created, or ``None`` if that is not known.") + last_updated: Optional[datetime] + + @validator("created_on", "last_updated", "expires", pre=True) + def parse_dates(cls, v): + if v is None: return None - def __eq__(self, other): - return isinstance(other, Dataset) and self.name == other.name and self.category == other.category \ - and self.dataset_type == other.dataset_type and self.data_domain == other.data_domain \ - and self.access_location == other.access_location and self.is_read_only == other.is_read_only \ - and self.created_on == other.created_on + if isinstance(v, datetime): + return v + + return datetime.strptime(v, cls.get_datetime_str_format()) + + @validator("created_on", "last_updated", "expires") + def drop_microseconds(cls, v: datetime): + return v.replace(microsecond=0) + + @validator("manager", pre=True) + def drop_manager_if_not_constructed_subtype(cls, value): + # manager can only be passed as constructed DatasetManager subtype + if isinstance(value, DatasetManager): + return value + return None + + @root_validator() + def set_manager_uuid(cls, values) -> dict: + manager: Optional[DatasetManager] = 
values["manager"] + # give preference to `manager.uuid` otherwise use specified `manager_uuid` + if manager is not None: + # pydantic will not validate this, so we need to check it + if not isinstance(manager.uuid, UUID): + raise ValueError(f"Expected UUID got {type(manager.uuid)}") + values["manager_uuid"] = manager.uuid + + return values + + class Config: + # NOTE: re-validate when any field is re-assigned (i.e. `model.foo = 12`) + # TODO: in future deprecate setting properties unless through a setter method + validate_assignment = True + arbitrary_types_allowed = True + field_serializers = {"uuid": lambda f: str(f)} def __hash__(self): return hash(','.join([self.__class__.__name__, self.name, self.category.name, str(hash(self.data_domain)), self.access_location, str(self.is_read_only), str(hash(self.created_on))])) - def __init__(self, name: str, category: DataCategory, data_domain: DataDomain, dataset_type: DatasetType, - access_location: str, uuid: Optional[UUID] = None, manager: Optional['DatasetManager'] = None, - manager_uuid: Optional[UUID] = None, is_read_only: bool = True, description: Optional[str] = None, expires: Optional[datetime] = None, - derived_from: Optional[str] = None, derivations: Optional[List[str]] = None, - created_on: Optional[datetime] = None, last_updated: Optional[datetime] = None): - self._name = name - self._category = category - self._data_domain = data_domain - self._dataset_type = dataset_type - self._access_location = access_location - self._uuid = uuid4() if uuid is None else uuid - self._manager = manager - self._manager_uuid = manager.uuid if manager is not None else manager_uuid - self._description = description - self._is_read_only = is_read_only - self._expires = expires if expires is None else expires.replace(microsecond=0) - self._derived_from = derived_from - self._derivations = derivations if derivations is not None else list() - self._created_on = created_on if created_on is None else created_on.replace(microsecond=0) 
- self._last_updated = last_updated if last_updated is None else last_updated.replace(microsecond=0) - # TODO: have manager handle the logic - #retention_strategy - def _set_expires(self, new_expires: datetime): """ "Private" function to set the ::attribute:`expires` property. @@ -150,60 +135,8 @@ def _set_expires(self, new_expires: datetime): new_expires : datetime The new value for ::attribute:`expires`. """ - self._expires = new_expires - # n = datetime.now() - # n.astimezone().tzinfo.tzname(n.astimezone()) - self._last_updated = datetime.now() - - @property - def access_location(self) -> str: - """ - String representation of the location at which this dataset is accessible. - - Depending on the subtype, this may be the string form of a URL, URI, or basic filesystem path. - - Returns - ------- - str - String representation of the location at which this dataset is accessible. - """ - return self._access_location - - @property - def category(self) -> DataCategory: - """ - The ::class:`DataCategory` type value for this instance. - - Returns - ------- - DataCategory - The ::class:`DataCategory` type value for this instance. - """ - return self._category - - @property - def created_on(self) -> Optional[datetime]: - """ - When this dataset was created, or ``None`` if that is not known. - - Returns - ------- - Optional[datetime] - When this dataset was created, or ``None`` if that is not known. - """ - return self._created_on - - @property - def data_domain(self) -> DataDomain: - """ - The data domain for this instance. - - Returns - ------- - DataDomain - The ::class:`DataDomain` for this instance. 
- """ - return self._data_domain + self.expires = new_expires + self.last_updated = datetime.now() @property def data_format(self) -> DataFormat: @@ -217,53 +150,6 @@ def data_format(self) -> DataFormat: """ return self.data_domain.data_format - @property - def dataset_type(self) -> DatasetType: - return self._dataset_type - - @property - def derivations(self) -> List[str]: - """ - List of names of datasets which were derived from this dataset. - - Note that it is not guaranteed that any such dataset still exist and/or are still available. - - Returns - ------- - List[str] - List of names of datasets which were derived from this dataset. - """ - return self._derivations - - @property - def derived_from(self) -> Optional[str]: - """ - The name of the dataset from which this dataset was derived, if it is known to have been derived. - - Returns - ------- - Optional[str] - The name of the dataset from which this dataset was derived, or ``None`` if this dataset is not known to - have been derived. - """ - return self._derived_from - - @property - def description(self) -> Optional[str]: - """ - An optional string description of this dataset. - - Returns - ------- - Optional[str] - An optional string description of this dataset. - """ - return self._description - - @description.setter - def description(self, desc: Optional[str]): - self._description = desc - @property def docker_mount(self) -> str: """ @@ -289,22 +175,6 @@ def docker_mount(self) -> str: else: return result - @property - def expires(self) -> Optional[datetime]: - """ - The time after which a dataset may "expire" and be removed, or ``None`` if the dataset is not temporary. - - A dataset may be temporary, meaning its availability and validity cannot be assumed perpetually; e.g., the data - may be removed from storage. This property indicates the time through which availability and validity is - guaranteed. 
- - Returns - ------- - Optional[datetime] - The time after which a dataset may "expire" and be removed, or ``None`` if the dataset is not temporary. - """ - return self._expires - def extend_life(self, value: Union[datetime, timedelta]) -> bool: """ Extend the expiration of this dataset. @@ -335,7 +205,7 @@ def extend_life(self, value: Union[datetime, timedelta]) -> bool: if not self.is_temporary: return False elif isinstance(value, timedelta): - self._set_expires(self._expires + value) + self._set_expires(self.expires + value) return True elif isinstance(value, datetime) and self.expires < value: self._set_expires(value) @@ -355,18 +225,6 @@ def fields(self) -> Dict[str, Type]: """ return self.data_domain.data_fields - @property - def is_read_only(self) -> bool: - """ - Whether this is a dataset that can only be read from. - - Returns - ------- - bool - Whether this is a dataset that can only be read from. - """ - return self._is_read_only - @property def is_temporary(self) -> bool: """ @@ -382,64 +240,6 @@ def is_temporary(self) -> bool: """ return self.expires is not None - @property - def last_updated(self) -> Optional[datetime]: - """ - When this dataset was last updated, or ``None`` if that is not known. - - Note that this includes adjustments to metadata, including the value for ::attribute:`expires`. - - Returns - ------- - Optional[datetime] - When this dataset was last updated, or ``None`` if that is not known. - """ - return self._last_updated - - @property - def manager(self) -> 'DatasetManager': - """ - The ::class:`DatasetManager` for this instance. - - Returns - ------- - DatasetManager - The ::class:`DatasetManager` for this instance. - """ - return self._manager - - @manager.setter - def manager(self, manager: 'DatasetManager'): - self._manager = manager - self._manager_uuid = manager.uuid - - @property - def manager_uuid(self) -> UUID: - """ - The UUID of the ::class:`DatasetManager` for this instance. 
- - Returns - ------- - DatasetManager - The UUID of the ::class:`DatasetManager` for this instance. - """ - return self._manager_uuid - - @property - def name(self) -> str: - """ - The name for this dataset, which also should be a unique identifier. - - Every dataset in the domain of all datasets known to this instance's ::attribute:`manager` must have a unique - name value. - - Returns - ------- - str - The dataset's unique name. - """ - return self._name - @property def time_range(self) -> Optional[TimeRange]: """ @@ -456,51 +256,17 @@ def time_range(self) -> Optional[TimeRange]: tr = self.data_domain.continuous_restrictions[StandardDatasetIndex.TIME] return tr if isinstance(tr, TimeRange) else TimeRange(begin=tr.begin, end=tr.end, variable=tr.variable) - @property - def uuid(self) -> UUID: - """ - The UUID for this instance. + def _get_exclude_fields(self) -> Set[str]: + """Set of fields to exclude during deserialization if they are some None variant (e.g. '', 0, None)""" + candidates = ("manager_uuid", "expires", "derived_from", "derivations", "description", "created_on", "last_updated") + return {f for f in candidates if not self.__getattribute__(f)} - Returns - ------- - UUID - The UUID for this instance. - """ - return self._uuid + def dict(self, **kwargs) -> dict: + # if exclude is set, ignore this _get_exclude_fields() + exclude = self._get_exclude_fields() if kwargs.get("exclude", False) is False else kwargs["exclude"] + kwargs["exclude"] = exclude - def to_dict(self) -> Dict[str, Union[str, Number, dict, list]]: - """ - Get the serial form of this instance as a dictionary object. - - Returns - ------- - Dict[str, Union[str, Number, dict, list]] - The serialized form of this instance. 
- """ - serial = dict() - serial[self._KEY_NAME] = self.name - serial[self._KEY_DATA_CATEGORY] = self.category.name - serial[self._KEY_DATA_DOMAIN] = self.data_domain.to_dict() - serial[self._KEY_TYPE] = self.dataset_type.name - # TODO: unit test this - serial[self._KEY_ACCESS_LOCATION] = self.access_location - serial[self._KEY_UUID] = str(self.uuid) - serial[self._KEY_IS_READ_ONLY] = self.is_read_only - if self.manager_uuid is not None: - serial[self._KEY_MANAGER_UUID] = str(self.manager_uuid) - if self.expires is not None: - serial[self._KEY_EXPIRES] = self.expires.strftime(self.get_datetime_str_format()) - if self.derived_from is not None: - serial[self._KEY_DERIVED_FROM] = self.derived_from - if len(self.derivations) > 0: - serial[self._KEY_DERIVATIONS] = self.derivations - if self.description is not None: - serial[self._KEY_DESCRIPTION] = self.description - if self.created_on is not None: - serial[self._KEY_CREATED_ON] = self.created_on.strftime(self.get_datetime_str_format()) - if self.last_updated is not None: - serial[self._KEY_LAST_UPDATE] = self.last_updated.strftime(self.get_datetime_str_format()) - return serial + return super().dict(**kwargs) class DatasetUser(ABC): diff --git a/python/lib/core/dmod/core/decorators/__init__.py b/python/lib/core/dmod/core/decorators/__init__.py index da8bc7c13..82b1d09bb 100644 --- a/python/lib/core/dmod/core/decorators/__init__.py +++ b/python/lib/core/dmod/core/decorators/__init__.py @@ -8,6 +8,7 @@ from .decorator_functions import initializer from .decorator_functions import additional_parameter +from .decorator_functions import deprecated from .message_handlers import socket_handler from .message_handlers import client_message_handler diff --git a/python/lib/core/dmod/core/decorators/decorator_functions.py b/python/lib/core/dmod/core/decorators/decorator_functions.py index 4af7abb25..13f3e25d3 100644 --- a/python/lib/core/dmod/core/decorators/decorator_functions.py +++ 
b/python/lib/core/dmod/core/decorators/decorator_functions.py @@ -2,6 +2,8 @@ Defines common decorators """ import typing +from warnings import warn +from functools import wraps from .decorator_constants import * @@ -77,3 +79,18 @@ def additional_parameter(function): if not hasattr(function, ADDITIONAL_PARAMETER_ATTRIBUTE): setattr(function, ADDITIONAL_PARAMETER_ATTRIBUTE, True) return function + + +def deprecated(deprecation_message: str): + def function_to_deprecate(fn): + + @wraps(fn) + def wrapper(*args, **kwargs): + warn(deprecation_message, DeprecationWarning) + return fn(*args, **kwargs) + + return wrapper + + return function_to_deprecate + + diff --git a/python/lib/core/dmod/core/enum.py b/python/lib/core/dmod/core/enum.py new file mode 100644 index 000000000..221664d21 --- /dev/null +++ b/python/lib/core/dmod/core/enum.py @@ -0,0 +1,85 @@ +from enum import Enum +from pydantic.fields import ModelField +from pprint import pformat + +from typing import Any, Dict, Union + +# inspiration from https://github.com/pydantic/pydantic/issues/598 +class PydanticEnum(Enum): + """ + Subtypes of this enum variant that are embedded in a pydantic model will be: + - coerced into an enum instance using member name (case insensitive) + - and expose member names (upper case) in model json schema. + + + Example: + ```python + class PowerState(PydanticEnum): + OFF = 0 + ON = 1 + + class Appliance(pydantic.BaseModel): + power_state: PowerState + ... + + Appliance(power_state=PowerState.ON) + Appliance(power_state="ON") + Appliance(power_state="on") + + Appliance(power_state=1) # invalid + ``` + + Note, `PydanticEnum` subtypes with member names that case-intensively match will yield + undesirable behavior. 
+ """ + + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any], field: ModelField) -> None: + """Method used by pydantic to populate json schema fields and their associated types.""" + # display enum field names as field options + if "enum" in field_schema: + field_schema["enum"] = [f.name.upper() for f in field.type_] + field_schema["type"] = "string" + + @classmethod + def __get_validators__(cls): + """Method used by pydantic to retrieve a class's validators.""" + yield cls.validate + + @classmethod + def validate(cls, v: Union[Enum, str]): + """ + Method used by pydantic to validate and potentially coerce a `v` into a `cls` enum type. + + Coercion from a `str` into a `cls` enum instance is performed _case-insensitively_ based on + the `cls` enum's `name` fields. For example, enum Foo with member `bar = 1` is coercible by + providing `"bar"`, _not_ `1`. + + Example: + ```python + class Foo(PydanticEnum): + bar = 1 + + class Model(pydantic.BaseModel): + foo: Foo + + Model(foo=Foo.bar) # valid + Model(foo="bar") # valid + Model(foo="BAR") # valid + + Model(foo=1) # invalid + ``` + """ + if isinstance(v, cls): + return v + + v = str(v).upper() + + for name, value in cls.__members__.items(): + if name.upper() == v: + return value + + error_message = pformat( + f"Invalid Enum field. Field {v!r} is not a member of {set(cls.__members__)}" + ) + raise ValueError(error_message) diff --git a/python/lib/core/dmod/core/execution.py b/python/lib/core/dmod/core/execution.py index f8f99f50b..9286b037b 100644 --- a/python/lib/core/dmod/core/execution.py +++ b/python/lib/core/dmod/core/execution.py @@ -1,8 +1,9 @@ -from enum import Enum from typing import Optional +from .enum import PydanticEnum -class AllocationParadigm(Enum): + +class AllocationParadigm(PydanticEnum): """ Representation of the ways compute assets may be combined to fulfill a total required asset amount for a task. 
diff --git a/python/lib/core/dmod/core/meta_data.py b/python/lib/core/dmod/core/meta_data.py index 9c74c9fe0..53269adcf 100644 --- a/python/lib/core/dmod/core/meta_data.py +++ b/python/lib/core/dmod/core/meta_data.py @@ -1,13 +1,15 @@ -from enum import Enum from datetime import datetime +from .enum import PydanticEnum from .serializable import Serializable from numbers import Number -from typing import Any, Dict, List, Optional, Set, Type, Union +from typing import Any, Dict, List, Literal, Optional, Set, Type, Union from collections.abc import Iterable +from collections import OrderedDict +from pydantic import root_validator, validator, PyObject, Field, StrictStr, StrictFloat, StrictInt -class StandardDatasetIndex(Enum): +class StandardDatasetIndex(PydanticEnum): UNKNOWN = (-1, Any) TIME = (0, datetime) @@ -34,8 +36,13 @@ def get_for_name(cls, name_str: str) -> 'StandardDatasetIndex': return value return StandardDatasetIndex.UNKNOWN +def _validate_variable_is_known(cls, variable: StandardDatasetIndex) -> StandardDatasetIndex: + if variable == StandardDatasetIndex.UNKNOWN: + raise ValueError("Invalid value for {} variable: {}".format(cls.__name__, variable)) + return variable -class DataFormat(Enum): + +class DataFormat(PydanticEnum): """ Supported data format types for data needed or produced by workflow execution tasks. @@ -215,6 +222,63 @@ class ContinuousRestriction(Serializable): """ A filtering component, typically applied as a restriction on a domain, by a continuous range of values of a variable. 
""" + variable: StandardDatasetIndex + begin: datetime + end: datetime + datetime_pattern: Optional[str] + subclass: PyObject + + @root_validator(pre=True) + def coerce_times_if_datetime_pattern(cls, values): + subclass_str = values.get("subclass") + + if subclass_str is None: + values["subclass"] = cls + + if isinstance(subclass_str, str): + if subclass_str == cls.__name__: + values["subclass"] = cls + + datetime_ptr = values.get("datetime_pattern") + + if datetime_ptr is not None: + # If there is a datetime pattern, then expect begin and end to parse properly to datetime objects + begin = values["begin"] + end = values["end"] + + if not isinstance(begin, datetime): + values["begin"] = datetime.strptime(begin, datetime_ptr) + + if not isinstance(end, datetime): + values["end"] = datetime.strptime(end, datetime_ptr) + return values + + @root_validator() + def validate_start_before_end(cls, values): + if values["begin"] > values["end"]: + raise RuntimeError("Cannot have {} with begin value larger than end.".format(cls.__name__)) + + return values + + # validate variable is not UNKNOWN variant + _validate_variable = validator("variable", allow_reuse=True)(_validate_variable_is_known) + + class Config: + def _serialize_datetime(self: "ContinuousRestriction", value: datetime) -> str: + if self.datetime_pattern is not None: + return value.strftime(self.datetime_pattern) + return str(value) + + field_serializers = { + "begin": _serialize_datetime, + "end": _serialize_datetime, + "subclass": lambda value: value.__name__ + } + + def __eq__(self, o: object) -> bool: + if not isinstance(o, ContinuousRestriction): + return False + return self.variable == o.variable and self.begin == o.begin and self.end == o.end @classmethod def convert_truncated_serial_form(cls, truncated_json_obj: dict, datetime_format: Optional[str] = None) -> dict: @@ -251,61 +315,28 @@ def convert_truncated_serial_form(cls, truncated_json_obj: dict, datetime_format @classmethod def 
factory_init_from_deserialized_json(cls, json_obj: dict): - datetime_ptr = json_obj["datetime_pattern"] if "datetime_pattern" in json_obj else None - try: - variable = StandardDatasetIndex.get_for_name(json_obj['variable']) - if variable == StandardDatasetIndex.UNKNOWN: - raise RuntimeError( - "Unrecognized continuous restriction serialize variable: {}".format(json_obj['variable'])) - # Handle simple case, which currently means non-datetime item (i.e., no pattern included) - if datetime_ptr is None: - return cls(variable=variable, begin=json_obj["begin"], end=json_obj["end"]) - - # If there is a datetime pattern, then expect begin and end to parse properly to datetime objects - begin = datetime.strptime(json_obj["begin"], datetime_ptr) - end = datetime.strptime(json_obj["end"], datetime_ptr) + if "subclass" in json_obj: + try: + subclass_str = json_obj["subclass"] - # Use this type if that's what the JSON specifies is the Serializable subtype - if cls.__name__ == json_obj["subclass"]: - return cls(variable=variable, begin=begin, end=end, datetime_pattern=datetime_ptr) + if subclass_str == cls.__name__: + json_obj["subclass"] = cls + return cls(**json_obj) - # Try to initialize the right subclass type, or fall back if appropriate to the base type - # TODO: consider adding something for recursive search for subclass, not just immediate children types - # Use nested try, because we want to fall back to cls type if no subclass attempt or subclass attempt fails - try: for subclass in cls.__subclasses__(): - if subclass.__name__ == json_obj["subclass"]: - return subclass(variable=variable, begin=begin, end=end, datetime_pattern=datetime_ptr) + if subclass.__name__ == subclass_str: + json_obj["subclass"] = subclass + return subclass(**json_obj) except: pass - # Fall back if needed - return cls(variable=variable, begin=begin, end=end, datetime_pattern=datetime_ptr) + try: + return cls(**json_obj) except: return None - def __init__(self, variable: Union[str, 
StandardDatasetIndex], begin, end, datetime_pattern: Optional[str] = None): - self.variable = StandardDatasetIndex.get_for_name(variable) if isinstance(variable, str) else variable - if self.variable == StandardDatasetIndex.UNKNOWN: - raise ValueError("Invalid value for {} variable: {}".format(self.__class__.__name__, variable)) - if begin > end: - raise RuntimeError("Cannot have {} with begin value larger than end.".format(self.__class__.__name__)) - self.begin = begin - self.end = end - self._datetime_pattern = datetime_pattern - - def __eq__(self, other): - if self.__class__ == other.__class__ or isinstance(other, self.__class__): - return self.variable == other.variable and self.begin == other.begin and self.end == other.end \ - and self._datetime_pattern == other._datetime_pattern - elif isinstance(self, other.__class__): - return other.__eq__(self) - else: - return False - def __hash__(self): - str_func = lambda x: str(x) if self._datetime_pattern is None else datetime.strptime(x, self._datetime_pattern) - hash('{}-{}-{}'.format(self.variable.name, str_func(self.begin), str_func(self.end))) + return hash('{}-{}-{}'.format(self.variable.name, self.begin, self.end)) def contains(self, other: 'ContinuousRestriction') -> bool: """ @@ -329,19 +360,6 @@ def contains(self, other: 'ContinuousRestriction') -> bool: else: return self.begin <= other.begin and self.end >= other.end - def to_dict(self) -> Dict[str, Union[str, Number, dict, list]]: - serial = dict() - serial["variable"] = self.variable.name - serial["subclass"] = self.__class__.__name__ - if self._datetime_pattern is not None: - serial["datetime_pattern"] = self._datetime_pattern - serial["begin"] = self.begin.strftime(self._datetime_pattern) - serial["end"] = self.end.strftime(self._datetime_pattern) - else: - serial["begin"] = self.begin - serial["end"] = self.end - return serial - class DiscreteRestriction(Serializable): """ @@ -350,35 +368,22 @@ class DiscreteRestriction(Serializable): Note that an 
empty list for the ::attribute:`values` property implies a restriction of all possible values being required. This is reflected by the :method:`is_all_possible_values` property. """ - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict): - try: - variable = StandardDatasetIndex.get_for_name(json_obj["variable"]) - if variable == StandardDatasetIndex.UNKNOWN: - return None - return cls(variable=variable, values=json_obj["values"]) - except: - return None + variable: StandardDatasetIndex + values: Union[List[StrictStr], List[StrictFloat], List[StrictInt]] + + # validate variable is not UNKNOWN variant + _validate_variable = validator("variable", allow_reuse=True)(_validate_variable_is_known) def __init__(self, variable: Union[str, StandardDatasetIndex], values: Union[List[str], List[Number]], allow_reorder: bool = True, - remove_duplicates: bool = True): - self.variable = StandardDatasetIndex.get_for_name(variable) if isinstance(variable, str) else variable - if self.variable == StandardDatasetIndex.UNKNOWN: - raise ValueError("Invalid value for {} variable: {}".format(self.__class__.__name__, variable)) - self.values: Union[List[str], List[Number]] = list(set(values)) if remove_duplicates else values + remove_duplicates: bool = True, **kwargs): + super().__init__(variable=variable, values=values, **kwargs) + if remove_duplicates: + self.values = list(OrderedDict.fromkeys(self.values)) if allow_reorder: self.values.sort() - def __eq__(self, other): - if self.__class__ == other.__class__ or isinstance(other, self.__class__): - return self.variable == other.variable and self.values == other.values - elif isinstance(self, other.__class__): - return other.__eq__(self) - else: - return False - - def __hash__(self): - hash('{}-{}'.format(self.variable.name, ','.join([str(v) for v in self.values]))) + def __hash__(self) -> int: + return hash('{}-{}'.format(self.variable.name, ','.join([str(v) for v in self.values]))) def contains(self, other: 
'DiscreteRestriction') -> bool: """ @@ -433,40 +438,72 @@ def is_all_possible_values(self) -> bool: """ return self.values is not None and len(self.values) == 0 - def to_dict(self) -> Dict[str, Union[str, Number, dict, list]]: - return {"variable": self.variable.name, "values": self.values} - class DataDomain(Serializable): """ A domain for a dataset, with domain-defining values contained by one or more discrete and/or continuous components. """ + data_format: DataFormat = Field( + description="The format for the data in this domain, which contains details like the indices and other data fields." + ) + continuous_restrictions: Optional[List[ContinuousRestriction]] = Field( + description="Map of the continuous restrictions defining this domain, keyed by variable name.", + alias="continuous", + default_factory=list + ) + discrete_restrictions: Optional[List[DiscreteRestriction]] = Field( + description="Map of the discrete restrictions defining this domain, keyed by variable name.", + alias="discrete", + default_factory=list + ) + # NOTE: remove this field after #239 is merged. will close #245. 
+ custom_data_fields: Optional[Dict[str, Union[str, int, float, Any]]] = Field( + description=("This will either be directly from the format, if its format specifies any fields, or from a custom fields" + "attribute that may be set during initialization (but is ignored when the format specifies fields)."), + alias="data_fields" + ) + + @validator("continuous_restrictions", pre=True, each_item=True) + def _factory_init_continuous_restrictions(cls, value): + if isinstance(value, ContinuousRestriction): + return value + return ContinuousRestriction.factory_init_from_deserialized_json(value) + + @validator("continuous_restrictions", "discrete_restrictions", always=True) + def _validate_restriction_default(cls, value): + if value is None: + return [] + return value + + @validator("custom_data_fields") + def validate_data_fields(cls, values): + def handle_type_map(t): + if t == "str" or t == str: + return str + elif t == "int" or t == int: + return int + elif t == "float" or t == float: + return float + elif t == "bool" or t == bool: + return bool + # maintain reference to a passed in python type or subtype + elif isinstance(t, type): + return t + return Any + + if values is None: + return None - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict): - try: - data_format = DataFormat.get_for_name(json_obj["data_format"]) - continuous = [ContinuousRestriction.factory_init_from_deserialized_json(c) for c in json_obj["continuous"]] - discrete = [DiscreteRestriction.factory_init_from_deserialized_json(d) for d in json_obj["discrete"]] - if 'data_fields' in json_obj: - data_fields = dict() - for key in json_obj['data_fields']: - val = json_obj['data_fields'][key] - if val == 'str': - data_fields[key] = str - elif val == 'int': - data_fields[key] = int - elif val == 'float': - data_fields[key] = float - else: - data_fields[key] = Any - else: - data_fields = None + return {k: handle_type_map(v) for k, v in values.items()} - return 
cls(data_format=data_format, continuous_restrictions=continuous, discrete_restrictions=discrete, - custom_data_fields=data_fields) - except: - return None + @root_validator() + def validate_sufficient_restrictions(cls, values): + continuous_restrictions = values.get("continuous_restrictions", []) + discrete_restrictions = values.get("discrete_restrictions", []) + if len(continuous_restrictions) + len(discrete_restrictions) == 0: + msg = "Cannot create {} without at least one finite continuous or discrete restriction" + raise RuntimeError(msg.format(cls.__name__)) + return values @classmethod def factory_init_from_restriction_collections(cls, data_format: DataFormat, **kwargs) -> 'DataDomain': @@ -537,12 +574,6 @@ def factory_init_from_restriction_collections(cls, data_format: DataFormat, **kw continuous_restrictions=None if len(continuous) == 0 else continuous, discrete_restrictions=None if len(discrete) == 0 else discrete) - def __eq__(self, other): - return self.__class__ == other.__class__ and self.data_format == other.data_format \ - and self.continuous_restrictions == other.continuous_restrictions \ - and self.discrete_restrictions == other.discrete_restrictions \ - and self._custom_data_fields == other._custom_data_fields - def __hash__(self): if self._custom_data_fields is None: cu = '' @@ -554,27 +585,6 @@ def __hash__(self): ','.join([str(hash(self.discrete_restrictions[k])) for k in sorted(self.discrete_restrictions)]), cu)) - def __init__(self, data_format: DataFormat, continuous_restrictions: Optional[List[ContinuousRestriction]] = None, - discrete_restrictions: Optional[List[DiscreteRestriction]] = None, - custom_data_fields: Optional[Dict[str, Type]] = None): - self._data_format = data_format - self._continuous_restrictions = dict() - self._discrete_restrictions = dict() - self._custom_data_fields = custom_data_fields - """ Extra attribute for custom data fields when format does not specify all data fields (ignore when format does specify). 
""" - - if continuous_restrictions is not None: - for c in continuous_restrictions: - self._continuous_restrictions[c.variable] = c - - if discrete_restrictions is not None: - for d in discrete_restrictions: - self._discrete_restrictions[d.variable] = d - - if len(self._continuous_restrictions) + len(self._discrete_restrictions) == 0: - msg = "Cannot create {} without at least one finite continuous or discrete restriction" - raise RuntimeError(msg.format(self.__class__.__name__)) - def _extends_continuous_restriction(self, continuous_restriction: ContinuousRestriction) -> bool: idx = continuous_restriction.variable return idx in self.continuous_restrictions and self.continuous_restrictions[idx].contains(continuous_restriction) @@ -612,30 +622,6 @@ def contains(self, other: Union[ContinuousRestriction, DiscreteRestriction, 'Dat return False return True - @property - def continuous_restrictions(self) -> Dict[StandardDatasetIndex, ContinuousRestriction]: - """ - Map of the continuous restrictions defining this domain, keyed by variable name. - - Returns - ------- - Dict[str, ContinuousRestriction] - Map of the continuous restrictions defining this domain, keyed by variable name. - """ - return self._continuous_restrictions - - @property - def discrete_restrictions(self) -> Dict[StandardDatasetIndex, DiscreteRestriction]: - """ - Map of the discrete restrictions defining this domain, keyed by variable name. - - Returns - ------- - Dict[str, DiscreteRestriction] - Map of the discrete restrictions defining this domain, keyed by variable name. - """ - return self._discrete_restrictions - @property def data_fields(self) -> Dict[str, Type]: """ @@ -649,23 +635,9 @@ def data_fields(self) -> Dict[str, Type]: """ if self.data_format.data_fields is None: - return self._custom_data_fields + return self.custom_data_fields else: - return self._data_format.data_fields - - @property - def data_format(self) -> DataFormat: - """ - The format for data in this domain. 
- - The format for the data in this domain, which contains details like the indices and other data fields. - - Returns - ------- - DataFormat - The format for data in this domain. - """ - return self._data_format + return self.data_format.data_fields @property def indices(self) -> List[str]: @@ -680,37 +652,70 @@ def indices(self) -> List[str]: List[str] List of the string forms of the ::class:`StandardDataIndex` indices that define this domain. """ - return self._data_format.indices - - def to_dict(self) -> Dict[str, Union[str, Number, dict, list]]: - """ - Serialize to a dictionary. - - Serialize this instance to a dictionary, with there being two top-level list items. These are made from the - the contained ::class:`ContinuousRestriction` and ::class:`DiscreteRestriction` objects - - Returns - ------- - - """ - serial = {"data_format": self._data_format.name, - "continuous": [component.to_dict() for idx, component in self.continuous_restrictions.items()], - "discrete": [component.to_dict() for idx, component in self.discrete_restrictions.items()]} - if self.data_format.data_fields is None: - serial['data_fields'] = dict() - for key in self._custom_data_fields: - if self._custom_data_fields[key] == str: - serial['data_fields'][key] = 'str' - elif self._custom_data_fields[key] == int: - serial['data_fields'][key] = 'int' - elif self._custom_data_fields[key] == float: - serial['data_fields'][key] = 'float' - else: - serial['data_fields'][key] = 'Any' + return self.data_format.indices + + @staticmethod + def _encode_py_type(o: type) -> str: + """Return string representation of a built in type (e.g. 
'int') or 'Any'.""" + if o in {str, int, float, bool}: + return o.__name__ + return "Any" + + def dict( + self, + *, + include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + by_alias: bool = True, # Note this follows Serializable convention + skip_defaults: Optional[bool] = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False + ) -> Dict[str, Union[str, int]]: + """ + `data_fields` is excluded from dict if `self.data_format.data_fields` is None. + + called by `to_dict` and `to_json`. + """ + DATA_FIELDS_KEY = "custom_data_fields" + DATA_FIELDS_ALIAS_KEY = "data_fields" + + exclude = exclude or set() + + exclude_data_fields = DATA_FIELDS_KEY in exclude + exclude.add(DATA_FIELDS_KEY) + + serial = super().dict( + include=include, + exclude=exclude, + by_alias=by_alias, + skip_defaults=skip_defaults, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + + # NOTE: `custom_data_fields` is excluded if it is a empty T variant. This breaks with + # Serializable's convention to only exclude `None` value fields. + if exclude_data_fields or self.data_format.data_fields or not self.custom_data_fields: + return serial + + # serialize "custom_data_fields" python types + custom_data_fields = ( + {k: self._encode_py_type(v) for k, v in self.custom_data_fields.items()} + if self.custom_data_fields is not None + else dict() + ) + + if by_alias: + serial[DATA_FIELDS_ALIAS_KEY] = custom_data_fields + return serial + + serial[DATA_FIELDS_KEY] = custom_data_fields return serial -class DataCategory(Enum): +class DataCategory(PydanticEnum): """ The general category values for different data. """ @@ -733,173 +738,32 @@ class TimeRange(ContinuousRestriction): """ Encapsulated representation of a time range. 
""" - - def __init__(self, begin: Union[str, datetime], end: Union[str, datetime], datetime_pattern: Optional[str] = None, - **kwargs): - dt_ptrn = self.get_datetime_str_format() if datetime_pattern is None else datetime_pattern - super(TimeRange, self).__init__(variable=StandardDatasetIndex.TIME, - begin=begin if isinstance(begin, datetime) else datetime.strptime(begin, dt_ptrn), - end=end if isinstance(end, datetime) else datetime.strptime(end, dt_ptrn), - datetime_pattern=dt_ptrn) + variable: StandardDatasetIndex = Field(StandardDatasetIndex.TIME, const=True) class DataRequirement(Serializable): """ A definition of a particular data requirement needed for an execution task. """ - - _KEY_CATEGORY = 'category' - """ Serialization dictionary JSON key for ::attribute:`category` property value. """ - _KEY_DOMAIN = 'domain' - """ Serialization dictionary JSON key for ::attribute:`domain_params` property value. """ - _KEY_FULFILLED_ACCESS_AT = 'fulfilled_access_at' - """ Serialization dictionary JSON key for ::attribute:`fulfilled_access_at` property value. """ - _KEY_FULFILLED_BY = 'fulfilled_by' - """ Serialization dictionary JSON key for ::attribute:`fulfilled_by` property value. """ - _KEY_IS_INPUT = 'is_input' - """ Serialization dictionary JSON key for ::attribute:`is_input` property value. """ - _KEY_SIZE = 'size' - """ Serialization dictionary JSON key for ::attribute:`size` property value. """ - - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict) -> Optional['DataRequirement']: - """ - Deserialize the given JSON to a ::class:`DataRequirement` instance, or return ``None`` if it is not valid. - - Parameters - ---------- - json_obj : dict - The JSON to be deserialized. - - Returns - ------- - Optional[DataRequirement] - A deserialized ::class:`DataRequirement` instance, or return ``None`` if the JSON is not valid. 
- """ - try: - domain = DataDomain.factory_init_from_deserialized_json(json_obj[cls._KEY_DOMAIN]) - category = DataCategory.get_for_name(json_obj[cls._KEY_CATEGORY]) - is_input = json_obj[cls._KEY_IS_INPUT] - - opt_kwargs_w_defaults = dict() - if cls._KEY_FULFILLED_BY in json_obj: - opt_kwargs_w_defaults['fulfilled_by'] = json_obj[cls._KEY_FULFILLED_BY] - if cls._KEY_SIZE in json_obj: - opt_kwargs_w_defaults['size'] = json_obj[cls._KEY_SIZE] - if cls._KEY_FULFILLED_ACCESS_AT in json_obj: - opt_kwargs_w_defaults['fulfilled_access_at'] = json_obj[cls._KEY_FULFILLED_ACCESS_AT] - - return cls(domain=domain, is_input=is_input, category=category, **opt_kwargs_w_defaults) - except: - return None - - def __eq__(self, other): - return self.__class__ == other.__class__ and self.domain == other.domain and self.is_input == other.is_input \ - and self.category == other.category + category: DataCategory + domain: DataDomain + fulfilled_access_at: Optional[str] = Field(description="The location at which the fulfilling dataset for this requirement is accessible, if the dataset known.") + fulfilled_by: Optional[str] = Field(description="The name of the dataset that will fulfill this, if it is known.") + is_input: bool = Field(description="Whether this represents required input data, as opposed to a requirement for storing output data.") + size: Optional[int] + + def __eq__(self, other: object) -> bool: + return ( + self.__class__ == other.__class__ + and self.domain == other.domain + and self.is_input == other.is_input + and self.category == other.category + ) def __hash__(self): return hash('{}-{}-{}'.format(hash(self.domain), self.is_input, self.category)) - def __init__(self, domain: DataDomain, is_input: bool, category: DataCategory, size: Optional[int] = None, - fulfilled_by: Optional[str] = None, fulfilled_access_at: Optional[str] = None): - self._domain = domain - self._is_input = is_input - self._category = category - self._size = size - self._fulfilled_by = fulfilled_by - 
self._fulfilled_access_at = fulfilled_access_at - - @property - def category(self) -> DataCategory: - """ - The ::class:`DataCategory` of data required. - - Returns - ------- - DataCategory - The category of data required. - """ - return self._category - - @property - def domain(self) -> DataDomain: - """ - The (restricted) domain of the data that is required. - - Returns - ------- - DataDomain - The (restricted) domain of the data that is required. - """ - return self._domain - - @property - def fulfilled_access_at(self) -> Optional[str]: - """ - The location at which the fulfilling dataset for this requirement is accessible, if the dataset known. - - Returns - ------- - Optional[str] - The location at which the fulfilling dataset for this requirement is accessible, if known, or ``None`` - otherwise. - """ - return self._fulfilled_access_at - - @fulfilled_access_at.setter - def fulfilled_access_at(self, location: str): - self._fulfilled_access_at = location - - @property - def fulfilled_by(self) -> Optional[str]: - """ - The name of the dataset that will fulfill this, if it is known. - - Returns - ------- - Optional[str] - The name of the dataset that will fulfill this, if it is known; ``None`` otherwise. - """ - return self._fulfilled_by - - @fulfilled_by.setter - def fulfilled_by(self, name: str): - self._fulfilled_by = name - - @property - def is_input(self) -> bool: - """ - Whether this represents required input data, as opposed to a requirement for storing output data. - - Returns - ------- - bool - Whether this represents required input data. - """ - return self._is_input - - @property - def size(self) -> Optional[int]: - """ - The size of the required data, if it is known. - - This is particularly important (though still not strictly required) for an output data requirement; i.e., a - requirement to store output data somewhere. - - Returns - ------- - Optional[int] - he size of the required data, if it is known, or ``None`` otherwise. 
- """ - return self._size - - def to_dict(self) -> Dict[str, Union[str, Number, dict, list]]: - serial = {self._KEY_DOMAIN: self.domain.to_dict(), self._KEY_IS_INPUT: self.is_input, - self._KEY_CATEGORY: self.category.name} - if self.size is not None: - serial[self._KEY_SIZE] = self.size - if self.fulfilled_by is not None: - serial[self._KEY_FULFILLED_BY] = self.fulfilled_by - if self.fulfilled_access_at is not None: - serial[self._KEY_FULFILLED_ACCESS_AT] = self.fulfilled_access_at - return serial + def dict(self, **kwargs) -> dict: + exclude_unset = True if kwargs.get("exclude_unset") is None else False + kwargs["exclude_unset"] = exclude_unset + return super().dict(**kwargs) diff --git a/python/lib/core/dmod/core/serializable.py b/python/lib/core/dmod/core/serializable.py index cde509b4d..59331a8f6 100644 --- a/python/lib/core/dmod/core/serializable.py +++ b/python/lib/core/dmod/core/serializable.py @@ -1,14 +1,42 @@ -from abc import ABC, abstractmethod +from abc import ABC from numbers import Number -from typing import Callable, Dict, Type, Union +from enum import Enum +from typing import Any, Callable, ClassVar, Dict, Type, TypeVar, TYPE_CHECKING, Union, Optional +from typing_extensions import Self, TypeAlias +from pydantic import BaseModel, Field +from functools import lru_cache +import inspect import json +from .decorators import deprecated -class Serializable(ABC): +if TYPE_CHECKING: + from pydantic.typing import ( + AbstractSetIntStr, + MappingIntStrAny, + DictStrAny + ) + +M = TypeVar("M", bound="Serializable") +T = TypeVar("T") +R = Union[str, int, float, bool, None] + +FnSerializer: TypeAlias = Callable[[T], R] +SelfFieldSerializer: TypeAlias = Callable[[M, T], R] +FieldSerializer = Union[SelfFieldSerializer[M, Any], FnSerializer[Any]] + + +class Serializable(BaseModel, ABC): """ An interface class for an object that can be serialized to a dictionary-like format (i.e., potentially a JSON object) and JSON string format based directly from dumping the 
aforementioned dictionary-like representation. + Subtypes of `Serializable` should specify their fields following + [`pydantic.BaseModel`](https://docs.pydantic.dev/usage/models/) semantics (see example below). + Notably, `to_dict` and `to_json` will exclude `None` fields and serialize fields using any + provided aliases (i.e. `pydantic.Field(alias="some_alias")`). Also, enum subtypes are + serialized using their member `name` property. + Objects of this type will also used the JSON string format as their default string representation. While not strictly enforced (because this probably isn't possible), it is HIGHLY recommended that instance @@ -20,13 +48,65 @@ class Serializable(ABC): An exception to the aforementioned recommendation is the ::class:`datetime.datetime` type. Subtype attributes of ::class:`datetime.datetime` type should be parsed and serialized using the pattern returned by the ::method:`get_datetime_str_format` class method. A reasonable default is provided in the base interface class, but - the pattern can be adjusted eitehr by overriding the class method directly or by having a subtypes set/override + the pattern can be adjusted either by overriding the class method directly or by having a subtypes set/override its ::attribute:`_SERIAL_DATETIME_STR_FORMAT` class attribute. Note that the actual parsing/serialization logic is left entirely to the subtypes, as many will not need it (and thus should not have to worry about implement another method or have their superclass bloated by importing the ``datetime`` package). + + Example: + ``` + # specify field as class variable, specify final type using type hint. + # pydantic will try to coerce a field into the specified type, if it can't, a + # `pydantic.ValidationError` is raised. 
+ + class User(Serializable): + id: int + username: str + email: str # more appropriately, `pydantic.EmailStr` + + >>> user = User(id=1, username="uncle_sam", email="uncle_sam@fake.gov") + >>> user.to_dict() # {"id": 1, "username": "uncle_sam", "email": "uncle_sam@fake.gov"} + >>> user.to_json() # '{"id": 1, "username": "uncle_sam", "email": "uncle_sam@fake.gov"}' + ``` """ - _SERIAL_DATETIME_STR_FORMAT = '%Y-%m-%d %H:%M:%S' + _SERIAL_DATETIME_STR_FORMAT: ClassVar[str] = '%Y-%m-%d %H:%M:%S' + + # global pydantic options + class Config: + # fields can be populated using their given name or provided alias + allow_population_by_field_name = True + field_serializers: Dict[str, FieldSerializer[M]] = {} + """ + Mapping of field name to callable that changes the default serialized form of a field. + This is often helpful when a field requires a use case specific representation (i.e. + datetime) or is not JSON serializable. + + Callables can be specified as either: + (value: T) -> R or + (self: M, value: T) -> R + where: + T is the field type + M is an instance of the Serializable subtype + R is the, json serializable, return type of the transformation + + Example: + + class Observation(Serializable): + value: float + value_time: datetime.datetime + value_unit: str + + class Config: + field_serializers = { + "value_time": lambda value_time: value_time.isoformat(timespec="seconds") + } + + o = Observation(value=42.0, value_time=datetime(2020, 1, 1), value_unit="m") + expect = {"value": 42.0, "value_time": "2020-01-01T00:00:00", "value_unit": "m"} + + assert o.dict() == expect + """ @classmethod def _get_invalid_type_message(cls): @@ -34,8 +114,7 @@ def _get_invalid_type_message(cls): return invalid_type_msg @classmethod - @abstractmethod - def factory_init_from_deserialized_json(cls, json_obj: dict): + def factory_init_from_deserialized_json(cls: Type[Self], json_obj: dict) -> Optional[Self]: """ Factory create a new instance of this type based on a JSON object dictionary 
deserialized from received JSON. @@ -47,7 +126,10 @@ def factory_init_from_deserialized_json(cls, json_obj: dict): ------- A new object of this type instantiated from the deserialize JSON object dictionary """ - pass + try: + return cls(**json_obj) + except: + return None @classmethod def get_datetime_str_format(cls): @@ -65,6 +147,7 @@ def get_datetime_str_format(cls): return cls._SERIAL_DATETIME_STR_FORMAT @classmethod + @deprecated("In the future this will be removed. Use pydantic type hints, validators, or root validators instead.") def parse_simple_serialized(cls, json_obj: dict, key: str, expected_type: Type, required_present: bool = True, converter: Callable = None): """ @@ -158,11 +241,11 @@ def parse_simple_serialized(cls, json_obj: dict, key: str, expected_type: Type, # If we get this far, then return the converted value return converted_value - @abstractmethod def to_dict(self) -> Dict[str, Union[str, Number, dict, list]]: """ Get the representation of this instance as a serialized dictionary or dictionary-like object (e.g., a JSON - object). + object). Field's are serialized using an alias, if provided. Field's that are `None` are + excluded from serialization. Since the returned value must be serializable and JSON-like, key and value types are restricted. In particular, the returned value type, which this docstring will call ``D``, must adhere to the criteria defined below: @@ -180,7 +263,31 @@ def to_dict(self) -> Dict[str, Union[str, Number, dict, list]]: The representation of this instance as a serialized dictionary or dictionary-like object, with valid types of keys and values. 
""" - pass + return self.dict(exclude_none=True, by_alias=True) + + def dict( + self, + *, + include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + by_alias: bool = False, + skip_defaults: Optional[bool] = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + ) -> "DictStrAny": + serial = super().dict( + include=include, + exclude=exclude, + by_alias=by_alias, + skip_defaults=skip_defaults, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + + transformers = _collect_field_transformers(type(self)) + return _transform_fields(self, transformers, serial, by_alias=by_alias) def __str__(self): return str(self.to_json()) @@ -196,21 +303,50 @@ def to_json(self) -> str: """ return json.dumps(self.to_dict(), sort_keys=True) + @classmethod + def _get_value( + cls, + v: Any, + to_dict: bool, + by_alias: bool, + include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]], + exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]], + exclude_unset: bool, + exclude_defaults: bool, + exclude_none: bool, + ) -> Any: + """ + Method used by pydantic to serialize field values. + + Override how `enum.Enum` subclasses are serialized by pydantic. Enums are serialized using + their member name, not their value. + """ + # serialize enum's using their name property + if isinstance(v, Enum) and not getattr(cls.Config, "use_enum_values", False): + return v.name + + return super()._get_value( + v, + to_dict=to_dict, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + include=include, + exclude=exclude, + exclude_none=exclude_none, + ) + class SerializedDict(Serializable): """ A basic encapsulation of a dictionary as a ::class:`Serializable`. 
""" + base_dict: dict @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict): - return cls(json_obj) - - def __init__(self, base_dict: dict): - self.base_dict = base_dict - - def to_dict(self) -> dict: - return self.base_dict + def factory_init_from_deserialized_json(cls: Self, json_obj: dict) -> Self: + # NOTE: could raise. return type has fewer constraints + return cls(**json_obj) class ResultIndicator(Serializable, ABC): @@ -236,18 +372,9 @@ class ResultIndicator(Serializable, ABC): An optional, more detailed explanation of the result, which by default is an empty string. """ - - def __init__(self, success: bool, reason: str, message: str = '', *args, **kwargs): - super(ResultIndicator, self).__init__(*args, **kwargs) - self.success: bool = success - """ Whether this indicates a successful result. """ - self.reason: str = reason - """ A very short, high-level summary of the result. """ - self.message: str = message - """ An optional, more detailed explanation of the result, which by default is an empty string. """ - - def to_dict(self) -> dict: - return {'success': self.success, 'reason': self.reason, 'message': self.message} + success: bool = Field(description="Whether this indicates a successful result.") + reason: str = Field(description="A very short, high-level summary of the result.") + message: str = Field("", description="An optional, more detailed explanation of the result, which by default is an empty string.") class BasicResultIndicator(ResultIndicator): @@ -255,12 +382,90 @@ class BasicResultIndicator(ResultIndicator): Bare-bones, concrete implementation of ::class:`ResultIndicator`. 
""" - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict): - try: - return cls(success=json_obj['success'], reason=json_obj['reason'], message=json_obj['message']) - except Exception as e: - return None - - def __init__(self, *args, **kwargs): - super(BasicResultIndicator, self).__init__(*args, **kwargs) \ No newline at end of file +# NOTE: function below are intentionally not methods on `Serializable` to avoid subclasses +# overriding their behavior. + +@lru_cache +def _collect_field_transformers(cls: Type[M]) -> Dict[str, FieldSerializer[M]]: + transformers: Dict[str, FieldSerializer[M]] = {} + + # base case + if cls == Serializable: + return transformers + + super_classes = cls.__mro__ + base_class_index = super_classes.index(Serializable) + + # index 0 is the calling cls try and merge `field_serializers` from superclasses up until + # Base class (stopping condition). merge in reverse order of mro so child class + # `field_serializers` override superclasses `field_serializers`. 
+ for s in super_classes[1:base_class_index][::-1]: + if not issubclass(s, Serializable): + continue + + # doesn't have a Config class or Config.field_serializers + if not hasattr(s, "Config") and not hasattr(s.Config, "field_serializers"): + continue + + transformers.update(_collect_field_transformers(s)) + + # has Config class and Config.field_serializers + if hasattr(cls, "Config") and hasattr(cls.Config, "field_serializers"): + transformers.update(cls.Config.field_serializers) + + return transformers + + +def _get_field_alias(cls: Type[M], field_name: str) -> str: + # NOTE: KeyError will raise if field_name does not exist + return cls.__fields__[field_name].alias + + +def _transform_fields( + self: M, + transformers: Dict[str, FieldSerializer[M]], + serial: Dict[str, Any], + by_alias: bool = False, +) -> Dict[str, Any]: + for field, transform in transformers.items(): + if by_alias: + field = _get_field_alias(type(self), field) + + if field not in serial: + # TODO: field could have been excluded. need to consider what to do if invalid + # serial key was provided. + continue + + if not inspect.isfunction(transform): + error_message = ( + f"non-callable field_transformer provided for field {field!r}." + "\n\n" + "field_transformers should be specified as either:" + "\n" + "\t(value: T) -> R\n" + "\t(self: M, value: T) -> R\n" + "where:\n" + "\tT is the field type\n" + "\tM is an instance of the Serializable subtype\n" + "\tR is the, json serializable, return type of the transformation" + ) + raise ValueError(error_message) + + sig = inspect.signature(transform) + + if len(sig.parameters) == 1: + serial[field] = transform(serial[field]) + + elif len(sig.parameters) == 2: + serial[field] = transform(self, serial[field]) + + else: + error_message = ( + f"unsupported parameter length for field_transformer callable, {field!r}." + "\n\n" + "field_transformer's take either 1 or 2 parameters, (value: T) or (self, value: T),\n" + "where T is the type of the field." 
+ ) + raise RuntimeError(error_message) + + return serial diff --git a/python/lib/core/dmod/test/test_data_requirement.py b/python/lib/core/dmod/test/test_data_requirement.py index 4cbf4e2c5..3d3ff52a2 100644 --- a/python/lib/core/dmod/test/test_data_requirement.py +++ b/python/lib/core/dmod/test/test_data_requirement.py @@ -30,15 +30,15 @@ def test_to_dict_0_a(self): requirement = self.example_reqs[ex] as_dict = requirement.to_dict() self.assertTrue(isinstance(as_dict, dict)) - self.assertTrue(DataRequirement._KEY_DOMAIN in as_dict) + self.assertTrue("domain" in as_dict) def test_to_dict_0_b(self): ex = 0 requirement = self.example_reqs[ex] as_dict = requirement.to_dict() self.assertTrue(requirement.is_input) - self.assertTrue(isinstance(as_dict[DataRequirement._KEY_IS_INPUT], bool)) - self.assertTrue(as_dict[DataRequirement._KEY_IS_INPUT]) + self.assertTrue(isinstance(as_dict["is_input"], bool)) + self.assertTrue(as_dict["is_input"]) def test_factory_init_from_deserialized_json_0_a(self): """ diff --git a/python/lib/core/dmod/test/test_dataset.py b/python/lib/core/dmod/test/test_dataset.py index 0548b395c..68f13ac4d 100644 --- a/python/lib/core/dmod/test/test_dataset.py +++ b/python/lib/core/dmod/test/test_dataset.py @@ -60,14 +60,14 @@ def setUp(self) -> None: discrete_restrictions=[self.example_catchment_restrictions[i]])) self.example_datasets.append(self._init_dataset_example(i)) date_fmt = Dataset.get_datetime_str_format() - self.example_data.append({Dataset._KEY_NAME: self.gen_dataset_name(i), - Dataset._KEY_DATA_DOMAIN: self.example_domains[i].to_dict(), - Dataset._KEY_DATA_CATEGORY: self.example_categories[i].name, - Dataset._KEY_TYPE: self.example_types[i].name, - Dataset._KEY_UUID: str(self.example_datasets[i].uuid), - Dataset._KEY_ACCESS_LOCATION: 'location_{}'.format(i), - Dataset._KEY_IS_READ_ONLY: False, - Dataset._KEY_CREATED_ON: self._created_on.strftime(date_fmt), + self.example_data.append({"name": self.gen_dataset_name(i), + "data_domain": 
self.example_domains[i].to_dict(), + "data_category": self.example_categories[i].name, + "type": self.example_types[i].name, + "uuid": str(self.example_datasets[i].uuid), + "access_location": 'location_{}'.format(i), + "is_read_only": False, + "created_on": self._created_on, # NOTE: breaking change }) def test_factory_init_from_deserialized_json_0_a(self): diff --git a/python/lib/core/dmod/test/test_decorator.py b/python/lib/core/dmod/test/test_decorator.py new file mode 100644 index 000000000..02b1fd4d2 --- /dev/null +++ b/python/lib/core/dmod/test/test_decorator.py @@ -0,0 +1,13 @@ +import unittest +from ..core.decorators import deprecated + +DEPRECATION_MESSAGE = "test is deprecated" + +@deprecated(DEPRECATION_MESSAGE) +def deprecated_function(): + ... + +class TestDeprecatedDecorator(unittest.TestCase): + def test_raises_deprecated_warning(self): + with self.assertWarns(DeprecationWarning): + deprecated_function() diff --git a/python/lib/core/dmod/test/test_enum.py b/python/lib/core/dmod/test/test_enum.py new file mode 100644 index 000000000..8d6e6834c --- /dev/null +++ b/python/lib/core/dmod/test/test_enum.py @@ -0,0 +1,44 @@ +import unittest +import enum +from pydantic import BaseModel + +from ..core.enum import PydanticEnum + + +class SomeEnum(PydanticEnum): + foo = 1 + bar = 2 + baz = 3 + + +class SomeModel(BaseModel): + some_enum: SomeEnum + + +class TestEnumValidateByNameMixIn(unittest.TestCase): + def test_instantiate_model_with_enum_field_name(self): + model = SomeModel(some_enum="foo") + self.assertEqual(model.some_enum, SomeEnum.foo) + + def test_instantiate_model_with_enum_instance(self): + model = SomeModel(some_enum=SomeEnum.foo) + self.assertEqual(model.some_enum, SomeEnum.foo) + + def test_raises_ValueError_instantiate_model_with_bad_enum_field_name(self): + with self.assertRaises(ValueError): + SomeModel(some_enum="missing_field") + + def test_raises_ValueError_instantiate_model_with_bad_enum_instance(self): + class BadEnum(enum.Enum): + bad = 1 
+ + with self.assertRaises(ValueError): + SomeModel(some_enum=BadEnum.bad) + + def test_enum_names_in_json_schema(self): + schema = SomeModel.schema() + some_enum_schema = schema["definitions"]["SomeEnum"] + self.assertEqual(some_enum_schema["type"], "string") + + enum_field_names = [member.name.upper() for member in SomeEnum] + self.assertListEqual(enum_field_names, some_enum_schema["enum"]) diff --git a/python/lib/core/dmod/test/test_meta_data.py b/python/lib/core/dmod/test/test_meta_data.py new file mode 100644 index 000000000..b3c9a43e1 --- /dev/null +++ b/python/lib/core/dmod/test/test_meta_data.py @@ -0,0 +1,241 @@ +import unittest +from datetime import datetime + +from ..core.meta_data import ( + ContinuousRestriction, + DiscreteRestriction, + StandardDatasetIndex, + DataDomain, + DataFormat, + TimeRange, + DataCategory, + DataRequirement, +) + +from typing import Any + + +class TestContinuousRestriction(unittest.TestCase): + def test_custom_datetime_pattern(self): + o = ContinuousRestriction( + begin="2020-01-01", + end="2020-01-02", + variable="TIME", + datetime_pattern="%Y-%m-%d", + ) + self.assertEqual(o.variable, StandardDatasetIndex.TIME) + + def test_custom_datetime_pattern_should_fail(self): + with self.assertRaises(RuntimeError): + ContinuousRestriction( + begin="2020-01-01", + end="2019-12-31", + variable="TIME", + datetime_pattern="%Y-%m-%d", + ) + + def test_create_from_python_objects(self): + begin = datetime(2020, 1, 1) + end = datetime(2020, 1, 2) + o = ContinuousRestriction( + begin=begin, end=end, variable=StandardDatasetIndex.TIME + ) + self.assertEqual(o.begin, begin) + self.assertEqual(o.end, end) + self.assertEqual(o.variable, StandardDatasetIndex.TIME) + + def test_create_fails_with_invalid_variable(self): + begin = datetime(2020, 1, 1) + end = datetime(2020, 1, 2) + with self.assertRaises(ValueError): + ContinuousRestriction( + begin=begin, end=end, variable=StandardDatasetIndex.UNKNOWN + ) + + def test_eq(self): + begin = 
datetime(2020, 1, 1) + end = datetime(2020, 1, 2) + o1 = ContinuousRestriction( + begin=begin, end=end, variable=StandardDatasetIndex.TIME + ) + o2 = ContinuousRestriction( + begin=begin, end=end, variable=StandardDatasetIndex.TIME + ) + + self.assertEqual(o1, o2) + + def test_hash(self): + begin = datetime(2020, 1, 1) + end = datetime(2020, 1, 2) + var = StandardDatasetIndex.TIME + expected_hash = hash(f"{var.name}-{begin}-{end}") + o_hash = hash(ContinuousRestriction(variable=var, begin=begin, end=end)) + self.assertEqual(expected_hash, o_hash) + + def test_to_dict(self): + begin = "2020-01-01" + end = "2020-01-02" + d = ContinuousRestriction( + begin=begin, + end=end, + variable="TIME", + datetime_pattern="%Y-%m-%d", + ).to_dict() + self.assertEqual(d["begin"], begin) + self.assertEqual(d["end"], end) + self.assertEqual(d["variable"], StandardDatasetIndex.TIME.name) + + def test_factory_init_from_deserialized_json(self): + deserialied = {"begin": 0, "end": 1, "variable": "TIME"} + o1 = ContinuousRestriction.factory_init_from_deserialized_json(deserialied) + + deserialied = { + "begin": 0, + "end": 1, + "variable": "TIME", + "subclass": "ContinuousRestriction", + } + o2 = ContinuousRestriction.factory_init_from_deserialized_json(deserialied) + self.assertEqual(o1, o2) + + def test_to_json(self): + import json + begin = "2020-01-01" + end = "2020-01-02" + d = json.loads( + ContinuousRestriction( + begin=begin, + end=end, + variable="TIME", + datetime_pattern="%Y-%m-%d", + ).to_json() + ) + + self.assertEqual(d["begin"], begin) + self.assertEqual(d["end"], end) + self.assertEqual(d["variable"], StandardDatasetIndex.TIME.name) + + +class TestDiscreteRestriction(unittest.TestCase): + def test_duplicate_values_removed(self): + o = DiscreteRestriction( + variable="TIME", values=[1, 1, 1], remove_duplicates=True + ) + self.assertListEqual(o.values, [1]) + + def test_values_reordered(self): + values = [3, 2, 1] + o = DiscreteRestriction(variable="TIME", values=values, 
allow_reorder=True) + self.assertListEqual(o.values, values[::-1]) + + def test_values_removed_not_reordered(self): + values = [3, 3, 2, 1] + o = DiscreteRestriction( + variable="TIME", + values=values, + allow_reorder=False, + remove_duplicates=True, + ) + self.assertListEqual(o.values, values[1:]) + + def test_values_reordered_not_removed(self): + values = [3, 3, 2, 1] + o = DiscreteRestriction( + variable="TIME", + values=values, + allow_reorder=True, + remove_duplicates=False, + ) + self.assertListEqual(o.values, values[::-1]) + + +class TestDataDomain(unittest.TestCase): + def test_it_works(self): + disc_rest = DiscreteRestriction( + variable=StandardDatasetIndex.DATA_ID, values=["0"] + ) + o = DataDomain( + data_format=DataFormat.AORC_CSV, + discrete_restrictions=[disc_rest], + data_fields=dict(a="str", b="float", c="int", d="datetime"), + ) + self.assertEqual(o.custom_data_fields["a"], str) + self.assertEqual(o.custom_data_fields["b"], float) + self.assertEqual(o.custom_data_fields["c"], int) + self.assertEqual(o.custom_data_fields["d"], Any) + + def test_init_fails_if_insufficient_restrictions(self): + with self.assertRaises(RuntimeError): + DataDomain( + data_format=DataFormat.AORC_CSV, + continuous_restrictions=[], + discrete_restrictions=[], + ) + + with self.assertRaises(RuntimeError): + DataDomain(data_format=DataFormat.AORC_CSV) + + def test_factory_init_from_deserialized_json(self): + data = { + "data_format": "AORC_CSV", + "continuous_restrictions": [], + "discrete_restrictions": [{"variable": "DATA_ID", "values": ["0"]}], + } + o = DataDomain.factory_init_from_deserialized_json(data) + self.assertEqual(o.data_format.name, "AORC_CSV") + + def test_to_dict(self): + input_data_fields = {"a": "int", "b": "float", "c": "bool", "d": "str", "e": "flux_capacitor"} + expected_serialized_data_fields = {"a": "int", "b": "float", "c": "bool", "d": "str", "e": "Any"} + data = { + # NOTE: NGEN_OUTPUT data_fields = None. 
+ "data_format": "NGEN_OUTPUT", + "continuous": [], + "discrete": [{"variable": "DATA_ID", "values": ["0"]}], + } + input_data = data.copy() + input_data["data_fields"] = input_data_fields + + expected_data = data.copy() + expected_data["data_fields"] = expected_serialized_data_fields + + # better error detection if this fails + o = DataDomain(**input_data) + serial = o.to_dict() + self.assertDictEqual(serial, expected_data) + + def test_factory_init_from_restriction_collections(self): + catchment_id = ["12"] + o = DataDomain.factory_init_from_restriction_collections(data_format=DataFormat.AORC_CSV, CATCHMENT_ID=catchment_id) + self.assertListEqual(o.discrete_restrictions[0].values, catchment_id) + + def test_factory_init_from_restriction_collections_fail_for_mismatching_index_field(self): + with self.assertRaises(RuntimeError): + DataDomain.factory_init_from_restriction_collections(data_format=DataFormat.AORC_CSV, DATA_ID=["12"]) + + +class TestTimeRange(unittest.TestCase): + def test_begin_cannot_come_after_end(self): + with self.assertRaises(RuntimeError): + TimeRange(begin=1, end=0) + + def test_cannot_provide_non_time_variable(self): + with self.assertRaises(RuntimeError): + TimeRange(variable=StandardDatasetIndex.DATA_ID, begin=1, end=0) + +class TestDataRequirement(unittest.TestCase): + def test_unset_fields_are_excluded_in_serialized_dict(self): + domain = DataDomain( + data_format=DataFormat.AORC_CSV, + discrete_restrictions=[ + DiscreteRestriction(variable=StandardDatasetIndex.DATA_ID, values=["0"]) + ], + ) + + d = DataRequirement( + domain=domain, is_input=True, category=DataCategory.CONFIG + ).to_dict() + self.assertNotIn("size", d) + self.assertNotIn("fulfilled_by", d) + self.assertNotIn("fulfilled_access_at", d) + diff --git a/python/lib/core/dmod/test/test_serializable_field_serializers.py b/python/lib/core/dmod/test/test_serializable_field_serializers.py new file mode 100644 index 000000000..648df8bda --- /dev/null +++ 
b/python/lib/core/dmod/test/test_serializable_field_serializers.py @@ -0,0 +1,279 @@ +import unittest +from typing import List +from pydantic import SecretStr +from datetime import date + +from ..core.serializable import Serializable + + +class Country(Serializable): + name: str + phone_code: int + + class Config: + field_serializers = {"name": lambda s: s.upper()} + + +class Address(Serializable): + post_code: int + country: Country + + +class CardDetails(Serializable): + number: SecretStr + expires: date + + +class Hobby(Serializable): + name: str + info: str + + class Config: + fields = {"name": {"alias": "NAME"}} + + +class User(Serializable): + first_name: str + second_name: str + address: Address + card_details: CardDetails + hobbies: List[Hobby] + + class Config: + field_serializers = {"first_name": lambda f: f.upper()} + + +class A(Serializable): + field: str + + class Config: + field_serializers = {"field": lambda s: s.lower()} + + +class B(Serializable): + field: str + + class Config: + field_serializers = {"field": lambda s: s.upper()} + + +class C(Serializable): + a: A + b: B + + +class D(Serializable): + field: str + + class Config: + field_serializers = {"field": lambda s: s.upper()} + + +class E(D): + class Config: + field_serializers = {"field": lambda s: s} + + +class F(Serializable): + a: str + + class Config: + field_serializers = {"a": lambda s: s.lower()} + + +class G(Serializable): + b: str + + class Config: + field_serializers = {"b": lambda s: s.upper()} + + +class H(F, G): + ... + + +class I(G, F): + ... 
+ + +class J(Serializable): + a: int + + +class K(Serializable): + j: J + + class Config: + field_serializers = {"j": lambda self, _: self.j.a} + + +class L(Serializable): + a: str + + class Config: + field_serializers = {"a": 12} + + +class M(Serializable): + a: str + + class Config: + field_serializers = {"a": lambda a, b, c: (a, b, c)} + + +class N(Serializable): + a: str + + class Config: + field_serializers = {"a": lambda: "should fail"} + + +class RootModel(Serializable): + __root__: int + + class Config: + field_serializers = {"__root__": lambda s: s ** 2} + + +def user_fixture() -> User: + return User( + first_name="John", + second_name="Doe", + address=Address(post_code=123456, country=Country(name="usa", phone_code=1)), + card_details=CardDetails(number=4212934504460000, expires=date(2020, 5, 1)), + hobbies=[ + Hobby(name="Programming", info="Writing code and stuff"), + Hobby(name="Gaming", info="Hell Yeah!!!"), + ], + ) + + +class TestFieldSerializerConfigOption(unittest.TestCase): + def test_exclude_keys_User(self): + user = user_fixture() + + exclude_keys = { + "second_name": True, + "address": {"post_code": True, "country": {"phone_code"}}, + "card_details": True, + # You can exclude fields from specific members of a tuple/list by index: + "hobbies": {-1: {"info"}}, + } + + expect = { + "first_name": "JOHN", + "address": {"country": {"name": "USA"}}, + "hobbies": [ + { + "name": "Programming", + "info": "Writing code and stuff", + }, + {"name": "Gaming"}, + ], + } + + self.assertDictEqual(user.dict(exclude=exclude_keys), expect) + + def test_include_keys_User(self): + user = user_fixture() + + include_keys = { + "first_name": True, + "address": {"country": {"name"}}, + "hobbies": {0: True, -1: {"name"}}, + } + + expect = { + "first_name": "JOHN", + "address": {"country": {"name": "USA"}}, + "hobbies": [ + { + "name": "Programming", + "info": "Writing code and stuff", + }, + {"name": "Gaming"}, + ], + } + + 
self.assertDictEqual(user.dict(include=include_keys), expect) + + def test_exclude_keys_by_alias_User(self): + user = user_fixture() + + exclude_keys = { + "second_name": True, + "address": {"post_code": True, "country": {"phone_code"}}, + "card_details": True, + # You can exclude fields from specific members of a tuple/list by index: + "hobbies": {-1: {"info"}}, + } + + expect = { + "first_name": "JOHN", + "address": {"country": {"name": "USA"}}, + "hobbies": [ + { + "NAME": "Programming", + "info": "Writing code and stuff", + }, + {"NAME": "Gaming"}, + ], + } + + self.assertDictEqual(user.dict(exclude=exclude_keys, by_alias=True), expect) + + def test_composed_fields_dont_mangle_C(self): + o = C(a=A(field="A"), b=B(field="b")) + + expect = {"a": {"field": "a"}, "b": {"field": "B"}} + self.assertDictEqual(o.dict(), expect) + + def test_override_in_subclass_D_E(self): + o = D(field="a") + self.assertEqual(o.dict()["field"], "A") + + subclass_o = E(field="a") + + self.assertEqual(subclass_o.dict()["field"], "a") + + def test_root_model_RootModel(self): + o = RootModel(__root__=12) + self.assertEqual(o.dict()["__root__"], 144) + + def test_multi_inheritance_H_I(self): + # H(F, G) + h = H(a="a", b="b") + # I(G, H) + i = I(a="a", b="b") + + expect = { + "a": "a", + "b": "B", + } + + self.assertDictEqual(h.dict(), expect) + self.assertDictEqual(i.dict(), expect) + + def test_pull_up_K(self): + o = K(j=J(a=12)) + + expect = {"j": 12} + + self.assertDictEqual(o.dict(), expect) + + def test_raises_value_error_L(self): + o = L(a="a") + with self.assertRaises(ValueError): + o.dict() + + def test_raises_runtime_error_too_many_params_M(self): + o = M(a="a") + + with self.assertRaises(RuntimeError): + o.dict() + + def test_raises_runtime_error_too_few_params_N(self): + o = N(a="a") + + with self.assertRaises(RuntimeError): + o.dict() diff --git a/python/lib/core/setup.py b/python/lib/core/setup.py index 3f485cccf..e69c55a56 100644 --- a/python/lib/core/setup.py +++ 
b/python/lib/core/setup.py @@ -20,6 +20,6 @@ author_email='', url='', license='', - install_requires=[], + install_requires=["pydantic"], packages=find_namespace_packages(exclude=['dmod.test', 'schemas', 'ssl', 'src']) ) diff --git a/python/lib/modeldata/dmod/modeldata/data/object_store_manager.py b/python/lib/modeldata/dmod/modeldata/data/object_store_manager.py index dffb5a8cf..33045f461 100644 --- a/python/lib/modeldata/dmod/modeldata/data/object_store_manager.py +++ b/python/lib/modeldata/dmod/modeldata/data/object_store_manager.py @@ -560,8 +560,8 @@ def reload(self, reload_from: str, serialized_item: Optional[str] = None) -> Dat response_obj.release_conn() # If we can safely infer it, make sure the "type" key is set in cases when it is missing - if len(self.supported_dataset_types) == 1 and Dataset._KEY_TYPE not in response_data: - response_data[Dataset._KEY_TYPE] = list(self.supported_dataset_types)[0].name + if len(self.supported_dataset_types) == 1 and "type" not in response_data: + response_data["type"] = list(self.supported_dataset_types)[0].name dataset = Dataset.factory_init_from_deserialized_json(response_data) dataset.manager = self diff --git a/python/lib/modeldata/dmod/modeldata/hydrofabric/partition.py b/python/lib/modeldata/dmod/modeldata/hydrofabric/partition.py index 54ba9ccd9..0b9f22255 100644 --- a/python/lib/modeldata/dmod/modeldata/hydrofabric/partition.py +++ b/python/lib/modeldata/dmod/modeldata/hydrofabric/partition.py @@ -1,7 +1,10 @@ -from numbers import Number -from typing import Collection, Dict, FrozenSet, List, Union +from typing import Collection, FrozenSet, List, Optional, TYPE_CHECKING, Union +from pydantic import Field, PrivateAttr, validator from dmod.core.serializable import Serializable +if TYPE_CHECKING: + from pydantic.typing import AbstractSetIntStr, DictStrAny, MappingIntStrAny + class Partition(Serializable): """ @@ -13,56 +16,84 @@ class Partition(Serializable): in the context of the related hydrofabric. 
""" - __slots__ = ["_catchment_ids", "_hash_val", "_nexus_ids", "_partition_id", "_remote_downstream_nexus_ids", - "_remote_upstream_nexus_ids"] + partition_id: int + catchment_ids: FrozenSet[str] + nexus_ids: FrozenSet[str] + """ + Note that, at the time this is committed, partition ids should always be integers. This is so they can easily + correspond to MPI ranks. However, because of how the expected + """ + remote_upstream_nexus_ids: FrozenSet[str] = Field(default_factory=frozenset) + remote_downstream_nexus_ids: FrozenSet[str] = Field(default_factory=frozenset) + + _hash_val: Optional[int] = PrivateAttr(None) + + class Config: + fields = { + "catchment_ids": {"alias": "cat-ids"}, + "partition_id": {"alias": "id"}, + "nexus_ids": {"alias": "nex-ids"}, + "remote_upstream_nexus_ids": {"alias": "remote-up"}, + "remote_downstream_nexus_ids": {"alias": "remote-down"}, + } + + def _serialize_frozenset(value: FrozenSet[str]) -> List[str]: + return list(value) - _KEY_CATCHMENT_IDS = 'cat-ids' - _KEY_PARTITION_ID = 'id' - # Note that these need to be included in the JSON, but initially aren't actually used at the JSON level - _KEY_NEXUS_IDS = 'nex-ids' - _KEY_REMOTE_UPSTREAM_NEXUS_IDS = 'remote-up' - _KEY_REMOTE_DOWNSTREAM_NEXUS_IDS = 'remote-down' + field_serializers = { + "catchment_ids": _serialize_frozenset, + "nexus_ids": _serialize_frozenset, + "remote_upstream_nexus_ids": _serialize_frozenset, + "remote_downstream_nexus_ids": _serialize_frozenset, + } - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict): - try: - # TODO: later these may be required, but for now, keep optional - if cls._KEY_REMOTE_UPSTREAM_NEXUS_IDS in json_obj: - remote_up = json_obj[cls._KEY_REMOTE_UPSTREAM_NEXUS_IDS] - else: - remote_up = [] - if cls._KEY_REMOTE_DOWNSTREAM_NEXUS_IDS in json_obj: - remote_down = json_obj[cls._KEY_REMOTE_UPSTREAM_NEXUS_IDS] - else: - remote_down = [] - return Partition(catchment_ids=json_obj[cls._KEY_CATCHMENT_IDS], 
nexus_ids=json_obj[cls._KEY_NEXUS_IDS], - remote_up_nexuses=remote_up, remote_down_nexuses=remote_down, - partition_id=int(json_obj[cls._KEY_PARTITION_ID])) - except: - return None - - def __init__(self, partition_id: int, catchment_ids: Collection[str], nexus_ids: Collection[str], - remote_up_nexuses: Collection[str] = tuple(), remote_down_nexuses: Collection[str] = tuple()): - self._partition_id = partition_id - self._catchment_ids = frozenset(catchment_ids) - self._nexus_ids = frozenset(nexus_ids) - self._remote_upstream_nexus_ids = frozenset(remote_up_nexuses) - self._remote_downstream_nexus_ids = frozenset(remote_down_nexuses) - - self._hash_val = None - - def __eq__(self, other): + def __init__( + self, + # required, but for backwards compatibility, None + partition_id: int = None, + catchment_ids: Collection[str] = None, + nexus_ids: Collection[str] = None, + # non-required fields + remote_up_nexuses: Collection[str] = None, + remote_down_nexuses: Collection[str] = None, + **data + ): + # if data exists, assume fields specified using their alias; no backwards compatibility. 
+ if data: + super().__init__(**data) + return + + + if remote_up_nexuses is None or remote_down_nexuses is None: + super().__init__( + partition_id=partition_id, + catchment_ids=catchment_ids, + nexus_ids=nexus_ids, + **data + ) + return + + super().__init__( + partition_id=partition_id, + catchment_ids=catchment_ids, + nexus_ids=nexus_ids, + remote_upstream_nexus_ids=remote_up_nexuses, + remote_downstream_nexus_ids=remote_down_nexuses + ) + + + def __eq__(self, other: object): if not isinstance(other, self.__class__) or other.partition_id != self.partition_id: return False else: return other.__hash__() == self.__hash__() - def __lt__(self, other): + def __lt__(self, other: "Partition"): # Go first by id, so this is clearly true - if self._partition_id < other._partition_id: + if self.partition_id < other.partition_id: return True # Again, going by id first, having greater id is also clear - elif self._partition_id > other._partition_id: + elif self.partition_id > other.partition_id: return False # Also can't be (strictly) less-than AND equal-to elif self == other: @@ -79,116 +110,37 @@ def __hash__(self): self._hash_val = hash(','.join(cat_id_list)) return self._hash_val - @property - def catchment_ids(self) -> FrozenSet[str]: - """ - Get the frozen set of ids for all catchments in this partition. - - Returns - ------- - Set[str] - The frozen set of string ids for all catchments in this partition. - """ - return self._catchment_ids - - @property - def nexus_ids(self) -> FrozenSet[str]: - """ - Get the frozen set of ids for all nexuses in this partition. - - Returns - ------- - Set[str] - The frozen set of string ids for all nexuses in this partition. - """ - return self._nexus_ids - - @property - def partition_id(self) -> int: - """ - Get the id of this partition. - - Note that, at the time this is committed, partition ids should always be integers. This is so they can easily - correspond to MPI ranks. 
However, because of how the expected - - Returns - ------- - str - The id of this partition, as a string. - """ - return self._partition_id - - @property - def remote_downstream_nexus_ids(self) -> FrozenSet[str]: - """ - Get the frozen set of ids for all remote downstream nexuses in this partition. - - Returns - ------- - Set[str] - The frozen set of string ids for all remote downstream nexuses in this partition. - """ - return self._remote_downstream_nexus_ids - - @property - def remote_upstream_nexus_ids(self) -> FrozenSet[str]: - """ - Get the frozen set of ids for all remote upstream nexuses in this partition. - - Returns - ------- - Set[str] - The frozen set of string ids for all remote upstream nexuses in this partition. - """ - return self._remote_upstream_nexus_ids - - def to_dict(self) -> Dict[str, Union[str, Number, dict, list]]: - """ - Get the instance represented as a dict (i.e., a JSON-like object). - - Note that, as described in the main docstring for the class, there are extra keys in the dict/JSON currently - that don't correspond to any attributes of the instance. This is for consistency with other tools. - - Returns - ------- - dict - The instance as a dict - """ - return { - self._KEY_PARTITION_ID: str(self.partition_id), - self._KEY_CATCHMENT_IDS: list(self.catchment_ids), - self._KEY_NEXUS_IDS: list(self.nexus_ids), - self._KEY_REMOTE_UPSTREAM_NEXUS_IDS: list(self.remote_upstream_nexus_ids), - self._KEY_REMOTE_DOWNSTREAM_NEXUS_IDS: list(self.remote_downstream_nexus_ids) - } - - class PartitionConfig(Serializable): """ A type to easily encapsulate the JSON object that is output from the NextGen partitioner. 
""" - _KEY_PARTITIONS = 'partitions' + partitions: FrozenSet[Partition] - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict): - try: - return PartitionConfig([Partition.factory_init_from_deserialized_json(serial_p) for serial_p in json_obj[cls._KEY_PARTITIONS]]) - except: - return None + @validator("partitions") + def _sort_partitions(cls, value: FrozenSet[Partition]) -> FrozenSet[Partition]: + return frozenset(sorted(value)) + + class Config: + def _serialize_frozenset(value: FrozenSet[Partition]) -> List[Partition]: + return list(value) + + field_serializers = { + "partitions": _serialize_frozenset + } @classmethod def get_serial_property_key_partitions(cls) -> str: - return cls._KEY_PARTITIONS + return "partitions" - def __init__(self, partitions: Collection[Partition]): - self._partitions = frozenset(partitions) + def __init__(self, partitions: Collection[Partition], **data): + super().__init__(partitions=partitions, **data) - def __eq__(self, other): + def __eq__(self, other: object): if not isinstance(other, PartitionConfig): return False other_partitions_dict = dict() - for other_p in other._partitions: + for other_p in other.partitions: other_partitions_dict[other_p.partition_id] = other_p other_pids = set([p2.partition_id for p2 in other.partitions]) @@ -197,7 +149,7 @@ def __eq__(self, other): return False return True - def __hash__(self): + def __hash__(self) -> int: """ Get the unique hash for this instance. @@ -206,22 +158,50 @@ def __hash__(self): Returns ------- - - """ - # - return hash(','.join([str(p.__hash__()) for p in sorted(self._partitions)])) - - @property - def partitions(self) -> List[Partition]: - """ - Get the (sorted) list of partitions for this config. - - Returns - ------- - List[Partition] - The (sorted) list of partitions for this config. 
- """ - return sorted(self._partitions) - - def to_dict(self) -> Dict[str, Union[str, Number, dict, list]]: - return {self._KEY_PARTITIONS: [p.to_dict() for p in self.partitions]} + int + Hash of instance + """ + return hash(",".join([str(p.__hash__()) for p in sorted(self.partitions)])) + + def dict( + self, + *, + include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + by_alias: bool = False, + skip_defaults: Optional[bool] = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False + ) -> "DictStrAny": + # reasons why dict is overridden here: + # pydantic will serialize from inner types outward, serializing each type as a dictionary, + # list, or primitive and replacing its previous type with the new "serialized" type. + # Consequently, this means hashable container types like tuples and frozensets that contain + # values that "serialize" to a non-hashable type (non-primitive, in this case) will raise a + # `TypeError: unhashable type: 'dict'`. In the case of PartitionConfig, + # FronzenSet[Partition] "serializes" inner Partition types as dictionaries which are not + # hashable. To get around this, we will momentarily swap the `partitions` field for a + # non-hashable container type, serialize using `.dict()`, and swap back in the original + # `partitions` container. + + # 1. take a reference to partitions: FrozenSet[Partition] + partitions = self.partitions + + # 2. cast and set partitions to a list, a non-hashable container type + self.partitions = list(partitions) + + # 3. serialize + serial = super().dict( + include=include, + exclude=exclude, + by_alias=by_alias, + skip_defaults=skip_defaults, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + + # 4. 
replace partitions with its hashable representation + self.partitions = partitions + return serial diff --git a/python/lib/modeldata/dmod/modeldata/subset/hydrofabric_subset.py b/python/lib/modeldata/dmod/modeldata/subset/hydrofabric_subset.py index b059b4b72..283b99020 100644 --- a/python/lib/modeldata/dmod/modeldata/subset/hydrofabric_subset.py +++ b/python/lib/modeldata/dmod/modeldata/subset/hydrofabric_subset.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from hypy import Catchment, Nexus -from typing import Collection, Optional, Sequence, Set, Tuple, Union +from typing import Collection, Optional, Set, Tuple +from pydantic import PrivateAttr from ..hydrofabric import Hydrofabric from .subset_definition import SubsetDefinition @@ -22,17 +23,19 @@ class HydrofabricSubset(SubsetDefinition, ABC): made in the case of invalid objects. In such cases, the hash is equal to the super class hash output plus ``1``. """ - __slots__ = ["_hydrofabric"] + hydrofabric: Hydrofabric - def __init__(self, catchment_ids: Collection[str], nexus_ids: Collection[str], hydrofabric: Hydrofabric): - super().__init__(catchment_ids=catchment_ids, nexus_ids=nexus_ids) + class Config: + arbitrary_types_allowed = True + + def __init__(self, catchment_ids: Collection[str], nexus_ids: Collection[str], hydrofabric: Hydrofabric, **data): + super().__init__(catchment_ids=catchment_ids, nexus_ids=nexus_ids, hydrofabric=hydrofabric, **data) if not self.validate_hydrofabric(hydrofabric): raise RuntimeError("Insufficient or wrongly formatted hydrofabric when trying to create {} object".format( self.__class__.__name__ )) - self._hydrofabric = hydrofabric - def __eq__(self, other): + def __eq__(self, other: object): if isinstance(other, self.__class__): return self.validate_hydrofabric() == other.validate_hydrofabric() and super().__eq__(other) else: @@ -46,7 +49,7 @@ def __hash__(self): @property @abstractmethod - def catchments(self) -> Tuple[Catchment]: + def catchments(self) -> 
Tuple[Catchment, ...]: """ Get the associated catchments as ::class:`Catchment` objects. @@ -59,7 +62,7 @@ def catchments(self) -> Tuple[Catchment]: @property @abstractmethod - def nexuses(self) -> Tuple[Nexus]: + def nexuses(self) -> Tuple[Nexus, ...]: """ Get the associated nexuses as ::class:`Nexus` objects. @@ -100,6 +103,9 @@ class SimpleHydrofabricSubset(HydrofabricSubset): Simple ::class:`HydrofabricSubset` type. """ + _catchments: Set[Catchment] = PrivateAttr(default_factory=set) + _nexuses: Set[Nexus] = PrivateAttr(default_factory=set) + @classmethod def factory_create_from_base_and_hydrofabric(cls, subset_def: SubsetDefinition, hydrofabric: Hydrofabric, *args, **kwargs) \ @@ -127,17 +133,12 @@ def factory_create_from_base_and_hydrofabric(cls, subset_def: SubsetDefinition, return cls(catchment_ids=subset_def.catchment_ids, nexus_ids=subset_def.nexus_ids, hydrofabric=hydrofabric, *args, **kwargs) - __slots__ = ["_catchments", "_nexuses"] - - def __init__(self, catchment_ids: Collection[str], nexus_ids: Collection[str], hydrofabric: Hydrofabric, *args, - **kwargs): - self._catchments: Set[Catchment] = set() - self._nexuses: Set[Nexus] = set() - super().__init__(catchment_ids, nexus_ids, hydrofabric) + def __init__(self, catchment_ids: Collection[str], nexus_ids: Collection[str], hydrofabric: Hydrofabric, **data): + super().__init__(catchment_ids=catchment_ids, nexus_ids=nexus_ids, hydrofabric=hydrofabric, **data) # Since super __init__ validates, and validate function make sure ids are recognized, these won't ever be None - for cid in catchment_ids: + for cid in self.catchment_ids: self._catchments.add(hydrofabric.get_catchment_by_id(cid)) - for nid in nexus_ids: + for nid in self.nexus_ids: self._nexuses.add(hydrofabric.get_nexus_by_id(nid)) @property @@ -147,19 +148,19 @@ def catchments(self) -> Tuple[Catchment]: Returns ------- - Tuple[Catchment] + Tuple[Catchment, ...] The associated catchments as ::class:`Catchment` objects. 
""" return tuple(self._catchments) @property - def nexuses(self) -> Tuple[Nexus]: + def nexuses(self) -> Tuple[Nexus, ...]: """ Get the associated nexuses as ::class:`Nexus` objects. Returns ------- - Tuple[Catchment] + Tuple[Catchment, ...] The associated nexuses as ::class:`Nexus` objects. """ return tuple(self._nexuses) @@ -184,11 +185,11 @@ def validate_hydrofabric(self, hydrofabric: Optional[Hydrofabric] = None) -> boo otherwise. """ if hydrofabric is None: - hydrofabric = self._hydrofabric - for cid in self._catchment_ids: + hydrofabric = self.hydrofabric + for cid in self.catchment_ids: if not hydrofabric.is_catchment_recognized(cid): return False - for nid in self._nexus_ids: + for nid in self.nexus_ids: if not hydrofabric.is_nexus_recognized(nid): return False return True diff --git a/python/lib/modeldata/dmod/modeldata/subset/subset_definition.py b/python/lib/modeldata/dmod/modeldata/subset/subset_definition.py index d9b14f25b..baa128fea 100644 --- a/python/lib/modeldata/dmod/modeldata/subset/subset_definition.py +++ b/python/lib/modeldata/dmod/modeldata/subset/subset_definition.py @@ -1,5 +1,5 @@ -from numbers import Number -from typing import Collection, Tuple, Dict, Union +from typing import Collection, Tuple +from pydantic import validator from dmod.core.serializable import Serializable @@ -13,34 +13,29 @@ class SubsetDefinition(Serializable): to be immutable. """ - __slots__ = ["_catchment_ids", "_nexus_ids"] + catchment_ids: Tuple[str, ...] + nexus_ids: Tuple[str, ...] 
- @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict): - try: - return cls(**json_obj) - except Exception as e: - return None + @validator("catchment_ids", "nexus_ids") + def _sort_and_dedupe_fields(cls, value: Tuple[str, ...]) -> Tuple[str, ...]: + return tuple(sorted(set(value))) - def __init__(self, catchment_ids: Collection[str], nexus_ids: Collection[str]): - self._catchment_ids = tuple(sorted(set(catchment_ids))) - self._nexus_ids = tuple(sorted(set(nexus_ids))) + def __init__(self, catchment_ids: Collection[str], nexus_ids: Collection[str], **data): + super().__init__(catchment_ids=catchment_ids, nexus_ids=nexus_ids, **data) - def __eq__(self, other): - return isinstance(other, SubsetDefinition) \ - and self.catchment_ids == other.catchment_ids \ - and self.nexus_ids == other.nexus_ids + def __eq__(self, other: object): + return ( + isinstance(other, SubsetDefinition) + and self.catchment_ids == other.catchment_ids + and self.nexus_ids == other.nexus_ids + ) def __hash__(self): - joined_cats = ','.join(self.catchment_ids) - joined_nexs = ','.join(self.nexus_ids) - joined_all = ','.join((joined_cats, joined_nexs)) + joined_cats = ",".join(self.catchment_ids) + joined_nexs = ",".join(self.nexus_ids) + joined_all = ",".join((joined_cats, joined_nexs)) return hash(joined_all) - @property - def catchment_ids(self) -> Tuple[str]: - return self._catchment_ids - @property def id(self): """ @@ -53,10 +48,3 @@ def id(self): The unique id of this instance. 
""" return self.__hash__() - - @property - def nexus_ids(self) -> Tuple[str]: - return self._nexus_ids - - def to_dict(self) -> Dict[str, Union[str, Number, dict, list]]: - return {'catchment_ids': list(self.catchment_ids), 'nexus_ids': list(self.nexus_ids)} diff --git a/python/lib/modeldata/dmod/test/it_object_store_dataset_manager.py b/python/lib/modeldata/dmod/test/it_object_store_dataset_manager.py index 3af8d9ddb..3fec85778 100644 --- a/python/lib/modeldata/dmod/test/it_object_store_dataset_manager.py +++ b/python/lib/modeldata/dmod/test/it_object_store_dataset_manager.py @@ -253,7 +253,7 @@ def test_get_data_1_b(self): data_dict = json.loads(self.manager.get_data(dataset_name, item_name=serial_file_name).decode()) - self.assertEqual(dataset_name, data_dict[Dataset._KEY_NAME]) + self.assertEqual(dataset_name, data_dict["name"]) def test_list_files_1_a(self): """ diff --git a/python/lib/modeldata/dmod/test/test_partition.py b/python/lib/modeldata/dmod/test/test_partition.py new file mode 100644 index 000000000..be9fa6f7f --- /dev/null +++ b/python/lib/modeldata/dmod/test/test_partition.py @@ -0,0 +1,213 @@ +import unittest +from ..modeldata.hydrofabric.partition import Partition, PartitionConfig + + +class TestPartition(unittest.TestCase): + partition_instance = Partition( + nexus_ids=["2"], + catchment_ids=["42"], + partition_id=0, + remote_up_nexuses=["1"], + remote_down_nexuses=["3"], + ) + + serialized_partition = { + "cat-ids": ["42"], + "id": 0, + "remote-up": ["1"], + "nex-ids": ["2"], + "remote-down": ["3"], + } + + def test_programmatically_create_partition(self): + """Test creating an instance programmatically""" + o = self.partition_instance + + self.assertEqual(len(o.catchment_ids), 1) + self.assertEqual(len(o.remote_upstream_nexus_ids), 1) + self.assertEqual(len(o.nexus_ids), 1) + self.assertEqual(len(o.remote_downstream_nexus_ids), 1) + + self.assertIn + self.assertEqual(o.partition_id, 0) + self.assertIn("42", o.catchment_ids) + 
self.assertIn("1", o.remote_upstream_nexus_ids) + self.assertIn("2", o.nexus_ids) + self.assertIn("3", o.remote_downstream_nexus_ids) + + def test_factory_init_from_deserialized_json(self): + """ + Test creating an instance from a dictionary, then re-serializing equals the original dict. + """ + data = self.serialized_partition + o = Partition.factory_init_from_deserialized_json(data) + self.assertIsNotNone(o) + self.assertDictEqual(data, o.to_dict()) # type: ignore + + def test_eq(self): + """ + Test equality of instances. Tests instances created programmatically and from dict + deserialization. + """ + o1 = self.partition_instance + o2 = self.partition_instance + + o3 = Partition.factory_init_from_deserialized_json(self.serialized_partition) + self.assertEqual(o1, o1) + self.assertEqual(o1, o2) + self.assertEqual(o1, o3) + + def test_hash(self): + """ + Test instances hash to the same value based on their data, not the order of their data. + """ + catchment_ids = ["1", "2", "3"] + rev_catchment_ids = catchment_ids[::-1] + + o1 = Partition( + # these fields are used by __hash__ + catchment_ids=catchment_ids, + partition_id=0, + nexus_ids=["2"], + remote_up_nexuses=["1"], + remote_down_nexuses=["3"], + ) + + o2 = Partition( + # these fields are used by __hash__ + catchment_ids=rev_catchment_ids, + partition_id=0, + nexus_ids=["2"], + remote_up_nexuses=["1"], + remote_down_nexuses=["3"], + ) + + self.assertNotEqual(catchment_ids, rev_catchment_ids) + self.assertEqual(hash(o1), hash(o2)) + + def test_to_dict(self): + """Test serializing to dict""" + o = Partition.factory_init_from_deserialized_json(self.serialized_partition) + + self.assertIsNotNone(o) + self.assertDictEqual(o.to_dict(), self.serialized_partition) # type: ignore + + +class TestPartitionConfig(unittest.TestCase): + partition_config_instance = PartitionConfig(partitions=[TestPartition.partition_instance]) + + serialized_partition_config = {"partitions": [TestPartition.serialized_partition]} + + def 
test_programmatically_create_partition(self): + """Test creating an instance programmatically""" + o = self.partition_config_instance + + self.assertEqual(len(o.partitions), 1) + + def test_factory_init_from_deserialized_json(self): + """Test creating an instance programmatically""" + data = self.serialized_partition_config + o = PartitionConfig.factory_init_from_deserialized_json(data) + + self.assertIsNotNone(o) + self.assertDictEqual(data, o.to_dict()) # type: ignore + + def test_to_dict(self): + o = PartitionConfig.factory_init_from_deserialized_json( + self.serialized_partition_config + ) + self.assertIsNotNone(o) + self.assertDictEqual(self.serialized_partition_config, o.to_dict()) # type: ignore + + def test_hash(self): + """ + Test instances hash to the same value based on their data, not the order of their data. + """ + self.assertEqual( + hash(self.partition_config_instance), hash(self.partition_config_instance) + ) + + # from dictionary + o = PartitionConfig.factory_init_from_deserialized_json( + self.partition_config_instance.to_dict() + ) + self.assertEqual(hash(self.partition_config_instance), hash(o)) + + catchment_ids = ["1", "2", "3"] + + o1 = PartitionConfig( + partitions=[ + Partition( + nexus_ids=["1"], + remote_up_nexuses=["2"], + remote_down_nexuses=["3"], + partition_id=0, + catchment_ids=catchment_ids, + ) + ] + ) + + o2 = PartitionConfig( + partitions=[ + Partition( + nexus_ids=["2222"], + remote_up_nexuses=["1111"], + remote_down_nexuses=["3333"], + partition_id=0, + catchment_ids=catchment_ids, + ) + ] + ) + + # same partition and catchment ids + # NOTE: this is the expected behavior + self.assertEqual(hash(o1), hash(o2)) + + def test_duplicate_partitions_removed_during_init(self): + catchment_ids = ["1", "2", "3"] + rev_catchment_ids = catchment_ids[::-1] + + o1 = Partition( + nexus_ids=["2"], + remote_up_nexuses=["1"], + remote_down_nexuses=["3"], + partition_id=0, + catchment_ids=catchment_ids, + ) + + o2 = Partition( + 
nexus_ids=["2"], + remote_up_nexuses=["1"], + remote_down_nexuses=["3"], + partition_id=0, + catchment_ids=rev_catchment_ids, + ) + + duplicate_partition_inst = PartitionConfig(partitions=[o1, o1]) + self.assertEqual(len(duplicate_partition_inst.partitions), 1) + + duplicate_partition_same_data_inst = PartitionConfig(partitions=[o1, o2]) + self.assertEqual(len(duplicate_partition_same_data_inst.partitions), 1) + catchment_ids = ["1", "2", "3"] + + o1 = Partition( + nexus_ids=["2"], + remote_up_nexuses=["1"], + remote_down_nexuses=["3"], + partition_id=0, + catchment_ids=catchment_ids, + ) + + o3 = Partition( + nexus_ids=["2222"], + remote_up_nexuses=["1111"], + remote_down_nexuses=["3333"], + partition_id=0, + catchment_ids=catchment_ids, + ) + + same_catchment_id_and_partition_id = PartitionConfig(partitions=[o1, o3]) + + # NOTE: this is the expected behavior + self.assertEqual(len(same_catchment_id_and_partition_id.partitions), 1) + diff --git a/python/lib/modeldata/setup.py b/python/lib/modeldata/setup.py index 7c389a0e7..176c58c05 100644 --- a/python/lib/modeldata/setup.py +++ b/python/lib/modeldata/setup.py @@ -4,23 +4,33 @@ ROOT = Path(__file__).resolve().parent try: - with open(ROOT / 'README.md', 'r') as readme: + with open(ROOT / "README.md", "r") as readme: long_description = readme.read() except: - long_description = '' + long_description = "" -exec(open(ROOT / 'dmod/modeldata/_version.py').read()) +exec(open(ROOT / "dmod/modeldata/_version.py").read()) setup( - name='dmod-modeldata', + name="dmod-modeldata", version=__version__, - description='', + description="", long_description=long_description, - author='', - author_email='', - url='', - license='', - install_requires=['numpy>=1.20.1', 'pandas', 'geopandas', 'dmod-communication>=0.4.2', 'dmod-core>=0.3.0', 'minio', - 'aiohttp<=3.7.4', 'hypy@git+https://github.com/NOAA-OWP/hypy@master#egg=hypy&subdirectory=python'], - packages=find_namespace_packages(exclude=['dmod.test', 'schemas', 'ssl', 'src']) + 
author="", + author_email="", + url="", + license="", + install_requires=[ + "numpy>=1.20.1", + "pandas", + "geopandas", + "dmod-communication>=0.4.2", + "dmod-core>=0.3.0", + "minio", + "aiohttp<=3.7.4", + "hypy@git+https://github.com/NOAA-OWP/hypy@master#egg=hypy&subdirectory=python", + "gitpython", + "pydantic", + ], + packages=find_namespace_packages(exclude=["dmod.test", "schemas", "ssl", "src"]), ) diff --git a/python/lib/scheduler/dmod/scheduler/job/job.py b/python/lib/scheduler/dmod/scheduler/job/job.py index e7cf9b912..2b21eeea0 100644 --- a/python/lib/scheduler/dmod/scheduler/job/job.py +++ b/python/lib/scheduler/dmod/scheduler/job/job.py @@ -1,26 +1,33 @@ from abc import ABC, abstractmethod from datetime import datetime -from numbers import Number +from pydantic import Field, PrivateAttr, validator, root_validator +from pydantic.fields import ModelField +from warnings import warn from dmod.core.execution import AllocationParadigm from dmod.communication import ExternalRequest, ModelExecRequest, NGENRequest, SchedulerRequestMessage from dmod.core.serializable import Serializable from dmod.core.meta_data import DataRequirement +from dmod.core.enum import PydanticEnum from dmod.modeldata.hydrofabric import PartitionConfig -from enum import Enum -from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union +from typing import Any, Callable, ClassVar, Dict, List, Optional, Set, Tuple, Type, TYPE_CHECKING, Union +from typing_extensions import Self from uuid import UUID from uuid import uuid4 as uuid_func from ..resources import ResourceAllocation - -if TYPE_CHECKING: - from .. import RsaKeyPair +from .. import RsaKeyPair import logging +if TYPE_CHECKING: + from pydantic.typing import AbstractSetIntStr, MappingIntStrAny, DictStrAny + +# SAFETY: tuple can be used in this context because this sentinel is being used to verify if the data is being +# deserialized from json. Tuple's are not datatypes in json or deserialized json. 
+JOB_CLASS_SENTINEL = tuple() -class JobExecStep(Enum): +class JobExecStep(PydanticEnum): """ A component of a JobStatus, representing the particular step within a "phase" encoded within the current status. @@ -133,7 +140,7 @@ def uid(self) -> int: return self._uid -class JobExecPhase(Enum): +class JobExecPhase(PydanticEnum): """ A component of a JobStatus, representing the high level transition stage at which a status exists. """ @@ -218,14 +225,29 @@ class JobStatus(Serializable): """ Representation of a ::class:`Job`'s status as a combination of phase and exec step. """ - _NAME_DELIMITER = ':' + _NAME_DELIMITER: ClassVar[str] = ':' - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict) -> 'JobStatus': - try: - cls(phase=JobExecPhase.get_for_name(json_obj['phase']), step=JobExecStep.get_for_name(json_obj['step'])) - except: - return None + # NOTE: `None` is valid input, default value for field will be used. + phase: Optional[JobExecPhase] = Field(JobExecPhase.UNKNOWN) + # NOTE: field value will be derived from `phase` field if field is unset or None. 
+ step: Optional[JobExecStep] + + @validator("phase", pre=True) + def _set_default_phase_if_none(cls, value: Optional[JobExecPhase], field: ModelField) -> JobExecPhase: + if value is None: + return field.default + + return value + + @validator("step", always=True) + def _set_default_or_derived_step_if_none(cls, value: Optional[JobExecStep], values: Dict[str, JobExecPhase]) -> JobExecStep: + # implicit assertion that `phase` key has already been processed by it's validator + phase: JobExecPhase = values["phase"] + + if value is None: + return phase.default_start_step + + return value @classmethod def get_for_name(cls, name: str) -> 'JobStatus': @@ -258,26 +280,20 @@ def get_for_name(cls, name: str) -> 'JobStatus': if len(parsed_list) != 2: return JobStatus(JobExecPhase.UNKNOWN, JobExecStep.DEFAULT) - return JobStatus(phase=JobExecPhase.get_for_name(parsed_list[0]), - step=JobExecStep.get_for_name(parsed_list[1])) + phase, step = parsed_list + return JobStatus(phase=phase, step=step) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, JobStatus): return self.job_exec_phase == other.job_exec_phase and self.job_exec_step == other.job_exec_step else: return False - def __hash__(self): + def __hash__(self) -> int: return hash(self.name) - def __init__(self, phase: Optional[JobExecPhase], step: Optional[JobExecStep] = None): - self._phase = JobExecPhase.UNKNOWN if phase is None else phase - if step is not None: - self._step = step - elif self._phase is not None: - self._step = self._phase.default_start_step - else: - self._step = JobExecStep.DEFAULT + def __init__(self, phase: Optional[JobExecPhase], step: Optional[JobExecStep] = None, **data): + super().__init__(phase=phase, step=step, **data) def get_for_new_step(self, step: JobExecStep) -> 'JobStatus': """ @@ -316,22 +332,15 @@ def is_interrupted(self) -> bool: @property def job_exec_phase(self) -> JobExecPhase: - return self._phase + return self.phase # type: ignore @property def 
job_exec_step(self) -> JobExecStep: - return self._step + return self.step # type: ignore @property def name(self) -> str: - return self.job_exec_phase.name + self._NAME_DELIMITER + self.job_exec_step.name - - def to_dict(self) -> Dict[str, Union[str, Number, dict, list]]: - serial = dict() - serial['phase'] = self.job_exec_phase.name - serial['step'] = self.job_exec_step.name - return serial - + return f"{self.job_exec_phase.name}{self._NAME_DELIMITER}{self.job_exec_step.name}" class Job(Serializable, ABC): """ @@ -343,8 +352,111 @@ class Job(Serializable, ABC): The hash value of a job is calculated as the hash of it's ::attribute:`job_id`. """ + allocation_paradigm: AllocationParadigm + """The ::class:`AllocationParadigm` type value that was used or should be used to make allocations.""" + + allocation_priority: int = 0 + """A score for how this job should be prioritized with respect to allocation.""" + + allocations: Optional[Tuple[ResourceAllocation, ...]] + """The scheduler resource allocations for this job, or ``None`` if it is queued or otherwise not yet allocated.""" + + cpu_count: int = Field(gt=0) + """The number of CPUs for this job.""" + + data_requirements: List[DataRequirement] = Field(default_factory=list) + """List of ::class:`DataRequirement` objects representing all data needed for the job.""" + + job_id: str = Field(default_factory=lambda: str(uuid_func())) + """The unique identifier for this particular job.""" + + last_updated: datetime = Field(default_factory=datetime.now) + """ The last time this objects state was updated.""" + + memory_size: int = Field(gt=0) + """The amount of the memory needed for this job.""" + + # TODO: do we need to account for jobs for anything other than model exec? 
+ model_request: ExternalRequest + """The underlying configuration for the model execution that is being requested.""" + + partition_config: Optional[PartitionConfig] + """This job's partitioning configuration.""" + + rsa_key_pair: Optional[RsaKeyPair] + """The ::class:`'RsaKeyPair'` for this job's shared SSH RSA keys, or ``None`` if not has been set.""" + + status: JobStatus = Field(default_factory=lambda: JobStatus(JobExecPhase.INIT)) + """The ::class:`JobStatus` of this object.""" + + job_class: Type[Self] = JOB_CLASS_SENTINEL + """A type or subtype of ::class:`Self`. This can be provided as a str (e.g. "Job"), but will be coerced into a Type + object. Class names, not including module namespace, are used when coercing from a str into a Type (i.e. "job.Job" + is invalid; "Job" is valid). This field is required when factory deserializing from a dictionary. The field defaults + to the type of Self when programmatically creating an instance. It may be possible to specify a `job_class` during + programmatic initialization, however that capability is subtype dependent. + + Notably, the `job_class` field of subtypes of Job are also covariant in Self. Meaning, the `job_class` field of a + subtype S can only be S or a subtype of S. Sibling and super types of S are not allowed. + """ + + @classmethod + def _subclass_search(cls, t: Union[str, Any]) -> Optional[Type[Self]]: + if isinstance(t, str): + # base case + if t == cls.__name__: + return cls + + current_level: List[Type[Self]] = cls.__subclasses__() + # bfs subclass search + while True: + next_level: List[Type[Self]] = list() + for subclass in current_level: + if t == subclass.__name__: + return subclass + next_level.extend(subclass.__subclasses__()) + + # no more levels to explore + if not next_level: + raise ValueError( + f"`t`: {t!r} must be a str with value name of Type[{cls.__name__}]. 
This includes subtypes of `{cls.__name__}`" + ) + + current_level = next_level + + return None + + @validator("job_class", pre=True, always=True) + def _validate_job_class(cls: Self, value: Union[str, Type[Self]]) -> Type[Self]: + # default case. Is unreachable when factory init from json. + if value is JOB_CLASS_SENTINEL: + return cls + + subclass = cls._subclass_search(value) + if subclass is not None: + return subclass + + if value == cls: + return value + + if issubclass(value, cls): + return value + + raise ValueError( + f"`job_class` field must be a Type[{cls.__name__}]. This includes subtypes of `{cls.__name__}`" + ) + + class Config: + fields = { + "partition_config": {"alias": "partitioning"} + } + field_serializers = { + "job_class": lambda cls: cls.__name__, + "last_updated": lambda self, value: value.strftime(self.get_datetime_str_format()) + } + @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict): + def factory_init_from_deserialized_json(cls, json_obj: dict) -> Optional[Self]: """ Factory create a new instance of the correct subtype based on a JSON object dictionary deserialized from received JSON, where this includes a ``job_class`` property containing the name of the appropriate subtype. @@ -358,36 +470,21 @@ def factory_init_from_deserialized_json(cls, json_obj: dict): A new object of the correct subtype instantiated from the deserialize JSON object dictionary, or ``None`` if this cannot be done successfully. 
""" - job_type_key = 'job_class' - recursive_loop_key = 'base_type_invoked_twice' + try: + if "job_class" not in json_obj: + raise KeyError("missing `job_class` field") - if job_type_key not in json_obj: - return None + subclass = cls._subclass_search(json_obj["job_class"]) + + if subclass is None: + raise ValueError("`job_class` field must be provided as a type `str`") - # Avoid accidental recursive infinite loop by adding an indicator key and bailing if we already see it - if recursive_loop_key in json_obj: + json_obj["job_class"] = subclass + return subclass(**json_obj) + except: return None - else: - json_obj[recursive_loop_key] = True - - # Traverse class type tree and get all subtypes of Job - subclasses = [] - subclasses.extend(cls.__subclasses__()) - traversed_subclasses = set() - while len(subclasses) > len(traversed_subclasses): - for s in subclasses: - if s not in traversed_subclasses: - subclasses.extend(s.__subclasses__()) - traversed_subclasses.add(s) - - for subclass in subclasses: - subclass_name = subclass.__name__ - if subclass_name == json_obj[job_type_key]: - json_obj.pop(job_type_key) - return subclass.factory_init_from_deserialized_json(json_obj) - return None - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if other is None: return False elif isinstance(other, Job): @@ -400,39 +497,12 @@ def __eq__(self, other): # infinite loop (perhaps via some shared interface where that's appropriate) return False - def __hash__(self): + def __hash__(self) -> int: return hash(self.job_id) - def __lt__(self, other): + def __lt__(self, other: "Job") -> bool: return self.allocation_priority < other.allocation_priority - @property - @abstractmethod - def allocation_paradigm(self) -> AllocationParadigm: - """ - The ::class:`AllocationParadigm` type value that was used or should be used to make allocations. 
- - Returns - ------- - AllocationParadigm - The ::class:`AllocationParadigm` type value that was used or should be used to make allocations. - """ - pass - - @property - @abstractmethod - def allocation_priority(self) -> int: - """ - Get a score for how this job should be prioritized with respect to allocation, with high scores being more - likely to received allocation. - - Returns - ------- - int - A score for how this job should be prioritized with respect to allocation. - """ - pass - @property @abstractmethod def allocation_service_names(self) -> Optional[Tuple[str]]: @@ -451,53 +521,28 @@ def allocation_service_names(self) -> Optional[Tuple[str]]: """ pass - @property @abstractmethod - def allocations(self) -> Optional[Tuple[ResourceAllocation]]: - """ - The resource allocations that have been allocated for this job. - - Returns - ------- - Optional[List[ResourceAllocation]] - The scheduler resource allocations for this job, or ``None`` if it is queued or otherwise not yet allocated. - """ + def set_allocations(self, allocations: List[ResourceAllocation]): pass - @allocations.setter @abstractmethod - def allocations(self, allocations: List[ResourceAllocation]): + def set_data_requirements(self, data_requirements: List[DataRequirement]): pass - @property @abstractmethod - def cpu_count(self) -> int: - """ - The number of CPUs for this job. - - Returns - ------- - int - The number of CPUs for this job. - """ + def set_partition_config(self, part_config: PartitionConfig): pass - @property @abstractmethod - def data_requirements(self) -> List[DataRequirement]: - """ - List of ::class:`DataRequirement` objects representing all data needed for the job. + def set_status(self, status: JobStatus): + pass - Returns - ------- - List[DataRequirement] - List of ::class:`DataRequirement` objects representing all data needed for the job. 
- """ + @abstractmethod + def set_status_phase(self, phase: JobExecPhase): pass - @data_requirements.setter @abstractmethod - def data_requirements(self, data_requirements: List[DataRequirement]): + def set_status_step(self, step: JobExecStep): pass @property @@ -513,89 +558,6 @@ def is_partitionable(self) -> bool: """ pass - @property - @abstractmethod - def job_id(self): - """ - The unique identifier for this particular job. - - Returns - ------- - The unique identifier for this particular job. - """ - pass - - @property - @abstractmethod - def last_updated(self) -> datetime: - """ - The last time this objects state was updated. - - Returns - ------- - datetime - The last time this objects state was updated. - """ - pass - - @property - @abstractmethod - def memory_size(self) -> int: - """ - The amount of the memory needed for this job. - - Returns - ------- - int - The amount of the memory needed for this job. - """ - pass - - # TODO: do we need to account for jobs for anything other than model exec? - @property - @abstractmethod - def model_request(self) -> ExternalRequest: - """ - Get the underlying configuration for the model execution that is being requested. - - Returns - ------- - ExternalRequest - The underlying configuration for the model execution that is being requested. - """ - pass - - @property - @abstractmethod - def partition_config(self) -> Optional[PartitionConfig]: - """ - Get this job's partitioning configuration. - - Returns - ------- - PartitionConfig - This job's partitioning configuration. - """ - pass - - @partition_config.setter - @abstractmethod - def partition_config(self, part_config: PartitionConfig): - pass - - @property - @abstractmethod - def rsa_key_pair(self) -> Optional['RsaKeyPair']: - """ - The ::class:`'RsaKeyPair'` for this job's shared SSH RSA keys. - - Returns - ------- - Optional['RsaKeyPair'] - The ::class:`'RsaKeyPair'` for this job's shared SSH RSA keys, or ``None`` if not has been set. 
- """ - pass - @property @abstractmethod def should_release_resources(self) -> bool: @@ -609,24 +571,6 @@ def should_release_resources(self) -> bool: """ pass - @property - @abstractmethod - def status(self) -> JobStatus: - """ - The ::class:`JobStatus` of this object. - - Returns - ------- - JobStatus - The ::class:`JobStatus` of this object. - """ - pass - - @status.setter - @abstractmethod - def status(self, status: JobStatus): - pass - @property def status_phase(self) -> JobExecPhase: """ @@ -639,11 +583,6 @@ def status_phase(self) -> JobExecPhase: """ return self.status.job_exec_phase - @status_phase.setter - @abstractmethod - def status_phase(self, phase: JobExecPhase): - pass - @property def status_step(self) -> JobExecStep: """ @@ -656,11 +595,6 @@ def status_step(self) -> JobExecStep: """ return self.status.job_exec_step - @status_step.setter - @abstractmethod - def status_step(self, step: JobExecStep): - pass - @property @abstractmethod def worker_data_requirements(self) -> List[List[DataRequirement]]: @@ -674,6 +608,45 @@ def worker_data_requirements(self) -> List[List[DataRequirement]]: """ pass + def _setter_methods(self) -> Dict[str, Callable]: + """Mapping of attribute name to setter method. This supports backwards functional compatibility.""" + # TODO: remove once migration to setters by down stream users is complete + return { + "allocations": self.set_allocations, + "data_requirements": self.set_data_requirements, + "partition_config": self.set_partition_config, + "status": self.set_status, + # derived properties + "status_phase": self.set_status_phase, + "status_step": self.set_status_step, + } + + def __setattr__(self, name: str, value: Any): + """ + Use property setter method when available. + + Note, all setter methods should modify their associated property using the instance `__dict__`. + This ensures that calls to, for example, `set_id` don't raise a warning, while `o.id = "new + id"` do. 
+ + Example: + ``` + class SomeJob(Job): + id: str + + def set_id(self, value: str): + self.__dict__["id"] = value + ``` + """ + if name not in self._setter_methods(): + return super().__setattr__(name, value) + + setter_fn = self._setter_methods()[name] + + message = f"Setting by attribute is deprecated. Use `{self.__class__.__name__}.{setter_fn.__name__}` method instead." + warn(message, DeprecationWarning) + + setter_fn(value) class JobImpl(Job): """ @@ -682,197 +655,80 @@ class JobImpl(Job): Job ids are simply the string cast of generated UUID values, stored within the ::attribute:`job_uuid` property. """ - @classmethod - def _parse_serialized_allocation_paradigm(cls, json_obj: dict, key: str): - paradigm = AllocationParadigm.get_from_name(name=json_obj[key], strict=True) if key in json_obj else None - if not isinstance(paradigm, AllocationParadigm): - if paradigm is None: - type_name = 'None' - else: - type_name = paradigm.__class__.__name__ - raise RuntimeError(cls._get_invalid_type_message().format(key, str.__name__, type_name)) - return paradigm - - @classmethod - def _parse_serialized_allocations(cls, json_obj: dict, key: Optional[str] = None): - if key is None: - key = 'allocations' + # NOTE: more specific ExternalRequest subtype than super class + model_request: ModelExecRequest - if key not in json_obj: - return None + _worker_data_requirements: Optional[List[List[DataRequirement]]] = PrivateAttr(None) + _allocation_service_names: Optional[Tuple[str]] = PrivateAttr(None) - serial_alloc_list = json_obj[key] - if not isinstance(serial_alloc_list, list): - raise RuntimeError("Invalid format for allocations list value '{}'".format(str(serial_alloc_list))) - allocations = [] - for serial_alloc in serial_alloc_list: - if not isinstance(serial_alloc, dict): - raise RuntimeError("Invalid format for allocation value '{}'".format(str(serial_alloc_list))) - allocation = ResourceAllocation.factory_init_from_dict(serial_alloc) - if not isinstance(allocation, 
ResourceAllocation): - raise RuntimeError( - "Unable to deserialize `{}` to resource allocation while deserializing {}".format( - str(allocation), cls.__name__)) - allocations.append(allocation) - return allocations + @validator("allocation_paradigm", pre=True) + def _parse_allocation_paradigm(cls, value: Union[AllocationParadigm, str]) -> Union[str, AllocationParadigm]: + if isinstance(value, AllocationParadigm): + return value - @classmethod - def _parse_serialized_data_requirements(cls, json_obj: dict, key: Optional[str] = None): - if key is None: - key = 'data_requirements' + # NOTE: potentially remove in future. There are cases in codebase where kabob case is being used. + return value.replace("-", "_") - if key not in json_obj: - return None + @validator("status", pre=True) + def _parse_status(cls, value: Optional[Union[str, JobStatus]], field: ModelField) -> JobStatus: + if value is None: + if field.default_factory is None: + raise RuntimeError("unreachable") + return field.default_factory() - serial_list = json_obj[key] - if not isinstance(serial_list, list): - raise RuntimeError("Invalid format for data requirements list value '{}'".format(str(serial_list))) - data_req_list = [] - for serial_data_req in serial_list: - if not isinstance(serial_data_req, dict): - raise RuntimeError("Invalid format for data requirements value '{}'".format(str(serial_list))) - data_req = DataRequirement.factory_init_from_deserialized_json(serial_data_req) - if not isinstance(data_req, DataRequirement): - msg = "Unable to deserialize `{}` to nested data requirements while deserializing {}" - raise RuntimeError(msg.format(serial_data_req, cls.__name__)) - data_req_list.append(data_req) - return data_req_list + if isinstance(value, JobStatus): + return value - @classmethod - def _parse_serialized_job_status(cls, json_obj: dict, key: Optional[str] = None): - # Set this to the default value if it is initially None - if key is None: - key = 'status' - status_str = 
cls.parse_simple_serialized(json_obj=json_obj, key=key, expected_type=str, required_present=False) - if status_str is None: - return None - return JobStatus.get_for_name(name=status_str) + value = str(value) + return JobStatus.get_for_name(name=value) - @classmethod - def _parse_serialized_last_updated(cls, json_obj: dict, key: Optional[str] = None): - date_str_converter = lambda date_str: datetime.strptime(date_str, cls.get_datetime_str_format()) - if key is None: - key = 'last_updated' - if key in json_obj: - return cls.parse_simple_serialized(json_obj=json_obj, key=key, expected_type=datetime, - converter=date_str_converter, required_present=False) - else: - return None + @validator("last_updated", pre=True) + def _parse_serialized_last_updated(cls, value: Union[str, datetime]) -> datetime: + if isinstance(value, datetime): + return value - @classmethod - def _parse_serialized_partition_config(cls, json_obj: dict, key: Optional[str] = None): - if key is None: - key = 'partitioning' - if key in json_obj: - return PartitionConfig.factory_init_from_deserialized_json(json_obj[key]) - else: - return None + try: + value = str(value) + return datetime.strptime(value, cls.get_datetime_str_format()) + except: + return datetime.now() - @classmethod - def _parse_serialized_rsa_key_pair(cls, json_obj: dict, key: Optional[str] = None, warn_if_missing: bool = False): - # Doing this here for now to avoid import errors - # TODO: find a better way for this - from .. import RsaKeyPair - - # Set this to the default value if it is initially None - if key is None: - # TODO: set somewhere globally - key = 'rsa_key_pair' - if key not in json_obj: - if warn_if_missing: - # TODO: log this better. NJF changed print to logging.warning, anything else needed? 
- msg = 'Warning: expected serialized RSA key at {} when deserializing {} object' - logging.warning(msg.format(key, cls.__name__)) - return None - if key not in json_obj or json_obj[key] is None: - return None - rsa_key_pair = RsaKeyPair.factory_init_from_deserialized_json(json_obj=json_obj[key]) - if rsa_key_pair is None: - raise RuntimeError('Could not deserialized child RsaKeyPair when deserializing ' + cls.__name__) - else: - return rsa_key_pair + @validator("data_requirements", pre=True) + def _populate_default_data_requirements(cls, value: Optional[List[DataRequirement]]) -> List[DataRequirement]: + if value is None: + return list() + return value - # TODO: unit test - # TODO: consider moving this up to Job or even Serializable + @validator("model_request", pre=True) + def _deserialize_model_request(cls, value: Union[Dict[str, Any], ModelExecRequest]) -> ModelExecRequest: + if isinstance(value, ModelExecRequest): + return value - @classmethod - def deserialize_core_attributes(cls, json_obj: dict): - """ - Deserialize the core attributes of the basic ::class:`JobImpl` implementation from the provided dictionary and - return as a tuple. + return ModelExecRequest.factory_init_correct_subtype_from_deserialized_json(value) - Parameters - ---------- - json_obj + @validator("job_id", pre=True) + def _validate_job_id(cls, value: Optional[Union[UUID, str]], field: ModelField) -> str: + if value is None: + if field.default_factory is None: + raise RuntimeError("unreachable") + return field.default_factory() - Returns - ------- - The tuple with parse values of (cpus, memory, paradigm, priority, job_id, rsa_key_pair, status, allocations, - updated, partitioning) from the provided dictionary. 
- """ - int_converter = lambda x: int(x) - cpus = cls.parse_simple_serialized(json_obj=json_obj, key='cpu_count', expected_type=int, - converter=int_converter) - memory = cls.parse_simple_serialized(json_obj=json_obj, key='memory_size', expected_type=int, - converter=int_converter) - paradigm = cls._parse_serialized_allocation_paradigm(json_obj=json_obj, key='allocation_paradigm') - priority = cls.parse_simple_serialized(json_obj=json_obj, key='allocation_priority', expected_type=int, - converter=int_converter) - job_id = cls.parse_serialized_job_id(serialized_value=None, json_obj=json_obj, key='job_id') - rsa_key_pair = cls._parse_serialized_rsa_key_pair(json_obj=json_obj) - status = cls._parse_serialized_job_status(json_obj=json_obj) - allocations = cls._parse_serialized_allocations(json_obj=json_obj) - updated = cls._parse_serialized_last_updated(json_obj=json_obj) - partitioning = cls._parse_serialized_partition_config(json_obj=json_obj, key='partitioning') - return cpus, memory, paradigm, priority, job_id, rsa_key_pair, status, allocations, updated, partitioning + if isinstance(value, UUID): + return str(value) - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict): - """ - Factory create a new instance of this type based on a JSON object dictionary deserialized from received JSON. 
    @root_validator(pre=True)
    def _parse_job_id(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """
        Pre-validation hook back-filling ``job_id`` from the other raw input values when absent.

        If ``job_id`` is already present and non-``None``, the raw values are returned untouched;
        otherwise ``parse_serialized_job_id`` is given a chance to locate a value among the
        remaining raw inputs.
        """
        job_id = values.get("job_id")
        if job_id is not None:
            return values

        values["job_id"] = cls.parse_serialized_job_id(job_id, **values)
        return values

    # TODO: unit test
    # TODO: consider moving this up to Job or even Serializable
    @classmethod
    def parse_serialized_job_id(cls, serialized_value: Optional[str], **kwargs):
        """
        Retrieve a serialized job id value, either directly or from the keyword arguments.

        When ``serialized_value`` is not ``None`` it is returned as-is.  Otherwise, if a ``key``
        keyword argument names another keyword argument, that argument's value is returned.
        ``None`` is returned when no value can be located.

        NOTE(review): unlike earlier revisions, this no longer parses/validates the value as a
        UUID; conversion appears to be delegated to the field validator — TODO confirm.
        """
        if serialized_value is not None:
            return serialized_value

        key_key = 'key'
        # First, try to obtain a serialized value, if one was not already set
        if kwargs is not None and key_key in kwargs:
            if kwargs[key_key] in kwargs:
                return kwargs[kwargs[key_key]]

        return None

    def __init__(self, cpu_count: int, memory_size: int, model_request: ExternalRequest,
                 allocation_paradigm: Union[str, AllocationParadigm], alloc_priority: int = 0, **data):
        """
        Initialize from explicit args (fresh construction) or deserialized keyword ``data``.

        When extra ``data`` is supplied (i.e., during deserialization), it is forwarded to the
        superclass along with the explicit args; ``alloc_priority`` is intentionally not forwarded
        in that branch, as ``data`` is expected to carry ``allocation_priority`` itself, and the
        deserialized ``last_updated`` value is preserved (no reset).
        """
        if data:
            super().__init__(
                allocation_paradigm=allocation_paradigm,
                cpu_count=cpu_count,
                memory_size=memory_size,
                model_request=model_request,
                **data,
            )
            return

        super().__init__(
            allocation_paradigm=allocation_paradigm,
            allocation_priority=alloc_priority,
            cpu_count=cpu_count,
            memory_size=memory_size,
            model_request=model_request,
        )
        self._reset_last_updated()

    def _process_per_worker_data_requirements(self) -> List[List[DataRequirement]]:
        """
        Process the "global" data requirements to per-worker requirements.

        Returns
        -------
        List[List[DataRequirement]]
            List (indexed analogously to worker allocations) of lists of per-worker data requirements;
            empty when no allocations have been made yet.
        """
        if self.allocations is None:
            return []
        # TODO: implement this properly/more efficiently
        return [list(self.data_requirements) for _ in self.allocations]

    def _reset_last_updated(self):
        # Record that some tracked piece of this job's state changed just now.
        self.last_updated = datetime.now()

    def add_allocation(self, allocation: ResourceAllocation):
        """
        Add a resource allocation to this object's tuple of allocations, initializing the
        collection first if needed, and invalidating the cached service names.

        Parameters
        ----------
        allocation : ResourceAllocation
            A resource allocation object to add.
        """
        if self.allocations is None:
            self.set_allocations(tuple())
        self.set_allocations((*self.allocations, allocation))  # type: ignore
        self._allocation_service_names = None
        self._reset_last_updated()
    def set_allocation_priority(self, priority: int):
        """Set the allocation priority score for this job and mark it as updated."""
        # NOTE: set using dict to avoid deprecation warning thrown by `__setattr__`. See `Job.__setattr__`
        # docstring for more detail.
        self.__dict__["allocation_priority"] = priority
        self._reset_last_updated()
    def set_allocations(self, allocations: Union[List[ResourceAllocation], Tuple[ResourceAllocation]]):
        """
        Replace this job's allocations (normalized to a tuple), invalidating the cached
        allocation service names and marking the job as updated.
        """
        if isinstance(allocations, list):
            # NOTE: set using dict to avoid deprecation warning thrown by `__setattr__`. See `Job.__setattr__`
            # docstring for more detail.
            self.__dict__["allocations"] = tuple(allocations)
        else:
            # NOTE: set using dict to avoid deprecation warning thrown by `__setattr__`. See `Job.__setattr__`
            # docstring for more detail.
            self.__dict__["allocations"] = allocations
        self._allocation_service_names = None
        self._reset_last_updated()

    def set_data_requirements(self, data_requirements: List[DataRequirement]):
        """Replace the job's data requirements list and mark the job as updated."""
        # Make sure to reset worker data requirements if this is changed
        self._worker_data_requirements = None
        # NOTE: set using dict to avoid deprecation warning thrown by `__setattr__`. See `Job.__setattr__`
        # docstring for more detail.
        self.__dict__["data_requirements"] = data_requirements
        self._reset_last_updated()

    @property
    def is_partitionable(self) -> bool:
        """
        Whether the requested model execution can be partitioned (currently only NGEN requests).
        """
        return self.model_request is not None and isinstance(self.model_request, NGENRequest)
    def set_job_id(self, job_id: Union[str, UUID]):
        """
        Set the unique job id, accepting either a ::class:`UUID` or a string parseable to one.

        The value is normalized to its string form; the last-updated time is reset only when the
        id actually changes.
        """
        job_uuid = job_id if isinstance(job_id, UUID) else UUID(str(job_id))
        job_uuid = str(job_uuid)
        if job_uuid != self.job_id:
            # NOTE: set using dict to avoid deprecation warning thrown by `__setattr__`. See `Job.__setattr__`
            # docstring for more detail.
            self.__dict__["job_id"] = job_uuid
            self._reset_last_updated()

    def set_partition_config(self, part_config: PartitionConfig):
        # NOTE: set using dict to avoid deprecation warning thrown by `__setattr__`. See `Job.__setattr__`
        # docstring for more detail.
        # NOTE(review): unlike the other setters, this does not reset last_updated — confirm intended.
        self.__dict__["partition_config"] = part_config

    def set_rsa_key_pair(self, key_pair: 'RsaKeyPair'):
        """Set the job's RSA key pair, resetting the last-updated time only on actual change."""
        if key_pair != self.rsa_key_pair:
            # NOTE: set using dict to avoid deprecation warning thrown by `__setattr__`. See `Job.__setattr__`
            # docstring for more detail.
            self.__dict__["rsa_key_pair"] = key_pair
            self._reset_last_updated()

    def set_status(self, status: JobStatus):
        """Set the job's status, resetting the last-updated time only on actual change."""
        if status != self.status:
            # NOTE: set using dict to avoid deprecation warning thrown by `__setattr__`. See `Job.__setattr__`
            # docstring for more detail.
            self.__dict__["status"] = status
            self._reset_last_updated()

    def set_status_phase(self, phase: JobExecPhase):
        # Changing phase implies starting over at that phase's default starting step.
        self.set_status(JobStatus(phase=phase, step=phase.default_start_step))

    def set_status_step(self, step: JobExecStep):
        # Changing step keeps the current execution phase.
        self.set_status(JobStatus(phase=self.status.job_exec_phase, step=step))
@@ -1230,64 +978,59 @@ def worker_data_requirements(self) -> List[List[DataRequirement]]: self._worker_data_requirements = self._process_per_worker_data_requirements() return self._worker_data_requirements - def to_dict(self) -> dict: - """ - Get the representation of this instance as a dictionary or dictionary-like object (e.g., a JSON object). - - { - "job_class" : "", - "cpu_count" : 4, - "memory_size" : 1000, - "model_request" : {}, - "allocation_paradigm" : "SINGLE_NODE", - "allocation_priority" : 0, - "job_id" : "12345678-1234-5678-1234-567812345678", - "rsa_key_pair" : {}, - "status" : INIT:DEFAULT, - "last_updated" : "2020-07-10 12:05:45", - "allocations" : [...], - 'data_requirements" : [...], - "partitioning" : { "partitions": [ ... ] } - } - - Returns - ------- - dict - the representation of this instance as a dictionary or dictionary-like object (e.g., a JSON object) - """ - serial = dict() - - serial['job_class'] = self.__class__.__name__ - serial['cpu_count'] = self.cpu_count - serial['memory_size'] = self.memory_size - - # TODO: support other scenarios along with deserializing (maybe even eliminate RequestedJob subtype) - if isinstance(self.model_request, ModelExecRequest): - request_key = 'model_request' - else: - msg = "Type {} can only support serializing to JSON when fulfilled request is a {}" - raise RuntimeError(msg.format(self.__class__.__name__, ModelExecRequest.__name__)) - serial[request_key] = self.model_request.to_dict() - - if self.allocation_paradigm: - serial['allocation_paradigm'] = self.allocation_paradigm.name - serial['allocation_priority'] = self.allocation_priority - if self.job_id is not None: - serial['job_id'] = str(self.job_id) - if self.rsa_key_pair is not None: - serial['rsa_key_pair'] = self.rsa_key_pair.to_dict() - serial['status'] = self.status.name - serial['last_updated'] = self._last_updated.strftime(self.get_datetime_str_format()) - serial['data_requirements'] = [] - for dr in self.data_requirements: - 
serial['data_requirements'].append(dr.to_dict()) - if self.allocations is not None and len(self.allocations) > 0: - serial['allocations'] = [] - for allocation in self.allocations: - serial['allocations'].append(allocation.to_dict()) - if self.partition_config is not None: - serial['partitioning'] = self.partition_config.to_dict() - + def _setter_methods(self) -> Dict[str, Callable]: + return { + **super()._setter_methods(), + "allocation_priority": self.set_allocation_priority, + "job_id": self.set_job_id, + "rsa_key_pair": self.set_rsa_key_pair, + } + + def dict( + self, + *, + include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + by_alias: bool = True, # Note, this follows Serializable convention + skip_defaults: Optional[bool] = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = True, + ) -> "DictStrAny": + def add(*fields: str, collection: Union[Set[str], Dict[str, bool]]) -> Union[Set[str], Dict[str, bool]]: + if isinstance(collection, set): + collection_copy = {*collection} + for field in fields: + collection_copy.add(field) + return collection_copy + + elif isinstance(exclude, dict): + collection_copy = {**collection} + for field in fields: + collection_copy[field] = True + return collection_copy + + return collection + + exclude = exclude or set() + + # conditionally exclude `allocations` and `partitioning` if allocations is None or is empty + if self.allocations is None or not len(self.allocations): + exclude = add("allocations", "partitioning", collection=exclude) + + serial = super().dict( + include=include, + exclude=exclude, + by_alias=by_alias, + skip_defaults=skip_defaults, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + + # serialize status as "{PHASE}:{STEP}" + if "status" not in exclude: + serial["status"] = self.status.name return serial @@ -1297,119 
+1040,44 @@ class RequestedJob(JobImpl): in the form of a ::class:`SchedulerRequestMessage` object. """ - @classmethod - def factory_init_from_deserialized_json(cls, json_obj: dict): - """ - Factory create a new instance of this type based on a JSON object dictionary deserialized from received JSON. - - Parameters - ---------- - json_obj - - Returns - ------- - A new object of this type instantiated from the deserialize JSON object dictionary - """ - - originating_request_key = 'originating_request' - - try: - cpus, memory, paradigm, priority, job_id, rsa_key_pair, status, allocations, updated, partitioning = \ - cls.deserialize_core_attributes(json_obj) - - if originating_request_key not in json_obj: - msg = 'Key for originating request ({}) not present when deserialize {} object' - raise RuntimeError(msg.format(originating_request_key, cls.__name__)) - request = SchedulerRequestMessage.factory_init_from_deserialized_json(json_obj[originating_request_key]) - if request is None: - msg = 'Invalid serialized scheduler request when deserialize {} object' - raise RuntimeError(msg.format(cls.__name__)) - except Exception as e: - logging.error(e) - return None - - # Create the object initially from the request - new_obj = cls(job_request=request) - - # Then update its properties based on the deserialized values, as those are considered most correct - - # Use property setter for job id to handle string or UUID - new_obj.job_id = job_id - - new_obj._cpu_count = cpus - new_obj._memory_size = memory - new_obj._allocation_paradigm = paradigm - new_obj._allocation_priority = priority - new_obj._rsa_key_pair = rsa_key_pair - new_obj._status = status - new_obj._allocations = allocations - new_obj.data_requirements = cls._parse_serialized_data_requirements(json_obj) - new_obj._partition_config = partitioning - - # Do last_updated last, as any usage of setters above might cause the value to be maladjusted - new_obj._last_updated = updated - - return new_obj - - def __init__(self, 
job_request: SchedulerRequestMessage): - self._originating_request = job_request - super().__init__(cpu_count=job_request.cpus, memory_size=job_request.memory, - model_request=job_request.model_request, - allocation_paradigm=job_request.allocation_paradigm) - self.data_requirements = self.model_request.data_requirements - - @property - def model_request(self) -> ExternalRequest: - """ - Get the underlying configuration for the model execution that is being requested. - - Returns - ------- - ExternalRequest - The underlying configuration for the model execution that is being requested. - """ - return self.originating_request.model_request - - @property - def originating_request(self) -> SchedulerRequestMessage: - """ - The original request that resulted in the creation of this job. - - Returns - ------- - SchedulerRequestMessage - The original request that resulted in the creation of this job. - """ - return self._originating_request - - def to_dict(self) -> dict: - """ - Get the representation of this instance as a dictionary or dictionary-like object (e.g., a JSON object). - - { - "job_class" : "", - "cpu_count" : 4, - "memory_size" : 1000, - "allocation_paradigm" : "SINGLE_NODE", - "allocation_priority" : 0, - "job_id" : "12345678-1234-5678-1234-567812345678", - "rsa_key_pair" : {}, - "status" : INIT:DEFAULT, - "last_updated" : "2020-07-10 12:05:45", - "allocations" : [...], - 'data_requirements" : [...], - "partitioning" : { "partitions": [ ... 
] }, - "originating_request" : {} - } - - Returns - ------- - dict - the representation of this instance as a dictionary or dictionary-like object (e.g., a JSON object) - """ - dictionary = super().to_dict() - # To avoid this being messy, rely on the superclass's implementation and the returned dict, but remove the - # 'model_request' key/value, since this is contained within the originating serialized scheduler request - dictionary.pop('model_request') - dictionary['originating_request'] = self.originating_request.to_dict() - return dictionary + originating_request: SchedulerRequestMessage + """The original request that resulted in the creation of this job.""" + + class Config: # type: ignore + fields = { + # exclude `model_request` during serialization + "model_request": {"exclude": True} + } + + def __init__(self, job_request: SchedulerRequestMessage = None, **data): + if data: + # NOTE: in previous version of code, `model_request` was always a derived field. + # this allows `model_request` be separately specified + if "model_request" in data: + super().__init__(**data) + return + + originating_request = data.get("originating_request") + if originating_request is None: + # this should fail, let pydantic handle that. + super().__init__(**data) + return + + if isinstance(originating_request, SchedulerRequestMessage): + # inject + data["model_request"] = originating_request.model_request + + data["model_request"] = originating_request.get("model_request") + super().__init__(**data) + return + + # NOTE: consider refactoring this into `from_job_request` class method. 
+ super().__init__( + cpu_count=job_request.cpus, + memory_size=job_request.memory, + model_request=job_request.model_request, + allocation_paradigm=job_request.allocation_paradigm, + originating_request=job_request, + ) + # NOTE: this implicitly resets `last_updated` field + self.set_data_requirements(job_request.model_request.data_requirements) diff --git a/python/lib/scheduler/dmod/scheduler/resources/resource.py b/python/lib/scheduler/dmod/scheduler/resources/resource.py index 755f58c07..4cbd45016 100644 --- a/python/lib/scheduler/dmod/scheduler/resources/resource.py +++ b/python/lib/scheduler/dmod/scheduler/resources/resource.py @@ -1,21 +1,28 @@ from abc import ABC, abstractmethod -from enum import Enum -from typing import Any, Dict, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, Optional, Tuple, Type, Union +from typing_extensions import Self +from pydantic import Field, Extra, validator, root_validator +from functools import cache +from warnings import warn -class ResourceAvailability(Enum): +from dmod.core.enum import PydanticEnum +from dmod.core.serializable import Serializable + + +class ResourceAvailability(PydanticEnum): ACTIVE = 1, INACTIVE = 2, UNKNOWN = -1 -class ResourceState(Enum): +class ResourceState(PydanticEnum): READY = 1 NOT_READY = 2, UNKNOWN = -1 -class AbstractProcessingAssetPool(ABC): +class AbstractProcessingAssetPool(Serializable, ABC): """ Abstract representation of some collection of assets used for processing jobs/tasks. @@ -26,10 +33,14 @@ class AbstractProcessingAssetPool(ABC): ::method:`factory_init_from_dict` class method, and serialization using the ::method:`to_dict` method. 
""" + cpu_count: int + memory: int + pool_id: str + unique_id_separator: str = ":" + @classmethod - @abstractmethod def factory_init_from_dict(cls, init_dict: Dict[str, Any], - ignore_extra_keys: bool = False) -> 'AbstractProcessingAssetPool': + ignore_extra_keys: bool = False) -> Self: """ Initialize a new object from the given dictionary, raising a ::class:`ValueError` if there are missing expected keys or there are extra keys when the method is not set to ignore them. @@ -69,45 +80,24 @@ def factory_init_from_dict(cls, init_dict: Dict[str, Any], TypeError If any parameters sourced from the init dictionary are not of a supported type for that param. """ - pass + original_extra_level = getattr(cls.Config, "extra", None) - def __init__(self, pool_id: str, cpu_count: int, memory: int): - self._pool_id = pool_id - self._cpu_count = cpu_count - self._memory = memory - self.unique_id_separator = ':' + if ignore_extra_keys: + setattr(cls.Config, "extra", Extra.ignore) + else: + setattr(cls.Config, "extra", Extra.forbid) - @property - def cpu_count(self) -> int: - return self._cpu_count + o = cls.parse_obj(init_dict) - @cpu_count.setter - def cpu_count(self, cpu_count: int): - self._cpu_count = cpu_count + if original_extra_level is None: + delattr(cls.Config, "extra") + else: + setattr(cls.Config, "extra", original_extra_level) - @property - def memory(self) -> int: - return self._memory + return o - @memory.setter - def memory(self, memory: int): - self._memory = memory - - @property - def pool_id(self) -> str: - return self._pool_id - - @abstractmethod - def to_dict(self) -> Dict[str, Union[str, int]]: - """ - Convert the object to a serialized dictionary. - - Returns - ------- - Dict[str, Union[str, int]] - The object as a serialized dictionary - """ - pass + class Config: + extra = Extra.forbid @property @abstractmethod @@ -132,13 +122,7 @@ class SingleHostProcessingAssetPool(AbstractProcessingAssetPool, ABC): creation. 
""" - def __init__(self, pool_id: str, hostname: str, cpu_count: int, memory: int): - super().__init__(pool_id=pool_id, cpu_count=cpu_count, memory=memory) - self._hostname = hostname - - @property - def hostname(self) -> str: - return self._hostname + hostname: str class Resource(SingleHostProcessingAssetPool): @@ -165,63 +149,92 @@ class Resource(SingleHostProcessingAssetPool): are expected to never change for a resource. """ - @classmethod - def factory_init_from_dict(cls, init_dict: dict, ignore_extra_keys: bool = False) -> 'Resource': - """ - Initialize a new object from the given dictionary, raising a ::class:`ValueError` if there are missing expected - keys or there are extra keys when the method is not set to ignore them. + availability: ResourceAvailability + """ + The availability of the resource. - Note that this method will allow ::class:`ResourceAvailability` and ::class:`ResourceState` values for the - init values of ``availability`` and ``state`` respectively, in addition to strings. It will also convert - numeric types from string values appropriately. + Note that the property setter accepts both string and ::class:`ResourceAvailability` values. For a string, the + argument is converted to a ::class:`ResourceAvailability` value using ::method:`get_resource_enum_value`. - Also, unlike other implementations, ``total cpus`` and ``total memory`` are expected keys, but they are not - required. If they are not present, the defaults (the respective available values) are used by the initializer. + However, if the conversion of a string with ::method:`get_resource_enum_value` returns ``None``, the setter + sets ::attribute:`availability` to the ``UNKNOWN`` enum value, rather than ``None``. This is more applicable + and allows the getter to always return an actual ::class:`ResourceAvailability` instance. 
+ """ - parent: - """ - node_id = None - hostname = None - avail = None - state = None - cpus = None - total_cpus = None - memory = None - total_memory = None - - for param_key in init_dict: - # We don't care about non-string keys directly, but they are implicitly extra ... - if not isinstance(param_key, str): - if not ignore_extra_keys: - raise ValueError("Unexpected non-string resource init key") - else: - continue - lower_case_key = param_key.lower() - if lower_case_key == 'node_id' and node_id is None: - node_id = init_dict[param_key] - elif lower_case_key == 'hostname' and hostname is None: - hostname = init_dict[param_key] - elif lower_case_key == 'availability' and avail is None: - avail = init_dict[param_key] - elif lower_case_key == 'state' and state is None: - state = init_dict[param_key] - elif lower_case_key == 'cpus' and cpus is None: - cpus = int(init_dict[param_key]) - elif lower_case_key == 'memorybytes' and memory is None: - memory = int(init_dict[param_key]) - elif lower_case_key == 'total cpus' and total_cpus is None: - total_cpus = int(init_dict[param_key]) - elif lower_case_key == 'total memory' and total_memory is None: - total_memory = int(init_dict[param_key]) - elif not ignore_extra_keys: - raise ValueError("Unexpected resource init key (or case-insensitive duplicate) {}".format(param_key)) - - # Make sure we have everything required set - if node_id is None or hostname is None or cpus is None or memory is None or avail is None or state is None: - raise ValueError("Insufficient valid values keyed within resource init dictionary") - - return cls(resource_id=node_id, hostname=hostname, availability=avail, state=state, cpu_count=cpus, - memory=memory, total_cpu_count=total_cpus, total_memory=total_memory) + state: ResourceState = Field(description="The readiness state of the resource.") + """ + Note that the property setter accepts both string and ::class:`ResourceState` values. 
For a string, the + argument is converted to a ::class:`ResourceState` value using ::method:`get_resource_enum_value`. + + However, if the conversion of a string with ::method:`get_resource_enum_value` returns ``None``, the setter sets + ::attribute:`state` to the ``UNKNOWN`` enum value, rather than ``None``. This is more applicable and allows the + getter to always return an actual ::class:`ResourceState` instance. + """ + + total_cpus: Optional[int] = Field(description="The total number of CPUs known to be on this resource.") + + total_memory: Optional[int] = Field(description="The total amount of memory known to be on this resource.") + + class Config: + fields = { + "availability": {"alias": "Availability"}, + "cpu_count": {"alias": "CPUs"}, + "hostname": {"alias": "Hostname"}, + "memory": {"alias": "MemoryBytes"}, + "pool_id": {"alias": "node_id"}, + "state": {"alias": "State"}, + "total_cpus": {"alias": "Total CPUs"}, + "total_memory": {"alias": "Total Memory"}, + "unique_id_separator": {"exclude": True} + } + + @validator("availability", pre=True) + def _validate_availability(cls, value: Optional[Any]) -> Union[Any, ResourceAvailability]: + if value is None: + return ResourceAvailability.UNKNOWN + return value + + @validator("state", pre=True) + def _validate_state(cls, value: Optional[Any]) -> Union[Any, ResourceState]: + if value is None: + return ResourceState.UNKNOWN + return value + + @root_validator(pre=True) + def _remap_alias_case_insensitive(cls, values: Dict[str, Any]) -> Dict[str, Any]: + alias_field_map = cls._alias_field_map() + + # NOTE: consider removing this in the future and enforcing case sensitive keys + new_values: Dict[str, Any] = dict() + for k, v in values.items(): + if k.lower() in alias_field_map: + new_values[alias_field_map[k.lower()]] = v + continue + new_values[k] = v + return new_values + + @root_validator() + def _set_total_cpus_and_total_memory_if_unset(cls, values: Dict[str, Any]) -> Dict[str, Any]: + if 
values.get("total_cpus") is None: + values["total_cpus"] = values["cpu_count"] + + if values.get("total_memory") is None: + values["total_memory"] = values["memory"] + + msg_template = "`{}` cannot be larger than `{}`. {} > {}" + if values["cpu_count"] > values["total_cpus"]: + raise ValueError(msg_template.format("cpu_count", "total_cpus", values["cpu_count"], values["total_cpus"])) + + if values["memory"] > values["total_memory"]: + raise ValueError(msg_template.format("memory", "total_memory", values["memory"], values["total_memory"])) + + return values + + @classmethod + @cache + def _alias_field_map(cls) -> Dict[str, str]: + """Mapping of lower cased alias names to cased alias names.""" + return {v.alias.lower(): v.alias for v in cls.__fields__.values()} @classmethod def generate_unique_id(cls, resource_id: str, separator: str): @@ -239,7 +252,7 @@ def generate_unique_id(cls, resource_id: str, separator: str): str The derived unique id. """ - return cls.__name__ + separator + resource_id + return f"{cls.__name__}{separator}{resource_id}" @classmethod def get_cpu_hash_key(cls) -> str: @@ -251,7 +264,7 @@ def get_cpu_hash_key(cls) -> str: str The hash key value for serialized dictionaries/hashes representations. 
    @classmethod
    def get_cpu_hash_key(cls) -> str:
        """
        Get the hash key for the serialized representation of the CPU count.

        Returns
        -------
        str
            The hash key value for serialized dictionaries/hashes representations.
        """
        return "CPUs"

    def __eq__(self, other: object):
        # Field-by-field comparison against other Resource instances; defer to the superclass
        # for any other type.
        if not isinstance(other, Resource):
            return super().__eq__(other)
        else:
            return self.resource_id == other.resource_id and self.hostname == other.hostname \
                   and self.availability == other.availability and self.state == other.state \
                   and self.cpu_count == other.cpu_count and self.memory == other.memory \
                   and self.total_cpus == other.total_cpus and self.total_memory == other.total_memory

    def __init__(
        self,
        resource_id: str = None,
        hostname: str = None,
        availability: Union[str, ResourceAvailability] = None,
        state: Union[str, ResourceState] = None,
        cpu_count: int = None,
        memory: int = None,
        total_cpu_count: Optional[int] = None,
        total_memory: Optional[int] = None,
        **data
    ):
        """
        Initialize either from the legacy explicit parameters or from alias-keyed ``data``
        (as supplied during pydantic deserialization).

        NOTE(review): in the ``data`` branch, explicit ``resource_id``, ``cpu_count``, ``memory``
        and ``total_cpu_count`` values are not merged into ``data`` — confirm intended.
        """
        if data:
            # NOTE: this can be removed once alias field names are case sensitive
            # Merge explicitly-passed values for fields whose aliases may also appear in `data`.
            potentially_aliased_fields = {
                "availability": availability,
                "hostname": hostname,
                "state": state,
                "total_memory": total_memory
            }

            for field_name, value in potentially_aliased_fields.items():
                if value is not None:
                    data[field_name] = value
            super().__init__(**data)
            return

        super().__init__(
            pool_id=resource_id,
            hostname=hostname,
            cpu_count=cpu_count,
            memory=memory,
            availability=availability,
            state=state,
            total_cpus=total_cpu_count,
            total_memory=total_memory
        )
- """ - return self._availability - - @availability.setter - def availability(self, availability: Union[str, ResourceAvailability]): + def set_availability(self, availability: Union[str, ResourceAvailability]): if isinstance(availability, ResourceAvailability): enum_val = availability else: enum_val = self.get_resource_enum_value(ResourceAvailability, availability) - self._availability = ResourceAvailability.UNKNOWN if enum_val is None else enum_val + self.__dict__["availability"] = ResourceAvailability.UNKNOWN if enum_val is None else enum_val def is_allocatable(self) -> bool: """ @@ -412,88 +430,64 @@ def release(self, cpu_count: int, memory: int): self.memory = self.memory + memory @property - def resource_id(self) -> str: - return self.pool_id - - @property - def state(self) -> ResourceState: + def total_cpu_count(self) -> int: """ - The readiness state of the resource. - - Note that the property setter accepts both string and ::class:`ResourceState` values. For a string, the - argument is converted to a ::class:`ResourceState` value using ::method:`get_resource_enum_value`. - - However, if the conversion of a string with ::method:`get_resource_enum_value` returns ``None``, the setter sets - ::attribute:`state` to the ``UNKNOWN`` enum value, rather than ``None``. This is more applicable and allows the - getter to always return an actual ::class:`ResourceState` instance. - + The total number of CPUs known to be on this resource. Returns ------- - ResourceState - The readiness state of the resource. + int + The total number of CPUs known to be on this resource. 
""" - return self._state + # NOTE: total cpus will be set or derived from `cpu_count` + return self.total_cpus # type: ignore - @state.setter - def state(self, state: Union[str, ResourceState]): + + @property + def resource_id(self) -> str: + return self.pool_id + + def set_state(self, state: Union[str, ResourceState]): if isinstance(state, ResourceState): enum_val = state else: enum_val = self.get_resource_enum_value(ResourceState, state) - self._state = ResourceState.UNKNOWN if enum_val is None else enum_val - - def to_dict(self) -> Dict[str, Union[str, int]]: - """ - Convert the object to a serialized dictionary. + self.__dict__["state"] = ResourceState.UNKNOWN if enum_val is None else enum_val - Key names are as shown in the example below. Enum values are represented as the lower-case version of the name - for the given value. Values shown for CPU and Memory are the max values. + @property + def unique_id(self) -> str: + return self.generate_unique_id(resource_id=self.resource_id, separator=self.unique_id_separator) - E.g.: - { - 'node_id': "Node-0001", - 'Hostname': "my-host", - 'Availability': "active", - 'State': "ready", - 'CPUs': 18, - 'MemoryBytes': 33548128256, - 'Total CPUs': 18, - 'Total Memory: 33548128256 + def _setter_methods(self) -> Dict[str, Callable]: + """Mapping of attribute name to setter method. This supports backwards functional compatibility.""" + # TODO: remove once migration to setters by down stream users is complete + return { + "state": self.set_state, + "availability": self.set_availability, } - Returns - ------- - Dict[str, Union[str, int]] - The object as a serialized dictionary. 
+ def __setattr__(self, name: str, value: Any): """ - return {'node_id': self.resource_id, 'Hostname': self.hostname, 'Availability': self.availability.name.lower(), - 'State': self.state.name.lower(), self.get_cpu_hash_key(): self.cpu_count, 'MemoryBytes': self.memory, - 'Total CPUs': self.total_cpu_count, 'Total Memory': self.total_memory} + Use property setter method when available. - @property - def total_cpu_count(self) -> int: - """ - The total number of CPUs known to be on this resource. + Note, all setter methods should modify their associated property using the instance `__dict__`. + This ensures that calls to, for example, `set_id` don't raise a warning, while `o.id = "new + id"` do. - Returns - ------- - int - The total number of CPUs known to be on this resource. - """ - return self._total_cpu_count + Example: + ``` + class SomeJob(Job): + id: str - @property - def total_memory(self) -> int: + def set_id(self, value: str): + self.__dict__["id"] = value + ``` """ - The total amount of memory known to be on this resource. + if name not in self._setter_methods(): + return super().__setattr__(name, value) - Returns - ------- - int - The total amount of memory known to be on this resource. - """ - return self._total_memory + setter_fn = self._setter_methods()[name] - @property - def unique_id(self) -> str: - return self.generate_unique_id(resource_id=self.resource_id, separator=self.unique_id_separator) + message = f"Setting by attribute is deprecated. Use `{self.__class__.__name__}.{setter_fn.__name__}` method instead." 
+ warn(message, DeprecationWarning) + + setter_fn(value) diff --git a/python/lib/scheduler/dmod/scheduler/resources/resource_allocation.py b/python/lib/scheduler/dmod/scheduler/resources/resource_allocation.py index f78289358..379e341c8 100644 --- a/python/lib/scheduler/dmod/scheduler/resources/resource_allocation.py +++ b/python/lib/scheduler/dmod/scheduler/resources/resource_allocation.py @@ -1,5 +1,6 @@ from datetime import datetime -from typing import Dict, Optional, Union +from typing import Any, Dict, Optional, Union +from pydantic import root_validator, validator from .resource import SingleHostProcessingAssetPool @@ -8,54 +9,36 @@ class ResourceAllocation(SingleHostProcessingAssetPool): Implementation of ::class:`SingleHostProcessingAssetPool` representing a sub-collection of processing assets on a resource that have been allocated for a job. """ - - @classmethod - def factory_init_from_dict(cls, alloc_dict: dict, ignore_extra_keys: bool = False) -> 'ResourceAllocation': - """ - parent: - """ - node_id = None - hostname = None - cpus_allocated = None - mem = None - created = None - separator = None - - for param_key in alloc_dict: - # We don't care about non-string keys directly, but they are implicitly extra ... 
- if not isinstance(param_key, str): - if not ignore_extra_keys: - raise ValueError("Unexpected non-string allocation key") - else: - continue - lower_case_key = param_key.lower() - if lower_case_key == 'node_id' and node_id is None: - node_id = alloc_dict[param_key] - elif lower_case_key == 'hostname' and hostname is None: - hostname = alloc_dict[param_key] - elif lower_case_key == 'cpus_allocated' and cpus_allocated is None: - cpus_allocated = int(alloc_dict[param_key]) - elif lower_case_key == 'mem' and mem is None: - mem = int(alloc_dict[param_key]) - elif lower_case_key == 'created' and created is None: - created = alloc_dict[param_key] - elif lower_case_key == 'separator' and separator is None: - separator = alloc_dict[param_key] - elif not ignore_extra_keys: - raise ValueError("Unexpected allocation key (or case-insensitive duplicate) {}".format(param_key)) - - # Make sure we have everything required set - if node_id is None or hostname is None or cpus_allocated is None or mem is None: - raise ValueError("Insufficient valid values keyed within allocation dictionary") - - deserialized = cls(resource_id=node_id, hostname=hostname, cpus_allocated=cpus_allocated, requested_memory=mem, - created=created) - if isinstance(separator, str): - deserialized.unique_id_separator = separator - - return deserialized - - def __eq__(self, other): + created: datetime + + class Config: + fields = { + "pool_id": {"alias": "node_id"}, + "hostname": {"alias": "Hostname"}, + "cpu_count": {"alias": "cpus_allocated"}, + "memory": {"alias": "mem"}, + "created": {"alias": "Created"}, + "unique_id_separator": {"alias": "separator"}, + } + field_serializers = { + "created": lambda v: v.timestamp() + } + + @validator("created", pre=True) + def _validate_datetime(cls, value) -> datetime: + if value is None: + return datetime.now() + elif isinstance(value, datetime): + return value + elif isinstance(value, float): + return datetime.fromtimestamp(value) + return 
datetime.fromtimestamp(float(value)) + + @root_validator(pre=True) + def _lowercase_all_keys(cls, values: Dict[str, Any]) -> Dict[str, Any]: + return {k.lower(): v for k, v in values.items()} + + def __eq__(self, other: object) -> bool: if not isinstance(other, ResourceAllocation): return False else: @@ -65,38 +48,22 @@ def __eq__(self, other): and self.memory == other.memory \ and self.created == other.created - def __init__(self, resource_id: str, hostname: str, cpus_allocated: int, requested_memory: int, - created: Optional[Union[str, float, datetime]] = None): - super().__init__(pool_id=resource_id, hostname=hostname, cpu_count=cpus_allocated, memory=requested_memory) - self._set_created(created) - - def _set_created(self, created: Optional[Union[str, float, datetime]] = None): - """ - A "private" method for setting the ::attribute:`created` property, potentially converting to value to set. - - A ``None`` argument is interpreted as ``now``. Other non-datetime args are interpreted as string or numeric - epoch timestamp representations (i.e., values like those from ::method:`datetime.timestamp`). - - Parameters - ---------- - created - The value to set. 
- """ - if created is None: - self._created = datetime.now() - elif isinstance(created, datetime): - self._created = created - elif isinstance(created, float): - self._created = datetime.fromtimestamp(created) - else: - self._created = datetime.fromtimestamp(float(created)) - - @property - def created(self) -> datetime: - return self._created + def __init__( + self, + resource_id: str = None, + hostname: str = None, + cpus_allocated: int = None, + requested_memory: int = None, + created: Optional[Union[str, float, datetime]] = None, + **data + ): + if data: + super().__init__(cpus_allocated=cpus_allocated, **data) + return + super().__init__(pool_id=resource_id, hostname=hostname, cpu_count=cpus_allocated, memory=requested_memory, created=created) def get_unique_id(self, separator: str) -> str: - return self.__class__.__name__ + separator + self.resource_id + separator + str(self.created.timestamp()) + return f"{self.__class__.__name__}{separator}{self.resource_id}{separator}{str(self.created.timestamp())}" @property def node_id(self) -> str: @@ -128,10 +95,6 @@ def resource_id(self) -> str: """ return self.pool_id - def to_dict(self) -> Dict[str, Union[str, int]]: - return {'node_id': self.node_id, 'Hostname': self.hostname, 'cpus_allocated': self.cpu_count, - 'mem': self.memory, 'Created': self.created.timestamp(), 'separator': self.unique_id_separator} - @property def unique_id(self) -> str: return self.get_unique_id(self.unique_id_separator) diff --git a/python/lib/scheduler/dmod/scheduler/rsa_key_pair.py b/python/lib/scheduler/dmod/scheduler/rsa_key_pair.py index ffc790ece..fd24f1f44 100644 --- a/python/lib/scheduler/dmod/scheduler/rsa_key_pair.py +++ b/python/lib/scheduler/dmod/scheduler/rsa_key_pair.py @@ -3,29 +3,235 @@ from cryptography.hazmat.backends import default_backend from dmod.core.serializable import Serializable from pathlib import Path -from typing import Dict, Union +from pydantic import Field, PrivateAttr, validator +from typing import 
ClassVar, Dict, Optional, Tuple, Union +from typing_extensions import Self import datetime import os from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKeyWithSerialization +class _RsaKeyPair(Serializable): + """ + This is a shim object that enables partial instantiation of a :class:`RsaKeyPair`. This class exposes methods and + properties to interact, generate, and write a key pair. However, it does not expose a way to serialize or + deserialize keys and other associated metadata from a dictionary. For the functionality, see :class:`RsaKeyPair`. + """ + + directory: Path + """ + The directory in which the key pair files have been or will be written, as a :class:`Path`. + + If `None` is provided, `directory` defaults to ``$HOME/.ssh/``. If the default or provided directory does not + exists, it and any intermediate directories will be created. Directory inputs that exist and are not directories + (i.e. a file) will raise a ValueError. + """ + + name: str = Field(min_length=1) + """Basename of private key file.""" + + _priv_key: RSAPrivateKeyWithSerialization = PrivateAttr(None) + _priv_key_pem: bytes = PrivateAttr(None) + _is_deserialized: bool = PrivateAttr(False) + + @validator("directory", pre=True) + def _validate_directory(cls, value: Union[str, Path, None]) -> Union[str, Path]: + if value is None: + return Path.home() / ".ssh" + + if isinstance(value, str): + return value.strip() + + return value + + @validator("directory") + def _post_validate_directory(cls, value: Path) -> Path: + if not value.exists(): + value.mkdir(parents=True) + + elif not value.is_dir(): + raise ValueError(f"Existing non-directory file at path provided for key pair directory. {value!r}") + + return value + + @validator("name") + def _validate_name(cls, value: str) -> str: + return value.strip() + + @property + def private_key_file(self) -> Path: + """ + + Returns + ------- + Path + Path to private key file. Is not guaranteed to exist. 
+ """ + return self.directory / self.name + + @property + def public_key_file(self) -> Path: + """ + Same as private key filepath, but with the suffix ".pub". + + Returns + ------- + Path + Path to public key file. Is not guaranteed to exist. + """ + return self.directory / f"{self.name}.pub" + + @property + def private_key_pem(self) -> bytes: + """ + + Returns + ------- + bytes + Encoded private key in PEM format + """ + if self._priv_key_pem is None: + self._priv_key_pem = self._private_key_bytes_from_private_key(self._private_key) + return self._priv_key_pem # type: ignore + + def delete_key_files(self) -> Tuple[bool, bool]: + """ + Delete the files at the paths specified by :attr:`private_key_file` and :attr:`public_key_file`, as long as + there is an existing, regular (i.e., from :method:`Path.is_file`) file at the individual paths. + + Note that whether a delete is performed for one file is independent of what the state of the other. I.e., if + the private key file does not exist, thus resulting in no attempt to delete it, this will not affect whether + there is a delete operation on the public key file. + + Returns + ------- + tuple + A tuple of boolean values, representing whether the private key file and the public key file respectively + were deleted + """ + deleted_private = False + deleted_public = False + if self.private_key_file.exists() and self.private_key_file.is_file(): + self.private_key_file.unlink() + deleted_private = True + if self.public_key_file.exists() and self.public_key_file.is_file(): + self.public_key_file.unlink() + deleted_public = True + return deleted_private, deleted_public + + def write_key_files(self, write_private: bool = True, write_public: bool = True): + """ + Write private and/or public keys to files at :attr:`private_key_file` and :attr:`public_key_file` respectively, + assuming the respective file does not already exist. 
+ + Parameters + ---------- + write_private : bool + An option, ``True`` by default, for whether the private key should be written to :attr:`private_key_file` + + write_public : bool + An option, ``True`` by default, for whether the public key should be written to :attr:`public_key_file` + """ + # if fail to write private key file, delete any existing pub / priv key files. + try: + if write_private and not self.private_key_file.exists(): + self._write_private_key(self._private_key, raise_on_fail=True) + except Exception as e: + if self.public_key_file.exists(): + _, deleted_public = self.delete_key_files() + if not deleted_public: + raise RuntimeError(f"Failed to write private key file. During failure, failed to remove public key file. '{self.public_key_file}'") from e + raise e + + # NOTE: if cannot write pub key file, priv key file, if it exists, will not be removed. + if write_public and not self.public_key_file.exists(): + self._write_public_key(self._private_key, raise_on_fail=True) + + @property + def _private_key(self) -> RSAPrivateKeyWithSerialization: + """ + Serialized private key. Lazily loads private key from :property:`private_key_file` or dynamically generates one. + + If the private key is loaded from :property:`private_key_file` and :property:`public_key_file` does not exist, a + public key is written to disk at :property:`public_key_file`. 
+ """ + if self._priv_key is None and self.private_key_file.exists(): + priv_key_file = self.private_key_file.read_bytes() + self._priv_key = serialization.load_pem_private_key(priv_key_file, None, default_backend()) + self._is_deserialized = True -class RsaKeyPair(Serializable): + self._write_public_key(self._priv_key, overwrite=False, raise_on_fail=True) + + elif self._priv_key is None: + self._priv_key = rsa.generate_private_key(backend=default_backend(), public_exponent=65537, key_size=3072) + + return self._priv_key # type: ignore + + @staticmethod + def _public_key_bytes_from_private_key(private_key: RSAPrivateKeyWithSerialization) -> bytes: + return private_key.public_key().public_bytes(serialization.Encoding.OpenSSH, serialization.PublicFormat.OpenSSH) + + @staticmethod + def _private_key_bytes_from_private_key(private_key: RSAPrivateKeyWithSerialization) -> bytes: + return private_key.private_bytes(encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption()) + + @staticmethod + def _read_private_key_ctime(location: Path) -> datetime.datetime: + return datetime.datetime.fromtimestamp(os.path.getctime(str(location))) + + @staticmethod + def __try_write(content: str, location: Path, overwrite: bool = False, raise_on_fail: bool = False) -> bool: + if not overwrite and location.exists(): + return False + try: + location.write_text(content) + except Exception as e: + if raise_on_fail: + raise e + return False + return True + + def _write_public_key(self, private_key: RSAPrivateKeyWithSerialization, overwrite: bool = False, raise_on_fail: bool = False) -> bool: + pub_key = self._public_key_bytes_from_private_key(private_key).decode("utf-8") + return self.__try_write(pub_key, self.public_key_file, overwrite=overwrite, raise_on_fail=raise_on_fail) + + def _write_private_key(self, private_key: RSAPrivateKeyWithSerialization, overwrite: bool = False, raise_on_fail: bool = False) -> 
bool: + priv_key = self._private_key_bytes_from_private_key(private_key).decode("utf-8") + return self.__try_write(priv_key, self.private_key_file, overwrite=overwrite, raise_on_fail=raise_on_fail) + + def _delete_existing_key_files_if_priv_keys_differ(self): + # Remove any existing private/public key files unless the contents match serialized private key value + if self.private_key_file.exists(): + priv_key_file_bytes = self.private_key_file.read_bytes() + + if priv_key_file_bytes != self._private_key_bytes_from_private_key(self._private_key): + self.public_key_file.unlink(missing_ok=True) + self.private_key_file.unlink() + raise RuntimeError("Existing private key from file does not match provided private.") + + elif self.public_key_file.exists(): + # Always remove an existing public key file if there was not a private key file + self.public_key_file.unlink() + + +class RsaKeyPair(_RsaKeyPair, Serializable): """ Representation of an RSA key pair and certain meta properties, in particular a name for the key and a pair of :class:`Path` objects for its private and public key files. Keys may be either dynamically generated or deserialized from existing files. Key file basenames are derived from the :attr:`name` value for the object, which is set from an init param that - defaults to ``id_rsa`` if not provided. The public key file will have the same basename as the private key file, - except with the ``.pub`` extension added. + defaults to ``id_rsa`` if not provided. However, :attr:`name` is a required field when initializing from a + dictionary. The public key file will have the same basename as the private key file, except with the ``.pub`` + extension added. When the private key file already exists, the private key will be deserialized from the file contents. This will happen immediately when the object is created. - When the private key file does not already exists, the actual keys will be generated dynamically, though this is - performed lazily. 
The :method:`generate_key_pair` method will trigger all necessary lazy instantiations and also - cause the key files to be written. + When the private key file does not already exists, the actual keys will be generated dynamically -- but not written + to a file. Use the :method:`write_key_files` to write key pairs to a file. Note that rich comparisons for ``==`` and ``<`` are expressly defined, with the other implementations being derived from these two. @@ -34,19 +240,49 @@ class RsaKeyPair(Serializable): # The basename of the private key file will always be the key pair's name self.name == self.private_key_file.name - # The returned generation time property value will always be equal to the time stamp of the private key file - self.generation_time == datetime.datetime.fromtimestamp(os.path.getctime(str(self.private_key_file))) + """ + private_key: RSAPrivateKeyWithSerialization + """ + Serialized private key for this key pair object. """ - _SERIAL_DATETIME_STR_FORMAT = '%Y-%m-%d %H:%M:%S.%f' - _SERIAL_KEY_DIRECTORY = 'directory' - _SERIAL_KEY_NAME = 'name' - _SERIAL_KEY_PRIVATE_KEY = 'private_key' - _SERIAL_KEY_GENERATION_TIME = 'generation_time' - _SERIAL_KEYS_REQUIRED = [_SERIAL_KEY_NAME, _SERIAL_KEY_DIRECTORY, _SERIAL_KEY_PRIVATE_KEY, _SERIAL_KEY_GENERATION_TIME] + + generation_time: datetime.datetime + + _pub_key: bytes = PrivateAttr(None) + __private_key_text: str = PrivateAttr(None) + + _SERIAL_DATETIME_STR_FORMAT: ClassVar[str] = '%Y-%m-%d %H:%M:%S.%f' + + @validator("generation_time", pre=True) + def _validate_datetime(cls, value: Union[str, datetime.datetime]) -> datetime.datetime: + if isinstance(value, datetime.datetime): + return value + + return datetime.datetime.strptime(value, cls.get_datetime_str_format()) + + @validator("private_key", pre=True) + def _validate_private_key(cls, value: Union[str, RSAPrivateKeyWithSerialization ]) -> RSAPrivateKeyWithSerialization: + if isinstance(value, RSAPrivateKeyWithSerialization): + return value + + 
priv_key_bytes = value.encode("utf-8") + return serialization.load_pem_private_key(priv_key_bytes, None, default_backend()) + + class Config: # type: ignore + arbitrary_types_allowed = True + validate_assignment = True + def _serialize_datetime(self: "RsaKeyPair", value: datetime.datetime) -> str: + return value.strftime(self.get_datetime_str_format()) + + field_serializers = { + "generation_time": _serialize_datetime, + "private_key": lambda self, _: self._private_key_text, + "directory": lambda directory: str(directory), + } @classmethod - def factory_init_from_deserialized_json(cls, json_obj: Dict[str, str]): + def factory_init_from_deserialized_json(cls, json_obj: Dict[str, str]) -> Optional[Self]: """ Factory create a new instance of this type based on a JSON object dictionary deserialized from received JSON. @@ -73,204 +309,104 @@ def factory_init_from_deserialized_json(cls, json_obj: Dict[str, str]): err_msg_start = 'Cannot deserialize {} object'.format(cls.__name__) try: # Sanity check serialized structure - for key in cls._SERIAL_KEYS_REQUIRED: - if key not in json_obj: - raise RuntimeError('{}: missing required serial {} key'.format(err_msg_start, key)) - # Parse the generation time - gen_time_str = json_obj[cls._SERIAL_KEY_GENERATION_TIME] - try: - gen_time_val = datetime.datetime.strptime(gen_time_str, cls.get_datetime_str_format()) - except: - raise RuntimeError('{}: invalid format for generation time ({})'.format(err_msg_start, gen_time_str)) - # Create the instance, passing serialize values for directory and name - try: - new_obj = RsaKeyPair(directory=json_obj[cls._SERIAL_KEY_DIRECTORY], name=json_obj[cls._SERIAL_KEY_NAME]) - except ValueError as ve: - raise RuntimeError('{}: problem with directory - {}'.format(err_msg_start, str(ve))) - # Manually set the generation time attribute - new_obj._generation_time = gen_time_val - # Set the private key value from serialized data - priv_key_str = json_obj[cls._SERIAL_KEY_PRIVATE_KEY] - priv_key_bytes = 
priv_key_str.encode('utf-8') - new_obj._priv_key = serialization.load_pem_private_key(priv_key_bytes, None, default_backend()) - # Remove any existing private/public key files unless the contents match serialized private key value - if new_obj.private_key_file.exists(): - try: - with new_obj.private_key_file.open('rb') as priv_key_file: - priv_key_file_bytes = priv_key_file.read() - if priv_key_file_bytes != priv_key_bytes: - raise RuntimeError('clear key file') - except: - new_obj.public_key_file.unlink(missing_ok=True) - new_obj.private_key_file.unlink() - elif new_obj.public_key_file.exists(): - # Always remove an existing public key file if there was not a private key file - new_obj.public_key_file.unlink() - # Finally, return the instance - return new_obj - - except RuntimeError as e: + for field in cls.__fields__.values(): + if field.alias not in json_obj: + raise RuntimeError('{}: missing required serial {} key'.format(err_msg_start, field.alias)) + + o = cls(**json_obj) + o._is_deserialized = True + return o + except: # TODO: log error return None - def __eq__(self, other: 'RsaKeyPair') -> bool: + def __eq__(self, other: Self) -> bool: return other is not None \ and self.generation_time == other.generation_time \ - and self._get_private_key_text() == other._get_private_key_text() \ + and self._private_key_text == other._private_key_text \ and self.private_key_file.absolute() == other.private_key_file.absolute() - def __ge__(self, other): + def __ge__(self, other: Self): return not self < other - def __gt__(self, other): + def __gt__(self, other: Self): return not self <= other - def __init__(self, directory: Union[str, Path, None], name: str = 'id_rsa'): + def __init__(self, directory: Union[str, Path, None], name: str = "id_rsa", **data): """ Initialize an instance. 
- Initializing an instance, setting the ``directory`` and ``name`` properties, and creating the other required - backing attributes used by the object, setting them to ``None`` (except for ::attribute:`_files_written`, which - is set to ``False``. - Parameters ---------- directory : str, Path, None The path (either as a :class:`Path` or string) to the parent directory for the backing key files, or - ``None`` if the default of ``.ssh/`` in the user's home directory should be used. + ``None`` if the default of ``{$HOME}/.ssh/`` should be used. name : str The name to use for the key pair, which will also be the basename of the private key file and the basis of the basename of the public key file (``id_rsa`` by default). """ - self._name = name.strip() - if self._name is None or len(self._name) < 1: - raise ValueError("Invalid key pair name") - - self.directory = directory - - self._public_key_file = None - self._private_key_file = None - - self._priv_key = None - self._priv_key_pem = None - self._pub_key = None - - self._private_key_text = None - self._public_key_text = None - - self._is_deserialized = None - self._generation_time = None - self._files_written = False - # Track whether actually in the process of writing something already, to not double-write during lazy load - self._is_writing_private_file = False - self._is_writing_public_file = False + # If `data` exists, we assume we are deserializing a message with all required fields. + # NOTE: method, `factory_init_from_deserialized_json`, verifies that all fields are passed + # before trying to initialize. + if data: + super().__init__( + directory=directory, + name=name, + **data + ) + # indirectly set `_private_key` property of parent class `_RsaKeyPair`. + # as a result a public key file will not be created during initialization even if a + # private key file exists and its contents match the passed `private_key` field and a + # public key file does not exist. 
+ self._priv_key = self.private_key + self._delete_existing_key_files_if_priv_keys_differ() + + # If `data` does not exists, partially initialize using fields we have, then derive / create + # all required byt unspecified fields. Then, fully initialize. + else: + key_pair = _RsaKeyPair(directory=directory, name=name) + # lazily generate or load private key + private_key = key_pair._private_key + # could raise `RuntimeError` + key_pair._delete_existing_key_files_if_priv_keys_differ() + key_pair.write_key_files() + generation_time = key_pair._read_private_key_ctime(key_pair.private_key_file) + + super().__init__( + directory=directory, + name=name, + private_key=private_key, + generation_time=generation_time, + ) + + # transfer how the key pair was created + self._is_deserialized = key_pair._is_deserialized + # no one should access this directly nor through property, `_private_key`, but just in case. + self._priv_key = self.private_key def __hash__(self) -> int: - hash_str = '{}:{}:{}'.format(self._get_private_key_text(), + hash_str = '{}:{}:{}'.format(self._private_key_text, str(self.private_key_file.absolute()), self.generation_time.strftime(self.get_datetime_str_format())) - return hash_str.__hash__() + return hash(hash_str) - def __le__(self, other: 'RsaKeyPair') -> bool: + def __le__(self, other: Self) -> bool: return self == other or self < other - def __lt__(self, other: 'RsaKeyPair') -> bool: + def __lt__(self, other: Self) -> bool: if self.generation_time != other.generation_time: return self.generation_time < other.generation_time - elif self._get_private_key_text != other._get_private_key_text: - return self._get_private_key_text < other._get_private_key_text + elif self._private_key_text != other._private_key_text: + return self._private_key_text < other._private_key_text else: return self.private_key_file.absolute() < other.private_key_file.absolute() - def _get_private_key_text(self): - if self._private_key_text is None: - self._load_key_text() - return 
self._private_key_text - - def _load_key_text(self): - if self._private_key_text is None: - self._private_key_text = self.private_key_pem.decode('utf-8') - if self._public_key_text is None: - self._public_key_text = self.public_key.decode('utf-8') - - def _read_private_key_ctime(self, skip_file_exists_check=False): - if skip_file_exists_check or self.private_key_file.exists(): - return datetime.datetime.fromtimestamp(os.path.getctime(str(self.private_key_file))) - else: - return None - - def delete_key_files(self) -> tuple: - """ - Delete the files at the paths specified by :attr:`private_key_file` and :attr:`public_key_file`, as long as - there is an existing, regular (i.e., from :method:`Path.is_file`) file at the individual paths. - - Note that whether a delete is performed for one file is independent of what the state of the other. I.e., if - the private key file does not exist, thus resulting in no attempt to delete it, this will not affect whether - there is a delete operation on the public key file. - - Returns - ------- - tuple - A tuple of boolean values, representing whether the private key file and the public key file respectively - were deleted - """ - deleted_private = False - deleted_public = False - if self.private_key_file.exists() and self.private_key_file.is_file(): - self.private_key_file.unlink() - deleted_private = True - if self.public_key_file.exists() and self.public_key_file.is_file(): - self.public_key_file.unlink() - deleted_public = True - return deleted_private, deleted_public - - @property - def directory(self) -> Path: - """ - The directory in which the key pair files have been or will be written, as a :class:`Path`. - - The property getter will lazily instantiate the backing attribute to ``/.ssh/`` if the attribute is - set to ``None``. This is done using the property setter function, thus triggering its potential side effects. - - The property setter will accept string or ::class:`Path` objects, as well as ``None``. 
- - The setter may, as a side effect, create the directory represented by the argument in the filesystem. This is - done in cases when a valid argument other than ``None`` is received, and no file or directory currently exists - in the file system at that path. For string arguments, the string is first stripped of whitespace and converted - to a ::class:`Path` object before checking if the directory should be created. All of this logic is executed - before setting the backing attribute, so if an error is raised, then the attribute value will not be modified. - - In particular, if the setter receives an argument representing a path to an existing, non-directory file, then a - the setter will raise ::class:`ValueError`, and the attribute will remain unchanged. - - Returns - ------- - Path - The directory in which the key pair files have been or will be written - """ - if self._directory is None: - self.directory = Path.home().joinpath(".ssh") - return self._directory - - @directory.setter - def directory(self, d: Union[str, Path, None]): - # Make sure we are working with either None or the equivalent Path object for a path as a string - d_path = Path(d.strip()) if isinstance(d, str) else d - if d_path is not None: - if not d_path.exists(): - d_path.mkdir() - elif not d_path.is_dir(): - raise ValueError("Existing non-directory file at path provided for key pair directory") - self._directory = d_path - @property - def generation_time(self): - if self._generation_time is None: - if not self.private_key_file.exists(): - self.write_key_files() - self._generation_time = self._read_private_key_ctime(skip_file_exists_check=True) - return self._generation_time + def _private_key_text(self) -> str: + if self.__private_key_text is None: + self.__private_key_text = self.private_key_pem.decode("utf-8") + return self.__private_key_text # type: ignore @property def is_deserialized(self) -> bool: @@ -278,159 +414,21 @@ def is_deserialized(self) -> bool: Whether this object was 
deserialized from an already-existing file or serialized object, as opposed to being created and dynamically generating its keys. - pre: self._is_deserialized is not None or self._priv_key is None - - post: self._is_deserialized is not None and self._priv_key is not None - Returns ------- bool Whether this object was created from a pre-existing private key file """ - if self._is_deserialized is None: - # We don't actually need the value directly, but the lazy instantiation will set _is_deserialized as a side- - # effect, since it intrinsically has to determine whether it can/should deserialized the private key - priv_key = self.private_key return self._is_deserialized @property - def name(self): - return self._name - - @property - def private_key(self) -> RSAPrivateKeyWithSerialization: - """ - Get the private key for this key pair object, lazily instantiating if necessary either through deserialization - or by dynamically generating a key. - - Note that, since lazy instantiation requires determining if the value should be deserialized, the attribute - backing the :attr:`is_deserialized` property is set as a side effect when performing that step. 
- - post: self._is_deserialized is not None - - Returns - ------- - RSAPrivateKeyWithSerialization - The actual RSA private key object - """ - if self._priv_key is None and self.private_key_file.exists(): - with self.private_key_file.open('rb') as priv_key_file: - self._priv_key = serialization.load_pem_private_key(priv_key_file.read(), None, default_backend()) - if not self.public_key_file.exists(): - self.write_key_files(write_private=False) - self._files_written = True - self._is_deserialized = True - elif self._priv_key is None: - self._priv_key = rsa.generate_private_key(backend=default_backend(), public_exponent=65537, key_size=3072) - return self._priv_key - - @property - def private_key_file(self) -> Path: - """ - Get the path to the private key file, lazily instantiating using the :attr:`name` and :method:`directory`. - - Returns - ------- - Path - The path to the private key file - """ - if self._private_key_file is None: - self._private_key_file = None if self.directory is None else self.directory.joinpath(self._name) - return self._private_key_file - - @property - def private_key_pem(self): - if self._priv_key_pem is None: - self._priv_key_pem = self.private_key.private_bytes(encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.NoEncryption()) - return self._priv_key_pem - - @property - def public_key(self): + def public_key(self) -> bytes: if self._pub_key is None: - self._pub_key = self.private_key.public_key().public_bytes(serialization.Encoding.OpenSSH, - serialization.PublicFormat.OpenSSH) + self._pub_key = self._public_key_bytes_from_private_key(self.private_key) return self._pub_key - @property - def public_key_file(self) -> Path: - """ - Get the path to the public key file, lazily instantiating based on the :attr:`name` and :method:`directory`. 
- - Returns - ------- - Path - The path to the public key file - """ - if self._public_key_file is None: - self._public_key_file = None if self.directory is None else self.directory.joinpath(self._name + '.pub') - return self._public_key_file - - def to_dict(self) -> Dict[str, str]: - """ - Serialize to a dictionary representation of string keys and values. - - The format is as follows: - - { - 'name': 'name_value', - 'directory': 'directory_path_as_string', - 'private_key': 'private_key_text', - 'generation_time': 'generation_time_str' - } - - Returns - ------- - Dict[str, str] - The serialized form of this instance as a dictionary object with string keys and string values. - """ - return { - self._SERIAL_KEY_NAME: self.name, - self._SERIAL_KEY_DIRECTORY: str(self.directory), - self._SERIAL_KEY_PRIVATE_KEY: self._get_private_key_text(), - self._SERIAL_KEY_GENERATION_TIME: self.generation_time.strftime(self.get_datetime_str_format()) - } - - def write_key_files(self, write_private=True, write_public=True): - """ - Write private and/or public keys to files at :attr:`private_key_file` and :attr:`public_key_file` respectively, - assuming the respective file does not already exist. - - Parameters - ---------- - write_private : bool - An option, ``True`` by default, for whether the private key should be written to :attr:`private_key_file` - - write_public : bool - An option, ``True`` by default, for whether the public key should be written to :attr:`public_key_file` - """ - # Keep track of whether we are in the process of writing public/private files. - # Also, adjust parameter values based on whether this is nested inside another call due to lazy loading. - # I.e., both the param and the corresponding instance variable will only be True for the highest applicable - # call/scope in the stack. 
- if self._is_writing_private_file: - write_private = False - else: - self._is_writing_private_file = write_private - - if self._is_writing_public_file: - write_public = False - else: - self._is_writing_public_file = write_public - - # Next, actually perform the writes, loading things as necessary via property getters - try: - self._load_key_text() - if write_private and not self.private_key_file.exists(): - self.private_key_file.write_text(self._get_private_key_text()) - self._is_deserialized = False - if write_public and not self.public_key_file.exists(): - self.public_key_file.write_text(self._public_key_text) - finally: - # Finally, put back instance values to False appropriately if True and the param is True (indicating this is - # the highest call in the stack and should not be skipped for the public/private key file) - if self._is_writing_private_file and write_private: - self._is_writing_private_file = False - if self._is_writing_public_file and write_public: - self._is_writing_public_file = False + def write_key_files(self, write_private: bool = True, write_public: bool = True): + super().write_key_files(write_private=write_private, write_public=write_public) + if write_private: + # update generation time + self.generation_time = self._read_private_key_ctime(self.private_key_file) diff --git a/python/lib/scheduler/dmod/test/test_JobImpl.py b/python/lib/scheduler/dmod/test/test_JobImpl.py index 256304189..44a2f3739 100644 --- a/python/lib/scheduler/dmod/test/test_JobImpl.py +++ b/python/lib/scheduler/dmod/test/test_JobImpl.py @@ -1,24 +1,26 @@ import unittest -from ..scheduler.job.job import JobImpl +from ..scheduler.job.job import JobImpl, JobStatus, JobExecPhase, JobExecStep from ..scheduler.resources.resource_allocation import ResourceAllocation from dmod.communication import NWMRequest from uuid import UUID +from typing import List + class TestJobImpl(unittest.TestCase): def setUp(self) -> None: self._nwm_model_request = 
NWMRequest.factory_init_from_deserialized_json( - {"model": {"nwm": {"version": 2.0, "output": "streamflow", "domain": "blah", "parameters": {}}}, + {"model": {"nwm": {"version": 2.0, "output": "streamflow", "domain": "blah", "parameters": {}, "config_data_id": "42"}}, "session-secret": "f21f27ac3d443c0948aab924bddefc64891c455a756ca77a4d86ec2f697cd13c"}) - self._example_jobs = [] + self._example_jobs: List[JobImpl]= [] self._example_jobs.append(JobImpl(4, 1000, model_request=self._nwm_model_request, allocation_paradigm='single-node')) - self._uuid_str_vals = [] + self._uuid_str_vals: List[str] = [] self._uuid_str_vals.append('12345678-1234-5678-1234-567812345678') - self._resource_allocations = [] + self._resource_allocations: List[ResourceAllocation] = [] self._resource_allocations.append(ResourceAllocation('node001', 'node001', 4, 1000)) def tearDown(self) -> None: @@ -189,5 +191,177 @@ def test_allocations_1_f(self): self.assertLess(initial_last_updated, job.last_updated) # TODO: add tests for rest of setters that should update last_updated property + def test_set_allocation_priority(self): + """ + Update allocation priority. + This should implicitly change the instance's `last_updated` field to the current time. + """ + example_index_job = 0 + job = self._example_jobs[example_index_job] + outdated_last_updated = job.last_updated + prior_allocation_priority = job.allocation_priority + + job.set_allocation_priority(prior_allocation_priority + 1) + self.assertEqual(job.allocation_priority, prior_allocation_priority + 1) + self.assertGreater(job.last_updated, outdated_last_updated) + + def test_add_allocation(self): + """ + Test that a resource allocation is added and that the instance's `last_updated` field is implicitly updated. 
+ """ + example_index_job = 0 + job = self._example_jobs[example_index_job] + resource_allocation = self._resource_allocations[example_index_job] + + # we should not have any allocations up to this point + self.assertIsNone(job.allocations) + outdated_last_updated = job.last_updated + + job.add_allocation(resource_allocation) + + self.assertIsNotNone(job.allocations) + self.assertIsInstance(job.allocations, tuple) + self.assertEqual(len(job.allocations), 1) # type: ignore + + self.assertEqual(job.allocations[0], resource_allocation) # type: ignore + + self.assertGreater(job.last_updated, outdated_last_updated) + + def test_set_allocations(self): + """ + Test setting resource allocations and that the instance's `last_updated` field is implicitly updated. + """ + example_index_job = 0 + job = self._example_jobs[example_index_job] + resource_allocation = self._resource_allocations[example_index_job] + + # we should not have any allocations up to this point + self.assertIsNone(job.allocations) + outdated_last_updated = job.last_updated + + job.set_allocations((resource_allocation, )) + + self.assertIsNotNone(job.allocations) + self.assertIsInstance(job.allocations, tuple) + self.assertEqual(len(job.allocations), 1) # type: ignore + + self.assertEqual(job.allocations[0], resource_allocation) # type: ignore + + # assert `last_updated` was updated and is greater than previous value + self.assertGreater(job.last_updated, outdated_last_updated) + + def test_set_data_requirements(self): + # importing here, not needed elsewhere + from dmod.core.meta_data import DataRequirement, DataCategory, DataDomain, DataFormat, DiscreteRestriction, StandardDatasetIndex + example_index_job = 0 + job = self._example_jobs[example_index_job] + + outdated_last_updated = job.last_updated + + domain = DataDomain( + data_format=DataFormat.NWM_CONFIG, + discrete=[DiscreteRestriction(variable=StandardDatasetIndex.DATA_ID, values=["42"])] + ) + data_reqs = 
[DataRequirement(category=DataCategory.CONFIG, domain=domain, is_input=True)] + + # data requirements should be an empty list at this point + self.assertFalse(job.data_requirements) + job.set_data_requirements(data_reqs) + + self.assertTrue(job.data_requirements) + self.assertIsInstance(job.data_requirements, list) + self.assertEqual(len(job.data_requirements), 1) # type: ignore + + # assert `last_updated` was updated and is greater than previous value + self.assertGreater(job.last_updated, outdated_last_updated) + + def test_set_job_id(self): + from uuid import UUID + example_index_job = 0 + job = self._example_jobs[example_index_job] + + fake_job_ids = ["00000000-0000-0000-0000-000000000000", UUID("11111111-1111-1111-1111-111111111111")] + + # test setting with `str` and `UUID` + for i, job_id in enumerate(fake_job_ids): + with self.subTest(i=i): + old_last_updated = job.last_updated + old_job_id = job.job_id + + self.assertIsInstance(old_job_id, str) + + job.set_job_id(job_id) + self.assertEqual(str(job_id), job.job_id) + + # assert `last_updated` was updated and is greater than previous value + self.assertGreater(job.last_updated, old_last_updated) + + def test_set_partition_config(self): + from dmod.modeldata.hydrofabric import Partition, PartitionConfig + + example_index_job = 0 + job = self._example_jobs[example_index_job] + + partition_config = PartitionConfig(partitions=[Partition(partition_id=42, catchment_ids=["42"], nexus_ids=["42"])]) + + # we should not have any partition configs up to this point + self.assertIsNone(job.partition_config) + job.set_partition_config(partition_config) + self.assertEqual(job.partition_config, partition_config) + + def test_set_rsa_key_pair(self): + from ..scheduler.rsa_key_pair import RsaKeyPair + from tempfile import TemporaryDirectory + example_index_job = 0 + job = self._example_jobs[example_index_job] + outdated_last_updated = job.last_updated + + self.assertIsNone(job.rsa_key_pair) + + with TemporaryDirectory() as 
dir: + key_pair = RsaKeyPair(directory=dir) + job.set_rsa_key_pair(key_pair) + self.assertEqual(job.rsa_key_pair, key_pair) + + # assert `last_updated` was updated and is greater than previous value + self.assertGreater(job.last_updated, outdated_last_updated) + + def test_set_status(self): + example_index_job = 0 + job = self._example_jobs[example_index_job] + outdated_last_updated = job.last_updated + + status = JobStatus(phase=None) + self.assertNotEqual(status, job.status) + job.set_status(status) + + self.assertEqual(status, job.status) + self.assertGreater(job.last_updated, outdated_last_updated) + + def test_set_status_phase(self): + example_index_job = 0 + job = self._example_jobs[example_index_job] + outdated_last_updated = job.last_updated + + new_status_phase = JobExecPhase.MODEL_EXEC + self.assertNotEqual(job.status_phase, new_status_phase) + + job.set_status_phase(new_status_phase) + self.assertEqual(job.status_phase, new_status_phase) + + # assert `last_updated` was implicitly updated and is greater than previous value + self.assertGreater(job.last_updated, outdated_last_updated) + + def test_set_status_step(self): + example_index_job = 0 + job = self._example_jobs[example_index_job] + outdated_last_updated = job.last_updated + + new_status_step = JobExecStep.AWAITING_ALLOCATION + self.assertNotEqual(job.status_phase, new_status_step) + + job.set_status_step(new_status_step) + self.assertEqual(job.status_step, new_status_step) - # TODO: add tests for status_phase and status_step + # assert `last_updated` was implicitly updated and is greater than previous value + self.assertGreater(job.last_updated, outdated_last_updated) diff --git a/python/lib/scheduler/dmod/test/test_job.py b/python/lib/scheduler/dmod/test/test_job.py index 9a9f8e5cc..c7b1c533f 100644 --- a/python/lib/scheduler/dmod/test/test_job.py +++ b/python/lib/scheduler/dmod/test/test_job.py @@ -2,14 +2,18 @@ from ..scheduler.job.job import Job, JobImpl, RequestedJob from dmod.core.meta_data 
import TimeRange from dmod.communication import NWMRequest, NGENRequest, SchedulerRequestMessage +from typing import Any, Dict, List, TYPE_CHECKING + +if TYPE_CHECKING: + from dmod.communication import ModelExecRequest class TestJob(unittest.TestCase): def setUp(self) -> None: - self._example_jobs = [] - self._model_requests = [] - self._model_requests_json = [] + self._example_jobs: List[RequestedJob] = [] + self._model_requests: List["ModelExecRequest"] = [] + self._model_requests_json: List[Dict[str, Any]] = [] # Example 0 - simple JobImpl instance based on NWMRequest for model_request value self._model_requests_json.append({ diff --git a/python/lib/scheduler/dmod/test/test_resource.py b/python/lib/scheduler/dmod/test/test_resource.py new file mode 100644 index 000000000..148bcceca --- /dev/null +++ b/python/lib/scheduler/dmod/test/test_resource.py @@ -0,0 +1,124 @@ +import unittest +from pydantic import ValidationError +from .scheduler_test_utils import _mock_resources +from ..scheduler.resources import Resource, ResourceAvailability + + +class TestResource(unittest.TestCase): + def setUp(self) -> None: + self._resource = Resource( + resource_id="1", + hostname="somehost", + availability="active", + state="ready", + cpu_count=4, + memory=(2 ** 30) * 8, + total_cpu_count=8, + total_memory=(2 ** 30) * 16, + ) + + def test_factory_init_from_dict_coerces_fields_correctly(self): + for i, input in enumerate(_mock_resources): + with self.subTest(i=i): + o = Resource.factory_init_from_dict(input) + assert o.resource_id == input["node_id"] + assert o.pool_id == input["node_id"] + assert o.hostname == input["Hostname"] + assert ( + o.availability.name.casefold() == input["Availability"].casefold() + ) + assert o.state.name.casefold() == input["State"].casefold() + assert o.memory == input["MemoryBytes"] + assert o.cpu_count == input["CPUs"] + assert o.total_cpus == input["CPUs"] + assert o.total_memory == input["MemoryBytes"] + + def 
test_factory_init_from_dict_works_case_insensitively(self): + input = { + "NODE_ID": "Node-0003", + "hostname": "hostname3", + "AVAILABILITY": "active", + "state": "ready", + "CPUS": 42, + "memorybytes": 200000000000, + } + o = Resource.factory_init_from_dict(input) + assert o.resource_id == input["NODE_ID"] + assert o.pool_id == input["NODE_ID"] + assert o.hostname == input["hostname"] + assert o.availability.name.casefold() == input["AVAILABILITY"].casefold() + assert o.state.name.casefold() == input["state"].casefold() + assert o.memory == input["memorybytes"] + assert o.cpu_count == input["CPUS"] + assert o.total_cpus == input["CPUS"] + assert o.total_memory == input["memorybytes"] + + def test_set_availability(self): + resource = self._resource + availability = ResourceAvailability.UNKNOWN + resource.set_availability(availability) + assert resource.availability == ResourceAvailability.UNKNOWN + + availability = ResourceAvailability.ACTIVE + resource.set_availability(availability) + assert resource.availability == ResourceAvailability.ACTIVE + + availability = ResourceAvailability.INACTIVE + resource.set_availability(availability) + assert resource.availability == ResourceAvailability.INACTIVE + + resource.set_availability("unknown") + assert resource.availability == ResourceAvailability.UNKNOWN + + resource.set_availability("active") + assert resource.availability == ResourceAvailability.ACTIVE + + resource.set_availability("inactive") + assert resource.availability == ResourceAvailability.INACTIVE + + # remove in future + with self.assertWarns(DeprecationWarning): + availability = ResourceAvailability.UNKNOWN + resource.availability = availability + assert resource.availability == ResourceAvailability.UNKNOWN + + with self.assertWarns(DeprecationWarning): + availability = ResourceAvailability.ACTIVE + resource.availability = availability + assert resource.availability == ResourceAvailability.ACTIVE + + with self.assertWarns(DeprecationWarning): + availability 
= ResourceAvailability.INACTIVE + resource.availability = availability + assert resource.availability == ResourceAvailability.INACTIVE + + def test_eq(self): + resource = self._resource + assert resource == resource + assert resource == Resource.factory_init_from_dict(resource.to_dict()) + + def test_init_with_more_cpu_than_total_cpu(self): + with self.assertRaises(ValidationError): + Resource( + cpu_count=8, + total_cpu_count=4, + resource_id="1", + hostname="somehost", + availability="active", + state="ready", + memory=8, + total_memory=8, + ) + + def test_init_with_more_memory_than_total_memory(self): + with self.assertRaises(ValidationError): + Resource( + memory=8, + total_memory=4, + resource_id="1", + hostname="somehost", + availability="active", + state="ready", + cpu_count=8, + total_cpu_count=8, + ) diff --git a/python/lib/scheduler/dmod/test/test_rsa_key_pair.py b/python/lib/scheduler/dmod/test/test_rsa_key_pair.py index e4d1276ec..6615ed9b5 100644 --- a/python/lib/scheduler/dmod/test/test_rsa_key_pair.py +++ b/python/lib/scheduler/dmod/test/test_rsa_key_pair.py @@ -1,16 +1,23 @@ import unittest +from pathlib import Path +from tempfile import TemporaryDirectory from ..scheduler.rsa_key_pair import RsaKeyPair +from typing import Dict class TestRsaKeyPair(unittest.TestCase): def setUp(self) -> None: - self.rsa_key_pairs = dict() + self.rsa_key_pairs: Dict[int, RsaKeyPair] = dict() self.rsa_key_pairs[1] = RsaKeyPair(directory='.', name='id_rsa_1') + self.serial_rsa_key_pairs: Dict[int, dict] = dict() + self.serial_rsa_key_pairs[1] = self.rsa_key_pairs[1].to_dict() + + def tearDown(self) -> None: - self.rsa_key_pairs[1].private_key_file.unlink() - self.rsa_key_pairs[1].public_key_file.unlink() + self.rsa_key_pairs[1].private_key_file.unlink(missing_ok=True) + self.rsa_key_pairs[1].public_key_file.unlink(missing_ok=True) def test_generate_key_pair_1_a(self): """ @@ -37,7 +44,7 @@ def test_generate_key_pair_1_c(self): # This should result in the same file 
names as key_pair, and so the constructor should resolve that it needs to # load the key, not regenerate it reserialized_key = RsaKeyPair(directory=key_pair.directory, name=key_pair.name) - self.assertTrue(key_pair, reserialized_key) + self.assertEqual(key_pair, reserialized_key) def test_generate_key_pair_1_d(self): """ @@ -61,4 +68,157 @@ def test_generate_key_pair_1_e(self): # This should result in the same file names as key_pair, and so the constructor should resolve that it needs to # load the key, not regenerate it reserialized_key = RsaKeyPair(directory=key_pair.directory, name=key_pair.name) - self.assertTrue(key_pair.private_key_pem, reserialized_key.private_key_pem) + self.assertEqual(key_pair.private_key_pem, reserialized_key.private_key_pem) + + def test_generate_key_pair_1_from_dict_a(self): + """ + """ + key_pair = self.rsa_key_pairs[1] + key_pair.write_key_files() + # This should result in the same file names as key_pair, and so the constructor should resolve that it needs to + # load the key, not regenerate it + + key_pair_dict = key_pair.to_dict() + key_pair_from_dict = RsaKeyPair.factory_init_from_deserialized_json(key_pair_dict) + self.assertEqual(key_pair_from_dict, key_pair) + + def test_delete_key_files(self): + """ + Verify that the `delete_key_files` method deletes both public and private key _if they + exist_ to start with. + """ + key_pair = self.rsa_key_pairs[1] + key_pair.delete_key_files() + self.assertFalse(key_pair.private_key_file.exists()) + self.assertFalse(key_pair.public_key_file.exists()) + + def test_factory_init_from_deserialized_json_does_not_write_key_files_on_init(self): + """ + verify key files are not created if they do not already exist on factory init. 
+ """ + key_pair = self.rsa_key_pairs[1] + kp_as_dict = key_pair.to_dict() + key_pair.delete_key_files() + + kp_from_factory = RsaKeyPair.factory_init_from_deserialized_json(kp_as_dict) + + assert kp_from_factory is not None + self.assertFalse(kp_from_factory.private_key_file.exists()) + self.assertFalse(kp_from_factory.public_key_file.exists()) + + def test_factory_init_from_deserialized_json_verifies_private_key_matches_successfully(self): + """ + verify factory init succeeds when the serialized private key matches the existing key file on disk. + """ + key_pair = self.rsa_key_pairs[1] + self.assertTrue(key_pair.private_key_file.exists()) + + kp_as_dict = key_pair.to_dict() + kp_from_factory = RsaKeyPair.factory_init_from_deserialized_json(kp_as_dict) + assert kp_from_factory is not None + # this should have been called in __init__ + kp_from_factory._delete_existing_key_files_if_priv_keys_differ() # type: ignore + self.assertTrue(kp_from_factory.private_key_file.exists()) + + def test_factory_init_from_deserialized_json_does_not_write_pub_key_file_when_priv_exists(self): + """ + verify pub key file is not created by factory init if priv key file already exists. + the existing private key file must remain on disk untouched. + """ + key_pair = self.rsa_key_pairs[1] + self.assertTrue(key_pair.private_key_file.exists()) + self.assertTrue(key_pair.public_key_file.exists()) + + key_pair.public_key_file.unlink(missing_ok=True) + self.assertFalse(key_pair.public_key_file.exists()) + + kp_as_dict = key_pair.to_dict() + kp_from_factory = RsaKeyPair.factory_init_from_deserialized_json(kp_as_dict) + assert kp_from_factory is not None + + # main concern being tested + self.assertFalse(kp_from_factory.public_key_file.exists()) + + self.assertTrue(kp_from_factory.private_key_file.exists()) + + def test_factory_init_from_deserialized_json_is_deserialized(self): + """ + verify object `is_deserialized` property is true on factory init with no key files on disk. 
+ """ + key_pair = self.rsa_key_pairs[1] + + kp_as_dict = key_pair.to_dict() + + # remove key files + key_pair.delete_key_files() + self.assertFalse(key_pair.private_key_file.exists()) + self.assertFalse(key_pair.public_key_file.exists()) + + kp_from_factory = RsaKeyPair.factory_init_from_deserialized_json(kp_as_dict) + assert kp_from_factory is not None + + # main concern being tested + self.assertTrue(kp_from_factory.is_deserialized) + + def test_factory_init_from_deserialized_json_is_deserialized_with_key_files_present(self): + """ + verify object `is_deserialized` property is true on factory init with key files on disk. + """ + key_pair = self.serial_rsa_key_pairs[1] + + kp_from_factory = RsaKeyPair.factory_init_from_deserialized_json(key_pair) + assert kp_from_factory is not None + + # main concern being tested + self.assertTrue(kp_from_factory.is_deserialized) + + def test_is_deserialized_is_false_when_key_is_generated(self): + """ + verify object `is_deserialized` property is false when key is generated. + """ + with TemporaryDirectory() as dir: + key_pair = RsaKeyPair(directory=dir, name="test_is_deserialized") + self.assertFalse(key_pair.is_deserialized) + + def test_is_deserialized_is_true_when_key_is_present(self): + """ + verify object `is_deserialized` property is true when key files already exist on disk. + """ + key_pair = self.rsa_key_pairs[1] + kp = RsaKeyPair(directory=key_pair.directory, name=key_pair.name) + self.assertTrue(kp.is_deserialized) + + def test_reassign_directory_to_default(self): + """ + verify assigning `directory = None` resets it to the default `~/.ssh` location. 
+ """ + key_pair = self.rsa_key_pairs[1] + default_location = Path.home() / ".ssh" + self.assertNotEqual(key_pair.directory, default_location) + + o_pub_key = key_pair.public_key_file + o_priv_key = key_pair.private_key_file + + key_pair.directory = None + self.assertEqual(key_pair.directory, default_location) + + # remove original public key and private key + o_priv_key.unlink(missing_ok=True) + o_pub_key.unlink(missing_ok=True) + + def test_reassign_directory_creates_directory_if_not_exist(self): + """ + verify assigning a directory that does not exist creates that directory on disk. + """ + key_pair = self.rsa_key_pairs[1] + with TemporaryDirectory() as dir: + dir = Path(dir) + new_dir = dir / ".ssh" + + self.assertFalse(new_dir.exists()) + self.assertNotEqual(key_pair.directory, new_dir) + + key_pair.directory = new_dir + + self.assertTrue(new_dir.exists()) + self.assertEqual(key_pair.directory, new_dir) diff --git a/python/lib/scheduler/setup.py b/python/lib/scheduler/setup.py index 817e02aa4..7e10fb7cc 100644 --- a/python/lib/scheduler/setup.py +++ b/python/lib/scheduler/setup.py @@ -21,7 +21,7 @@ url='', license='', install_requires=['docker', 'Faker', 'dmod-communication>=0.8.0', 'dmod-modeldata>=0.7.1', 'dmod-redis>=0.1.0', - 'dmod-core>=0.2.0', 'cryptography', 'uri', 'pyyaml'], + 'dmod-core>=0.2.0', 'cryptography', 'uri', 'pyyaml', 'pydantic'], packages=find_namespace_packages(exclude=['dmod.test', 'src']) ) diff --git a/requirements.txt b/requirements.txt index f3142c57c..a7a99a3e5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -29,3 +29,4 @@ channels channels-redis Pint django_rq +pydantic