From b983ddbefd83d74d6414d66b8f66a4697765de79 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Tue, 5 Nov 2024 12:30:35 +0100 Subject: [PATCH 01/14] first draft classifier-free guidance --- src/cfp/model/_cellflow.py | 9 +++++++++ src/cfp/solvers/_otfm.py | 38 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/src/cfp/model/_cellflow.py b/src/cfp/model/_cellflow.py index dfc71349..9680730d 100644 --- a/src/cfp/model/_cellflow.py +++ b/src/cfp/model/_cellflow.py @@ -246,6 +246,8 @@ def prepare_model( flow: dict[Literal["constant_noise", "bridge"], float] | None = None, match_fn: Callable[[ArrayLike, ArrayLike], ArrayLike] = match_linear, optimizer: optax.GradientTransformation = optax.adam(1e-4), + cfg_p_resample: float = 0.0, + cfg_ode_weight: float = 0.0, layer_norm_before_concatenation: bool = False, linear_projection_before_concatenation: bool = False, genot_source_layers: Layers_t | None = None, @@ -346,6 +348,11 @@ def prepare_model( data and return the optimal transport matrix, see e.g. :func:`cfp.utils.match_linear`. optimizer Optimizer used for training. + cfg_p_resample + Probability of the null condition for classifier free guidance. + cfg_ode_weight + Weighting factor of the null condition for classifier free guidance. + 0 corresponds to no classifier-free guidance, the larger 0, the more guidance. layer_norm_before_concatenation If :obj:`True`, applies layer normalization before concatenating the embedded time, embedded data, and condition embeddings. @@ -447,6 +454,8 @@ def prepare_model( match_fn=match_fn, flow=flow, optimizer=optimizer, + cfg_p_resample=cfg_p_resample, + cfg_ode_weight=cfg_ode_weight, conditions=self.train_data.condition_data, rng=jax.random.PRNGKey(seed), **solver_kwargs, diff --git a/src/cfp/solvers/_otfm.py b/src/cfp/solvers/_otfm.py index 31e4866e..3c194807 100644 --- a/src/cfp/solvers/_otfm.py +++ b/src/cfp/solvers/_otfm.py @@ -36,6 +36,11 @@ class OTFlowMatching: time_sampler Time sampler with a ``(rng, n_samples) -> time`` signature, see e.g. :func:`ott.solvers.utils.uniform_sampler`. + cfg_p_resample + Probability of the null condition for classifier free guidance. + cfg_ode_weight + Weighting factor of the null condition for classifier free guidance. + 0 corresponds to no classifier-free guidance, the larger 0, the more guidance. kwargs Keyword arguments for :meth:`cfp.networks.ConditionalVelocityField.create_train_state`. """ @@ -48,12 +53,26 @@ def __init__( time_sampler: Callable[ [jax.Array, int], jnp.ndarray ] = solver_utils.uniform_sampler, + cfg_p_resample: float = 0.0, + cfg_ode_weight: float = 0.0, **kwargs: Any, ): self._is_trained: bool = False self.vf = vf self.flow = flow self.time_sampler = time_sampler + if cfg_p_resample > 0 and cfg_ode_weight == 0: + raise ValueError( + "cfg_p_resample > 0 requires cfg_ode_weight > 0 for classifier free guidance." + ) + if cfg_p_resample == 0 and cfg_ode_weight > 0: + raise ValueError( + "cfg_ode_weight > 0 requires cfg_p_resample > 0 for classifier free guidance." + ) + if cfg_ode_weight < 0: + raise ValueError("cfg_ode_weight must be non-negative.") + self.cfg_p_resample = cfg_p_resample + self.cfg_ode_weight = cfg_ode_weight self.match_fn = jax.jit(match_fn) self.vf_state = self.vf.create_train_state( @@ -125,7 +144,12 @@ def step_fn( """ src, tgt = batch["src_cell_data"], batch["tgt_cell_data"] condition = batch.get("condition") - rng_resample, rng_step_fn = jax.random.split(rng, 2) + rng_resample, rng_cfg, rng_step_fn = jax.random.split(rng, 3) + cfg_null = jax.random.bernoulli(rng_cfg, self.cfg_p_resample) + if cfg_null: + # TODO: adapt to null condition in transformer + condition = jax.tree_util.tree_map(lambda x: jnp.zeros_like(x), condition) + if self.match_fn is not None: tmat = self.match_fn(src, tgt) src_ixs, tgt_ixs = solver_utils.sample_joint(rng_resample, tmat) @@ -192,8 +216,18 @@ def vf( params = self.vf_state.params return self.vf_state.apply_fn({"params": params}, t, x, cond, train=False) + def vf_cfg( + t: jnp.ndarray, x: jnp.ndarray, cond: dict[str, jnp.ndarray] | None + ) -> jnp.ndarray: + # TODO: adapt to null condition in transformer + null_cond = jax.tree_util.tree_map(lambda x: jnp.zeros_like(x), cond) + params = self.vf_state.params + v_cond = self.vf_state.apply_fn({"params": params}, t, x, cond) + v_uncond = self.vf_state.apply_fn({"params": params}, t, x, null_cond) + return (1 + self.cfg_ode_weight) * v_cond - self.cfg_ode_weight * v_uncond + def solve_ode(x: jnp.ndarray, condition: jnp.ndarray | None) -> jnp.ndarray: - ode_term = diffrax.ODETerm(vf) + ode_term = diffrax.ODETerm(vf_cfg if self.cfg_p_resample else vf) result = diffrax.diffeqsolve( ode_term, t0=0.0, From caf489a45033a318723ee1f957f5cbed690679ad Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Tue, 5 Nov 2024 13:29:49 +0100 Subject: [PATCH 02/14] add test --- tests/model/test_cellflow.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/tests/model/test_cellflow.py b/tests/model/test_cellflow.py index 89686398..14e3a06b 100644 --- a/tests/model/test_cellflow.py +++ b/tests/model/test_cellflow.py @@ -16,11 +16,18 @@ class TestCellFlow: @pytest.mark.parametrize("solver", ["otfm", "genot"]) + @pytest.mark.parametrize("use_classifier_free_guidance", [False, True]) def test_cellflow_solver( - self, - adata_perturbation, - solver, + self, adata_perturbation, solver, use_classifier_free_guidance ): + if solver == "genot" and use_classifier_free_guidance: + pytest.skip("Classifier free guidance is not implemented for GENOT") + if use_classifier_free_guidance: + cfg_p_resample = 0.3 + cfg_ode_weight = 2.0 + else: + cfg_p_resample = 0.0 + cfg_ode_weight = 0.0 sample_rep = "X" control_key = "control" perturbation_covariates = {"drug": ["drug1", "drug2"]} @@ -47,6 +54,8 @@ def test_cellflow_solver( hidden_dims=(32, 32), decoder_dims=(32, 32), condition_encoder_kwargs=condition_encoder_kwargs, + cfg_p_resample=cfg_p_resample, + cfg_ode_weight=cfg_ode_weight, ) assert cf._trainer is not None From 2a03c358c4bd1a99977137847e75c04779535ef5 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Wed, 6 Nov 2024 20:00:20 +0100 Subject: [PATCH 03/14] try fix --- src/cfp/solvers/_otfm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/cfp/solvers/_otfm.py b/src/cfp/solvers/_otfm.py index 3c194807..e7c580f9 100644 --- a/src/cfp/solvers/_otfm.py +++ b/src/cfp/solvers/_otfm.py @@ -209,6 +209,8 @@ def predict( kwargs.setdefault( "stepsize_controller", diffrax.PIDController(rtol=1e-5, atol=1e-5) ) + null_cond = jax.tree_util.tree_map(lambda x: jnp.zeros_like(x), condition) + def vf( t: jnp.ndarray, x: jnp.ndarray, cond: dict[str, jnp.ndarray] | None @@ -220,7 +222,6 @@ def vf_cfg( t: jnp.ndarray, x: jnp.ndarray, cond: dict[str, jnp.ndarray] | None ) -> jnp.ndarray: # TODO: adapt to null condition in transformer - null_cond = jax.tree_util.tree_map(lambda x: jnp.zeros_like(x), cond) params = self.vf_state.params v_cond = self.vf_state.apply_fn({"params": params}, t, x, cond) v_uncond = self.vf_state.apply_fn({"params": params}, t, x, null_cond) From c75f8ebf9f578c1adb03811ad737cd2aa64ab94a Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Wed, 6 Nov 2024 20:18:10 +0100 Subject: [PATCH 04/14] try ones_like instead --- src/cfp/solvers/_otfm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/cfp/solvers/_otfm.py b/src/cfp/solvers/_otfm.py index e7c580f9..f21f0ccc 100644 --- a/src/cfp/solvers/_otfm.py +++ b/src/cfp/solvers/_otfm.py @@ -148,7 +148,7 @@ def step_fn( cfg_null = jax.random.bernoulli(rng_cfg, self.cfg_p_resample) if cfg_null: # TODO: adapt to null condition in transformer - condition = jax.tree_util.tree_map(lambda x: jnp.zeros_like(x), condition) + condition = jax.tree_util.tree_map(lambda x: jnp.ones_like(x), condition) if self.match_fn is not None: tmat = self.match_fn(src, tgt) @@ -209,7 +209,7 @@ def predict( kwargs.setdefault( "stepsize_controller", diffrax.PIDController(rtol=1e-5, atol=1e-5) ) - null_cond = jax.tree_util.tree_map(lambda x: jnp.zeros_like(x), condition) + null_cond = jax.tree_util.tree_map(lambda x: jnp.ones_like(x), condition) def vf( From ea13acee93bb6c77df616fd162c5d20cb47bfae6 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Wed, 6 Nov 2024 20:27:49 +0100 Subject: [PATCH 05/14] try dummy --- src/cfp/solvers/_otfm.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/cfp/solvers/_otfm.py b/src/cfp/solvers/_otfm.py index f21f0ccc..f0d69514 100644 --- a/src/cfp/solvers/_otfm.py +++ b/src/cfp/solvers/_otfm.py @@ -221,11 +221,14 @@ def vf( def vf_cfg( t: jnp.ndarray, x: jnp.ndarray, cond: dict[str, jnp.ndarray] | None ) -> jnp.ndarray: - # TODO: adapt to null condition in transformer params = self.vf_state.params - v_cond = self.vf_state.apply_fn({"params": params}, t, x, cond) - v_uncond = self.vf_state.apply_fn({"params": params}, t, x, null_cond) - return (1 + self.cfg_ode_weight) * v_cond - self.cfg_ode_weight * v_uncond + return self.vf_state.apply_fn({"params": params}, t, x, cond, train=False) + + # TODO: adapt to null condition in transformer + #params = self.vf_state.params + #v_cond = self.vf_state.apply_fn({"params": params}, t, x, cond) + #v_uncond = self.vf_state.apply_fn({"params": params}, t, x, null_cond) + #return (1 + self.cfg_ode_weight) * v_cond - self.cfg_ode_weight * v_uncond def solve_ode(x: jnp.ndarray, condition: jnp.ndarray | None) -> jnp.ndarray: ode_term = diffrax.ODETerm(vf_cfg if self.cfg_p_resample else vf) From e738419db7c77cc5cc89771272c57a57107cab7e Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Wed, 6 Nov 2024 20:35:41 +0100 Subject: [PATCH 06/14] before this it worked --- src/cfp/solvers/_otfm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/cfp/solvers/_otfm.py b/src/cfp/solvers/_otfm.py index f0d69514..304e5b88 100644 --- a/src/cfp/solvers/_otfm.py +++ b/src/cfp/solvers/_otfm.py @@ -215,6 +215,7 @@ def predict( def vf( t: jnp.ndarray, x: jnp.ndarray, cond: dict[str, jnp.ndarray] | None ) -> jnp.ndarray: + cond = jax.tree_util.tree_map(lambda x: jnp.ones_like(x), cond) params = self.vf_state.params return self.vf_state.apply_fn({"params": params}, t, x, cond, train=False) From a86971124accb9ac9c6fff60903da15cd8ea9e75 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Wed, 6 Nov 2024 20:49:47 +0100 Subject: [PATCH 07/14] now jtu tree map in right vf --- src/cfp/solvers/_otfm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cfp/solvers/_otfm.py b/src/cfp/solvers/_otfm.py index 304e5b88..eb0b98ab 100644 --- a/src/cfp/solvers/_otfm.py +++ b/src/cfp/solvers/_otfm.py @@ -215,13 +215,13 @@ def predict( def vf( t: jnp.ndarray, x: jnp.ndarray, cond: dict[str, jnp.ndarray] | None ) -> jnp.ndarray: - cond = jax.tree_util.tree_map(lambda x: jnp.ones_like(x), cond) params = self.vf_state.params return self.vf_state.apply_fn({"params": params}, t, x, cond, train=False) def vf_cfg( t: jnp.ndarray, x: jnp.ndarray, cond: dict[str, jnp.ndarray] | None ) -> jnp.ndarray: + cond = jax.tree_util.tree_map(lambda x: jnp.ones_like(x), cond) params = self.vf_state.params return self.vf_state.apply_fn({"params": params}, t, x, cond, train=False) From 1b0def39cebbd03d197b378f63f67e40da9f9f6d Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Wed, 6 Nov 2024 21:04:23 +0100 Subject: [PATCH 08/14] before working --- src/cfp/solvers/_otfm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cfp/solvers/_otfm.py b/src/cfp/solvers/_otfm.py index eb0b98ab..c365343c 100644 --- a/src/cfp/solvers/_otfm.py +++ b/src/cfp/solvers/_otfm.py @@ -223,7 +223,7 @@ def vf_cfg( ) -> jnp.ndarray: cond = jax.tree_util.tree_map(lambda x: jnp.ones_like(x), cond) params = self.vf_state.params - return self.vf_state.apply_fn({"params": params}, t, x, cond, train=False) + return (1 + self.cfg_ode_weight) * self.vf_state.apply_fn({"params": params}, t, x, cond, train=False) # TODO: adapt to null condition in transformer #params = self.vf_state.params From c6612944cbc9ccd89f01ffde946feace861be66a Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Wed, 6 Nov 2024 21:14:19 +0100 Subject: [PATCH 09/14] before working --- src/cfp/solvers/_otfm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/cfp/solvers/_otfm.py b/src/cfp/solvers/_otfm.py index c365343c..173a7e0d 100644 --- a/src/cfp/solvers/_otfm.py +++ b/src/cfp/solvers/_otfm.py @@ -221,9 +221,9 @@ def vf( def vf_cfg( t: jnp.ndarray, x: jnp.ndarray, cond: dict[str, jnp.ndarray] | None ) -> jnp.ndarray: - cond = jax.tree_util.tree_map(lambda x: jnp.ones_like(x), cond) + cond_mask = jax.tree_util.tree_map(lambda x: jnp.ones_like(x), cond) params = self.vf_state.params - return (1 + self.cfg_ode_weight) * self.vf_state.apply_fn({"params": params}, t, x, cond, train=False) + return (1 + self.cfg_ode_weight) * self.vf_state.apply_fn({"params": params}, t, x, cond, train=False) - self.cfg_ode_weight * (1 + self.cfg_ode_weight) * self.vf_state.apply_fn({"params": params}, t, x, cond_mask, train=False) # TODO: adapt to null condition in transformer #params = self.vf_state.params From 7e9c27b9cd5fe1623910e22a72acf4a6e61e970f Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Wed, 6 Nov 2024 21:25:02 +0100 Subject: [PATCH 10/14] before working --- src/cfp/solvers/_otfm.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/cfp/solvers/_otfm.py b/src/cfp/solvers/_otfm.py index 173a7e0d..50688351 100644 --- a/src/cfp/solvers/_otfm.py +++ b/src/cfp/solvers/_otfm.py @@ -209,9 +209,7 @@ def predict( kwargs.setdefault( "stepsize_controller", diffrax.PIDController(rtol=1e-5, atol=1e-5) ) - null_cond = jax.tree_util.tree_map(lambda x: jnp.ones_like(x), condition) - def vf( t: jnp.ndarray, x: jnp.ndarray, cond: dict[str, jnp.ndarray] | None ) -> jnp.ndarray: @@ -223,7 +221,7 @@ def vf_cfg( ) -> jnp.ndarray: cond_mask = jax.tree_util.tree_map(lambda x: jnp.ones_like(x), cond) params = self.vf_state.params - return (1 + self.cfg_ode_weight) * self.vf_state.apply_fn({"params": params}, t, x, cond, train=False) - self.cfg_ode_weight * (1 + self.cfg_ode_weight) * self.vf_state.apply_fn({"params": params}, t, x, cond_mask, train=False) + return (1 + self.cfg_ode_weight) * self.vf_state.apply_fn({"params": params}, t, x, cond, train=False) - self.cfg_ode_weight * self.vf_state.apply_fn({"params": params}, t, x, cond_mask, train=False) # TODO: adapt to null condition in transformer #params = self.vf_state.params From 011fb944569475a46175d2baba87b1d2caba0aa5 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Wed, 6 Nov 2024 21:49:41 +0100 Subject: [PATCH 11/14] before working --- src/cfp/solvers/_otfm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/cfp/solvers/_otfm.py b/src/cfp/solvers/_otfm.py index 50688351..fb078ef7 100644 --- a/src/cfp/solvers/_otfm.py +++ b/src/cfp/solvers/_otfm.py @@ -148,7 +148,7 @@ def step_fn( cfg_null = jax.random.bernoulli(rng_cfg, self.cfg_p_resample) if cfg_null: # TODO: adapt to null condition in transformer - condition = jax.tree_util.tree_map(lambda x: jnp.ones_like(x), condition) + condition = jax.tree_util.tree_map(lambda x: jnp.zeros_like(x), condition) if self.match_fn is not None: tmat = self.match_fn(src, tgt) @@ -219,7 +219,7 @@ def vf( def vf_cfg( t: jnp.ndarray, x: jnp.ndarray, cond: dict[str, jnp.ndarray] | None ) -> jnp.ndarray: - cond_mask = jax.tree_util.tree_map(lambda x: jnp.ones_like(x), cond) + cond_mask = jax.tree_util.tree_map(lambda x: jnp.zeros_like(x), cond) params = self.vf_state.params return (1 + self.cfg_ode_weight) * self.vf_state.apply_fn({"params": params}, t, x, cond, train=False) - self.cfg_ode_weight * self.vf_state.apply_fn({"params": params}, t, x, cond_mask, train=False) From 5d2679ab25e1a97d95cbee9a6dbb7b8c62ce2d97 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Fri, 15 Nov 2024 08:59:30 +0100 Subject: [PATCH 12/14] adapt validation --- src/cfp/training/_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cfp/training/_trainer.py b/src/cfp/training/_trainer.py index 0040a826..4aaee27e 100644 --- a/src/cfp/training/_trainer.py +++ b/src/cfp/training/_trainer.py @@ -117,7 +117,7 @@ def train( loss = self.solver.step_fn(rng_step_fn, batch) self.training_logs["loss"].append(float(loss)) - if ((it - 1) % valid_freq == 0) and (it > 1): + if ((it + 1) % valid_freq == 0) and (it > 1): # Get predictions from validation data valid_true_data, valid_pred_data = self._validation_step( valid_loaders, mode="on_log_iteration" From 3e8d3512013a97a3b0d8fa3ddc96bd7fbc469de8 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Fri, 15 Nov 2024 09:04:04 +0100 Subject: [PATCH 13/14] adapt validation --- src/cfp/solvers/_otfm.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/src/cfp/solvers/_otfm.py b/src/cfp/solvers/_otfm.py index fb078ef7..88ddebfc 100644 --- a/src/cfp/solvers/_otfm.py +++ b/src/cfp/solvers/_otfm.py @@ -79,6 +79,7 @@ def __init__( input_dim=self.vf.output_dims[-1], **kwargs ) self.vf_step_fn = self._get_vf_step_fn() + self.null_value_cfg = self.vf.mask_value def _get_vf_step_fn(self) -> Callable: # type: ignore[type-arg] @@ -148,7 +149,7 @@ def step_fn( cfg_null = jax.random.bernoulli(rng_cfg, self.cfg_p_resample) if cfg_null: # TODO: adapt to null condition in transformer - condition = jax.tree_util.tree_map(lambda x: jnp.zeros_like(x), condition) + condition = jax.tree_util.tree_map(lambda x: jnp.full(x.shape, self.null_value_cfg), condition) if self.match_fn is not None: tmat = self.match_fn(src, tgt) @@ -219,16 +220,10 @@ def vf( def vf_cfg( t: jnp.ndarray, x: jnp.ndarray, cond: dict[str, jnp.ndarray] | None ) -> jnp.ndarray: - cond_mask = jax.tree_util.tree_map(lambda x: jnp.zeros_like(x), cond) + cond_mask = jax.tree_util.tree_map(lambda x: jnp.full(x.shape, self.null_value_cfg), cond) params = self.vf_state.params return (1 + self.cfg_ode_weight) * self.vf_state.apply_fn({"params": params}, t, x, cond, train=False) - self.cfg_ode_weight * self.vf_state.apply_fn({"params": params}, t, x, cond_mask, train=False) - # TODO: adapt to null condition in transformer - #params = self.vf_state.params - #v_cond = self.vf_state.apply_fn({"params": params}, t, x, cond) - #v_uncond = self.vf_state.apply_fn({"params": params}, t, x, null_cond) - #return (1 + self.cfg_ode_weight) * v_cond - self.cfg_ode_weight * v_uncond - def solve_ode(x: jnp.ndarray, condition: jnp.ndarray | None) -> jnp.ndarray: ode_term = diffrax.ODETerm(vf_cfg if self.cfg_p_resample else vf) result = diffrax.diffeqsolve( From 4c72c0a27b497f4cdb1d48a2a6c98575ec6b0e25 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Fri, 15 Nov 2024 09:10:04 +0100 Subject: [PATCH 14/14] run pre-commit --- src/cfp/solvers/_otfm.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/cfp/solvers/_otfm.py b/src/cfp/solvers/_otfm.py index 88ddebfc..75c2c5d8 100644 --- a/src/cfp/solvers/_otfm.py +++ b/src/cfp/solvers/_otfm.py @@ -149,7 +149,9 @@ def step_fn( cfg_null = jax.random.bernoulli(rng_cfg, self.cfg_p_resample) if cfg_null: # TODO: adapt to null condition in transformer - condition = jax.tree_util.tree_map(lambda x: jnp.full(x.shape, self.null_value_cfg), condition) + condition = jax.tree_util.tree_map( + lambda x: jnp.full(x.shape, self.null_value_cfg), condition + ) if self.match_fn is not None: tmat = self.match_fn(src, tgt) @@ -210,7 +212,7 @@ def predict( kwargs.setdefault( "stepsize_controller", diffrax.PIDController(rtol=1e-5, atol=1e-5) ) - + def vf( t: jnp.ndarray, x: jnp.ndarray, cond: dict[str, jnp.ndarray] | None ) -> jnp.ndarray: @@ -220,9 +222,15 @@ def vf( def vf_cfg( t: jnp.ndarray, x: jnp.ndarray, cond: dict[str, jnp.ndarray] | None ) -> jnp.ndarray: - cond_mask = jax.tree_util.tree_map(lambda x: jnp.full(x.shape, self.null_value_cfg), cond) + cond_mask = jax.tree_util.tree_map( + lambda x: jnp.full(x.shape, self.null_value_cfg), cond + ) params = self.vf_state.params - return (1 + self.cfg_ode_weight) * self.vf_state.apply_fn({"params": params}, t, x, cond, train=False) - self.cfg_ode_weight * self.vf_state.apply_fn({"params": params}, t, x, cond_mask, train=False) + return (1 + self.cfg_ode_weight) * self.vf_state.apply_fn( + {"params": params}, t, x, cond, train=False + ) - self.cfg_ode_weight * self.vf_state.apply_fn( + {"params": params}, t, x, cond_mask, train=False + ) def solve_ode(x: jnp.ndarray, condition: jnp.ndarray | None) -> jnp.ndarray: ode_term = diffrax.ODETerm(vf_cfg if self.cfg_p_resample else vf)