From ef6684e8f2a0da5f7a3d4aa0e237e3ff1e856cc7 Mon Sep 17 00:00:00 2001
From: AlejandroTL
Date: Fri, 6 Sep 2024 17:50:03 +0200
Subject: [PATCH] SampledMetrics callback

---
 src/cfp/training/__init__.py   |   2 +
 src/cfp/training/_callbacks.py | 115 +++++++++++++++++++++++++++++++++
 2 files changed, 117 insertions(+)

diff --git a/src/cfp/training/__init__.py b/src/cfp/training/__init__.py
index 95b10625..ec40d7bd 100644
--- a/src/cfp/training/__init__.py
+++ b/src/cfp/training/__init__.py
@@ -4,6 +4,7 @@
     ComputationCallback,
     LoggingCallback,
     Metrics,
     PCADecodedMetrics,
+    SampledMetrics,
     WandbLogger,
 )
@@ -19,4 +20,5 @@
     "CallbackRunner",
     "PCADecodedMetrics",
     "PCADecoder",
+    "SampledMetrics",
 ]
diff --git a/src/cfp/training/_callbacks.py b/src/cfp/training/_callbacks.py
index 73a11db1..65c58fae 100644
--- a/src/cfp/training/_callbacks.py
+++ b/src/cfp/training/_callbacks.py
@@ -7,6 +7,7 @@
 import jax.tree_util as jtu
 import numpy as np
 from numpy.typing import ArrayLike
+import random
 
 from cfp.metrics._metrics import compute_e_distance, compute_r_squared, compute_scalar_mmd, compute_sinkhorn_div
 
@@ -18,6 +19,7 @@
     "WandbLogger",
     "CallbackRunner",
     "PCADecodedMetrics",
+    "SampledMetrics",
 ]
 
 
@@ -210,6 +212,119 @@ def on_train_end(
             Predicted data
         """
         return self.on_log_iteration(validation_data, predicted_data)
+
+
+class SampledMetrics(ComputationCallback):
+    """Callback to compute metrics on sampled validation data during training.
+
+    Parameters
+    ----------
+    sample_size : int
+        Number of cells to sample from each distribution for metric computation.
+    metrics : list
+        List of metrics to compute. Supported metrics are "sinkhorn_div",
+        "e_distance", and "mmd".
+    metric_aggregations : list
+        List of aggregation functions applied across conditions for each
+        metric. Supported aggregations are "mean" and "median".
+    log_prefix : str
+        Prefix to add to the log keys.
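+
+    Examples
+    --------
+    Illustrative only; how the callback is wired into training depends on
+    the :class:`CallbackRunner` setup.
+
+    >>> callback = SampledMetrics(
+    ...     sample_size=512,
+    ...     metrics=["e_distance", "mmd"],
+    ...     metric_aggregations=["mean", "median"],
+    ... )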
+    """
+
+    def __init__(
+        self,
+        sample_size: int,
+        metrics: list[Literal["sinkhorn_div", "e_distance", "mmd"]],
+        metric_aggregations: list[Literal["mean", "median"]] | None = None,
+        log_prefix: str = "sampled_",
+    ):
+        self.sample_size = sample_size
+        self.metrics = metrics
+        self.metric_aggregations = (
+            ["mean"] if metric_aggregations is None else metric_aggregations
+        )
+        self.log_prefix = log_prefix
+
+        for metric in metrics:
+            if metric not in ["sinkhorn_div", "e_distance", "mmd"]:
+                raise ValueError(
+                    f"Metric {metric} not supported. "
+                    "Supported metrics are 'sinkhorn_div', 'e_distance', and 'mmd'."
+                )
+
+    def on_train_begin(self, *args: Any, **kwargs: Any) -> Any:
+        """Called at the beginning of training."""
+        pass
+
+    def sample_data(self, data: ArrayLike) -> ArrayLike:
+        """Randomly subsample ``data`` to at most ``sample_size`` rows."""
+        if len(data) <= self.sample_size:
+            return data
+        indices = random.sample(range(len(data)), self.sample_size)
+        return data[np.asarray(indices)]
+
+    def on_log_iteration(
+        self,
+        validation_data: dict[str, dict[str, ArrayLike]],
+        predicted_data: dict[str, dict[str, ArrayLike]],
+    ) -> dict[str, float]:
+        """Called at each validation/log iteration to compute metrics on sampled data.
+
+        Parameters
+        ----------
+        validation_data : dict
+            Validation data
+        predicted_data : dict
+            Predicted data
+        """
+        metric_to_func = {
+            "sinkhorn_div": compute_sinkhorn_div,
+            "e_distance": compute_e_distance,
+            "mmd": compute_scalar_mmd,
+        }
+        metrics = {}
+        for metric in self.metrics:
+            for k in validation_data.keys():
+                # Sample each condition separately, then compute the metric per condition.
+                results = [
+                    metric_to_func[metric](
+                        self.sample_data(validation_data[k][cond]),
+                        self.sample_data(predicted_data[k][cond]),
+                    )
+                    for cond in validation_data[k]
+                ]
+                # Aggregate per-condition values across conditions.
+                for agg_fn in self.metric_aggregations:
+                    metrics[f"{self.log_prefix}{k}_{metric}_{agg_fn}"] = agg_fn_to_func[agg_fn](results)
+        return metrics
+
+    def on_train_end(
+        self,
+        validation_data: dict[str, dict[str, ArrayLike]],
+        predicted_data: dict[str, dict[str, ArrayLike]],
+    ) -> dict[str, float]:
+        """Called at the end of training to compute metrics.
+
+        Parameters
+        ----------
+        validation_data : dict
+            Validation data
+        predicted_data : dict
+            Predicted data
+        """
+        return self.on_log_iteration(validation_data, predicted_data)
+
 
 class PCADecodedMetrics(Metrics):