diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py
index a5dbf12a3..34c1272d1 100644
--- a/bayesflow/approximators/continuous_approximator.py
+++ b/bayesflow/approximators/continuous_approximator.py
@@ -1,6 +1,7 @@
 from collections.abc import Mapping, Sequence, Callable
 
 import numpy as np
+import warnings
 
 import keras
 
@@ -9,7 +10,6 @@ from bayesflow.types import Tensor
 from bayesflow.utils import (
     filter_kwargs,
-    logging,
     split_arrays,
     squeeze_inner_estimates_dict,
     concatenate_valid,
@@ -148,18 +148,25 @@ def build_adapter(
     def compile(
         self,
         *args,
-        inference_metrics: Sequence[keras.Metric] = None,
-        summary_metrics: Sequence[keras.Metric] = None,
         **kwargs,
     ):
-        if inference_metrics:
-            self.inference_network._metrics = inference_metrics
+        if "inference_metrics" in kwargs:
+            warnings.warn(
+                "Supplying inference metrics to the approximator is no longer supported. "
+                "Please pass the metrics directly to the network using the metrics parameter.",
+                DeprecationWarning,
+            )
+            # discard the deprecated argument so it does not reach keras.Model.compile
+            kwargs.pop("inference_metrics")
 
-        if summary_metrics:
-            if self.summary_network is None:
-                logging.warning("Ignoring summary metrics because there is no summary network.")
-            else:
-                self.summary_network._metrics = summary_metrics
+        if "summary_metrics" in kwargs:
+            warnings.warn(
+                "Supplying summary metrics to the approximator is no longer supported. "
+                "Please pass the metrics directly to the network using the metrics parameter.",
+                DeprecationWarning,
+            )
+            # discard the deprecated argument so it does not reach keras.Model.compile
+            kwargs.pop("summary_metrics")
 
         return super().compile(*args, **kwargs)
 
@@ -329,16 +332,6 @@ def get_config(self):
 
         return base_config | serialize(config)
 
-    def get_compile_config(self):
-        base_config = super().get_compile_config() or {}
-
-        config = {
-            "inference_metrics": self.inference_network._metrics,
-            "summary_metrics": self.summary_network._metrics if self.summary_network is not None else None,
-        }
-
-        return base_config | serialize(config)
-
     def estimate(
         self,
         conditions: Mapping[str, np.ndarray],
diff --git a/bayesflow/approximators/model_comparison_approximator.py b/bayesflow/approximators/model_comparison_approximator.py
index d71aafdaf..ebaedf924 100644
--- a/bayesflow/approximators/model_comparison_approximator.py
+++ b/bayesflow/approximators/model_comparison_approximator.py
@@ -2,6 +2,7 @@
 
 import keras
 import numpy as np
+import warnings
 
 from bayesflow.adapters import Adapter
 from bayesflow.datasets import OnlineDataset
@@ -151,18 +152,25 @@ def build_dataset(
     def compile(
         self,
         *args,
-        classifier_metrics: Sequence[keras.Metric] = None,
-        summary_metrics: Sequence[keras.Metric] = None,
         **kwargs,
     ):
-        if classifier_metrics:
-            self.classifier_network._metrics = classifier_metrics
+        if "classifier_metrics" in kwargs:
+            warnings.warn(
+                "Supplying classifier metrics to the approximator is no longer supported. "
+                "Please pass the metrics directly to the network using the metrics parameter.",
+                DeprecationWarning,
+            )
+            # discard the deprecated argument so it does not reach keras.Model.compile
+            kwargs.pop("classifier_metrics")
 
-        if summary_metrics:
-            if self.summary_network is None:
-                logging.warning("Ignoring summary metrics because there is no summary network.")
-            else:
-                self.summary_network._metrics = summary_metrics
+        if "summary_metrics" in kwargs:
+            warnings.warn(
+                "Supplying summary metrics to the approximator is no longer supported. "
+                "Please pass the metrics directly to the network using the metrics parameter.",
+                DeprecationWarning,
+            )
+            # discard the deprecated argument so it does not reach keras.Model.compile
+            kwargs.pop("summary_metrics")
 
         return super().compile(*args, **kwargs)
 
@@ -223,9 +227,10 @@ def compute_metrics(
         classifier_metrics = {"loss": cross_entropy}
 
         if stage != "training" and any(self.classifier_network.metrics):
-            predictions = keras.ops.argmax(logits, axis=-1)
+            # compute metrics on the predicted model probabilities
+            probabilities = keras.ops.softmax(logits)
             classifier_metrics |= {
-                metric.name: metric(model_indices, predictions) for metric in self.classifier_network.metrics
+                metric.name: metric(model_indices, probabilities) for metric in self.classifier_network.metrics
             }
 
         if "loss" in summary_metrics:
@@ -342,16 +347,6 @@ def get_config(self):
 
         return base_config | serialize(config)
 
-    def get_compile_config(self):
-        base_config = super().get_compile_config() or {}
-
-        config = {
-            "classifier_metrics": self.classifier_network._metrics,
-            "summary_metrics": self.summary_network._metrics if self.summary_network is not None else None,
-        }
-
-        return base_config | serialize(config)
-
     def predict(
         self,
         *,
diff --git a/bayesflow/experimental/free_form_flow/free_form_flow.py b/bayesflow/experimental/free_form_flow/free_form_flow.py
index d1e826864..bc324770f 100644
--- a/bayesflow/experimental/free_form_flow/free_form_flow.py
+++ b/bayesflow/experimental/free_form_flow/free_form_flow.py
@@ -8,7 +8,6 @@
     find_network,
     jacobian,
     jvp,
-    model_kwargs,
     vjp,
     weighted_mean,
 )
@@ -240,7 +239,6 @@ def from_config(cls, config, custom_objects=None):
 
     def get_config(self):
         base_config = super().get_config()
-        base_config = model_kwargs(base_config)
 
         config = {
             "beta": self.beta,
diff --git a/bayesflow/networks/consistency_models/consistency_model.py b/bayesflow/networks/consistency_models/consistency_model.py
index 3e1778e89..170181279 100644
--- a/bayesflow/networks/consistency_models/consistency_model.py
+++ b/bayesflow/networks/consistency_models/consistency_model.py
@@ -4,7 +4,7 @@
 import numpy as np
 
 from bayesflow.types import Tensor
-from bayesflow.utils import find_network, layer_kwargs, weighted_mean, tensor_utils, expand_right_as
+from bayesflow.utils import find_network, weighted_mean, tensor_utils, expand_right_as
 from bayesflow.utils.serialization import deserialize, serializable, serialize
 
 from ..inference_network import InferenceNetwork
@@ -115,7 +115,9 @@ def from_config(cls, config, custom_objects=None):
 
     def get_config(self):
         base_config = super().get_config()
-        base_config = layer_kwargs(base_config)
+
+        # base distribution is passed manually to InferenceNetwork parent class, do not store it here
+        base_config.pop("base_distribution")
 
         config = {
             "total_steps": self.total_steps,
diff --git a/bayesflow/networks/coupling_flow/coupling_flow.py b/bayesflow/networks/coupling_flow/coupling_flow.py
index 781e6148d..7302c720b 100644
--- a/bayesflow/networks/coupling_flow/coupling_flow.py
+++ b/bayesflow/networks/coupling_flow/coupling_flow.py
@@ -3,7 +3,6 @@
 from bayesflow.types import Tensor
 from bayesflow.utils import (
     find_permutation,
-    layer_kwargs,
     weighted_mean,
 )
 from bayesflow.utils.serialization import deserialize, serializable, serialize
@@ -91,7 +90,21 @@ def __init__(
             Keyword arguments forwarded to the affine or spline transforms (e.g., bins for splines)
         **kwargs
-            Additional keyword arguments passed to `InvertibleLayer`.
+            Additional keyword arguments passed to `InferenceNetwork`.
""" super().__init__(base_distribution=base_distribution, **kwargs) @@ -131,7 +130,6 @@ def from_config(cls, config, custom_objects=None): def get_config(self): base_config = super().get_config() - base_config = layer_kwargs(base_config) config = { "subnet": self.subnet, diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index 744e86b37..6b2f22934 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -10,7 +10,6 @@ expand_right_as, find_network, jacobian_trace, - layer_kwargs, weighted_mean, integrate, integrate_stochastic, @@ -156,7 +155,9 @@ def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None: def get_config(self): base_config = super().get_config() - base_config = layer_kwargs(base_config) + + # base distribution is passed manually to InferenceNetwork parent class, do not store it here + base_config.pop("base_distribution") config = { "subnet": self.subnet, diff --git a/bayesflow/networks/flow_matching/flow_matching.py b/bayesflow/networks/flow_matching/flow_matching.py index da4acd321..0f5e67df5 100644 --- a/bayesflow/networks/flow_matching/flow_matching.py +++ b/bayesflow/networks/flow_matching/flow_matching.py @@ -9,7 +9,6 @@ find_network, integrate, jacobian_trace, - layer_kwargs, optimal_transport, weighted_mean, tensor_utils, @@ -154,7 +153,6 @@ def from_config(cls, config, custom_objects=None): def get_config(self): base_config = super().get_config() - base_config = layer_kwargs(base_config) config = { "subnet": self.subnet, diff --git a/bayesflow/networks/inference_network.py b/bayesflow/networks/inference_network.py index b092ce2cb..3469d434f 100644 --- a/bayesflow/networks/inference_network.py +++ b/bayesflow/networks/inference_network.py @@ -1,13 +1,46 @@ +from typing import Literal +from collections.abc import Sequence + import keras +from keras.src.utils import python_utils from bayesflow.types import Shape, Tensor from bayesflow.utils import layer_kwargs, find_distribution from bayesflow.utils.decorators import allow_batch_size +from bayesflow.utils.serialization import serializable, serialize +@serializable("bayesflow.networks") class InferenceNetwork(keras.Layer): - def __init__(self, base_distribution: str = "normal", **kwargs): + """ + Constructs an inference network using a specified base distribution and optional custom metrics. + Use this interface for custom inference networks. + """ + + def __init__( + self, + base_distribution: Literal["normal", "student", "mixture"] | keras.Layer = "normal", + *, + metrics: Sequence[keras.Metric] | None = None, + **kwargs, + ): + """ + Creates the network with provided arguments. Optional user-supplied metrics will be stored + in a `custom_metrics` attribute. A special `metrics` attribute will be created internally by `keras.Layer`. + + Parameters + ---------- + base_distribution : Literal["normal", "student", "mixture"] or keras.Layer + Name or the actual base distribution to use. Passed to `find_distribution` to + obtain the corresponding distribution object. + metrics : Sequence[keras.Metric] or None, optional + Sequence of custom Keras Metric instances to compute during training + and evaluation. If `None`, no custom metrics are used. + **kwargs + Additional keyword arguments forwarded to the `keras.Layer` constructor. 
+        """
         super().__init__(**layer_kwargs(kwargs))
+        self.custom_metrics = metrics
         self.base_distribution = find_distribution(base_distribution)
 
     def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
@@ -66,9 +99,16 @@ def compute_metrics(
 
         if stage != "training" and any(self.metrics):
             # compute sample-based metrics
-            samples = self.sample((keras.ops.shape(x)[0],), conditions=conditions)
+            samples = self.sample(batch_shape=(keras.ops.shape(x)[0],), conditions=conditions)
 
             for metric in self.metrics:
                 metrics[metric.name] = metric(samples, x)
 
         return metrics
+
+    @python_utils.default
+    def get_config(self):
+        base_config = super().get_config()
+
+        config = {"metrics": self.custom_metrics, "base_distribution": self.base_distribution}
+        return base_config | serialize(config)
diff --git a/bayesflow/networks/mlp/mlp.py b/bayesflow/networks/mlp/mlp.py
index 7184070af..935933836 100644
--- a/bayesflow/networks/mlp/mlp.py
+++ b/bayesflow/networks/mlp/mlp.py
@@ -29,11 +29,12 @@
         dropout: Literal[0, None] | float = 0.05,
         norm: Literal["batch", "layer"] | keras.Layer = None,
         spectral_normalization: bool = False,
+        metrics: Sequence[keras.Metric] | None = None,
         **kwargs,
     ):
         """
-        Implements a flexible multi-layer perceptron (MLP) with optional residual connections, dropout, and
-        spectral normalization.
+        Creates a flexible multi-layer perceptron (MLP) with optional residual connections, dropout,
+        spectral normalization, and metrics.
 
         This MLP can be used as a general-purpose feature extractor or function approximator,
         supporting configurable depth, width, activation functions, and weight initializations.
@@ -41,6 +42,9 @@
         If `residual` is enabled, each layer includes a skip connection for improved gradient flow. The model also
         supports dropout for regularization and spectral normalization for stability in learning smooth functions.
 
+        Optional user-supplied metrics will be stored in a `custom_metrics` attribute. A special `metrics` attribute
+        will be created internally by `keras.Layer`.
+
         Parameters
         ----------
         widths : Sequence[int], optional
@@ -54,12 +58,24 @@
         dropout : float or None, optional
             Dropout rate applied within the MLP layers for regularization. Default is 0.05.
         norm: str, optional
-
+            Type of learnable normalization to be used (e.g., "batch" or "layer"). Default is None.
         spectral_normalization : bool, optional
             Whether to apply spectral normalization to stabilize training. Default is False.
+        metrics : Sequence[keras.Metric], optional
+            A sequence of callable metrics following keras' `Metric` signature. Default is None.
         **kwargs
             Additional keyword arguments passed to the Keras layer initialization.
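+
+        Examples
+        --------
+        A small MLP with a custom metric (sketch; any `keras.Metric` can be
+        supplied)::
+
+            from keras.metrics import CategoricalAccuracy
+
+            mlp = MLP(widths=[32, 32], metrics=[CategoricalAccuracy()])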
""" + self.custom_metrics = metrics self.widths = list(widths) self.activation = activation self.kernel_initializer = kernel_initializer @@ -90,6 +97,7 @@ def get_config(self): "dropout": self.dropout, "norm": self.norm, "spectral_normalization": self.spectral_normalization, + "metrics": self.custom_metrics, } return base_config | serialize(config) diff --git a/bayesflow/networks/point_inference_network.py b/bayesflow/networks/point_inference_network.py index 2328d992f..f4209fabf 100644 --- a/bayesflow/networks/point_inference_network.py +++ b/bayesflow/networks/point_inference_network.py @@ -1,4 +1,6 @@ +from collections.abc import Sequence import keras +from keras.src.utils import python_utils from bayesflow.utils import model_kwargs, find_network from bayesflow.utils.serialization import deserialize, serializable, serialize @@ -17,9 +19,12 @@ def __init__( self, scores: dict[str, ScoringRule], subnet: str | keras.Layer = "mlp", + *, + metrics: Sequence[keras.Metric] | None = None, **kwargs, ): super().__init__(**model_kwargs(kwargs)) + self.custom_metrics = metrics self.scores = scores @@ -28,6 +33,7 @@ def __init__( self.config = { "subnet": serialize(subnet), "scores": serialize(scores), + "metrics": serialize(metrics), **kwargs, } @@ -106,6 +112,7 @@ def build_from_config(self, config): for head_key, head in self.heads[score_key].items(): head.name = config["heads"][score_key][head_key] + @python_utils.default def get_config(self): base_config = super().get_config() @@ -114,9 +121,7 @@ def get_config(self): @classmethod def from_config(cls, config): config = config.copy() - config["scores"] = deserialize(config["scores"]) - config["subnet"] = deserialize(config["subnet"]) - return cls(**config) + return cls(**deserialize(config)) def call( self, diff --git a/bayesflow/networks/summary_network.py b/bayesflow/networks/summary_network.py index e821be3f3..77b4306a1 100644 --- a/bayesflow/networks/summary_network.py +++ b/bayesflow/networks/summary_network.py @@ -1,15 +1,52 @@ import keras +from keras.src.utils import python_utils +from collections.abc import Sequence from bayesflow.metrics.functional import maximum_mean_discrepancy from bayesflow.types import Tensor from bayesflow.utils import layer_kwargs, find_distribution from bayesflow.utils.decorators import sanitize_input_shape -from bayesflow.utils.serialization import deserialize +from bayesflow.utils.serialization import serializable, serialize +@serializable("bayesflow.networks") class SummaryNetwork(keras.Layer): - def __init__(self, base_distribution: str = None, **kwargs): + """ + Builds a summary network with an optional base distribution and custom metrics. Use this class + as an interface for custom summary networks. + + Important + --------- + If a base distribution is passed, the summary outputs will be optimized to follow + that distribution, as described in [1]. + + References + ---------- + [1] Schmitt, M., Bürkner, P. C., Köthe, U., & Radev, S. T. (2023). + Detecting model misspecification in amortized Bayesian inference with neural networks. + In DAGM German Conference on Pattern Recognition (pp. 541-557). + Cham: Springer Nature Switzerland. + """ + + def __init__(self, base_distribution: str = None, *, metrics: Sequence[keras.Metric] | None = None, **kwargs): + """ + Creates the network with provided arguments. Optional user-supplied metrics will be stored + in a `custom_metrics` attribute. A special `metrics` attribute will be created internally by `keras.Layer`. 
+
+        Parameters
+        ----------
+        base_distribution : str or None, default None
+            Name of the base distribution to use. If `None`, no base distribution
+            is used. Passed to `find_distribution` to obtain the corresponding
+            distribution object.
+        metrics : Sequence[keras.Metric] or None, optional
+            Sequence of custom Keras Metric instances to compute during training
+            and evaluation. If `None`, no custom metrics are used.
+        **kwargs
+            Additional keyword arguments forwarded to the `keras.Layer` constructor.
+        """
         super().__init__(**layer_kwargs(kwargs))
+        self.custom_metrics = metrics
         self.base_distribution = find_distribution(base_distribution)
 
     @sanitize_input_shape
@@ -17,7 +54,7 @@ def build(self, input_shape):
         x = keras.ops.zeros(input_shape)
         z = self.call(x)
 
-        if self.base_distribution is not None:
+        if self.base_distribution is not None and not self.base_distribution.built:
             self.base_distribution.build(keras.ops.shape(z))
 
     @sanitize_input_shape
@@ -51,6 +88,9 @@ def compute_metrics(self, x: Tensor, stage: str = "training", **kwargs) -> dict[
 
         return metrics
 
-    @classmethod
-    def from_config(cls, config, custom_objects=None):
-        return cls(**deserialize(config, custom_objects=custom_objects))
+    @python_utils.default
+    def get_config(self):
+        base_config = super().get_config()
+
+        config = {"metrics": self.custom_metrics, "base_distribution": self.base_distribution}
+        return base_config | serialize(config)
diff --git a/bayesflow/utils/serialization.py b/bayesflow/utils/serialization.py
index 5be0e0e1d..be63c7972 100644
--- a/bayesflow/utils/serialization.py
+++ b/bayesflow/utils/serialization.py
@@ -3,12 +3,16 @@
 import builtins
 import inspect
 import keras
+import functools
 import numpy as np
 import sys
 from warnings import warn
 
 # this import needs to be exactly like this to work with monkey patching
-from keras.saving import deserialize_keras_object
+from keras.saving import deserialize_keras_object, get_registered_object, get_registered_name
+from keras.src.saving.serialization_lib import SerializableDict
+from keras import dtype_policies
+from keras import tree
 
 from .context_managers import monkey_patch
 from .decorators import allow_args
@@ -95,6 +99,10 @@ def deserialize(config: dict, custom_objects=None, safe_mode=True, **kwargs):
     return obj
 
 
+def _deserializing_from_config(cls, config, custom_objects=None):
+    return cls(**deserialize(config, custom_objects=custom_objects))
+
+
 @allow_args
 def serializable(cls, package: str, name: str | None = None, disable_module_check: bool = False):
     """Register class as Keras serializable.
@@ -143,6 +151,73 @@
     if name is None:
         name = copy(cls.__name__)
 
+    def init_decorator(original_init):
+        # Wraps __init__ to add auto-config behavior after it runs. This extends the auto-config capabilities
+        # provided by keras.Operation (base class of keras.Layer) with support for all serializable objects.
+        # This produces a serialized config that has to be deserialized properly, see below.
+        @functools.wraps(original_init)
+        def wrapper(instance, *args, **kwargs):
+            original_init(instance, *args, **kwargs)
+
+            # Generate a config to be returned by default by `get_config()`.
+            # Adapted from keras.Operation.
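+            # The recorded constructor kwargs become the config returned by the default
+            # `get_config()`, so subclasses need not override it; e.g. `MLP(widths=[8, 8])`
+            # is recorded as `{"widths": [8, 8]}`.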
+            kwargs = kwargs.copy()
+            arg_names = inspect.getfullargspec(original_init).args
+            kwargs.update(dict(zip(arg_names[1 : len(args) + 1], args)))
+
+            # Explicitly serialize `dtype` to support auto_config
+            dtype = kwargs.get("dtype", None)
+            if dtype is not None and isinstance(dtype, dtype_policies.DTypePolicy):
+                # For backward compatibility, we use a str (`name`) for
+                # `DTypePolicy`
+                if dtype.quantization_mode is None:
+                    kwargs["dtype"] = dtype.name
+                # Otherwise, use `dtype_policies.serialize`
+                else:
+                    kwargs["dtype"] = dtype_policies.serialize(dtype)
+
+            # supported basic types
+            supported_types = (str, int, float, bool, type(None))
+
+            flat_arg_values = tree.flatten(kwargs)
+            auto_config = True
+            for value in flat_arg_values:
+                # adaptation: we allow all registered serializable objects
+                is_serializable_object = (
+                    isinstance(value, supported_types)
+                    or get_registered_object(get_registered_name(type(value))) is not None
+                )
+                # adaptation: we also allow classes that can be round-tripped through (de)serialization
+                try:
+                    is_serializable_class = inspect.isclass(value) and deserialize(serialize(value))
+                except ValueError:
+                    # deserialization of type failed, probably not registered
+                    is_serializable_class = False
+                if not (is_serializable_object or is_serializable_class):
+                    auto_config = False
+                    break
+
+            if auto_config:
+                with monkey_patch(keras.saving.serialize_keras_object, serialize):
+                    instance._auto_config = SerializableDict(**kwargs)
+            else:
+                instance._auto_config = None
+
+        return wrapper
+
+    cls.__init__ = init_decorator(cls.__init__)
+
+    if hasattr(cls, "from_config") and cls.from_config.__func__ == keras.Layer.from_config.__func__:
+        # By default, keras.Layer.from_config does not deserialize the config. This class inherits that
+        # default implementation unchanged, so we replace it with a variant that applies
+        # deserialization to the config.
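+        # e.g. a serialized {"class_name": ..., "config": ...} entry for a subnet is
+        # rebuilt into an actual layer instance before it reaches cls.__init__.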
+        cls.from_config = classmethod(_deserializing_from_config)
+
     # register subclasses as keras serializable
     return keras.saving.register_keras_serializable(package=package, name=name)(cls)
diff --git a/tests/test_adapters/test_adapters.py b/tests/test_adapters/test_adapters.py
index a5868deb5..f782dbe62 100644
--- a/tests/test_adapters/test_adapters.py
+++ b/tests/test_adapters/test_adapters.py
@@ -4,6 +4,7 @@
 import keras
 
 from bayesflow.utils.serialization import deserialize, serialize
+from tests.utils import assert_configs_equal
 
 import bayesflow as bf
 
@@ -29,7 +30,7 @@ def test_serialize_deserialize(adapter, random_data):
     deserialized = deserialize(serialized)
     reserialized = serialize(deserialized)
 
-    assert keras.tree.lists_to_tuples(serialized) == keras.tree.lists_to_tuples(reserialized)
+    assert_configs_equal(serialized, reserialized)
 
     random_data["foo"] = random_data["x1"]
     deserialized_processed = deserialized(random_data)
@@ -122,7 +123,6 @@ def test_simple_transforms(random_data):
 
 def test_custom_transform():
     # test that transform raises errors in all relevant cases
-    import keras
     from bayesflow.adapters.transforms import SerializableCustomTransform
     from copy import deepcopy
 
diff --git a/tests/test_approximators/conftest.py b/tests/test_approximators/conftest.py
index 3c4d2fd4c..31f17404d 100644
--- a/tests/test_approximators/conftest.py
+++ b/tests/test_approximators/conftest.py
@@ -20,8 +20,11 @@ def summary_network():
 @pytest.fixture()
 def inference_network():
     from bayesflow.networks import CouplingFlow
+    from bayesflow.metrics import RootMeanSquaredError
 
-    return CouplingFlow(subnet="mlp", depth=2, subnet_kwargs=dict(widths=(32, 32)))
+    return CouplingFlow(
+        subnet="mlp", depth=2, subnet_kwargs=dict(widths=(32, 32)), metrics=[RootMeanSquaredError(name="rmse")]
+    )
 
 
 @pytest.fixture()
@@ -37,6 +40,7 @@ def continuous_approximator(adapter, inference_network, summary_network):
 
 @pytest.fixture()
 def point_inference_network():
+    from bayesflow.metrics import RootMeanSquaredError
     from bayesflow.networks import PointInferenceNetwork
     from bayesflow.scores import NormedDifferenceScore, QuantileScore, MultivariateNormalScore
 
@@ -48,11 +52,13 @@
         ),
         subnet="mlp",
         subnet_kwargs=dict(widths=(32, 32)),
+        metrics=[RootMeanSquaredError(name="rmse")],
     )
 
 
 @pytest.fixture()
 def point_inference_network_with_multiple_parametric_scores():
+    from bayesflow.metrics import RootMeanSquaredError
     from bayesflow.networks import PointInferenceNetwork
     from bayesflow.scores import MultivariateNormalScore
 
@@ -61,6 +67,7 @@
             mvn1=MultivariateNormalScore(),
             mvn2=MultivariateNormalScore(),
         ),
+        metrics=[RootMeanSquaredError(name="rmse")],
     )
 
 
@@ -178,9 +185,10 @@ def validation_dataset(batch_size, adapter, simulator):
 
 @pytest.fixture()
 def mean_std_summary_network():
+    from bayesflow.metrics import MaximumMeanDiscrepancy
     from tests.utils import MeanStdSummaryNetwork
 
-    return MeanStdSummaryNetwork()
+    return MeanStdSummaryNetwork(metrics=[MaximumMeanDiscrepancy("mmd")])
 
 
 @pytest.fixture(params=["continuous_approximator", "point_approximator", "model_comparison_approximator"])
diff --git a/tests/test_approximators/test_fit.py b/tests/test_approximators/test_fit.py
index b561efb77..8d416c1e2 100644
--- a/tests/test_approximators/test_fit.py
+++ b/tests/test_approximators/test_fit.py
@@ -49,3 +49,7 @@ def test_loss_progress(approximator, train_dataset, validation_dataset):
     # check that the shown loss is not nan or zero
     assert re.search(r"\bnan\b", output) is None, "found nan in output"
     assert re.search(r"\bloss: 0\.0000e\+00\b", output) is None, "found zero loss in output"
+
+    # check that additional metric is present
+    assert "val_rmse/inference_rmse" in output, "custom metric (RMSE) not shown"
+    assert re.search(r"\bval_rmse/inference_rmse: \d+\.\d+", output) is not None, "custom metric not correctly shown"
diff --git a/tests/test_approximators/test_model_comparison_approximator/conftest.py b/tests/test_approximators/test_model_comparison_approximator/conftest.py
index c5df51533..cdb6c584b 100644
--- a/tests/test_approximators/test_model_comparison_approximator/conftest.py
+++ b/tests/test_approximators/test_model_comparison_approximator/conftest.py
@@ -51,15 +51,17 @@ def adapter():
 @pytest.fixture
 def summary_network():
     from bayesflow.networks import DeepSet
+    from bayesflow.metrics import RootMeanSquaredError
 
-    return DeepSet(summary_dim=2, depth=1)
+    return DeepSet(summary_dim=2, depth=1, base_distribution="normal", metrics=[RootMeanSquaredError(name="rmse")])
 
 
 @pytest.fixture
 def classifier_network():
     from bayesflow.networks import MLP
+    from keras.metrics import CategoricalAccuracy
 
-    return MLP(widths=[32, 32])
+    return MLP(widths=[32, 32], metrics=[CategoricalAccuracy(name="categorical_accuracy")])
 
 
 @pytest.fixture
diff --git a/tests/test_approximators/test_model_comparison_approximator/test_model_comparison_approximator.py b/tests/test_approximators/test_model_comparison_approximator/test_model_comparison_approximator.py
index 0246ee7b7..5ca1d28e0 100644
--- a/tests/test_approximators/test_model_comparison_approximator/test_model_comparison_approximator.py
+++ b/tests/test_approximators/test_model_comparison_approximator/test_model_comparison_approximator.py
@@ -55,7 +55,10 @@ def test_fit(approximator, train_dataset, validation_dataset):
     output = stream.getvalue()
 
     # check that the loss is shown
-    assert "loss" in output
+    assert "loss/summary_loss" in output
+    assert "loss/classifier_loss" in output
+    assert "val_categorical_accuracy/classifier_categorical_accuracy" in output
+    assert "val_rmse/summary_rmse" in output
 
 
 def test_save_and_load(tmp_path, approximator, train_dataset, validation_dataset):
diff --git a/tests/test_networks/conftest.py b/tests/test_networks/conftest.py
index 63bc317ff..de536a30c 100644
--- a/tests/test_networks/conftest.py
+++ b/tests/test_networks/conftest.py
@@ -1,6 +1,7 @@
 import pytest
 from collections.abc import Sequence
 
+from bayesflow.metrics import RootMeanSquaredError
 from bayesflow.networks import MLP, Sequential
 from bayesflow.utils.tensor_utils import concatenate_valid
 from bayesflow.utils.serialization import serializable, serialize
@@ -15,6 +16,7 @@ def diffusion_model_edm_F():
         integrate_kwargs={"method": "rk45", "steps": 250},
         noise_schedule="edm",
         prediction_type="F",
+        metrics=[RootMeanSquaredError()],
     )
 
 
@@ -140,6 +142,7 @@ def flow_matching():
     return FlowMatching(
         subnet=MLP([8, 8]),
         integrate_kwargs={"method": "rk45", "steps": 100},
+        metrics=[RootMeanSquaredError()],
     )
 
 
@@ -156,7 +159,11 @@ def flow_matching_subnet_separate_inputs():
 def consistency_model():
     from bayesflow.networks import ConsistencyModel
 
-    return ConsistencyModel(total_steps=100, subnet=MLP([8, 8]))
+    return ConsistencyModel(
+        total_steps=100,
+        subnet=MLP([8, 8]),
+        metrics=[RootMeanSquaredError()],
+    )
 
 
 @pytest.fixture()
@@ -170,7 +177,12 @@
 def affine_coupling_flow():
     from bayesflow.networks import CouplingFlow
 
     return CouplingFlow(
-        depth=2, subnet="mlp", subnet_kwargs=dict(widths=[8, 8]), transform="affine", transform_kwargs=dict(clamp=1.8)
+        depth=2,
+        subnet="mlp",
+        subnet_kwargs=dict(widths=[8, 8]),
+        transform="affine",
+        transform_kwargs=dict(clamp=1.8),
+        metrics=[RootMeanSquaredError()],
     )
 
 
 @pytest.fixture()
@@ -180,7 +192,12 @@ def spline_coupling_flow():
     from bayesflow.networks import CouplingFlow
 
     return CouplingFlow(
-        depth=2, subnet="mlp", subnet_kwargs=dict(widths=[8, 8]), transform="spline", transform_kwargs=dict(bins=8)
+        depth=2,
+        subnet="mlp",
+        subnet_kwargs=dict(widths=[8, 8]),
+        transform="spline",
+        transform_kwargs=dict(bins=8),
+        metrics=[RootMeanSquaredError()],
     )
 
 
 @pytest.fixture()
@@ -188,7 +205,11 @@ def free_form_flow():
     from bayesflow.experimental import FreeFormFlow
 
-    return FreeFormFlow(encoder_subnet=MLP([16, 16]), decoder_subnet=MLP([16, 16]))
+    return FreeFormFlow(
+        encoder_subnet=MLP([16, 16]),
+        decoder_subnet=MLP([16, 16]),
+        metrics=[RootMeanSquaredError()],
+    )
 
 
 @pytest.fixture()
@@ -318,35 +339,35 @@ def inference_network_subnet_separate_inputs(request):
 def time_series_network(summary_dim):
     from bayesflow.networks import TimeSeriesNetwork
 
-    return TimeSeriesNetwork(summary_dim=summary_dim)
+    return TimeSeriesNetwork(summary_dim=summary_dim, metrics=[RootMeanSquaredError()])
 
 
 @pytest.fixture(scope="function")
 def time_series_transformer(summary_dim):
     from bayesflow.networks import TimeSeriesTransformer
 
-    return TimeSeriesTransformer(summary_dim=summary_dim)
+    return TimeSeriesTransformer(summary_dim=summary_dim, metrics=[RootMeanSquaredError()])
 
 
 @pytest.fixture(scope="function")
 def fusion_transformer(summary_dim):
     from bayesflow.networks import FusionTransformer
 
-    return FusionTransformer(summary_dim=summary_dim)
+    return FusionTransformer(summary_dim=summary_dim, metrics=[RootMeanSquaredError()])
 
 
 @pytest.fixture(scope="function")
 def set_transformer(summary_dim):
     from bayesflow.networks import SetTransformer
 
-    return SetTransformer(summary_dim=summary_dim)
+    return SetTransformer(summary_dim=summary_dim, metrics=[RootMeanSquaredError()])
 
 
 @pytest.fixture(scope="function")
 def deep_set(summary_dim):
     from bayesflow.networks import DeepSet
 
-    return DeepSet(summary_dim=summary_dim)
+    return DeepSet(summary_dim=summary_dim, metrics=[RootMeanSquaredError()])
 
 
 @pytest.fixture(
diff --git a/tests/test_networks/test_fusion_network/test_fusion_network.py b/tests/test_networks/test_fusion_network/test_fusion_network.py
index f1dbfa1c0..b73e7c426 100644
--- a/tests/test_networks/test_fusion_network/test_fusion_network.py
+++ b/tests/test_networks/test_fusion_network/test_fusion_network.py
@@ -2,7 +2,7 @@
 import pytest
 import keras
 
-from tests.utils import assert_layers_equal, allclose
+from tests.utils import assert_layers_equal, assert_configs_equal, allclose
 
 
 @pytest.mark.parametrize("automatic", [True, False])
@@ -57,7 +57,7 @@ def test_serialize_deserialize(fusion_network, multimodal_data):
     deserialized = deserialize(serialized)
     reserialized = serialize(deserialized)
 
-    assert keras.tree.lists_to_tuples(serialized) == keras.tree.lists_to_tuples(reserialized)
+    assert_configs_equal(serialized, reserialized)
 
 
 def test_save_and_load(tmp_path, fusion_network, multimodal_data):
diff --git a/tests/test_networks/test_inference_networks.py b/tests/test_networks/test_inference_networks.py
index c0035e0f2..5c3ef32f3 100644
--- a/tests/test_networks/test_inference_networks.py
+++ b/tests/test_networks/test_inference_networks.py
@@ -4,7 +4,7 @@
 
 from bayesflow.utils.serialization import serialize, deserialize
 
-from tests.utils import assert_allclose, assert_layers_equal
+from tests.utils import assert_allclose, assert_layers_equal, assert_configs_equal
 
 
 def test_build(inference_network, random_samples, random_conditions):
@@ -140,7 +140,7 @@ def test_serialize_deserialize(inference_network, random_samples, random_conditi
     deserialized = deserialize(serialized)
     reserialized = serialize(deserialized)
 
-    assert keras.tree.lists_to_tuples(serialized) == keras.tree.lists_to_tuples(reserialized)
+    assert_configs_equal(serialized, reserialized)
 
 
 def test_save_and_load(tmp_path, inference_network, random_samples, random_conditions):
diff --git a/tests/test_networks/test_point_inference_network/test_point_inference_network.py b/tests/test_networks/test_point_inference_network/test_point_inference_network.py
index 38ba8ea4e..ac924bc1d 100644
--- a/tests/test_networks/test_point_inference_network/test_point_inference_network.py
+++ b/tests/test_networks/test_point_inference_network/test_point_inference_network.py
@@ -3,7 +3,7 @@
     deserialize_keras_object as deserialize,
     serialize_keras_object as serialize,
 )
-from tests.utils import assert_layers_equal
+from tests.utils import assert_layers_equal, assert_configs_equal
 
 import pytest
 
@@ -72,7 +72,7 @@ def test_save_and_load_quantile(tmp_path, quantile_point_inference_network, rand
     loaded = keras.saving.load_model(tmp_path / "model.keras")
 
     print(net.get_config())
-    assert net.get_config() == loaded.get_config()
+    assert_configs_equal(net.get_config(), loaded.get_config())
 
     assert_layers_equal(net, loaded)
diff --git a/tests/test_networks/test_summary_networks.py b/tests/test_networks/test_summary_networks.py
index 74ce1f5fd..6b6518452 100644
--- a/tests/test_networks/test_summary_networks.py
+++ b/tests/test_networks/test_summary_networks.py
@@ -4,7 +4,7 @@
 
 from bayesflow.utils.serialization import deserialize, serialize
 
-from tests.utils import assert_layers_equal
+from tests.utils import assert_layers_equal, assert_configs_equal
 
 
 @pytest.mark.parametrize("automatic", [True, False])
@@ -85,7 +85,7 @@ def test_serialize_deserialize(summary_network, random_set):
     deserialized = deserialize(serialized)
     reserialized = serialize(deserialized)
 
-    assert keras.tree.lists_to_tuples(serialized) == keras.tree.lists_to_tuples(reserialized)
+    assert_configs_equal(serialized, reserialized)
 
 
 def test_save_and_load(tmp_path, summary_network, random_set):
diff --git a/tests/test_two_moons/conftest.py b/tests/test_two_moons/conftest.py
index 5cd6f59db..282354f23 100644
--- a/tests/test_two_moons/conftest.py
+++ b/tests/test_two_moons/conftest.py
@@ -4,8 +4,9 @@
 @pytest.fixture()
 def inference_network():
     from bayesflow.networks import CouplingFlow
+    from bayesflow.metrics import MaximumMeanDiscrepancy
 
-    return CouplingFlow(depth=2, subnet="mlp", subnet_kwargs=dict(widths=(32, 32)))
+    return CouplingFlow(depth=2, subnet="mlp", subnet_kwargs=dict(widths=(32, 32)), metrics=[MaximumMeanDiscrepancy()])
 
 
 @pytest.fixture()
diff --git a/tests/test_two_moons/test_two_moons.py b/tests/test_two_moons/test_two_moons.py
index 9189b142e..92aeaf929 100644
--- a/tests/test_two_moons/test_two_moons.py
+++ b/tests/test_two_moons/test_two_moons.py
@@ -13,13 +13,9 @@ def test_compile(approximator, random_samples, jit_compile):
 
 
 def test_fit(approximator, train_dataset, validation_dataset, batch_size):
-    from bayesflow.metrics import MaximumMeanDiscrepancy
     from bayesflow.networks import PointInferenceNetwork
 
-    inference_metrics = []
-    if not isinstance(approximator.inference_network, PointInferenceNetwork):
-        inference_metrics += [MaximumMeanDiscrepancy()]
-    approximator.compile(inference_metrics=inference_metrics)
+    approximator.compile()
 
     mock_data = train_dataset[0]
     mock_data = keras.tree.map_structure(keras.ops.convert_to_tensor, mock_data)
diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py
index 9c2affc22..5c7806cc9 100644
--- a/tests/utils/__init__.py
+++ b/tests/utils/__init__.py
@@ -3,4 +3,5 @@
 from .check_combinations import *
 from .jupyter import *
 from .networks import *
+from .normalize import *
 from .ops import *
diff --git a/tests/utils/assertions.py b/tests/utils/assertions.py
index fc4219206..240502984 100644
--- a/tests/utils/assertions.py
+++ b/tests/utils/assertions.py
@@ -1,4 +1,5 @@
 import keras
+from .normalize import normalize_config
 
 
 def assert_models_equal(model1: keras.Model, model2: keras.Model):
@@ -13,6 +14,11 @@ def assert_models_equal(model1: keras.Model, model2: keras.Model):
         else:
             assert_layers_equal(layer1, layer2)
 
+    assert len(model1.metrics) == len(model2.metrics)
+    for metric1, metric2 in zip(model1.metrics, model2.metrics):
+        assert type(metric1) is type(metric2)
+        assert metric1.name == metric2.name
+
 
 def assert_layers_equal(layer1: keras.Layer, layer2: keras.Layer):
     msg = f"Layers {layer1.name} and {layer2.name} have different types."
@@ -40,3 +46,12 @@ def assert_layers_equal(layer1: keras.Layer, layer2: keras.Layer):
     # this is turned off for now, see https://github.com/bayesflow-org/bayesflow/issues/412
     msg = f"Layers {layer1.name} and {layer2.name} have a different name."
     # assert layer1.name == layer2.name, msg
+
+    assert len(layer1.metrics) == len(layer2.metrics), f"metrics do not match: {layer1.metrics}!={layer2.metrics}"
+    for metric1, metric2 in zip(layer1.metrics, layer2.metrics):
+        assert type(metric1) is type(metric2)
+        assert metric1.name == metric2.name
+
+
+def assert_configs_equal(config1: dict, config2: dict):
+    assert normalize_config(config1) == normalize_config(config2)
diff --git a/tests/utils/normalize.py b/tests/utils/normalize.py
new file mode 100644
index 000000000..96198dc93
--- /dev/null
+++ b/tests/utils/normalize.py
@@ -0,0 +1,26 @@
+from copy import deepcopy
+import keras
+
+
+def normalize_dtype(config):
+    """Convert dtypes with DTypePolicy to simple strings"""
+    config = deepcopy(config)
+
+    def walk_dictionary(cur_dict):
+        # walks the dictionary and modifies entries in-place
+        for key, value in cur_dict.items():
+            if key == "dtype" and isinstance(value, dict):
+                if value.get("class_name", "") == "DTypePolicy":
+                    cur_dict[key] = value["config"]["name"]
+                continue
+            if isinstance(value, dict):
+                walk_dictionary(value)
+
+    walk_dictionary(config)
+    return config
+
+
+def normalize_config(config):
+    config = normalize_dtype(config)
+    config = keras.tree.lists_to_tuples(config)
+    return config
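
Usage sketch (illustrative; mirrors the test fixtures above): metrics now attach to a
network at construction time, and `compile()` takes no metric arguments.

    from bayesflow.metrics import RootMeanSquaredError
    from bayesflow.networks import CouplingFlow

    inference_network = CouplingFlow(
        subnet="mlp",
        depth=2,
        subnet_kwargs=dict(widths=(32, 32)),
        metrics=[RootMeanSquaredError(name="rmse")],
    )
    # approximator = ContinuousApproximator(adapter=..., inference_network=inference_network, ...)
    # approximator.compile(optimizer="adam")  # no inference_metrics / summary_metrics here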