From 2f40a19aa6e6f9c1cf9737a600872ceea92cbfcc Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Thu, 4 Dec 2025 10:52:15 -0500 Subject: [PATCH] Drop pydantic v1 and use most recent validators --- src/stratocaster/base/models.py | 1 + src/stratocaster/base/strategy.py | 1 + src/stratocaster/strategies/connectivity.py | 25 ++++++++++----------- 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/src/stratocaster/base/models.py b/src/stratocaster/base/models.py index 44e78ca..a8b663c 100644 --- a/src/stratocaster/base/models.py +++ b/src/stratocaster/base/models.py @@ -3,4 +3,5 @@ class StrategySettings(SettingsBaseModel): """Base class for Strategy settings.""" + pass diff --git a/src/stratocaster/base/strategy.py b/src/stratocaster/base/strategy.py index 039d2c2..a7e3d3e 100644 --- a/src/stratocaster/base/strategy.py +++ b/src/stratocaster/base/strategy.py @@ -8,6 +8,7 @@ TProtocolResult = TypeVar("TProtocolResult", bound=ProtocolResult) + class StrategyResult(GufeTokenizable): """Results produced by a Strategy.""" diff --git a/src/stratocaster/strategies/connectivity.py b/src/stratocaster/strategies/connectivity.py index 739eaf2..d6a310b 100644 --- a/src/stratocaster/strategies/connectivity.py +++ b/src/stratocaster/strategies/connectivity.py @@ -4,14 +4,11 @@ from stratocaster.base import Strategy, StrategyResult from stratocaster.base.models import StrategySettings -try: - from pydantic.v1 import Field, root_validator, validator -except ImportError: - from pydantic import ( - Field, - root_validator, - validator, - ) +from pydantic import ( + Field, + model_validator, + field_validator, +) import pydantic @@ -31,27 +28,27 @@ class ConnectivityStrategySettings(StrategySettings): description="the upper limit of protocol DAG results needed before a transformation is no longer weighed", ) - @validator("cutoff") + @field_validator("cutoff", mode="before") def validate_cutoff(cls, value): if value is not None: if not (0 < value): raise ValueError("`cutoff` must be greater than 0") return value - @validator("decay_rate") + @field_validator("decay_rate", mode="before") def validate_decay_rate(cls, value): if not (0 < value < 1): raise ValueError("`decay_rate` must be between 0 and 1") return value - @validator("max_runs") + @field_validator("max_runs", mode="before") def validate_max_runs(cls, value): if value is not None: if not value >= 1: raise ValueError("`max_runs` must be greater than or equal to 1") return value - @root_validator + @model_validator(mode="before") def check_cutoff_or_max_runs(cls, values): """Check that at either max_runs or cutoff is set.""" max_runs, cutoff = values.get("max_runs"), values.get("cutoff") @@ -72,7 +69,9 @@ class ConnectivityStrategy(Strategy): _settings_cls = ConnectivityStrategySettings - def _exponential_decay_scaling(self, number_of_results: int, decay_rate: float) -> float: + def _exponential_decay_scaling( + self, number_of_results: int, decay_rate: float + ) -> float: """Transformation weight decay factor. Parameters