diff --git a/pyproject.toml b/pyproject.toml index 152c690b..911b088f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,11 +38,11 @@ dependencies = [ "diffrax", "flax", "orbax", - "ott-jax==0.5", - "pyarrow", # required for dask.dataframe + "ott-jax[neural]>=0.5", + "pyarrow", # required for dask.dataframe "scanpy", "scikit-learn==1.5.1", - "scipy<1.16", # see https://github.com/statsmodels/statsmodels/issues/9584 + "scipy<1.16", # see https://github.com/statsmodels/statsmodels/issues/9584 "session-info", ] diff --git a/src/cellflow/_compat.py b/src/cellflow/_compat.py new file mode 100644 index 00000000..7b232332 --- /dev/null +++ b/src/cellflow/_compat.py @@ -0,0 +1,87 @@ +"""Compatibility layer for ``ott-jax`` across versions. + +``ott-jax>=0.6`` removed ``ott.neural.methods.flows.dynamics`` and the +``ott.neural.networks.velocity_field.VelocityField`` (flax linen) class. +This module re-exports the symbols needed by CellFlow so that both +``ott-jax>=0.5,<0.6`` and ``ott-jax>=0.6`` are supported. +""" + +# --------------------------------------------------------------------------- +# Probability-path dynamics (BaseFlow, ConstantNoiseFlow, BrownianBridge) +# +# For ott-jax <0.6 we import directly from ott. For ott-jax >=0.6 the +# module was removed, so we provide a vendored copy below. +# +# The fallback classes are a verbatim copy of +# ott.neural.methods.flows.dynamics +# from ott-jax 0.5.0 (commit 690b1ae, 2024-12-03). +# ott-jax is licensed under the Apache License 2.0, which permits +# reproduction and distribution of derivative works provided the license +# and copyright notice are retained. See: +# https://github.com/ott-jax/ott/blob/0.5.0/LICENSE +# --------------------------------------------------------------------------- +try: + from ott.neural.methods.flows.dynamics import ( # ott-jax <0.6 + BaseFlow, + BrownianBridge, + ConstantNoiseFlow, + ) +except ImportError: + # -- Vendored from ott-jax 0.5.0 (Apache-2.0) -------------------------- + # Source: src/ott/neural/methods/flows/dynamics.py + # Copyright OTT-JAX contributors + # ----------------------------------------------------------------------- + import abc + + import jax + import jax.numpy as jnp + + class BaseFlow(abc.ABC): + """Base class for all flows.""" + + def __init__(self, sigma: float): + self.sigma = sigma + + @abc.abstractmethod + def compute_mu_t(self, t: jnp.ndarray, x0: jnp.ndarray, x1: jnp.ndarray) -> jnp.ndarray: ... + + @abc.abstractmethod + def compute_sigma_t(self, t: jnp.ndarray) -> jnp.ndarray: ... + + @abc.abstractmethod + def compute_ut(self, t: jnp.ndarray, x: jnp.ndarray, x0: jnp.ndarray, x1: jnp.ndarray) -> jnp.ndarray: ... + + def compute_xt(self, rng: jax.Array, t: jnp.ndarray, x0: jnp.ndarray, x1: jnp.ndarray) -> jnp.ndarray: + """Sample from the probability path.""" + noise = jax.random.normal(rng, shape=x0.shape) + mu_t = self.compute_mu_t(t, x0, x1) + sigma_t = self.compute_sigma_t(t) + return mu_t + sigma_t * noise + + class _StraightFlow(BaseFlow, abc.ABC): + def compute_mu_t(self, t, x0, x1): + return (1.0 - t) * x0 + t * x1 + + def compute_ut(self, t, x, x0, x1): + del t, x + return x1 - x0 + + class ConstantNoiseFlow(_StraightFlow): + r"""Flow with straight paths and constant noise :math:`\sigma`.""" + + def compute_sigma_t(self, t): + return jnp.full_like(t, fill_value=self.sigma) + + class BrownianBridge(_StraightFlow): + r"""Brownian Bridge with :math:`\sigma_t = \sigma \sqrt{t(1-t)}`.""" + + def compute_sigma_t(self, t): + return self.sigma * jnp.sqrt(t * (1.0 - t)) + + def compute_ut(self, t, x, x0, x1): + drift_term = (1 - 2 * t) / (2 * t * (1 - t)) * (x - (t * x1 + (1 - t) * x0)) + control_term = x1 - x0 + return drift_term + control_term + + +__all__ = ["BaseFlow", "ConstantNoiseFlow", "BrownianBridge"] diff --git a/src/cellflow/model/_cellflow.py b/src/cellflow/model/_cellflow.py index a535ddf7..0ee02cba 100644 --- a/src/cellflow/model/_cellflow.py +++ b/src/cellflow/model/_cellflow.py @@ -13,9 +13,9 @@ import numpy as np import optax import pandas as pd -from ott.neural.methods.flows import dynamics from cellflow import _constants +from cellflow._compat import BrownianBridge, ConstantNoiseFlow from cellflow._types import ArrayLike, Layers_separate_input_t, Layers_t from cellflow.data._data import ConditionData, TrainingData, ValidationData from cellflow.data._dataloader import OOCTrainSampler, PredictionSampler, TrainSampler, ValidationSampler @@ -477,9 +477,9 @@ def prepare_model( probability_path, noise = next(iter(probability_path.items())) if probability_path == "constant_noise": - probability_path = dynamics.ConstantNoiseFlow(noise) + probability_path = ConstantNoiseFlow(noise) elif probability_path == "bridge": - probability_path = dynamics.BrownianBridge(noise) + probability_path = BrownianBridge(noise) else: raise NotImplementedError( f"The key of `probability_path` must be `'constant_noise'` or `'bridge'` but found {probability_path}." diff --git a/src/cellflow/solvers/_genot.py b/src/cellflow/solvers/_genot.py index 7270ad7f..178511d5 100644 --- a/src/cellflow/solvers/_genot.py +++ b/src/cellflow/solvers/_genot.py @@ -6,12 +6,12 @@ import jax import jax.numpy as jnp import numpy as np +from flax import linen as nn from flax.training import train_state -from ott.neural.methods.flows import dynamics -from ott.neural.networks import velocity_field from ott.solvers import utils as solver_utils from cellflow import utils +from cellflow._compat import BaseFlow from cellflow._types import ArrayLike from cellflow.model._utils import _multivariate_normal @@ -57,8 +57,8 @@ class GENOT: def __init__( self, - vf: velocity_field.VelocityField, - probability_path: dynamics.BaseFlow, + vf: nn.Module, + probability_path: BaseFlow, data_match_fn: DataMatchFn, *, source_dim: int, diff --git a/src/cellflow/solvers/_otfm.py b/src/cellflow/solvers/_otfm.py index 31114a6b..3a12fd9e 100644 --- a/src/cellflow/solvers/_otfm.py +++ b/src/cellflow/solvers/_otfm.py @@ -8,10 +8,10 @@ import numpy as np from flax.core import frozen_dict from flax.training import train_state -from ott.neural.methods.flows import dynamics from ott.solvers import utils as solver_utils from cellflow import utils +from cellflow._compat import BaseFlow from cellflow._types import ArrayLike from cellflow.networks._velocity_field import ConditionalVelocityField from cellflow.solvers.utils import ema_update @@ -47,7 +47,7 @@ class OTFlowMatching: def __init__( self, vf: ConditionalVelocityField, - probability_path: dynamics.BaseFlow, + probability_path: BaseFlow, match_fn: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray] | None = None, time_sampler: Callable[[jax.Array, int], jnp.ndarray] = solver_utils.uniform_sampler, **kwargs: Any, diff --git a/tests/solver/test_solver.py b/tests/solver/test_solver.py index 6bf10401..9ed3131b 100644 --- a/tests/solver/test_solver.py +++ b/tests/solver/test_solver.py @@ -5,9 +5,9 @@ import numpy as np import optax import pytest -from ott.neural.methods.flows import dynamics import cellflow +from cellflow._compat import ConstantNoiseFlow from cellflow.solvers import _genot, _otfm from cellflow.utils import match_linear @@ -42,7 +42,7 @@ def test_predict_batch(self, dataloader, solver_class): solver = _otfm.OTFlowMatching( vf=vf, match_fn=match_linear, - probability_path=dynamics.ConstantNoiseFlow(0.0), + probability_path=ConstantNoiseFlow(0.0), optimizer=opt, conditions={"drug": np.random.rand(2, 1, 3)}, rng=vf_rng, @@ -51,7 +51,7 @@ def test_predict_batch(self, dataloader, solver_class): solver = _genot.GENOT( vf=vf, data_match_fn=match_linear, - probability_path=dynamics.ConstantNoiseFlow(0.0), + probability_path=ConstantNoiseFlow(0.0), optimizer=opt, source_dim=5, target_dim=5, @@ -105,7 +105,7 @@ def test_EMA(self, dataloader, ema): solver1 = _otfm.OTFlowMatching( vf=vf1, match_fn=match_linear, - probability_path=dynamics.ConstantNoiseFlow(0.0), + probability_path=ConstantNoiseFlow(0.0), optimizer=opt, conditions={"drug": drug}, rng=vf_rng, diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index beef4eb1..dc693236 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -5,9 +5,9 @@ import numpy as np import optax import pytest -from ott.neural.methods.flows import dynamics import cellflow +from cellflow._compat import ConstantNoiseFlow from cellflow.solvers import _otfm from cellflow.training import CellFlowTrainer, ComputationCallback, Metrics from cellflow.utils import match_linear @@ -55,7 +55,7 @@ def test_cellflow_trainer(self, dataloader, valid_freq): model = _otfm.OTFlowMatching( vf=vf, match_fn=match_linear, - probability_path=dynamics.ConstantNoiseFlow(0.0), + probability_path=ConstantNoiseFlow(0.0), optimizer=opt, conditions=cond, rng=vf_rng, @@ -91,7 +91,7 @@ def test_cellflow_trainer_with_callback(self, dataloader, valid_loader, use_vali model = _otfm.OTFlowMatching( vf=vf, match_fn=match_linear, - probability_path=dynamics.ConstantNoiseFlow(0.0), + probability_path=ConstantNoiseFlow(0.0), optimizer=opt, conditions=cond, rng=vf_rng, @@ -137,7 +137,7 @@ def test_cellflow_trainer_with_custom_callback(self, dataloader, valid_loader): solver = _otfm.OTFlowMatching( vf=vf, match_fn=match_linear, - probability_path=dynamics.ConstantNoiseFlow(0.0), + probability_path=ConstantNoiseFlow(0.0), optimizer=opt, conditions=cond, rng=vf_rng, @@ -181,7 +181,7 @@ def test_predict_kwargs_iter(self, dataloader, valid_loader): model_1 = _otfm.OTFlowMatching( vf=vf_1, match_fn=match_linear, - probability_path=dynamics.ConstantNoiseFlow(0.0), + probability_path=ConstantNoiseFlow(0.0), optimizer=opt_1, conditions=cond, rng=vf_rng, @@ -189,7 +189,7 @@ def test_predict_kwargs_iter(self, dataloader, valid_loader): model_2 = _otfm.OTFlowMatching( vf=vf_2, match_fn=match_linear, - probability_path=dynamics.ConstantNoiseFlow(0.0), + probability_path=ConstantNoiseFlow(0.0), optimizer=opt_2, conditions=cond, rng=vf_rng,