Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ dependencies = [
"diffrax",
"flax",
"orbax",
"ott-jax==0.5",
"pyarrow", # required for dask.dataframe
"ott-jax[neural]>=0.5",
"pyarrow", # required for dask.dataframe
"scanpy",
"scikit-learn==1.5.1",
"scipy<1.16", # see https://github.com/statsmodels/statsmodels/issues/9584
"scipy<1.16", # see https://github.com/statsmodels/statsmodels/issues/9584
"session-info",
]

Expand Down
87 changes: 87 additions & 0 deletions src/cellflow/_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""Compatibility layer for ``ott-jax`` across versions.

``ott-jax>=0.6`` removed ``ott.neural.methods.flows.dynamics`` and the
``ott.neural.networks.velocity_field.VelocityField`` (flax linen) class.
This module re-exports the symbols needed by CellFlow so that both
``ott-jax>=0.5,<0.6`` and ``ott-jax>=0.6`` are supported.
"""

# ---------------------------------------------------------------------------
# Probability-path dynamics (BaseFlow, ConstantNoiseFlow, BrownianBridge)
#
# For ott-jax <0.6 we import directly from ott. For ott-jax >=0.6 the
# module was removed, so we provide a vendored copy below.
#
# The fallback classes are a verbatim copy of
# ott.neural.methods.flows.dynamics
# from ott-jax 0.5.0 (commit 690b1ae, 2024-12-03).
# ott-jax is licensed under the Apache License 2.0, which permits
# reproduction and distribution of derivative works provided the license
# and copyright notice are retained. See:
# https://github.com/ott-jax/ott/blob/0.5.0/LICENSE
# ---------------------------------------------------------------------------
try:
from ott.neural.methods.flows.dynamics import ( # ott-jax <0.6
BaseFlow,
BrownianBridge,
ConstantNoiseFlow,
)
except ImportError:
# -- Vendored from ott-jax 0.5.0 (Apache-2.0) --------------------------
# Source: src/ott/neural/methods/flows/dynamics.py
# Copyright OTT-JAX contributors
# -----------------------------------------------------------------------
import abc

import jax
import jax.numpy as jnp

class BaseFlow(abc.ABC):
"""Base class for all flows."""

def __init__(self, sigma: float):
self.sigma = sigma

@abc.abstractmethod
def compute_mu_t(self, t: jnp.ndarray, x0: jnp.ndarray, x1: jnp.ndarray) -> jnp.ndarray: ...

@abc.abstractmethod
def compute_sigma_t(self, t: jnp.ndarray) -> jnp.ndarray: ...

@abc.abstractmethod
def compute_ut(self, t: jnp.ndarray, x: jnp.ndarray, x0: jnp.ndarray, x1: jnp.ndarray) -> jnp.ndarray: ...

def compute_xt(self, rng: jax.Array, t: jnp.ndarray, x0: jnp.ndarray, x1: jnp.ndarray) -> jnp.ndarray:
"""Sample from the probability path."""
noise = jax.random.normal(rng, shape=x0.shape)
mu_t = self.compute_mu_t(t, x0, x1)
sigma_t = self.compute_sigma_t(t)
return mu_t + sigma_t * noise

class _StraightFlow(BaseFlow, abc.ABC):
def compute_mu_t(self, t, x0, x1):
return (1.0 - t) * x0 + t * x1

def compute_ut(self, t, x, x0, x1):
del t, x
return x1 - x0

class ConstantNoiseFlow(_StraightFlow):
r"""Flow with straight paths and constant noise :math:`\sigma`."""

def compute_sigma_t(self, t):
return jnp.full_like(t, fill_value=self.sigma)

class BrownianBridge(_StraightFlow):
r"""Brownian Bridge with :math:`\sigma_t = \sigma \sqrt{t(1-t)}`."""

def compute_sigma_t(self, t):
return self.sigma * jnp.sqrt(t * (1.0 - t))

def compute_ut(self, t, x, x0, x1):
drift_term = (1 - 2 * t) / (2 * t * (1 - t)) * (x - (t * x1 + (1 - t) * x0))
control_term = x1 - x0
return drift_term + control_term


__all__ = ["BaseFlow", "ConstantNoiseFlow", "BrownianBridge"]
6 changes: 3 additions & 3 deletions src/cellflow/model/_cellflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
import numpy as np
import optax
import pandas as pd
from ott.neural.methods.flows import dynamics

from cellflow import _constants
from cellflow._compat import BrownianBridge, ConstantNoiseFlow
from cellflow._types import ArrayLike, Layers_separate_input_t, Layers_t
from cellflow.data._data import ConditionData, TrainingData, ValidationData
from cellflow.data._dataloader import OOCTrainSampler, PredictionSampler, TrainSampler, ValidationSampler
Expand Down Expand Up @@ -477,9 +477,9 @@ def prepare_model(

probability_path, noise = next(iter(probability_path.items()))
if probability_path == "constant_noise":
probability_path = dynamics.ConstantNoiseFlow(noise)
probability_path = ConstantNoiseFlow(noise)
elif probability_path == "bridge":
probability_path = dynamics.BrownianBridge(noise)
probability_path = BrownianBridge(noise)
else:
raise NotImplementedError(
f"The key of `probability_path` must be `'constant_noise'` or `'bridge'` but found {probability_path}."
Expand Down
8 changes: 4 additions & 4 deletions src/cellflow/solvers/_genot.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
import jax
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from flax.training import train_state
from ott.neural.methods.flows import dynamics
from ott.neural.networks import velocity_field
from ott.solvers import utils as solver_utils

from cellflow import utils
from cellflow._compat import BaseFlow
from cellflow._types import ArrayLike
from cellflow.model._utils import _multivariate_normal

Expand Down Expand Up @@ -57,8 +57,8 @@ class GENOT:

def __init__(
self,
vf: velocity_field.VelocityField,
probability_path: dynamics.BaseFlow,
vf: nn.Module,
probability_path: BaseFlow,
data_match_fn: DataMatchFn,
*,
source_dim: int,
Expand Down
4 changes: 2 additions & 2 deletions src/cellflow/solvers/_otfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
import numpy as np
from flax.core import frozen_dict
from flax.training import train_state
from ott.neural.methods.flows import dynamics
from ott.solvers import utils as solver_utils

from cellflow import utils
from cellflow._compat import BaseFlow
from cellflow._types import ArrayLike
from cellflow.networks._velocity_field import ConditionalVelocityField
from cellflow.solvers.utils import ema_update
Expand Down Expand Up @@ -47,7 +47,7 @@ class OTFlowMatching:
def __init__(
self,
vf: ConditionalVelocityField,
probability_path: dynamics.BaseFlow,
probability_path: BaseFlow,
match_fn: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray] | None = None,
time_sampler: Callable[[jax.Array, int], jnp.ndarray] = solver_utils.uniform_sampler,
**kwargs: Any,
Expand Down
8 changes: 4 additions & 4 deletions tests/solver/test_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import numpy as np
import optax
import pytest
from ott.neural.methods.flows import dynamics

import cellflow
from cellflow._compat import ConstantNoiseFlow
from cellflow.solvers import _genot, _otfm
from cellflow.utils import match_linear

Expand Down Expand Up @@ -42,7 +42,7 @@ def test_predict_batch(self, dataloader, solver_class):
solver = _otfm.OTFlowMatching(
vf=vf,
match_fn=match_linear,
probability_path=dynamics.ConstantNoiseFlow(0.0),
probability_path=ConstantNoiseFlow(0.0),
optimizer=opt,
conditions={"drug": np.random.rand(2, 1, 3)},
rng=vf_rng,
Expand All @@ -51,7 +51,7 @@ def test_predict_batch(self, dataloader, solver_class):
solver = _genot.GENOT(
vf=vf,
data_match_fn=match_linear,
probability_path=dynamics.ConstantNoiseFlow(0.0),
probability_path=ConstantNoiseFlow(0.0),
optimizer=opt,
source_dim=5,
target_dim=5,
Expand Down Expand Up @@ -105,7 +105,7 @@ def test_EMA(self, dataloader, ema):
solver1 = _otfm.OTFlowMatching(
vf=vf1,
match_fn=match_linear,
probability_path=dynamics.ConstantNoiseFlow(0.0),
probability_path=ConstantNoiseFlow(0.0),
optimizer=opt,
conditions={"drug": drug},
rng=vf_rng,
Expand Down
12 changes: 6 additions & 6 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import numpy as np
import optax
import pytest
from ott.neural.methods.flows import dynamics

import cellflow
from cellflow._compat import ConstantNoiseFlow
from cellflow.solvers import _otfm
from cellflow.training import CellFlowTrainer, ComputationCallback, Metrics
from cellflow.utils import match_linear
Expand Down Expand Up @@ -55,7 +55,7 @@ def test_cellflow_trainer(self, dataloader, valid_freq):
model = _otfm.OTFlowMatching(
vf=vf,
match_fn=match_linear,
probability_path=dynamics.ConstantNoiseFlow(0.0),
probability_path=ConstantNoiseFlow(0.0),
optimizer=opt,
conditions=cond,
rng=vf_rng,
Expand Down Expand Up @@ -91,7 +91,7 @@ def test_cellflow_trainer_with_callback(self, dataloader, valid_loader, use_vali
model = _otfm.OTFlowMatching(
vf=vf,
match_fn=match_linear,
probability_path=dynamics.ConstantNoiseFlow(0.0),
probability_path=ConstantNoiseFlow(0.0),
optimizer=opt,
conditions=cond,
rng=vf_rng,
Expand Down Expand Up @@ -137,7 +137,7 @@ def test_cellflow_trainer_with_custom_callback(self, dataloader, valid_loader):
solver = _otfm.OTFlowMatching(
vf=vf,
match_fn=match_linear,
probability_path=dynamics.ConstantNoiseFlow(0.0),
probability_path=ConstantNoiseFlow(0.0),
optimizer=opt,
conditions=cond,
rng=vf_rng,
Expand Down Expand Up @@ -181,15 +181,15 @@ def test_predict_kwargs_iter(self, dataloader, valid_loader):
model_1 = _otfm.OTFlowMatching(
vf=vf_1,
match_fn=match_linear,
probability_path=dynamics.ConstantNoiseFlow(0.0),
probability_path=ConstantNoiseFlow(0.0),
optimizer=opt_1,
conditions=cond,
rng=vf_rng,
)
model_2 = _otfm.OTFlowMatching(
vf=vf_2,
match_fn=match_linear,
probability_path=dynamics.ConstantNoiseFlow(0.0),
probability_path=ConstantNoiseFlow(0.0),
optimizer=opt_2,
conditions=cond,
rng=vf_rng,
Expand Down
Loading