diff --git a/clu/metrics.py b/clu/metrics.py index 103d871..f0c888b 100644 --- a/clu/metrics.py +++ b/clu/metrics.py @@ -784,7 +784,7 @@ class Average(Metric): @classmethod def empty(cls) -> Average: - return cls(total=jnp.array(0, jnp.float32), count=jnp.array(0, jnp.int32)) + return cls(total=jnp.array(0, jnp.float32), count=jnp.array(0, jnp.float32)) @classmethod def from_model_output( @@ -792,12 +792,13 @@ def from_model_output( ) -> Average: values, mask = _broadcast_masks(values, mask) return cls( - total=jnp.where(mask, values, jnp.zeros_like(values)).sum(), + total=jnp.where(mask, values, jnp.zeros_like(values)).sum().astype( + jnp.float32), count=jnp.where( mask, jnp.ones_like(values, dtype=jnp.int32), jnp.zeros_like(values, dtype=jnp.int32), - ).sum(), + ).sum().astype(jnp.float32), ) def merge(self, other: Average) -> Average: