diff --git a/docs/_templates/README.md b/docs/_templates/README.md new file mode 100644 index 000000000..78625de95 --- /dev/null +++ b/docs/_templates/README.md @@ -0,0 +1,5 @@ +# `templates` directory + +This directory is appended to the `templates_path` Sphinx config variable in `conf.py`. Jinja templates in this directory are available to Sphinx. This is used to provide an Autosummary template that supports an automatic API reference. + +Note that templates used for the *theme itself* are not placed in this directory! \ No newline at end of file diff --git a/docs/_templates/autosummary/base.rst b/docs/_templates/autosummary/base.rst new file mode 100644 index 000000000..35816c986 --- /dev/null +++ b/docs/_templates/autosummary/base.rst @@ -0,0 +1,15 @@ +{% block title -%} + +{{ ("``" ~ objname ~ "``") | underline}} + +{%- endblock %} +{% block base %} + +.. currentmodule:: {{ module }} + +.. auto{{ objtype }}:: {{ objname }} + {% if objtype in ["attribute", "data"] -%} + :no-value: + {%- endif %} + +{%- endblock %} diff --git a/docs/_templates/autosummary/class.rst b/docs/_templates/autosummary/class.rst new file mode 100644 index 000000000..ad9380a46 --- /dev/null +++ b/docs/_templates/autosummary/class.rst @@ -0,0 +1,65 @@ +{% block title -%} + +.. raw:: html + +
+ +{{ ("``" ~ objname ~ "``") | underline('=')}} + +.. raw:: html + +
+ +{%- endblock %} +{% block base %} + +.. currentmodule:: {{ module }} + +.. autoclass:: {{ objname }} + :members: + :member-order: alphabetical {# For consistency with Autosummary #} + {% if show_inherited_members %}:inherited-members: + {% endif %}{% if show_undoc_members %}:undoc-members: + {% endif %}{% if show_inheritance %}:show-inheritance: + {% endif %} + + {% block methods %} + + {%- set doc_methods = [] -%} + {%- for item in methods -%} + {%- if item not in ["__new__", "__init__"] and (show_inherited_members or item not in inherited_members) -%} + {%- set _ = doc_methods.append(item) -%} + {%- endif -%} + {%- endfor %} + + {% if doc_methods %} + .. rubric:: {{ _('Methods') }} + + .. autosummary:: + :nosignatures: + {% for item in doc_methods %} + ~{{ name }}.{{ item }} + {%- endfor %} + {% endif %} + {% endblock %} + + {% block attributes %} + + {%- set doc_attributes = [] -%} + {%- for item in attributes -%} + {%- if show_inherited_members or item not in inherited_members -%} + {%- set _ = doc_attributes.append(item) -%} + {%- endif -%} + {%- endfor %} + + {% if doc_attributes %} + .. rubric:: {{ _('Attributes') }} + + .. autosummary:: + :nosignatures: + {% for item in doc_attributes %} + ~{{ name }}.{{ item }} + {%- endfor %} + {% endif %} + {% endblock %} +{% endblock %} diff --git a/docs/_templates/autosummary/module.rst b/docs/_templates/autosummary/module.rst new file mode 100644 index 000000000..b7691f7b7 --- /dev/null +++ b/docs/_templates/autosummary/module.rst @@ -0,0 +1,106 @@ +{% block title -%} + +{{ ("``" ~ objname ~ "``") | underline('=')}} + +{%- endblock %} +{% block base %} + +.. automodule:: {{ fullname }} + :no-members: + :no-inherited-members: + :no-special-members: + + {% block modules %} + + {%- set included_modules = [] -%} + {%- for item in modules -%} + {%- if item not in exclude_modules -%} + {%- set _ = included_modules.append(item) -%} + {%- endif -%} + {%- endfor -%} + + {% if included_modules %} + .. rubric:: Modules + + .. autosummary:: + :caption: Modules + :toctree: + :recursive: + {% for item in included_modules %} + ~{{ item }} + {%- endfor %} + + {% endif %} + {% endblock %} + + {% block attributes %} + {% if attributes %} + .. rubric:: {{ _('Module Attributes') }} + + .. autosummary:: + :caption: Attributes + :toctree: + :nosignatures: + {% for item in attributes %} + {{ item }} + {%- endfor %} + {% endif %} + {%- endblock -%} + + {% block functions %} + {% if functions %} + .. rubric:: {{ _('Functions') }} + + .. autosummary:: + :caption: Functions + :toctree: + :nosignatures: + {% for item in functions %} + {{ item }} + {%- endfor %} + {% endif %} + {% endblock %} + + {% block classes %} + + {%- set types = [] -%} + {%- for item in members -%} + {%- if not item.startswith('_') and not ( + item in functions + or item in attributes + or item in exceptions + or fullname ~ "." ~ item in modules + or item in methods + ) -%} + {%- set _ = types.append(item) -%} + {%- endif -%} + {%- endfor %} + + {% if types %} + .. rubric:: {{ _('Classes') }} + + .. autosummary:: + :caption: Classes + :toctree: + :nosignatures: + {% for item in types %} + {{ item }} + {%- endfor %} + {% endif %} + {% endblock %} + + {% block exceptions %} + {% if exceptions %} + .. rubric:: {{ _('Exceptions') }} + + .. autosummary:: + :caption: Exceptions + :toctree: + :nosignatures: + {% for item in exceptions %} + {{ item }} + {%- endfor %} + {% endif %} + {% endblock %} + +{% endblock %} diff --git a/docs/concepts/tokenizables.rst b/docs/concepts/tokenizables.rst index 1adb8c3a7..6046b88db 100644 --- a/docs/concepts/tokenizables.rst +++ b/docs/concepts/tokenizables.rst @@ -238,9 +238,9 @@ Any nested ``GufeTokenizable``\s are left as-is. ChemicalSystem(name=phenol-solvent, components={'ligand': SmallMoleculeComponent(name=phenol), 'solvent': SolventComponent(name=O, K+, Cl-)}) ], 'edges': [ - Transformation(stateA=ChemicalSystem(name=benzene-solvent, components={'ligand': SmallMoleculeComponent(name=benzene), 'solvent': SolventComponent(name=O, K+, Cl-)}), stateB=ChemicalSystem(name=toluene-solvent, components={'ligand': SmallMoleculeComponent(name=toluene), 'solvent': SolventComponent(name=O, K+, Cl-)}), protocol=, name=None), - Transformation(stateA=ChemicalSystem(name=benzene-solvent, components={'ligand': SmallMoleculeComponent(name=benzene), 'solvent': SolventComponent(name=O, K+, Cl-)}), stateB=ChemicalSystem(name=styrene-solvent, components={'ligand': SmallMoleculeComponent(name=styrene), 'solvent': SolventComponent(name=O, K+, Cl-)}), protocol=, name=None), - Transformation(stateA=ChemicalSystem(name=benzene-solvent, components={'ligand': SmallMoleculeComponent(name=benzene), 'solvent': SolventComponent(name=O, K+, Cl-)}), stateB=ChemicalSystem(name=phenol-solvent, components={'ligand': SmallMoleculeComponent(name=phenol), 'solvent': SolventComponent(name=O, K+, Cl-)}), protocol=, name=None) + Transformation(stateA=ChemicalSystem(name=benzene-solvent, components={'ligand': SmallMoleculeComponent(name=benzene), 'solvent': SolventComponent(name=O, K+, Cl-)}), stateB=ChemicalSystem(name=toluene-solvent, components={'ligand': SmallMoleculeComponent(name=toluene), 'solvent': SolventComponent(name=O, K+, Cl-)}), protocol=, name=None), + Transformation(stateA=ChemicalSystem(name=benzene-solvent, components={'ligand': SmallMoleculeComponent(name=benzene), 'solvent': SolventComponent(name=O, K+, Cl-)}), stateB=ChemicalSystem(name=styrene-solvent, components={'ligand': SmallMoleculeComponent(name=styrene), 'solvent': SolventComponent(name=O, K+, Cl-)}), protocol=, name=None), + Transformation(stateA=ChemicalSystem(name=benzene-solvent, components={'ligand': SmallMoleculeComponent(name=benzene), 'solvent': SolventComponent(name=O, K+, Cl-)}), stateB=ChemicalSystem(name=phenol-solvent, components={'ligand': SmallMoleculeComponent(name=phenol), 'solvent': SolventComponent(name=O, K+, Cl-)}), protocol=, name=None) ], 'name': None, '__qualname__': 'AlchemicalNetwork', @@ -310,7 +310,7 @@ To show the structure of a keyed chain, below we have redacted all information e ('SmallMoleculeComponent-3b51f5f92521c712049da092ab061930', {...}), ('SmallMoleculeComponent-ec3c7a92771f8872dab1a9fc4911c795', {...}), ('SmallMoleculeComponent-8225dfb11f2e8157a3fcdcd673d3d40e', {...}), - ('Protocol-d01baed9cf2500c393bd6ddb35ee38aa', {...}), + ('Protocol-489fb1395a32c5183bcc1d43fa521960', {...}), ('ChemicalSystem-ba83a53f18700b3738680da051ff35f3', { 'components': { 'ligand': {':gufe-key:': 'SmallMoleculeComponent-3b51f5f92521c712049da092ab061930'}, @@ -332,12 +332,12 @@ To show the structure of a keyed chain, below we have redacted all information e ('Transformation-e8d1ccf53116e210d1ccbc3870007271', { 'stateA': {':gufe-key:': 'ChemicalSystem-3c648332ff8dccc03a1e1a3d44bc9755'}, 'stateB': {':gufe-key:': 'ChemicalSystem-ba83a53f18700b3738680da051ff35f3'}, - 'protocol': {':gufe-key:': 'DummyProtocol-d01baed9cf2500c393bd6ddb35ee38aa'}, + 'protocol': {':gufe-key:': 'DummyProtocol-489fb1395a32c5183bcc1d43fa521960'}, ...}), ('Transformation-4d0f802817071c8d14b37efd35187318', { 'stateA': {':gufe-key:': 'ChemicalSystem-3c648332ff8dccc03a1e1a3d44bc9755'}, 'stateB': {':gufe-key:': 'ChemicalSystem-655f4d0008a537fe811b11a2dc4a029e'}, - 'protocol': {':gufe-key:': 'DummyProtocol-d01baed9cf2500c393bd6ddb35ee38aa'}, + 'protocol': {':gufe-key:': 'DummyProtocol-489fb1395a32c5183bcc1d43fa521960'}, ...}), ('AlchemicalNetwork-f8bfd63bc848672aa52b081b4d68fadf', { 'nodes': [ diff --git a/docs/conf.py b/docs/conf.py index 3ae6eecf8..d24a7c8e1 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -49,6 +49,9 @@ "undoc-members": True, } +# TODO: temporary workaround to get docs to build I figure out why only OpenMMSystemGeneratorFFSettings GufeQuantities won't serialize. +autodoc_pydantic_model_show_json_error_strategy = "coerce" + autosummary_generate = True intersphinx_mapping = { @@ -68,7 +71,6 @@ "rdkit", ] - # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for diff --git a/docs/environment.yaml b/docs/environment.yaml index a82af8c9c..47444d394 100644 --- a/docs/environment.yaml +++ b/docs/environment.yaml @@ -3,7 +3,7 @@ channels: - https://conda.anaconda.org/jaimergp/label/unsupported-cudatoolkit-shim - https://conda.anaconda.org/conda-forge dependencies: -- autodoc-pydantic +- autodoc-pydantic >=2.0 - openff-units - python=3.12 - sphinx diff --git a/environment.yml b/environment.yml index 896975cdf..6c3eb34f0 100644 --- a/environment.yml +++ b/environment.yml @@ -13,7 +13,7 @@ dependencies: - pint - pip - pooch - - pydantic >1 + - pydantic >=2.0 - pytest - pytest-cov - pytest-xdist @@ -25,4 +25,4 @@ dependencies: - sphinx-jsonschema==1.15 - sphinx <7.1.2 - pip: - - autodoc_pydantic<2.0.0 + - autodoc_pydantic>=2.0.0 diff --git a/gufe/components/proteincomponent.py b/gufe/components/proteincomponent.py index 7294ce82e..283a2b00f 100644 --- a/gufe/components/proteincomponent.py +++ b/gufe/components/proteincomponent.py @@ -56,8 +56,8 @@ # ions and charges pulled from amber: -# https://github.com/Amber-MD/AmberClassic/blob/42e88bf9a2214ba008140280713a430f3ecd4a90/dat/leap/lib/atomic_ions.lib#L1C1-L68C6 -ions_dict = { +# see `amber list bool: SETTINGS_CODEC = JSONCodec( cls=SettingsBaseModel, - to_dict=lambda obj: {field: getattr(obj, field) for field in obj.__fields__}, + to_dict=lambda obj: {field: getattr(obj, field) for field in obj.model_fields}, from_dict=default_from_dict, is_my_dict=functools.partial(inherited_is_my_dict, cls=SettingsBaseModel), ) diff --git a/gufe/serialization/msgpack.py b/gufe/serialization/msgpack.py index 9a59bb063..6d7cf57d8 100644 --- a/gufe/serialization/msgpack.py +++ b/gufe/serialization/msgpack.py @@ -68,7 +68,7 @@ def pack_default(obj) -> msgpack.ExtType: return msgpack.ExtType(MPEXT.NPGENERIC, npg_payload) case SettingsBaseModel(): - settings_data = {field: getattr(obj, field) for field in obj.__fields__} + settings_data = {field: getattr(obj, field) for field in obj.model_fields} settings_data.update({"__class__": obj.__class__.__qualname__, "__module__": obj.__class__.__module__}) settings_payload: bytes = msgpack.packb(settings_data, default=pack_default) return msgpack.ExtType(MPEXT.SETTINGS, settings_payload) diff --git a/gufe/settings/models.py b/gufe/settings/models.py index d18054283..980781a9c 100644 --- a/gufe/settings/models.py +++ b/gufe/settings/models.py @@ -6,40 +6,26 @@ import abc import pprint -from typing import Optional, Union +from typing import Annotated, Any, Literal, TypeAlias +from annotated_types import Ge from openff.units import unit +from pydantic import BeforeValidator, ConfigDict, Field, InstanceOf, PositiveFloat, PrivateAttr, field_validator -from gufe.vendor.openff.models.models import DefaultModel -from gufe.vendor.openff.models.types import FloatQuantity - -try: - from pydantic.v1 import Extra, Field, PositiveFloat, PrivateAttr, validator -except ImportError: - from pydantic import ( - Extra, - Field, - PositiveFloat, - PrivateAttr, - validator, - ) +from ..vendor.openff.interchange.pydantic import _BaseModel +from .types import AtmQuantity, GufeQuantity, KelvinQuantity, NanometerQuantity, specify_quantity_units -import pydantic +VoltsQuantity: TypeAlias = Annotated[GufeQuantity, specify_quantity_units("volts")] -class SettingsBaseModel(DefaultModel): +class SettingsBaseModel(_BaseModel): """Settings and modifications we want for all settings classes.""" _is_frozen: bool = PrivateAttr(default_factory=lambda: False) - - class Config: - """ - :noindex: - """ - - extra = "forbid" - arbitrary_types_allowed = False - smart_union = True + model_config = ConfigDict( + extra="forbid", + arbitrary_types_allowed=True, # needed to parse custom types + ) def _ipython_display_(self): pprint.pprint(self.dict()) @@ -50,11 +36,11 @@ def frozen_copy(self): This is intended to be used by Protocols to make their stored Settings read-only """ - copied = self.copy(deep=True) + copied = self.model_copy(deep=True) def freeze_model(model): submodels = ( - mod for field in model.__fields__ if isinstance(mod := getattr(model, field), SettingsBaseModel) + mod for field in model.model_fields if isinstance(mod := getattr(model, field), SettingsBaseModel) ) for mod in submodels: freeze_model(mod) @@ -71,11 +57,11 @@ def unfrozen_copy(self): Settings objects become frozen when within a Protocol. If you *really* need to reverse this, this method is how. """ - copied = self.copy(deep=True) + copied = self.model_copy(deep=True) def unfreeze_model(model): submodels = ( - mod for field in model.__fields__ if isinstance(mod := getattr(model, field), SettingsBaseModel) + mod for field in model.model_fields if isinstance(mod := getattr(model, field), SettingsBaseModel) ) for mod in submodels: unfreeze_model(mod) @@ -100,6 +86,15 @@ def __setattr__(self, name, value): ) return super().__setattr__(name, value) + def __eq__(self, other: Any) -> bool: + # reproduces pydantic v1 equality, since v2 checks for private attr equality, + # which results in frozen/unfrozen objects not being equal + # https://github.com/pydantic/pydantic/blob/2486e068e85c51728c9f2d344cfee2f7e11d555c/pydantic/v1/main.py#L911 + if isinstance(other, _BaseModel): + return self.model_dump() == other.model_dump() + else: + return self.model_dump() == other + class ThermoSettings(SettingsBaseModel): """Settings for thermodynamic parameters. @@ -109,40 +104,40 @@ class ThermoSettings(SettingsBaseModel): possible. """ - temperature: FloatQuantity["kelvin"] = Field(None, description="Simulation temperature, default units kelvin") - pressure: FloatQuantity["standard_atmosphere"] = Field( - None, description="Simulation pressure, default units standard atmosphere (atm)" - ) + temperature: KelvinQuantity | None = Field(None, description="Simulation temperature in kelvin)") + pressure: AtmQuantity | None = Field(None, description="Simulation pressure in standard atmosphere (atm)") ph: PositiveFloat | None = Field(None, description="Simulation pH") - redox_potential: float | None = Field(None, description="Simulation redox potential") + redox_potential: VoltsQuantity | None = Field(None, description="Simulation redox potential in millivolts (mV).") class BaseForceFieldSettings(SettingsBaseModel, abc.ABC): """Base class for ForceFieldSettings objects""" - class Config: - """:noindex:""" + ... - pass - ... +def _to_lowercase(value: Any): + """make any string input lowercase""" + if isinstance(value, (str)): + return value.lower() + else: + return value class OpenMMSystemGeneratorFFSettings(BaseForceFieldSettings): """Parameters to set up the force field with OpenMM ForceFields .. note:: - Right now we just basically just grab what we need for the - :class:`openmmforcefields.system_generators.SystemGenerator` - signature. See the `OpenMMForceField SystemGenerator documentation`_ - for more details. + Currently, this stores what is needed for the + :class:`openmmforcefields.system_generators.SystemGenerator` signature. + See the `OpenMMForceField SystemGenerator documentation`_ for more details. .. _`OpenMMForceField SystemGenerator documentation`: https://github.com/openmm/openmmforcefields#automating-force-field-management-with-systemgenerator """ - constraints: str | None = "hbonds" + constraints: Annotated[Literal["hbonds", "allbonds", "hangles"], BeforeValidator(_to_lowercase)] | None = "hbonds" """Constraints to be applied to system. One of 'hbonds', 'allbonds', 'hangles' or None, default 'hbonds'""" rigid_water: bool = True @@ -164,43 +159,24 @@ class OpenMMSystemGeneratorFFSettings(BaseForceFieldSettings): nonbonded_method: str = "PME" """ - Method for treating nonbonded interactions, currently only PME and - NoCutoff are allowed. Default PME. + Method for treating nonbonded interactions, options are currently + "CutoffNonPeriodic", "CutoffPeriodic", "Ewald", "LJPME", "NoCutoff", "PME". + Default PME. """ - nonbonded_cutoff: FloatQuantity["nanometer"] = 1.0 * unit.nanometer - """ - Cutoff value for short range nonbonded interactions. - Default 1.0 * unit.nanometer. - """ - - @validator("nonbonded_method") - def allowed_nonbonded(cls, v): - if v.lower() not in ["pme", "nocutoff"]: - errmsg = "Only PME and NoCutoff are allowed nonbonded_methods" - raise ValueError(errmsg) - return v + # TODO: currently, serialization scheme doesn't work for default values since we're not using PlainValidator + # see https://github.com/pydantic/pydantic/issues/11446 + nonbonded_cutoff: Annotated[NanometerQuantity, Ge(0)] = Field( + default=1.0 * unit.nanometer, description="Cutoff value for short range nonbonded interactions." + ) - @validator("nonbonded_cutoff") - def is_positive_distance(cls, v): - # these are time units, not simulation steps - if not v.is_compatible_with( - unit.nanometer - ): # TODO: invalid units get caught earlier and so this code is never executed - raise ValueError("nonbonded_cutoff must be in distance units (i.e. nanometers)") - if v < 0: - errmsg = "nonbonded_cutoff must be a positive value" + @field_validator("nonbonded_method", mode="after") + def allowed_nonbonded_methods(cls, v): + options = ["CutoffNonPeriodic", "CutoffPeriodic", "Ewald", "LJPME", "NoCutoff", "PME"] + if v.lower() not in [x.lower() for x in options]: + errmsg = f"Only {options} are allowed nonbonded_methods" raise ValueError(errmsg) return v - @validator("constraints") - def constraint_check(cls, v): - allowed = {"hbonds", "hangles", "allbonds"} - - if not (v is None or v.lower() in allowed): - raise ValueError(f"Bad constraints value, use one of {allowed}") - - return v - class Settings(SettingsBaseModel): """ @@ -211,8 +187,8 @@ class Settings(SettingsBaseModel): Protocols can subclass this to extend this to cater for their additional settings. """ - forcefield_settings: BaseForceFieldSettings - thermo_settings: ThermoSettings + forcefield_settings: InstanceOf[BaseForceFieldSettings] + thermo_settings: InstanceOf[ThermoSettings] @classmethod def get_defaults(cls): diff --git a/gufe/settings/types.py b/gufe/settings/types.py new file mode 100644 index 000000000..221ce76a7 --- /dev/null +++ b/gufe/settings/types.py @@ -0,0 +1,122 @@ +# adapted from from https://github.com/openforcefield/openff-interchange/blob/main/openff/interchange/_annotations.py +""" +Custom types that inherit from openff.units.Quantity and are pydantic-compatible. +""" + +from typing import Annotated, Any, TypeAlias + +from openff.units import Quantity +from pydantic import ( + AfterValidator, + BeforeValidator, + GetCoreSchemaHandler, +) +from pydantic_core import core_schema + +from ..vendor.openff.interchange._annotations import ( + _unit_validator_factory, + _unwrap_list_of_openmm_quantities, + quantity_json_serializer, + quantity_validator, + _BoxQuantity as BoxQuantity, +) + + +class _QuantityPydanticAnnotation: + @classmethod + def __get_pydantic_core_schema__( + cls, + source: Any, + handler: GetCoreSchemaHandler, + ) -> core_schema.CoreSchema: + """ + This Annotation lets us define a GufeQuantity that is identical to + an openff-units Quantity, except it's also pydantic-compatible. + """ + json_schema = core_schema.with_info_wrap_validator_function( + function=quantity_validator, schema=core_schema.float_schema() + ) + python_schema = core_schema.with_info_wrap_validator_function( + function=quantity_validator, + schema=core_schema.is_instance_schema(Quantity), + ) + + serialize_schema = core_schema.wrap_serializer_function_ser_schema(quantity_json_serializer) + return core_schema.json_or_python_schema( + json_schema=json_schema, + python_schema=python_schema, + serialization=serialize_schema, + ) + + +GufeQuantity = Annotated[Quantity, _QuantityPydanticAnnotation] + + +def specify_quantity_units(unit_name: str) -> AfterValidator: + """Helper function for generating custom quantity types. + + Parameters + ---------- + unit_name : str + unit name to validate against (e.g. 'nanometer') + + Returns + ------- + AfterValidator + An AfterValidator for defining a custom Quantity type. + + + + """ + + return AfterValidator(_unit_validator_factory(unit_name)) + + +NanometerQuantity: TypeAlias = Annotated[ + GufeQuantity, + specify_quantity_units("nanometer"), +] +"""Convert a pint.Quantity or to nanometers, if possible.""" + +AtmQuantity: TypeAlias = Annotated[ + GufeQuantity, + specify_quantity_units("atm"), +] +"""Convert a pint.Quantity or to atm, if possible.""" + +KelvinQuantity: TypeAlias = Annotated[ + GufeQuantity, + specify_quantity_units("kelvin"), +] +"""Convert a pint.Quantity or to kelvin, if possible.""" + +# types used elsewhere in the ecosystem +NanosecondQuantity: TypeAlias = Annotated[ + GufeQuantity, + specify_quantity_units("nanosecond"), +] +"""Convert a pint.Quantity or to nanoseconds, if possible.""" + + +PicosecondQuantity: TypeAlias = Annotated[ + GufeQuantity, + specify_quantity_units("picosecond"), +] +"""Convert a pint.Quantity or to picoseconds, if possible.""" + +AngstromQuantity: TypeAlias = Annotated[ + GufeQuantity, + specify_quantity_units("angstrom"), +] +"""Convert a pint.Quantity or to angstroms, if possible.""" + +KCalPerMolQuantity: TypeAlias = Annotated[ + GufeQuantity, + specify_quantity_units("kilocalorie_per_mole"), +] +"""Convert a pint.Quantity or to kcal/mol, if possible.""" + +GufeArrayQuantity: TypeAlias = Annotated[ + GufeQuantity, + BeforeValidator(_unwrap_list_of_openmm_quantities), +] \ No newline at end of file diff --git a/gufe/tests/test_models.py b/gufe/tests/test_models.py index 4fcf81a2b..ce9313ea6 100644 --- a/gufe/tests/test_models.py +++ b/gufe/tests/test_models.py @@ -15,69 +15,126 @@ def test_settings_schema(): """Settings schema should be stable""" expected_schema = { - "title": "Settings", - "description": "Container for all settings needed by a protocol\n\nThis represents the minimal surface that all settings objects will have.\n\nProtocols can subclass this to extend this to cater for their additional settings.", - "type": "object", - "properties": { - "forcefield_settings": {"$ref": "#/definitions/BaseForceFieldSettings"}, - "thermo_settings": {"$ref": "#/definitions/ThermoSettings"}, - }, - "required": ["forcefield_settings", "thermo_settings"], - "additionalProperties": False, - "definitions": { + "$defs": { "BaseForceFieldSettings": { - "title": "BaseForceFieldSettings", + "additionalProperties": False, "description": "Base class for ForceFieldSettings objects", - "type": "object", "properties": {}, - "additionalProperties": False, + "title": "BaseForceFieldSettings", + "type": "object", }, "ThermoSettings": { - "title": "ThermoSettings", + "additionalProperties": False, "description": "Settings for thermodynamic parameters.\n\n.. note::\n No checking is done to ensure a valid thermodynamic ensemble is\n possible.", - "type": "object", "properties": { "temperature": { + "anyOf": [{"type": "number"}, {"type": "null"}], + "default": None, + "description": "Simulation temperature in kelvin)", "title": "Temperature", - "description": "Simulation temperature, default units kelvin", - "type": "number", }, "pressure": { + "anyOf": [{"type": "number"}, {"type": "null"}], + "default": None, + "description": "Simulation pressure in standard atmosphere (atm)", "title": "Pressure", - "description": "Simulation pressure, default units standard atmosphere (atm)", - "type": "number", }, - "ph": {"title": "Ph", "description": "Simulation pH", "exclusiveMinimum": 0, "type": "number"}, + "ph": { + "anyOf": [{"exclusiveMinimum": 0, "type": "number"}, {"type": "null"}], + "default": None, + "description": "Simulation pH", + "title": "Ph", + }, "redox_potential": { + "anyOf": [{"type": "number"}, {"type": "null"}], + "default": None, + "description": "Simulation redox potential in millivolts (mV).", "title": "Redox Potential", - "description": "Simulation redox potential", - "type": "number", }, }, - "additionalProperties": False, + "title": "ThermoSettings", + "type": "object", }, }, + "additionalProperties": False, + "description": "Container for all settings needed by a protocol\n\nThis represents the minimal surface that all settings objects will have.\n\nProtocols can subclass this to extend this to cater for their additional settings.", + "properties": { + "forcefield_settings": {"$ref": "#/$defs/BaseForceFieldSettings", "title": "Forcefield Settings"}, + "thermo_settings": {"$ref": "#/$defs/ThermoSettings", "title": "Thermo Settings"}, + }, + "required": ["forcefield_settings", "thermo_settings"], + "title": "Settings", + "type": "object", } - schema = Settings.schema() - assert schema == expected_schema + ser_schema = Settings.model_json_schema(mode="serialization") + val_schema = Settings.model_json_schema(mode="validation") + # TODO: should our serialization and validation schemas really be the same? + assert ser_schema == expected_schema + assert val_schema == expected_schema + + +def test_openmmffsettings_schema(): + expected_schema = { + "additionalProperties": False, + "description": "Parameters to set up the force field with OpenMM ForceFields\n\n.. note::\n Currently, this stores what is needed for the\n :class:`openmmforcefields.system_generators.SystemGenerator` signature.\n See the `OpenMMForceField SystemGenerator documentation`_ for more details.\n\n\n.. _`OpenMMForceField SystemGenerator documentation`:\n https://github.com/openmm/openmmforcefields#automating-force-field-management-with-systemgenerator", + "properties": { + "constraints": { + "anyOf": [{"enum": ["hbonds", "allbonds", "hangles"], "type": "string"}, {"type": "null"}], + "default": "hbonds", + "title": "Constraints", + }, + "rigid_water": {"default": True, "title": "Rigid Water", "type": "boolean"}, + "hydrogen_mass": {"default": 3.0, "title": "Hydrogen Mass", "type": "number"}, + "forcefields": { + "default": [ + "amber/ff14SB.xml", + "amber/tip3p_standard.xml", + "amber/tip3p_HFE_multivalent.xml", + "amber/phosaa10.xml", + ], + "items": {"type": "string"}, + "title": "Forcefields", + "type": "array", + }, + "small_molecule_forcefield": { + "default": "openff-2.1.1", + "title": "Small Molecule Forcefield", + "type": "string", + }, + "nonbonded_method": {"default": "PME", "title": "Nonbonded Method", "type": "string"}, + "nonbonded_cutoff": { + "description": "Cutoff value for short range nonbonded interactions.", + "ge": 0, + "title": "Nonbonded Cutoff", + "type": "number", + }, + }, + "title": "OpenMMSystemGeneratorFFSettings", + "type": "object", + } + ser_schema = OpenMMSystemGeneratorFFSettings.model_json_schema(mode="serialization") + val_schema = OpenMMSystemGeneratorFFSettings.model_json_schema(mode="validation") + assert ser_schema == expected_schema + assert val_schema == expected_schema def test_default_settings(): my_settings = Settings.get_defaults() my_settings.thermo_settings.temperature = 298 * unit.kelvin - my_settings.json() - my_settings.schema_json(indent=2) + my_settings.model_dump_json() + json.dumps(my_settings.model_json_schema(mode="serialization"), indent=2) class TestSettingsValidation: @pytest.mark.parametrize( "value,valid,expected", [ - ("parsnips", False, None), # shouldn't be allowed + ("Parsnips", False, None), # shouldn't be allowed + (1.0, False, None), # shouldn't be allowed ("hbonds", True, "hbonds"), ("hangles", True, "hangles"), ("allbonds", True, "allbonds"), # allowed options - ("HBonds", True, "HBonds"), # check case insensitivity TODO: cast this to lower? + ("HBonds", True, "hbonds"), # check case insensitivity (None, True, None), ], ) @@ -93,12 +150,13 @@ def test_openmmff_constraints(self, value, valid, expected): "value,valid,expected", [ (1.0 * unit.nanometer, True, 1.0 * unit.nanometer), - (1.0, True, 1.0 * unit.nanometer), # should cast float to nanometer + (0 * unit.nanometer, True, 0 * unit.nanometer), + (1.0, False, None), # requires a length unit. ("1.1 nm", True, 1.1 * unit.nanometer), - ("1.1 ", False, None), - (0, True, 0 * unit.nanometer), + ("1.1", False, None), (-1.0 * unit.nanometer, False, None), - # (1.0 * unit.angstrom, True, 0.100 * unit.nanometer), # TODO: why does this not work? + # NOTE: this is not precisely equal for smaller values due to pint unit floating point precision + (100.0 * unit.angstrom, True, 10.0 * unit.nanometer), (300 * unit.kelvin, False, None), (True, False, None), (None, False, None), @@ -116,7 +174,7 @@ def test_openmmff_nonbonded_cutoff(self, value, valid, expected): @pytest.mark.parametrize( "value,valid,expected", [ - ("pme", True, "pme"), + ("NoCutoff", True, "NoCutoff"), ("NOCUTOFF", True, "NOCUTOFF"), ("no cutoff", False, None), (1.0, False, None), @@ -127,24 +185,23 @@ def test_openmmff_nonbonded_method(self, value, valid, expected): s = OpenMMSystemGeneratorFFSettings(nonbonded_method=value) assert s.nonbonded_method == expected else: - with pytest.raises(ValueError, match="Only PME and NoCutoff are allowed"): + with pytest.raises(ValueError): _ = OpenMMSystemGeneratorFFSettings(nonbonded_method=value) @pytest.mark.parametrize( "value,valid,expected", [ (298 * unit.kelvin, True, 298 * unit.kelvin), - (298, True, 298 * unit.kelvin), - (298.0, True, 298 * unit.kelvin), ("298 kelvin", True, 298 * unit.kelvin), + (298, False, None), # requires units ("298", False, None), (298 * unit.angstrom, False, None), ], ) def test_thermo_temperature(self, value, valid, expected): if valid: - s = ThermoSettings(temperature=value) - assert s.temperature == expected + settings = ThermoSettings(temperature=value) + assert settings.temperature == expected else: with pytest.raises(ValueError): _ = ThermoSettings(temperature=value) @@ -153,7 +210,7 @@ def test_thermo_temperature(self, value, valid, expected): "value,valid,expected", [ (1.0 * unit.atm, True, 1.0 * unit.atm), - (1.0, True, 1.0 * unit.atm), + (1.0, False, None), # require units ("1 atm", True, 1.0 * unit.atm), ("1.0", False, None), ], @@ -186,9 +243,13 @@ def test_thermo_ph(self, value, valid, expected): @pytest.mark.parametrize( "value,valid,expected", [ - (1.0, True, 1.0), (None, True, None), - ("1", True, 1.0), + (1 * unit.mV, True, 1 * unit.mV), + ("1.0 mV", True, 1 * unit.mV), + ("0.001 volts", True, 1 * unit.mV), + (0.001 * unit.volt, True, 1 * unit.mV), + (0.001 * unit.nanometer, False, None), + ("0.001 nm", False, None), ], ) def test_thermo_redox(self, value, valid, expected): @@ -236,10 +297,11 @@ def test_frozen_equality(self): # the frozen-ness of Settings doesn't alter its contents # therefore a frozen/unfrozen Settings which are otherwise identical # should be considered equal - s = Settings.get_defaults() - s2 = s.frozen_copy() + s1 = Settings.get_defaults() + s2 = s1.frozen_copy() - assert s == s2 + # TODO: equality checks have changed in v2 such that this is no longer true + assert s1 == s2 def test_set_subsection(self): # check that attempting to set a subsection of settings still respects diff --git a/gufe/tests/test_serialization_migration.py b/gufe/tests/test_serialization_migration.py index 11991a8e6..0275e26a2 100644 --- a/gufe/tests/test_serialization_migration.py +++ b/gufe/tests/test_serialization_migration.py @@ -233,11 +233,11 @@ def __init__(self, settings: GrandparentSettings): self.settings = settings def _to_dict(self): - return {"settings": self.settings.dict()} + return {"settings": self.settings.model_dump()} @classmethod def _from_dict(cls, dct): - settings = GrandparentSettings.parse_obj(dct["settings"]) + settings = GrandparentSettings.model_validate(dct["settings"]) return cls(settings=settings) @classmethod diff --git a/gufe/tests/test_transformation.py b/gufe/tests/test_transformation.py index c70bdac3f..7b79a0297 100644 --- a/gufe/tests/test_transformation.py +++ b/gufe/tests/test_transformation.py @@ -35,7 +35,7 @@ def complex_equilibrium(solvated_complex): class TestTransformation(GufeTokenizableTestsMixin): cls = Transformation - repr = "Transformation(stateA=ChemicalSystem(name=, components={'ligand': SmallMoleculeComponent(name=toluene), 'solvent': SolventComponent(name=O, K+, Cl-)}), stateB=ChemicalSystem(name=, components={'protein': ProteinComponent(name=), 'solvent': SolventComponent(name=O, K+, Cl-), 'ligand': SmallMoleculeComponent(name=toluene)}), protocol=, name=None)" + repr = "Transformation(stateA=ChemicalSystem(name=, components={'ligand': SmallMoleculeComponent(name=toluene), 'solvent': SolventComponent(name=O, K+, Cl-)}), stateB=ChemicalSystem(name=, components={'protein': ProteinComponent(name=), 'solvent': SolventComponent(name=O, K+, Cl-), 'ligand': SmallMoleculeComponent(name=toluene)}), protocol=, name=None)" @pytest.fixture def instance(self, absolute_transformation): @@ -174,7 +174,7 @@ def test_deprecation_warning_on_dict_mapping(self, solvated_ligand, solvated_com class TestNonTransformation(GufeTokenizableTestsMixin): cls = NonTransformation - repr = "NonTransformation(system=ChemicalSystem(name=, components={'protein': ProteinComponent(name=), 'solvent': SolventComponent(name=O, K+, Cl-), 'ligand': SmallMoleculeComponent(name=toluene)}), protocol=, name=None)" + repr = "NonTransformation(system=ChemicalSystem(name=, components={'protein': ProteinComponent(name=), 'solvent': SolventComponent(name=O, K+, Cl-), 'ligand': SmallMoleculeComponent(name=toluene)}), protocol=, name=None)" @pytest.fixture def instance(self, complex_equilibrium): diff --git a/gufe/vendor/openff/models/__init__.py b/gufe/vendor/openff/interchange/__init__.py similarity index 100% rename from gufe/vendor/openff/models/__init__.py rename to gufe/vendor/openff/interchange/__init__.py diff --git a/gufe/vendor/openff/interchange/_annotations.py b/gufe/vendor/openff/interchange/_annotations.py new file mode 100644 index 000000000..65192bad6 --- /dev/null +++ b/gufe/vendor/openff/interchange/_annotations.py @@ -0,0 +1,257 @@ +# Vendored from https://github.com/openforcefield/openff-interchange/blob/main/openff/interchange/_annotations.py +import functools +from collections.abc import Callable +from typing import Annotated, Any + +import numpy +from annotated_types import Gt +from openff.units import Quantity # import from units so we don't have to build toolkit just for docs +from pydantic import ( + AfterValidator, + BeforeValidator, + ValidationInfo, + ValidatorFunctionWrapHandler, + WrapSerializer, + WrapValidator, +) + +PositiveFloat = Annotated[float, Gt(0)] + + +def _has_compatible_dimensionality( + quantity: Quantity, + unit: str, + convert: bool, +) -> Quantity: + """Check if a Quantity has the same dimensionality as a given unit and optionally convert.""" + if quantity.is_compatible_with(unit): + if convert: + return quantity.to(unit) + else: + return quantity + else: + raise ValueError( + f"Dimensionality of {quantity=} is not compatible with {unit=}", + ) + + +def _dimensionality_validator_factory(unit: str) -> Callable: + """Return a function, meant to be passed to a validator, that checks for a specific unit.""" + return functools.partial(_has_compatible_dimensionality, unit=unit, convert=False) + + +def _unit_validator_factory(unit: str) -> Callable: + """Return a function, meant to be passed to a validator, that checks for a specific unit.""" + return functools.partial(_has_compatible_dimensionality, unit=unit, convert=True) + + +( + _is_distance, + _is_velocity, + _is_mass, + _is_temperature, +) = ( + _dimensionality_validator_factory(unit=_unit) + for _unit in [ + "nanometer", + "nanometer / picosecond", + "unified_atomic_mass_unit", + "kelvin", + ] +) + +( + _is_dimensionless, + _is_kj_mol, + _is_nanometer, + _is_degree, + _is_elementary_charge, +) = ( + _unit_validator_factory(unit=_unit) + for _unit in [ + "dimensionless", + "kilojoule / mole", + "nanometer", + "degree", + "elementary_charge", + ] +) + + +def quantity_validator( + value: str | Quantity | dict, + handler: ValidatorFunctionWrapHandler, + info: ValidationInfo, +) -> Quantity: + """Take Quantity-like objects and convert them to Quantity objects.""" + if info.mode == "json": + assert isinstance(value, dict), "Quantity must be in dict form here." + + # this is coupled to how a Quantity looks in JSON + return Quantity(value["val"], value["unit"]) + + # some more work may be needed to work with arrays, lists, tuples, etc. + + assert info.mode == "python" + + if isinstance(value, Quantity): + return value + elif isinstance(value, str): + return Quantity(value) + elif isinstance(value, dict): + return Quantity(value["val"], value["unit"]) + if "openmm" in str(type(value)): + from openff.units.openmm import from_openmm + + return from_openmm(value) + else: + raise ValueError(f"Invalid type {type(value)} for Quantity") + + +def quantity_json_serializer( + quantity: Quantity, + nxt, +) -> dict: + """Serialize a Quantity to a JSON-compatible dictionary.""" + magnitude = quantity.m + + if isinstance(magnitude, numpy.ndarray): + # This could be something fancier, list a bytestring + magnitude = magnitude.tolist() + + return { + "val": magnitude, + "unit": str(quantity.units), + } + + +# Pydantic v2 likes to marry validators and serializers to types with Annotated +# https://docs.pydantic.dev/latest/concepts/validators/#annotated-validators +_Quantity = Annotated[ + Quantity, + WrapValidator(quantity_validator), + WrapSerializer(quantity_json_serializer), +] + +_DimensionlessQuantity = Annotated[ + Quantity, + WrapValidator(quantity_validator), + AfterValidator(_is_dimensionless), + WrapSerializer(quantity_json_serializer), +] + +_DistanceQuantity = Annotated[ + Quantity, + WrapValidator(quantity_validator), + AfterValidator(_is_distance), + WrapSerializer(quantity_json_serializer), +] + +_LengthQuantity = _DistanceQuantity + +_VelocityQuantity = Annotated[ + Quantity, + WrapValidator(quantity_validator), + AfterValidator(_is_velocity), + WrapSerializer(quantity_json_serializer), +] + +_MassQuantity = Annotated[ + Quantity, + WrapValidator(quantity_validator), + AfterValidator(_is_mass), + WrapSerializer(quantity_json_serializer), +] + +_TemperatureQuantity = Annotated[ + Quantity, + WrapValidator(quantity_validator), + AfterValidator(_is_temperature), + WrapSerializer(quantity_json_serializer), +] + +_DegreeQuantity = Annotated[ + Quantity, + WrapValidator(quantity_validator), + AfterValidator(_is_degree), + WrapSerializer(quantity_json_serializer), +] + +_ElementaryChargeQuantity = Annotated[ + Quantity, + WrapValidator(quantity_validator), + AfterValidator(_is_elementary_charge), + WrapSerializer(quantity_json_serializer), +] + +_kJMolQuantity = Annotated[ + Quantity, + WrapValidator(quantity_validator), + AfterValidator(_is_kj_mol), + WrapSerializer(quantity_json_serializer), +] + + +def _is_positions_shape(quantity: Quantity) -> Quantity: + if quantity.m.shape[1] == 3: + return quantity + else: + raise ValueError( + f"Quantity {quantity} of wrong shape ({quantity.shape}) to be positions.", + ) + + +def _duck_to_nanometer(value: Any): + """Cast list or ndarray without units to Quantity[ndarray] of nanometer.""" + if isinstance(value, (list, numpy.ndarray)): + return Quantity(value, "nanometer") + else: + return value + + +_PositionsQuantity = Annotated[ + Quantity, + WrapValidator(quantity_validator), + AfterValidator(_is_nanometer), + AfterValidator(_is_positions_shape), + BeforeValidator(_duck_to_nanometer), + WrapSerializer(quantity_json_serializer), +] + + +def _is_box_shape(quantity) -> Quantity: + if quantity.m.shape == (3, 3): + return quantity + elif quantity.m.shape == (3,): + return numpy.eye(3) * quantity + else: + raise ValueError(f"Quantity {quantity} is not a box.") + + +def _unwrap_list_of_openmm_quantities(value: Any): + """Unwrap a list of OpenMM quantities to a single Quantity.""" + if isinstance(value, list): + if any(["openmm" in str(type(element)) for element in value]): + from openff.units.openmm import from_openmm + + if len({element.unit for element in value}) != 1: + raise ValueError("All units must be the same.") + + return from_openmm(value) + + else: + return value + + else: + return value + + +_BoxQuantity = Annotated[ + Quantity, + WrapValidator(quantity_validator), + AfterValidator(_is_distance), + AfterValidator(_is_box_shape), + BeforeValidator(_duck_to_nanometer), + BeforeValidator(_unwrap_list_of_openmm_quantities), + WrapSerializer(quantity_json_serializer), +] diff --git a/gufe/vendor/openff/interchange/pydantic.py b/gufe/vendor/openff/interchange/pydantic.py new file mode 100644 index 000000000..88853fd8f --- /dev/null +++ b/gufe/vendor/openff/interchange/pydantic.py @@ -0,0 +1,21 @@ +# Vendored from https://github.com/openforcefield/openff-interchange/blob/main/openff/interchange/pydantic.py +"""Pydantic base model with custom settings.""" + +from typing import Any + +from pydantic import BaseModel, ConfigDict + + +class _BaseModel(BaseModel): + """A custom Pydantic model used by other components.""" + + model_config = ConfigDict( + validate_assignment=True, + arbitrary_types_allowed=True, + ) + + def model_dump(self, **kwargs) -> dict[str, Any]: + return super().model_dump(serialize_as_any=True, **kwargs) + + def model_dump_json(self, **kwargs) -> str: + return super().model_dump_json(serialize_as_any=True, **kwargs) diff --git a/gufe/vendor/openff/models/README.md b/gufe/vendor/openff/models/README.md deleted file mode 100644 index cdda80525..000000000 --- a/gufe/vendor/openff/models/README.md +++ /dev/null @@ -1,16 +0,0 @@ -So I just yanked what we needed from from https://github.com/openforcefield/openff-models/tree/077ed7b - -Some changes: - -* Instead of using: -```python -try: - from pydantic.v1 import BaseModel -except ImportError: - from pydantic import BaseModel # type: ignore[assignment] -``` -We are going to just use from `pydantic.v1 import BaseModel` directly since we depend on the pydantic 1.x version where that was added. - -Then for `types.py`, `models.py`, and `exceptions.py` I ran our formatting hooks + pyupgrade --py310-plus. - -I've included the LICENSE from the openff repo for good measure. diff --git a/gufe/vendor/openff/models/exceptions.py b/gufe/vendor/openff/models/exceptions.py deleted file mode 100644 index 332823711..000000000 --- a/gufe/vendor/openff/models/exceptions.py +++ /dev/null @@ -1,16 +0,0 @@ -class MissingUnitError(ValueError): - """ - Exception for data missing a unit tag. - """ - - -class UnitValidationError(ValueError): - """ - Exception for bad behavior when validating unit-tagged data. - """ - - -class UnsupportedExportError(BaseException): - """ - Exception for attempting to write to an unsupported file format. - """ diff --git a/gufe/vendor/openff/models/models.py b/gufe/vendor/openff/models/models.py deleted file mode 100644 index 67f2e1590..000000000 --- a/gufe/vendor/openff/models/models.py +++ /dev/null @@ -1,21 +0,0 @@ -from collections.abc import Callable -from typing import Any - -from openff.units import Quantity -from pydantic.v1 import BaseModel - -from .types import custom_quantity_encoder, json_loader - - -class DefaultModel(BaseModel): - """A custom Pydantic model used by other components.""" - - class Config: - """Custom Pydantic configuration.""" - - json_encoders: dict[Any, Callable] = { - Quantity: custom_quantity_encoder, - } - json_loads: Callable = json_loader - validate_assignment: bool = True - arbitrary_types_allowed: bool = True diff --git a/gufe/vendor/openff/models/types.py b/gufe/vendor/openff/models/types.py deleted file mode 100644 index 670fc05df..000000000 --- a/gufe/vendor/openff/models/types.py +++ /dev/null @@ -1,231 +0,0 @@ -"""Custom models for dealing with unit-bearing quantities in a Pydantic-compatible manner.""" - -import json -from typing import TYPE_CHECKING, Any - -import numpy -from openff.units import Quantity, Unit, unit -from openff.utilities import has_package, requires_package - -from .exceptions import ( - MissingUnitError, - UnitValidationError, - UnsupportedExportError, -) - -if TYPE_CHECKING: - import openmm.unit - - -class _FloatQuantityMeta(type): - def __getitem__(self, t): - return type("FloatQuantity", (FloatQuantity,), {"__unit__": t}) - - -if TYPE_CHECKING: - FloatQuantity = unit.Quantity -else: - - class FloatQuantity(float, metaclass=_FloatQuantityMeta): - """A model for unit-bearing floats.""" - - @classmethod - def __get_validators__(cls): - yield cls.validate_type - - @classmethod - def validate_type(cls, val): - """Process a value tagged with units into one tagged with "OpenFF" style units.""" - unit_ = getattr(cls, "__unit__", Any) - if unit_ is Any: - if isinstance(val, (float, int)): - # TODO: Can this exception be raised with knowledge of the field it's in? - raise MissingUnitError(f"Value {val} needs to be tagged with a unit") - elif isinstance(val, Quantity): - return Quantity(val) - elif _is_openmm_quantity(val): - return _from_omm_quantity(val) - else: - raise UnitValidationError(f"Could not validate data of type {type(val)}") - else: - unit_ = Unit(unit_) - if isinstance(val, Quantity): - # some custom behavior could go here - assert unit_.dimensionality == val.dimensionality - # return through converting to some intended default units (taken from the class) - val._magnitude = float(val.m) - return val.to(unit_) - - if _is_openmm_quantity(val): - return _from_omm_quantity(val).to(unit_) - if isinstance(val, int) and not isinstance(val, bool): - # coerce ints into floats for a FloatQuantity - return float(val) * unit_ - if isinstance(val, float): - return val * unit_ - if isinstance(val, str): - # could do custom deserialization here? - val = Quantity(val).to(unit_) - val._magnitude = float(val._magnitude) - return val - if "unyt" in str(val.__class__): - if val.value.shape == (): - # this is a scalar force into an array by unyt's design - if "float" in str(val.value.dtype): - return float(val.value) * unit_ - elif "int" in str(val.value.dtype): - return int(val.value) * unit_ - - raise UnitValidationError(f"Could not validate data of type {type(val)}") - - -def _is_openmm_quantity(obj: object) -> bool: - if has_package("openmm"): - import openmm.unit - - return isinstance(obj, openmm.unit.Quantity) - - else: - return "openmm.unit.quantity.Quantity" in str(type(object)) - - -@requires_package("openmm.unit") -def _from_omm_quantity(val: "openmm.unit.Quantity") -> Quantity: - """ - Convert float or array quantities tagged with SimTK/OpenMM units to a Pint-compatible quantity. - """ - unit_: openmm.unit.Unit = val.unit - val_ = val.value_in_unit(unit_) - if type(val_) in {float, int}: - unit_ = val.unit - return float(val_) * Unit(str(unit_)) - # Here is where the toolkit's ValidatedList could go, if present in the environment - elif (type(val_) in {tuple, list, numpy.ndarray}) or (type(val_).__module__ == "openmm.vec3"): - array = numpy.asarray(val_) - return array * Unit(str(unit_)) - elif isinstance(val_, (float, int)) and type(val_).__module__ == "numpy": - return val_ * Unit(str(unit_)) - else: - raise UnitValidationError( - "Found a openmm.unit.Unit wrapped around something other than a float-like " - f"or numpy.ndarray-like. Found a unit wrapped around type {type(val_)}." - ) - - -class QuantityEncoder(json.JSONEncoder): - """ - JSON encoder for unit-wrapped floats and NumPy arrays. - - This is intended to operate on FloatQuantity and ArrayQuantity objects. - """ - - def default(self, obj): - if isinstance(obj, Quantity): - if isinstance(obj.magnitude, (float, int)): - data = obj.magnitude - elif isinstance(obj.magnitude, numpy.ndarray): - data = obj.magnitude.tolist() - else: - # This shouldn't ever be hit if our object models - # behave in ways we expect? - raise UnsupportedExportError(f"trying to serialize unsupported type {type(obj.magnitude)}") - return { - "val": data, - "unit": str(obj.units), - } - - -def custom_quantity_encoder(v): - """Wrap json.dump to use QuantityEncoder.""" - return json.dumps(v, cls=QuantityEncoder) - - -def json_loader(data: str) -> dict: - """Load JSON containing custom unit-tagged quantities.""" - # TODO: recursively call this function for nested models - out: dict = json.loads(data) - for key, val in out.items(): - try: - # Directly look for an encoded FloatQuantity/ArrayQuantity, - # which is itself a dict - v = json.loads(val) - except (json.JSONDecodeError, TypeError): - # Handles some cases of the val being a primitive type - continue - # TODO: More gracefully parse non-FloatQuantity/ArrayQuantity dicts - unit_ = Unit(v["unit"]) - val = v["val"] - out[key] = unit_ * val - return out - - -class _ArrayQuantityMeta(type): - def __getitem__(self, t): - return type("ArrayQuantity", (ArrayQuantity,), {"__unit__": t}) - - -if TYPE_CHECKING: - ArrayQuantity = unit.Quantity -else: - - class ArrayQuantity(float, metaclass=_ArrayQuantityMeta): - """A model for unit-bearing arrays.""" - - @classmethod - def __get_validators__(cls): - yield cls.validate_type - - @classmethod - def validate_type(cls, val): - """Process an array tagged with units into one tagged with "OpenFF" style units.""" - unit_ = getattr(cls, "__unit__", Any) - if unit_ is Any: - if isinstance(val, (list, numpy.ndarray)): - # Work around a special case in which val might be list[openmm.unit.Quantity] - if isinstance(val, list) and {type(element).__module__ for element in val} == { - "openmm.unit.quantity" - }: - unit_ = _from_omm_quantity(val[-1]).units - return Quantity( - [_from_omm_quantity(element).m for element in val], - units=unit_, - ) - - # TODO: Can this exception be raised with knowledge of the field it's in? - raise MissingUnitError(f"Value {val} needs to be tagged with a unit") - - elif isinstance(val, Quantity): - # TODO: This might be a redundant cast causing wasted CPU time. - # But maybe it handles pint vs openff.units.unit? - return Quantity(val) - elif _is_openmm_quantity(val): - return _from_omm_quantity(val) - else: - raise UnitValidationError(f"Could not validate data of type {type(val)}") - else: - unit_ = Unit(unit_) - if isinstance(val, Quantity): - assert unit_.dimensionality == val.dimensionality - return val.to(unit_) - if _is_openmm_quantity(val): - return _from_omm_quantity(val).to(unit_) - if isinstance(val, (numpy.ndarray, list)): - if "unyt" in str(val.__class__): - val = val.to_ndarray() - try: - return val * unit_ - except RuntimeError as error: - # unyt subclasses ndarray but doesn't __mult__ with - # pint.Unit objects - if val.__class__.__module__.startswith("unyt"): - return val.to_ndarray() * unit_ - else: - raise error - if isinstance(val, bytes): - # Define outside loop - dt = numpy.dtype(int).newbyteorder("<") - return numpy.frombuffer(val, dtype=dt) * unit_ - if isinstance(val, str): - # could do custom deserialization here? - raise NotImplementedError - raise UnitValidationError(f"Could not validate data of type {type(val)}") diff --git a/news/pydantic_v2.rst b/news/pydantic_v2.rst new file mode 100644 index 000000000..d7c6e15f1 --- /dev/null +++ b/news/pydantic_v2.rst @@ -0,0 +1,26 @@ +**Added:** + +* + +**Changed:** + +* ``FloatQuantity`` is no longer supported. Instead, use `GufeQuantity` and `specify_quantity_units()` to make a `TypeAlias`. +* System generator setting ``nonbonded_cutoff`` no longer attempts to coerce ambiguous inputs to ``unit.nanometer``. Instead, a length unit is required, e.g. ``2.2 * unit.nanometer`` or ``"2.2 nm"``. +* ``ThermoSettings`` parameters ``pressure`` and ``temperature`` no longer attempt to coerce ambiguous inputs to unts. Instead, the units must be passed explicitly, e.g. ``1.0 * units.atm`` or ``"1 atm"`` for pressure, and ``300 * unit.kelvin`` or ``"300 kelvin"`` for temperature. + +**Deprecated:** + +* + +.. TODO: add a link to docs +**Removed:** + +* + +**Fixed:** + +* + +**Security:** + +*