Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/stratocaster/base/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@

class StrategySettings(SettingsBaseModel):
"""Base class for Strategy settings."""

pass
1 change: 1 addition & 0 deletions src/stratocaster/base/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

TProtocolResult = TypeVar("TProtocolResult", bound=ProtocolResult)


class StrategyResult(GufeTokenizable):
"""Results produced by a Strategy."""

Expand Down
25 changes: 12 additions & 13 deletions src/stratocaster/strategies/connectivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,11 @@
from stratocaster.base import Strategy, StrategyResult
from stratocaster.base.models import StrategySettings

try:
from pydantic.v1 import Field, root_validator, validator
except ImportError:
from pydantic import (
Field,
root_validator,
validator,
)
from pydantic import (
Field,
model_validator,
field_validator,
)

import pydantic

Expand All @@ -31,27 +28,27 @@ class ConnectivityStrategySettings(StrategySettings):
description="the upper limit of protocol DAG results needed before a transformation is no longer weighed",
)

@validator("cutoff")
@field_validator("cutoff", mode="before")
def validate_cutoff(cls, value):
if value is not None:
if not (0 < value):
raise ValueError("`cutoff` must be greater than 0")
return value

@validator("decay_rate")
@field_validator("decay_rate", mode="before")
def validate_decay_rate(cls, value):
if not (0 < value < 1):
raise ValueError("`decay_rate` must be between 0 and 1")
return value

@validator("max_runs")
@field_validator("max_runs", mode="before")
def validate_max_runs(cls, value):
if value is not None:
if not value >= 1:
raise ValueError("`max_runs` must be greater than or equal to 1")
return value

@root_validator
@model_validator(mode="before")
def check_cutoff_or_max_runs(cls, values):
"""Check that at either max_runs or cutoff is set."""
max_runs, cutoff = values.get("max_runs"), values.get("cutoff")
Expand All @@ -72,7 +69,9 @@ class ConnectivityStrategy(Strategy):

_settings_cls = ConnectivityStrategySettings

def _exponential_decay_scaling(self, number_of_results: int, decay_rate: float) -> float:
def _exponential_decay_scaling(
self, number_of_results: int, decay_rate: float
) -> float:
"""Transformation weight decay factor.

Parameters
Expand Down