From 1280e5b92265a1216895f2fdb4e18af02dc56d66 Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Tue, 17 Feb 2026 11:38:50 +0100 Subject: [PATCH 1/3] ottjax compat initial fix attempt --- pyproject.toml | 2 +- src/cellflow/_compat.py | 78 +++++++++++++++++++++++++++++++++ src/cellflow/model/_cellflow.py | 6 +-- src/cellflow/solvers/_genot.py | 9 ++-- src/cellflow/solvers/_otfm.py | 4 +- tests/solver/test_solver.py | 8 ++-- tests/trainer/test_trainer.py | 12 ++--- 7 files changed, 99 insertions(+), 20 deletions(-) create mode 100644 src/cellflow/_compat.py diff --git a/pyproject.toml b/pyproject.toml index 152c690b..8dd4d43d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,7 @@ dependencies = [ "diffrax", "flax", "orbax", - "ott-jax==0.5", + "ott-jax[neural]>=0.5.0", "pyarrow", # required for dask.dataframe "scanpy", "scikit-learn==1.5.1", diff --git a/src/cellflow/_compat.py b/src/cellflow/_compat.py new file mode 100644 index 00000000..dda2650f --- /dev/null +++ b/src/cellflow/_compat.py @@ -0,0 +1,78 @@ +"""Compatibility layer for ``ott-jax`` across versions. + +``ott-jax>=0.6`` removed ``ott.neural.methods.flows.dynamics`` and the +``ott.neural.networks.velocity_field.VelocityField`` (flax linen) class. +This module re-exports the symbols needed by CellFlow so that both +``ott-jax>=0.5,<0.6`` and ``ott-jax>=0.6`` are supported. +""" + +# --------------------------------------------------------------------------- +# Probability-path dynamics (BaseFlow, ConstantNoiseFlow, BrownianBridge) +# --------------------------------------------------------------------------- +try: + from ott.neural.methods.flows.dynamics import ( # ott-jax <0.6 + BaseFlow, + BrownianBridge, + ConstantNoiseFlow, + ) +except ImportError: + import abc + + import jax + import jax.numpy as jnp + + class BaseFlow(abc.ABC): + """Base class for all flows.""" + + def __init__(self, sigma: float): + self.sigma = sigma + + @abc.abstractmethod + def compute_mu_t( + self, t: jnp.ndarray, x0: jnp.ndarray, x1: jnp.ndarray + ) -> jnp.ndarray: ... + + @abc.abstractmethod + def compute_sigma_t(self, t: jnp.ndarray) -> jnp.ndarray: ... + + @abc.abstractmethod + def compute_ut( + self, t: jnp.ndarray, x: jnp.ndarray, x0: jnp.ndarray, x1: jnp.ndarray + ) -> jnp.ndarray: ... + + def compute_xt( + self, rng: jax.Array, t: jnp.ndarray, x0: jnp.ndarray, x1: jnp.ndarray + ) -> jnp.ndarray: + """Sample from the probability path.""" + noise = jax.random.normal(rng, shape=x0.shape) + mu_t = self.compute_mu_t(t, x0, x1) + sigma_t = self.compute_sigma_t(t) + return mu_t + sigma_t * noise + + class _StraightFlow(BaseFlow, abc.ABC): + def compute_mu_t(self, t, x0, x1): + return (1.0 - t) * x0 + t * x1 + + def compute_ut(self, t, x, x0, x1): + del t, x + return x1 - x0 + + class ConstantNoiseFlow(_StraightFlow): + r"""Flow with straight paths and constant noise :math:`\sigma`.""" + + def compute_sigma_t(self, t): + return jnp.full_like(t, fill_value=self.sigma) + + class BrownianBridge(_StraightFlow): + r"""Brownian Bridge with :math:`\sigma_t = \sigma \sqrt{t(1-t)}`.""" + + def compute_sigma_t(self, t): + return self.sigma * jnp.sqrt(t * (1.0 - t)) + + def compute_ut(self, t, x, x0, x1): + drift_term = (1 - 2 * t) / (2 * t * (1 - t)) * (x - (t * x1 + (1 - t) * x0)) + control_term = x1 - x0 + return drift_term + control_term + + +__all__ = ["BaseFlow", "ConstantNoiseFlow", "BrownianBridge"] diff --git a/src/cellflow/model/_cellflow.py b/src/cellflow/model/_cellflow.py index a535ddf7..f572c770 100644 --- a/src/cellflow/model/_cellflow.py +++ b/src/cellflow/model/_cellflow.py @@ -13,7 +13,7 @@ import numpy as np import optax import pandas as pd -from ott.neural.methods.flows import dynamics +from cellflow._compat import BrownianBridge, ConstantNoiseFlow from cellflow import _constants from cellflow._types import ArrayLike, Layers_separate_input_t, Layers_t @@ -477,9 +477,9 @@ def prepare_model( probability_path, noise = next(iter(probability_path.items())) if probability_path == "constant_noise": - probability_path = dynamics.ConstantNoiseFlow(noise) + probability_path = ConstantNoiseFlow(noise) elif probability_path == "bridge": - probability_path = dynamics.BrownianBridge(noise) + probability_path = BrownianBridge(noise) else: raise NotImplementedError( f"The key of `probability_path` must be `'constant_noise'` or `'bridge'` but found {probability_path}." diff --git a/src/cellflow/solvers/_genot.py b/src/cellflow/solvers/_genot.py index 7270ad7f..e8e350b0 100644 --- a/src/cellflow/solvers/_genot.py +++ b/src/cellflow/solvers/_genot.py @@ -7,10 +7,11 @@ import jax.numpy as jnp import numpy as np from flax.training import train_state -from ott.neural.methods.flows import dynamics -from ott.neural.networks import velocity_field +from flax import linen as nn from ott.solvers import utils as solver_utils +from cellflow._compat import BaseFlow + from cellflow import utils from cellflow._types import ArrayLike from cellflow.model._utils import _multivariate_normal @@ -57,8 +58,8 @@ class GENOT: def __init__( self, - vf: velocity_field.VelocityField, - probability_path: dynamics.BaseFlow, + vf: nn.Module, + probability_path: BaseFlow, data_match_fn: DataMatchFn, *, source_dim: int, diff --git a/src/cellflow/solvers/_otfm.py b/src/cellflow/solvers/_otfm.py index 31114a6b..478e397d 100644 --- a/src/cellflow/solvers/_otfm.py +++ b/src/cellflow/solvers/_otfm.py @@ -8,7 +8,7 @@ import numpy as np from flax.core import frozen_dict from flax.training import train_state -from ott.neural.methods.flows import dynamics +from cellflow._compat import BaseFlow from ott.solvers import utils as solver_utils from cellflow import utils @@ -47,7 +47,7 @@ class OTFlowMatching: def __init__( self, vf: ConditionalVelocityField, - probability_path: dynamics.BaseFlow, + probability_path: BaseFlow, match_fn: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray] | None = None, time_sampler: Callable[[jax.Array, int], jnp.ndarray] = solver_utils.uniform_sampler, **kwargs: Any, diff --git a/tests/solver/test_solver.py b/tests/solver/test_solver.py index 6bf10401..cd28477b 100644 --- a/tests/solver/test_solver.py +++ b/tests/solver/test_solver.py @@ -5,7 +5,7 @@ import numpy as np import optax import pytest -from ott.neural.methods.flows import dynamics +from cellflow._compat import ConstantNoiseFlow import cellflow from cellflow.solvers import _genot, _otfm @@ -42,7 +42,7 @@ def test_predict_batch(self, dataloader, solver_class): solver = _otfm.OTFlowMatching( vf=vf, match_fn=match_linear, - probability_path=dynamics.ConstantNoiseFlow(0.0), + probability_path=ConstantNoiseFlow(0.0), optimizer=opt, conditions={"drug": np.random.rand(2, 1, 3)}, rng=vf_rng, @@ -51,7 +51,7 @@ def test_predict_batch(self, dataloader, solver_class): solver = _genot.GENOT( vf=vf, data_match_fn=match_linear, - probability_path=dynamics.ConstantNoiseFlow(0.0), + probability_path=ConstantNoiseFlow(0.0), optimizer=opt, source_dim=5, target_dim=5, @@ -105,7 +105,7 @@ def test_EMA(self, dataloader, ema): solver1 = _otfm.OTFlowMatching( vf=vf1, match_fn=match_linear, - probability_path=dynamics.ConstantNoiseFlow(0.0), + probability_path=ConstantNoiseFlow(0.0), optimizer=opt, conditions={"drug": drug}, rng=vf_rng, diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index beef4eb1..153361b9 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -5,7 +5,7 @@ import numpy as np import optax import pytest -from ott.neural.methods.flows import dynamics +from cellflow._compat import ConstantNoiseFlow import cellflow from cellflow.solvers import _otfm @@ -55,7 +55,7 @@ def test_cellflow_trainer(self, dataloader, valid_freq): model = _otfm.OTFlowMatching( vf=vf, match_fn=match_linear, - probability_path=dynamics.ConstantNoiseFlow(0.0), + probability_path=ConstantNoiseFlow(0.0), optimizer=opt, conditions=cond, rng=vf_rng, @@ -91,7 +91,7 @@ def test_cellflow_trainer_with_callback(self, dataloader, valid_loader, use_vali model = _otfm.OTFlowMatching( vf=vf, match_fn=match_linear, - probability_path=dynamics.ConstantNoiseFlow(0.0), + probability_path=ConstantNoiseFlow(0.0), optimizer=opt, conditions=cond, rng=vf_rng, @@ -137,7 +137,7 @@ def test_cellflow_trainer_with_custom_callback(self, dataloader, valid_loader): solver = _otfm.OTFlowMatching( vf=vf, match_fn=match_linear, - probability_path=dynamics.ConstantNoiseFlow(0.0), + probability_path=ConstantNoiseFlow(0.0), optimizer=opt, conditions=cond, rng=vf_rng, @@ -181,7 +181,7 @@ def test_predict_kwargs_iter(self, dataloader, valid_loader): model_1 = _otfm.OTFlowMatching( vf=vf_1, match_fn=match_linear, - probability_path=dynamics.ConstantNoiseFlow(0.0), + probability_path=ConstantNoiseFlow(0.0), optimizer=opt_1, conditions=cond, rng=vf_rng, @@ -189,7 +189,7 @@ def test_predict_kwargs_iter(self, dataloader, valid_loader): model_2 = _otfm.OTFlowMatching( vf=vf_2, match_fn=match_linear, - probability_path=dynamics.ConstantNoiseFlow(0.0), + probability_path=ConstantNoiseFlow(0.0), optimizer=opt_2, conditions=cond, rng=vf_rng, From a35e1bafd38810758fc26948287ede8173464074 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 17 Feb 2026 10:40:36 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pyproject.toml | 6 +++--- src/cellflow/_compat.py | 12 +++--------- src/cellflow/model/_cellflow.py | 2 +- src/cellflow/solvers/_genot.py | 5 ++--- src/cellflow/solvers/_otfm.py | 2 +- tests/solver/test_solver.py | 2 +- tests/trainer/test_trainer.py | 2 +- 7 files changed, 12 insertions(+), 19 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8dd4d43d..911b088f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,11 +38,11 @@ dependencies = [ "diffrax", "flax", "orbax", - "ott-jax[neural]>=0.5.0", - "pyarrow", # required for dask.dataframe + "ott-jax[neural]>=0.5", + "pyarrow", # required for dask.dataframe "scanpy", "scikit-learn==1.5.1", - "scipy<1.16", # see https://github.com/statsmodels/statsmodels/issues/9584 + "scipy<1.16", # see https://github.com/statsmodels/statsmodels/issues/9584 "session-info", ] diff --git a/src/cellflow/_compat.py b/src/cellflow/_compat.py index dda2650f..512140fe 100644 --- a/src/cellflow/_compat.py +++ b/src/cellflow/_compat.py @@ -28,21 +28,15 @@ def __init__(self, sigma: float): self.sigma = sigma @abc.abstractmethod - def compute_mu_t( - self, t: jnp.ndarray, x0: jnp.ndarray, x1: jnp.ndarray - ) -> jnp.ndarray: ... + def compute_mu_t(self, t: jnp.ndarray, x0: jnp.ndarray, x1: jnp.ndarray) -> jnp.ndarray: ... @abc.abstractmethod def compute_sigma_t(self, t: jnp.ndarray) -> jnp.ndarray: ... @abc.abstractmethod - def compute_ut( - self, t: jnp.ndarray, x: jnp.ndarray, x0: jnp.ndarray, x1: jnp.ndarray - ) -> jnp.ndarray: ... + def compute_ut(self, t: jnp.ndarray, x: jnp.ndarray, x0: jnp.ndarray, x1: jnp.ndarray) -> jnp.ndarray: ... - def compute_xt( - self, rng: jax.Array, t: jnp.ndarray, x0: jnp.ndarray, x1: jnp.ndarray - ) -> jnp.ndarray: + def compute_xt(self, rng: jax.Array, t: jnp.ndarray, x0: jnp.ndarray, x1: jnp.ndarray) -> jnp.ndarray: """Sample from the probability path.""" noise = jax.random.normal(rng, shape=x0.shape) mu_t = self.compute_mu_t(t, x0, x1) diff --git a/src/cellflow/model/_cellflow.py b/src/cellflow/model/_cellflow.py index f572c770..0ee02cba 100644 --- a/src/cellflow/model/_cellflow.py +++ b/src/cellflow/model/_cellflow.py @@ -13,9 +13,9 @@ import numpy as np import optax import pandas as pd -from cellflow._compat import BrownianBridge, ConstantNoiseFlow from cellflow import _constants +from cellflow._compat import BrownianBridge, ConstantNoiseFlow from cellflow._types import ArrayLike, Layers_separate_input_t, Layers_t from cellflow.data._data import ConditionData, TrainingData, ValidationData from cellflow.data._dataloader import OOCTrainSampler, PredictionSampler, TrainSampler, ValidationSampler diff --git a/src/cellflow/solvers/_genot.py b/src/cellflow/solvers/_genot.py index e8e350b0..178511d5 100644 --- a/src/cellflow/solvers/_genot.py +++ b/src/cellflow/solvers/_genot.py @@ -6,13 +6,12 @@ import jax import jax.numpy as jnp import numpy as np -from flax.training import train_state from flax import linen as nn +from flax.training import train_state from ott.solvers import utils as solver_utils -from cellflow._compat import BaseFlow - from cellflow import utils +from cellflow._compat import BaseFlow from cellflow._types import ArrayLike from cellflow.model._utils import _multivariate_normal diff --git a/src/cellflow/solvers/_otfm.py b/src/cellflow/solvers/_otfm.py index 478e397d..3a12fd9e 100644 --- a/src/cellflow/solvers/_otfm.py +++ b/src/cellflow/solvers/_otfm.py @@ -8,10 +8,10 @@ import numpy as np from flax.core import frozen_dict from flax.training import train_state -from cellflow._compat import BaseFlow from ott.solvers import utils as solver_utils from cellflow import utils +from cellflow._compat import BaseFlow from cellflow._types import ArrayLike from cellflow.networks._velocity_field import ConditionalVelocityField from cellflow.solvers.utils import ema_update diff --git a/tests/solver/test_solver.py b/tests/solver/test_solver.py index cd28477b..9ed3131b 100644 --- a/tests/solver/test_solver.py +++ b/tests/solver/test_solver.py @@ -5,9 +5,9 @@ import numpy as np import optax import pytest -from cellflow._compat import ConstantNoiseFlow import cellflow +from cellflow._compat import ConstantNoiseFlow from cellflow.solvers import _genot, _otfm from cellflow.utils import match_linear diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 153361b9..dc693236 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -5,9 +5,9 @@ import numpy as np import optax import pytest -from cellflow._compat import ConstantNoiseFlow import cellflow +from cellflow._compat import ConstantNoiseFlow from cellflow.solvers import _otfm from cellflow.training import CellFlowTrainer, ComputationCallback, Metrics from cellflow.utils import match_linear From aa9cf352bb1508c93e2366576018243a3f83a1a8 Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Tue, 17 Feb 2026 11:43:25 +0100 Subject: [PATCH 3/3] comment on copied file --- src/cellflow/_compat.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/cellflow/_compat.py b/src/cellflow/_compat.py index 512140fe..7b232332 100644 --- a/src/cellflow/_compat.py +++ b/src/cellflow/_compat.py @@ -8,6 +8,17 @@ # --------------------------------------------------------------------------- # Probability-path dynamics (BaseFlow, ConstantNoiseFlow, BrownianBridge) +# +# For ott-jax <0.6 we import directly from ott. For ott-jax >=0.6 the +# module was removed, so we provide a vendored copy below. +# +# The fallback classes are a verbatim copy of +# ott.neural.methods.flows.dynamics +# from ott-jax 0.5.0 (commit 690b1ae, 2024-12-03). +# ott-jax is licensed under the Apache License 2.0, which permits +# reproduction and distribution of derivative works provided the license +# and copyright notice are retained. See: +# https://github.com/ott-jax/ott/blob/0.5.0/LICENSE # --------------------------------------------------------------------------- try: from ott.neural.methods.flows.dynamics import ( # ott-jax <0.6 @@ -16,6 +27,10 @@ ConstantNoiseFlow, ) except ImportError: + # -- Vendored from ott-jax 0.5.0 (Apache-2.0) -------------------------- + # Source: src/ott/neural/methods/flows/dynamics.py + # Copyright OTT-JAX contributors + # ----------------------------------------------------------------------- import abc import jax