diff --git a/src/cfp/model/_cellflow.py b/src/cfp/model/_cellflow.py
index dfc71349..9680730d 100644
--- a/src/cfp/model/_cellflow.py
+++ b/src/cfp/model/_cellflow.py
@@ -246,6 +246,8 @@ def prepare_model(
         flow: dict[Literal["constant_noise", "bridge"], float] | None = None,
         match_fn: Callable[[ArrayLike, ArrayLike], ArrayLike] = match_linear,
         optimizer: optax.GradientTransformation = optax.adam(1e-4),
+        cfg_p_resample: float = 0.0,
+        cfg_ode_weight: float = 0.0,
         layer_norm_before_concatenation: bool = False,
         linear_projection_before_concatenation: bool = False,
         genot_source_layers: Layers_t | None = None,
@@ -346,6 +348,11 @@ def prepare_model(
             data and return the optimal transport matrix, see e.g. :func:`cfp.utils.match_linear`.
         optimizer
             Optimizer used for training.
+        cfg_p_resample
+            Probability of the null condition for classifier free guidance.
+        cfg_ode_weight
+            Weighting factor of the null condition for classifier free guidance.
+            0 corresponds to no classifier-free guidance; the larger the value, the stronger the guidance.
         layer_norm_before_concatenation
             If :obj:`True`, applies layer normalization before concatenating
             the embedded time, embedded data, and condition embeddings.
@@ -447,6 +454,8 @@ def prepare_model(
             match_fn=match_fn,
             flow=flow,
             optimizer=optimizer,
+            cfg_p_resample=cfg_p_resample,
+            cfg_ode_weight=cfg_ode_weight,
             conditions=self.train_data.condition_data,
             rng=jax.random.PRNGKey(seed),
             **solver_kwargs,
diff --git a/src/cfp/solvers/_otfm.py b/src/cfp/solvers/_otfm.py
index 31e4866e..75c2c5d8 100644
--- a/src/cfp/solvers/_otfm.py
+++ b/src/cfp/solvers/_otfm.py
@@ -36,6 +36,11 @@ class OTFlowMatching:
     time_sampler
         Time sampler with a ``(rng, n_samples) -> time`` signature, see e.g.
         :func:`ott.solvers.utils.uniform_sampler`.
+    cfg_p_resample
+        Probability of the null condition for classifier free guidance.
+    cfg_ode_weight
+        Weighting factor of the null condition for classifier free guidance.
+        0 corresponds to no classifier-free guidance; the larger the value, the stronger the guidance.
     kwargs
         Keyword arguments for :meth:`cfp.networks.ConditionalVelocityField.create_train_state`.
     """

@@ -48,18 +53,33 @@ def __init__(
         time_sampler: Callable[
             [jax.Array, int], jnp.ndarray
         ] = solver_utils.uniform_sampler,
+        cfg_p_resample: float = 0.0,
+        cfg_ode_weight: float = 0.0,
         **kwargs: Any,
     ):
         self._is_trained: bool = False
         self.vf = vf
         self.flow = flow
         self.time_sampler = time_sampler
+        if cfg_p_resample > 0 and cfg_ode_weight == 0:
+            raise ValueError(
+                "cfg_p_resample > 0 requires cfg_ode_weight > 0 for classifier free guidance."
+            )
+        if cfg_p_resample == 0 and cfg_ode_weight > 0:
+            raise ValueError(
+                "cfg_ode_weight > 0 requires cfg_p_resample > 0 for classifier free guidance."
+            )
+        if cfg_ode_weight < 0:
+            raise ValueError("cfg_ode_weight must be non-negative.")
+        self.cfg_p_resample = cfg_p_resample
+        self.cfg_ode_weight = cfg_ode_weight
         self.match_fn = jax.jit(match_fn)
         self.vf_state = self.vf.create_train_state(
             input_dim=self.vf.output_dims[-1], **kwargs
         )
         self.vf_step_fn = self._get_vf_step_fn()
+        self.null_value_cfg = self.vf.mask_value

     def _get_vf_step_fn(self) -> Callable:  # type: ignore[type-arg]
@@ -125,7 +145,14 @@ def step_fn(
             """
             src, tgt = batch["src_cell_data"], batch["tgt_cell_data"]
             condition = batch.get("condition")
-            rng_resample, rng_step_fn = jax.random.split(rng, 2)
+            rng_resample, rng_cfg, rng_step_fn = jax.random.split(rng, 3)
+            cfg_null = jax.random.bernoulli(rng_cfg, self.cfg_p_resample)
+            if cfg_null:
+                # TODO: adapt to null condition in transformer
+                condition = jax.tree_util.tree_map(
+                    lambda x: jnp.full(x.shape, self.null_value_cfg), condition
+                )
+
             if self.match_fn is not None:
                 tmat = self.match_fn(src, tgt)
                 src_ixs, tgt_ixs = solver_utils.sample_joint(rng_resample, tmat)
@@ -192,8 +219,21 @@ def vf(
             params = self.vf_state.params
             return self.vf_state.apply_fn({"params": params}, t, x, cond, train=False)

+        def vf_cfg(
+            t: jnp.ndarray, x:
+            jnp.ndarray, cond: dict[str, jnp.ndarray] | None
+        ) -> jnp.ndarray:
+            cond_mask = jax.tree_util.tree_map(
+                lambda x: jnp.full(x.shape, self.null_value_cfg), cond
+            )
+            params = self.vf_state.params
+            return (1 + self.cfg_ode_weight) * self.vf_state.apply_fn(
+                {"params": params}, t, x, cond, train=False
+            ) - self.cfg_ode_weight * self.vf_state.apply_fn(
+                {"params": params}, t, x, cond_mask, train=False
+            )
+
         def solve_ode(x: jnp.ndarray, condition: jnp.ndarray | None) -> jnp.ndarray:
-            ode_term = diffrax.ODETerm(vf)
+            ode_term = diffrax.ODETerm(vf_cfg if self.cfg_p_resample else vf)
             result = diffrax.diffeqsolve(
                 ode_term,
                 t0=0.0,
diff --git a/src/cfp/training/_trainer.py b/src/cfp/training/_trainer.py
index 0040a826..4aaee27e 100644
--- a/src/cfp/training/_trainer.py
+++ b/src/cfp/training/_trainer.py
@@ -117,7 +117,7 @@ def train(
             loss = self.solver.step_fn(rng_step_fn, batch)
             self.training_logs["loss"].append(float(loss))

-            if ((it - 1) % valid_freq == 0) and (it > 1):
+            if ((it + 1) % valid_freq == 0) and (it > 1):
                 # Get predictions from validation data
                 valid_true_data, valid_pred_data = self._validation_step(
                     valid_loaders, mode="on_log_iteration"
diff --git a/tests/model/test_cellflow.py b/tests/model/test_cellflow.py
index 89686398..14e3a06b 100644
--- a/tests/model/test_cellflow.py
+++ b/tests/model/test_cellflow.py
@@ -16,11 +16,18 @@ class TestCellFlow:
     @pytest.mark.parametrize("solver", ["otfm", "genot"])
+    @pytest.mark.parametrize("use_classifier_free_guidance", [False, True])
     def test_cellflow_solver(
-        self,
-        adata_perturbation,
-        solver,
+        self, adata_perturbation, solver, use_classifier_free_guidance
     ):
+        if solver == "genot" and use_classifier_free_guidance:
+            pytest.skip("Classifier free guidance is not implemented for GENOT")
+        if use_classifier_free_guidance:
+            cfg_p_resample = 0.3
+            cfg_ode_weight = 2.0
+        else:
+            cfg_p_resample = 0.0
+            cfg_ode_weight = 0.0
         sample_rep = "X"
         control_key = "control"
         perturbation_covariates = {"drug": ["drug1",
"drug2"]} @@ -47,6 +54,8 @@ def test_cellflow_solver( hidden_dims=(32, 32), decoder_dims=(32, 32), condition_encoder_kwargs=condition_encoder_kwargs, + cfg_p_resample=cfg_p_resample, + cfg_ode_weight=cfg_ode_weight, ) assert cf._trainer is not None