From 54e2e9e0ae505cd3d63f1a4e7722388f755f2a47 Mon Sep 17 00:00:00 2001 From: LeonStadelmann Date: Sat, 3 May 2025 18:05:25 +0200 Subject: [PATCH 1/8] init --- src/cellflow/training/_callbacks.py | 57 +++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/src/cellflow/training/_callbacks.py b/src/cellflow/training/_callbacks.py index fda372fc..54940b84 100644 --- a/src/cellflow/training/_callbacks.py +++ b/src/cellflow/training/_callbacks.py @@ -266,6 +266,63 @@ def on_log_iteration( return metrics +class PCADecodedMetrics2(Metrics): + """Callback to compute metrics on true validation data during training + + Parameters + ---------- + ref_adata + An :class:`~anndata.AnnData` object with the reference data containing + ``adata.varm["X_mean"]`` and ``adata.varm["PCs"]``. + metrics + List of metrics to compute. Supported metrics are ``"r_squared"``, ``"mmd"``, + ``"sinkhorn_div"``, and ``"e_distance"``. + metric_aggregations + List of aggregation functions to use for each metric. Supported aggregations are ``"mean"`` + and ``"median"``. + condition_id_key + Key in :attr:`~anndata.AnnData.obs` that defines the condition id. + log_prefix + Prefix to add to the log keys. 
+ """ + + def __init__( + self, + ref_adata: ad.AnnData, + metrics: list[Literal["r_squared", "mmd", "sinkhorn_div", "e_distance"]], + metric_aggregations: list[Literal["mean", "median"]] = None, + condition_id_key: str | None = None, + log_prefix: str = "pca_decoded_", + ): + super().__init__(metrics, metric_aggregations) + self.pcs = ref_adata.varm["PCs"] + self.means = ref_adata.varm["X_mean"] + self.reconstruct_data = lambda x: x @ np.transpose(self.pcs) + np.transpose(self.means) + self.condition_id_key = condition_id_key + self.log_prefix = log_prefix + + def on_log_iteration( + self, + validation_data: dict[str, dict[str, ArrayLike]], + predicted_data: dict[str, dict[str, ArrayLike]], + ) -> dict[str, float]: + """Called at each validation/log iteration to reconstruct the data and compute metrics on the reconstruction + + Parameters + ---------- + validation_data + Validation data in nested dictionary format with same keys as ``predicted_data`` + predicted_data + Predicted data in nested dictionary format with same keys as ``validation_data`` + """ + validation_data_gt = None + predicted_data_decoded = jtu.tree_map(self.reconstruct_data, predicted_data) + + metrics = super().on_log_iteration(validation_data_gt, predicted_data_decoded) + metrics = {f"{self.log_prefix}{k}": v for k, v in metrics.items()} + return metrics + + class VAEDecodedMetrics(Metrics): """Callback to compute metrics on decoded validation data during training From 577da5a1d0d00658418a8121d3598849a11d3caf Mon Sep 17 00:00:00 2001 From: LeonStadelmann Date: Wed, 21 May 2025 14:36:56 +0200 Subject: [PATCH 2/8] Add method --- src/cellflow/model/_cellflow.py | 5 ++- src/cellflow/training/_callbacks.py | 52 +++++++++++++++++++++++------ src/cellflow/training/_trainer.py | 3 ++ 3 files changed, 49 insertions(+), 11 deletions(-) diff --git a/src/cellflow/model/_cellflow.py b/src/cellflow/model/_cellflow.py index 859cf10b..53624d81 100644 --- a/src/cellflow/model/_cellflow.py +++ 
b/src/cellflow/model/_cellflow.py @@ -56,6 +56,7 @@ def __init__(self, adata: ad.AnnData, solver: Literal["otfm", "genot"] = "otfm") self._dataloader: TrainSampler | None = None self._trainer: CellFlowTrainer | None = None self._validation_data: dict[str, ValidationData] = {} + self._validation_adata: dict[str, ad.Anndata] = {} self._solver: _otfm.OTFlowMatching | _genot.GENOT | None = None self._condition_dim: int | None = None self._vf: _velocity_field.ConditionalVelocityField | _velocity_field.GENOTConditionalVelocityField | None = None @@ -225,6 +226,7 @@ def prepare_validation_data( n_conditions_on_log_iteration=n_conditions_on_log_iteration, n_conditions_on_train_end=n_conditions_on_train_end, ) + self._validation_adata[name] = adata self._validation_data[name] = val_data def prepare_model( @@ -498,7 +500,8 @@ def prepare_model( ) else: raise NotImplementedError(f"Solver must be an instance of OTFlowMatching or GENOT, got {type(self.solver)}") - self._trainer = CellFlowTrainer(solver=self.solver) # type: ignore[arg-type] + validation_adata = self._validation_adata or {} + self._trainer = CellFlowTrainer(solver=self.solver, validation_adata=validation_adata) # type: ignore[arg-type] def train( self, diff --git a/src/cellflow/training/_callbacks.py b/src/cellflow/training/_callbacks.py index 54940b84..035c177f 100644 --- a/src/cellflow/training/_callbacks.py +++ b/src/cellflow/training/_callbacks.py @@ -14,6 +14,7 @@ compute_scalar_mmd, compute_sinkhorn_div, ) +from cellflow.solvers import _genot, _otfm __all__ = [ "BaseCallback", @@ -301,27 +302,58 @@ def __init__( self.condition_id_key = condition_id_key self.log_prefix = log_prefix + def add_validation_adata( + self, + validation_adata: dict[str, ad.AnnData], + ) -> None: + self.validation_adata = validation_adata + def on_log_iteration( self, - validation_data: dict[str, dict[str, ArrayLike]], - predicted_data: dict[str, dict[str, ArrayLike]], + valid_source_data: dict[str, dict[str, ArrayLike]], + 
valid_true_data: dict[str, dict[str, ArrayLike]], + valid_pred_data: dict[str, dict[str, ArrayLike]], + solver: _genot.GENOT | _otfm.OTFlowMatching, ) -> dict[str, float]: """Called at each validation/log iteration to reconstruct the data and compute metrics on the reconstruction Parameters ---------- - validation_data - Validation data in nested dictionary format with same keys as ``predicted_data`` - predicted_data - Predicted data in nested dictionary format with same keys as ``validation_data`` + valid_source_data + Source data in nested dictionary format with same keys as ``valid_true_data`` + valid_true_data + Validation data in nested dictionary format with same keys as ``valid_pred_data`` + valid_pred_data + Predicted data in nested dictionary format with same keys as ``valid_true_data`` + solver + :class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT` + solver with a conditional velocity field. """ - validation_data_gt = None - predicted_data_decoded = jtu.tree_map(self.reconstruct_data, predicted_data) - - metrics = super().on_log_iteration(validation_data_gt, predicted_data_decoded) + true_counts = {} + for name in self.validation_adata.keys(): + true_counts[name] = {} + conditions_adata = self.validation_adata[name].obs[self.condition_id_key].unique() + conditions_pred = valid_pred_data[name].keys() + for cond in conditions_adata & conditions_pred: + true_counts[name][cond] = self.validation_adata[name][ + self.validation_adata[name].obs[self.condition_id_key] == cond + ].X + + predicted_data_decoded = jtu.tree_map(self.reconstruct_data, valid_pred_data) + + metrics = super().on_log_iteration(true_counts, predicted_data_decoded) metrics = {f"{self.log_prefix}{k}": v for k, v in metrics.items()} return metrics + def on_train_end( + self, + valid_source_data: dict[str, dict[str, ArrayLike]], + valid_true_data: dict[str, dict[str, ArrayLike]], + valid_pred_data: dict[str, dict[str, ArrayLike]], + solver: _genot.GENOT | 
_otfm.OTFlowMatching, + ) -> dict[str, float]: + return super().on_train_end(valid_source_data, valid_true_data, valid_pred_data, solver) + class VAEDecodedMetrics(Metrics): """Callback to compute metrics on decoded validation data during training diff --git a/src/cellflow/training/_trainer.py b/src/cellflow/training/_trainer.py index e898393c..98f71415 100644 --- a/src/cellflow/training/_trainer.py +++ b/src/cellflow/training/_trainer.py @@ -1,6 +1,7 @@ from collections.abc import Sequence from typing import Any, Literal +import anndata as ad import jax import numpy as np from numpy.typing import ArrayLike @@ -31,12 +32,14 @@ class CellFlowTrainer: def __init__( self, solver: _otfm.OTFlowMatching | _genot.GENOT, + validation_adata: dict[str, ad.AnnData], seed: int = 0, ): if not isinstance(solver, (_otfm.OTFlowMatching | _genot.GENOT)): raise NotImplementedError(f"Solver must be an instance of OTFlowMatching or GENOT, got {type(solver)}") self.solver = solver + self.validation_adata = validation_adata self.rng_subsampling = np.random.default_rng(seed) self.training_logs: dict[str, Any] = {} From 5574c8febbd184e0f58a3fc5d9552516b725647e Mon Sep 17 00:00:00 2001 From: LeonStadelmann Date: Wed, 21 May 2025 16:07:39 +0200 Subject: [PATCH 3/8] Add test --- src/cellflow/training/__init__.py | 2 ++ src/cellflow/training/_callbacks.py | 9 +++++---- src/cellflow/training/_trainer.py | 6 +++++- tests/trainer/test_callbacks.py | 23 +++++++++++++++++++++++ 4 files changed, 35 insertions(+), 5 deletions(-) diff --git a/src/cellflow/training/__init__.py b/src/cellflow/training/__init__.py index 387411d2..a2d5f53e 100644 --- a/src/cellflow/training/__init__.py +++ b/src/cellflow/training/__init__.py @@ -5,6 +5,7 @@ LoggingCallback, Metrics, PCADecodedMetrics, + PCADecodedMetrics2, VAEDecodedMetrics, WandbLogger, ) @@ -19,6 +20,7 @@ "WandbLogger", "CallbackRunner", "PCADecodedMetrics", + "PCADecodedMetrics2", "PCADecoder", "VAEDecodedMetrics", ] diff --git 
a/src/cellflow/training/_callbacks.py b/src/cellflow/training/_callbacks.py index 035c177f..7c9af4dd 100644 --- a/src/cellflow/training/_callbacks.py +++ b/src/cellflow/training/_callbacks.py @@ -24,6 +24,7 @@ "WandbLogger", "CallbackRunner", "PCADecodedMetrics", + "PCADecodedMetrics2", "VAEDecodedMetrics", ] @@ -292,8 +293,8 @@ def __init__( ref_adata: ad.AnnData, metrics: list[Literal["r_squared", "mmd", "sinkhorn_div", "e_distance"]], metric_aggregations: list[Literal["mean", "median"]] = None, - condition_id_key: str | None = None, - log_prefix: str = "pca_decoded_", + condition_id_key: str = "condition", + log_prefix: str = "pca_decoded_2_", ): super().__init__(metrics, metric_aggregations) self.pcs = ref_adata.varm["PCs"] @@ -332,12 +333,12 @@ def on_log_iteration( true_counts = {} for name in self.validation_adata.keys(): true_counts[name] = {} - conditions_adata = self.validation_adata[name].obs[self.condition_id_key].unique() + conditions_adata = set(self.validation_adata[name].obs[self.condition_id_key].unique()) conditions_pred = valid_pred_data[name].keys() for cond in conditions_adata & conditions_pred: true_counts[name][cond] = self.validation_adata[name][ self.validation_adata[name].obs[self.condition_id_key] == cond - ].X + ].X.toarray() predicted_data_decoded = jtu.tree_map(self.reconstruct_data, valid_pred_data) diff --git a/src/cellflow/training/_trainer.py b/src/cellflow/training/_trainer.py index 98f71415..4006d643 100644 --- a/src/cellflow/training/_trainer.py +++ b/src/cellflow/training/_trainer.py @@ -9,7 +9,7 @@ from cellflow.data._dataloader import TrainSampler, ValidationSampler from cellflow.solvers import _genot, _otfm -from cellflow.training._callbacks import BaseCallback, CallbackRunner +from cellflow.training._callbacks import BaseCallback, CallbackRunner, PCADecodedMetrics2 class CellFlowTrainer: @@ -106,6 +106,10 @@ def train( self.training_logs = {"loss": []} rng = jax.random.PRNGKey(0) + for callback in callbacks: + if 
isinstance(callback, PCADecodedMetrics2): + callback.add_validation_adata(self.validation_adata) + # Initiate callbacks valid_loaders = valid_loaders or {} crun = CallbackRunner( diff --git a/tests/trainer/test_callbacks.py b/tests/trainer/test_callbacks.py index f1346ce6..4cd586a6 100644 --- a/tests/trainer/test_callbacks.py +++ b/tests/trainer/test_callbacks.py @@ -1,6 +1,7 @@ import anndata as ad import jax.numpy as jnp import jax.tree_util as jtu +import numpy as np import pytest @@ -18,6 +19,28 @@ def test_pca_reconstruction(self, adata_pca: ad.AnnData, metrics): assert reconstruction.shape == adata_pca.X.shape assert jnp.allclose(reconstruction, adata_pca.layers["counts"]) + def test_pca_decoded_2(self, adata_pca: ad.AnnData): + from cellflow.solvers import OTFlowMatching + from cellflow.training import PCADecodedMetrics2 + + adata_gt = adata_pca.copy() + adata_gt.obs["condition"] = np.random.choice(["A", "B"], size=adata_pca.shape[0]) + + decoded_metrics_callback = PCADecodedMetrics2( + ref_adata=adata_pca, metrics=["r_squared"], condition_id_key="condition" + ) + + callbacks = [decoded_metrics_callback] + for e in callbacks: + if isinstance(e, PCADecodedMetrics2): + e.add_validation_adata({"test": adata_gt}) + + valid_pred_data = {"test": {"A": np.random.random((2, 10)), "B": np.random.random((2, 10))}} + + res = decoded_metrics_callback.on_log_iteration({}, {}, valid_pred_data, OTFlowMatching) + assert "pca_decoded_2_test_r_squared_mean" in res + assert isinstance(res["pca_decoded_2_test_r_squared_mean"], float) + @pytest.mark.parametrize("metrics", [["r_squared"]]) def test_vae_reconstruction(self, metrics): from scvi.data import synthetic_iid From e960ee14b0b89d70712e0406b2f034945594a742 Mon Sep 17 00:00:00 2001 From: LeonStadelmann Date: Wed, 21 May 2025 16:11:41 +0200 Subject: [PATCH 4/8] fix typo --- src/cellflow/training/_callbacks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cellflow/training/_callbacks.py 
b/src/cellflow/training/_callbacks.py index 7c9af4dd..aed4deae 100644 --- a/src/cellflow/training/_callbacks.py +++ b/src/cellflow/training/_callbacks.py @@ -353,7 +353,7 @@ def on_train_end( valid_pred_data: dict[str, dict[str, ArrayLike]], solver: _genot.GENOT | _otfm.OTFlowMatching, ) -> dict[str, float]: - return super().on_train_end(valid_source_data, valid_true_data, valid_pred_data, solver) + return self.on_log_iteration(valid_source_data, valid_true_data, valid_pred_data, solver) class VAEDecodedMetrics(Metrics): From ca25068cb2efe1e97849312f2d9c1cbd0abccd9e Mon Sep 17 00:00:00 2001 From: LeonStadelmann Date: Fri, 23 May 2025 13:22:01 +0200 Subject: [PATCH 5/8] Allow for dense and layer input --- src/cellflow/training/_callbacks.py | 16 +++++++++++++--- tests/trainer/test_callbacks.py | 10 ++++++++-- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/src/cellflow/training/_callbacks.py b/src/cellflow/training/_callbacks.py index aed4deae..bd2bf3ad 100644 --- a/src/cellflow/training/_callbacks.py +++ b/src/cellflow/training/_callbacks.py @@ -6,6 +6,7 @@ import jax.tree as jt import jax.tree_util as jtu import numpy as np +import scipy from cellflow._types import ArrayLike from cellflow.metrics._metrics import ( @@ -284,6 +285,9 @@ class PCADecodedMetrics2(Metrics): and ``"median"``. condition_id_key Key in :attr:`~anndata.AnnData.obs` that defines the condition id. + layers + Key in :attr:`~anndata.AnnData.layers` from which to get the counts. + If :obj:`None`, use :attr:`~anndata.AnnData.X`. log_prefix Prefix to add to the log keys. 
""" @@ -294,6 +298,7 @@ def __init__( metrics: list[Literal["r_squared", "mmd", "sinkhorn_div", "e_distance"]], metric_aggregations: list[Literal["mean", "median"]] = None, condition_id_key: str = "condition", + layers: str | None = None, log_prefix: str = "pca_decoded_2_", ): super().__init__(metrics, metric_aggregations) @@ -301,6 +306,7 @@ def __init__( self.means = ref_adata.varm["X_mean"] self.reconstruct_data = lambda x: x @ np.transpose(self.pcs) + np.transpose(self.means) self.condition_id_key = condition_id_key + self.layers = layers self.log_prefix = log_prefix def add_validation_adata( @@ -336,9 +342,13 @@ def on_log_iteration( conditions_adata = set(self.validation_adata[name].obs[self.condition_id_key].unique()) conditions_pred = valid_pred_data[name].keys() for cond in conditions_adata & conditions_pred: - true_counts[name][cond] = self.validation_adata[name][ - self.validation_adata[name].obs[self.condition_id_key] == cond - ].X.toarray() + condition_mask = self.validation_adata[name].obs[self.condition_id_key] == cond + counts = ( + self.validation_adata[name][condition_mask].X + if self.layers is None + else self.validation_adata[name][condition_mask].layers[self.layers] + ) + true_counts[name][cond] = counts.toarray() if scipy.sparse.issparse(counts) else counts predicted_data_decoded = jtu.tree_map(self.reconstruct_data, valid_pred_data) diff --git a/tests/trainer/test_callbacks.py b/tests/trainer/test_callbacks.py index 4cd586a6..b8f58da5 100644 --- a/tests/trainer/test_callbacks.py +++ b/tests/trainer/test_callbacks.py @@ -19,15 +19,21 @@ def test_pca_reconstruction(self, adata_pca: ad.AnnData, metrics): assert reconstruction.shape == adata_pca.X.shape assert jnp.allclose(reconstruction, adata_pca.layers["counts"]) - def test_pca_decoded_2(self, adata_pca: ad.AnnData): + @pytest.mark.parametrize("sparse_matrix", [True, False]) + @pytest.mark.parametrize("layers", [None, "test"]) + def test_pca_decoded_2(self, adata_pca: ad.AnnData, 
sparse_matrix, layers): from cellflow.solvers import OTFlowMatching from cellflow.training import PCADecodedMetrics2 adata_gt = adata_pca.copy() adata_gt.obs["condition"] = np.random.choice(["A", "B"], size=adata_pca.shape[0]) + if not sparse_matrix: + adata_gt.X = adata_gt.X.toarray() + if layers is not None: + adata_gt.layers[layers] = adata_gt.X.copy() decoded_metrics_callback = PCADecodedMetrics2( - ref_adata=adata_pca, metrics=["r_squared"], condition_id_key="condition" + ref_adata=adata_pca, metrics=["r_squared"], condition_id_key="condition", layers=layers ) callbacks = [decoded_metrics_callback] From 41a7dff9e9d3ed879eb9e2699e7ed43c4b78b486 Mon Sep 17 00:00:00 2001 From: LeonStadelmann Date: Wed, 11 Jun 2025 13:12:35 +0200 Subject: [PATCH 6/8] Pass validation adata in init --- src/cellflow/model/_cellflow.py | 5 +---- src/cellflow/training/__init__.py | 1 - src/cellflow/training/_callbacks.py | 12 ++++++------ src/cellflow/training/_trainer.py | 11 +---------- tests/trainer/test_callbacks.py | 11 +++++------ 5 files changed, 13 insertions(+), 27 deletions(-) diff --git a/src/cellflow/model/_cellflow.py b/src/cellflow/model/_cellflow.py index 53624d81..859cf10b 100644 --- a/src/cellflow/model/_cellflow.py +++ b/src/cellflow/model/_cellflow.py @@ -56,7 +56,6 @@ def __init__(self, adata: ad.AnnData, solver: Literal["otfm", "genot"] = "otfm") self._dataloader: TrainSampler | None = None self._trainer: CellFlowTrainer | None = None self._validation_data: dict[str, ValidationData] = {} - self._validation_adata: dict[str, ad.Anndata] = {} self._solver: _otfm.OTFlowMatching | _genot.GENOT | None = None self._condition_dim: int | None = None self._vf: _velocity_field.ConditionalVelocityField | _velocity_field.GENOTConditionalVelocityField | None = None @@ -226,7 +225,6 @@ def prepare_validation_data( n_conditions_on_log_iteration=n_conditions_on_log_iteration, n_conditions_on_train_end=n_conditions_on_train_end, ) - self._validation_adata[name] = adata 
self._validation_data[name] = val_data def prepare_model( @@ -500,8 +498,7 @@ def prepare_model( ) else: raise NotImplementedError(f"Solver must be an instance of OTFlowMatching or GENOT, got {type(self.solver)}") - validation_adata = self._validation_adata or {} - self._trainer = CellFlowTrainer(solver=self.solver, validation_adata=validation_adata) # type: ignore[arg-type] + self._trainer = CellFlowTrainer(solver=self.solver) # type: ignore[arg-type] def train( self, diff --git a/src/cellflow/training/__init__.py b/src/cellflow/training/__init__.py index a2d5f53e..d6b6f176 100644 --- a/src/cellflow/training/__init__.py +++ b/src/cellflow/training/__init__.py @@ -21,6 +21,5 @@ "CallbackRunner", "PCADecodedMetrics", "PCADecodedMetrics2", - "PCADecoder", "VAEDecodedMetrics", ] diff --git a/src/cellflow/training/_callbacks.py b/src/cellflow/training/_callbacks.py index bd2bf3ad..1c4e7d33 100644 --- a/src/cellflow/training/_callbacks.py +++ b/src/cellflow/training/_callbacks.py @@ -277,6 +277,10 @@ class PCADecodedMetrics2(Metrics): ref_adata An :class:`~anndata.AnnData` object with the reference data containing ``adata.varm["X_mean"]`` and ``adata.varm["PCs"]``. + validation_adata + Dictionary where the keys are the names of the datasets given in + :func:`~cellflow.model.prepare_validation_data` and the values are the corresponding + :class:`~anndata.AnnData` objects. metrics List of metrics to compute. Supported metrics are ``"r_squared"``, ``"mmd"``, ``"sinkhorn_div"``, and ``"e_distance"``. 
@@ -295,6 +299,7 @@ class PCADecodedMetrics2(Metrics): def __init__( self, ref_adata: ad.AnnData, + validation_adata: dict[str, ad.AnnData], metrics: list[Literal["r_squared", "mmd", "sinkhorn_div", "e_distance"]], metric_aggregations: list[Literal["mean", "median"]] = None, condition_id_key: str = "condition", @@ -305,16 +310,11 @@ def __init__( self.pcs = ref_adata.varm["PCs"] self.means = ref_adata.varm["X_mean"] self.reconstruct_data = lambda x: x @ np.transpose(self.pcs) + np.transpose(self.means) + self.validation_adata = validation_adata self.condition_id_key = condition_id_key self.layers = layers self.log_prefix = log_prefix - def add_validation_adata( - self, - validation_adata: dict[str, ad.AnnData], - ) -> None: - self.validation_adata = validation_adata - def on_log_iteration( self, valid_source_data: dict[str, dict[str, ArrayLike]], diff --git a/src/cellflow/training/_trainer.py b/src/cellflow/training/_trainer.py index 4006d643..5ace511a 100644 --- a/src/cellflow/training/_trainer.py +++ b/src/cellflow/training/_trainer.py @@ -1,7 +1,6 @@ from collections.abc import Sequence from typing import Any, Literal -import anndata as ad import jax import numpy as np from numpy.typing import ArrayLike @@ -9,7 +8,7 @@ from cellflow.data._dataloader import TrainSampler, ValidationSampler from cellflow.solvers import _genot, _otfm -from cellflow.training._callbacks import BaseCallback, CallbackRunner, PCADecodedMetrics2 +from cellflow.training._callbacks import BaseCallback, CallbackRunner class CellFlowTrainer: @@ -17,8 +16,6 @@ class CellFlowTrainer: Parameters ---------- - dataloader - Data sampler. solver OTFM/GENOT solver with a conditional velocity field. 
seed @@ -32,14 +29,12 @@ class CellFlowTrainer: def __init__( self, solver: _otfm.OTFlowMatching | _genot.GENOT, - validation_adata: dict[str, ad.AnnData], seed: int = 0, ): if not isinstance(solver, (_otfm.OTFlowMatching | _genot.GENOT)): raise NotImplementedError(f"Solver must be an instance of OTFlowMatching or GENOT, got {type(solver)}") self.solver = solver - self.validation_adata = validation_adata self.rng_subsampling = np.random.default_rng(seed) self.training_logs: dict[str, Any] = {} @@ -106,10 +101,6 @@ def train( self.training_logs = {"loss": []} rng = jax.random.PRNGKey(0) - for callback in callbacks: - if isinstance(callback, PCADecodedMetrics2): - callback.add_validation_adata(self.validation_adata) - # Initiate callbacks valid_loaders = valid_loaders or {} crun = CallbackRunner( diff --git a/tests/trainer/test_callbacks.py b/tests/trainer/test_callbacks.py index b8f58da5..4459362d 100644 --- a/tests/trainer/test_callbacks.py +++ b/tests/trainer/test_callbacks.py @@ -33,14 +33,13 @@ def test_pca_decoded_2(self, adata_pca: ad.AnnData, sparse_matrix, layers): adata_gt.layers[layers] = adata_gt.X.copy() decoded_metrics_callback = PCADecodedMetrics2( - ref_adata=adata_pca, metrics=["r_squared"], condition_id_key="condition", layers=layers + ref_adata=adata_pca, + validation_adata={"test": adata_gt}, + metrics=["r_squared"], + condition_id_key="condition", + layers=layers, ) - callbacks = [decoded_metrics_callback] - for e in callbacks: - if isinstance(e, PCADecodedMetrics2): - e.add_validation_adata({"test": adata_gt}) - valid_pred_data = {"test": {"A": np.random.random((2, 10)), "B": np.random.random((2, 10))}} res = decoded_metrics_callback.on_log_iteration({}, {}, valid_pred_data, OTFlowMatching) From dea483ffcfde4b3fc89dc9632677d518b6ba5476 Mon Sep 17 00:00:00 2001 From: LeonStadelmann Date: Wed, 11 Jun 2025 14:43:22 +0200 Subject: [PATCH 7/8] Add new metric to docs --- docs/user/training.rst | 1 + src/cellflow/training/_callbacks.py | 2 +- 2 files
changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/user/training.rst b/docs/user/training.rst index a2741424..28a6dc42 100644 --- a/docs/user/training.rst +++ b/docs/user/training.rst @@ -11,6 +11,7 @@ Training training.LoggingCallback training.Metrics training.PCADecodedMetrics + training.PCADecodedMetrics2 training.VAEDecodedMetrics training.WandbLogger training.CellFlowTrainer diff --git a/src/cellflow/training/_callbacks.py b/src/cellflow/training/_callbacks.py index 1c4e7d33..c08759bc 100644 --- a/src/cellflow/training/_callbacks.py +++ b/src/cellflow/training/_callbacks.py @@ -279,7 +279,7 @@ class PCADecodedMetrics2(Metrics): ``adata.varm["X_mean"]`` and ``adata.varm["PCs"]``. validation_adata Dictionary where the keys are the names of the datasets given in - :func:`~cellflow.model.prepare_validation_data` and the values are the corresponding + :func:`~cellflow.model.CellFlow.prepare_validation_data` and the values are the corresponding :class:`~anndata.AnnData` objects. metrics List of metrics to compute. Supported metrics are ``"r_squared"``, ``"mmd"``, From c37431c3e55b7fbe11eabc955e8b94f6bd6155d6 Mon Sep 17 00:00:00 2001 From: LeonStadelmann Date: Wed, 11 Jun 2025 14:46:31 +0200 Subject: [PATCH 8/8] Add func arguments --- src/cellflow/training/_callbacks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cellflow/training/_callbacks.py b/src/cellflow/training/_callbacks.py index 45092d0f..ae63a863 100644 --- a/src/cellflow/training/_callbacks.py +++ b/src/cellflow/training/_callbacks.py @@ -406,7 +406,7 @@ def on_log_iteration( predicted_data_decoded = jtu.tree_map(self.reconstruct_data, valid_pred_data) - metrics = super().on_log_iteration(true_counts, predicted_data_decoded) + metrics = super().on_log_iteration(valid_source_data, true_counts, predicted_data_decoded, solver) metrics = {f"{self.log_prefix}{k}": v for k, v in metrics.items()} return metrics