From cfac12bfdacc63d2060e3bca2642e9a9994447bf Mon Sep 17 00:00:00 2001 From: CLU Authors Date: Tue, 4 Feb 2025 04:23:23 -0800 Subject: [PATCH] Change the count type from int32 to float32 to avoid overflows. PiperOrigin-RevId: 723026512 --- clu/metrics.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/clu/metrics.py b/clu/metrics.py index 103d871..f0c888b 100644 --- a/clu/metrics.py +++ b/clu/metrics.py @@ -784,7 +784,7 @@ class Average(Metric): @classmethod def empty(cls) -> Average: - return cls(total=jnp.array(0, jnp.float32), count=jnp.array(0, jnp.int32)) + return cls(total=jnp.array(0, jnp.float32), count=jnp.array(0, jnp.float32)) @classmethod def from_model_output( @@ -792,12 +792,13 @@ def from_model_output( ) -> Average: values, mask = _broadcast_masks(values, mask) return cls( - total=jnp.where(mask, values, jnp.zeros_like(values)).sum(), + total=jnp.where(mask, values, jnp.zeros_like(values)).sum().astype( + jnp.float32), count=jnp.where( mask, jnp.ones_like(values, dtype=jnp.int32), jnp.zeros_like(values, dtype=jnp.int32), - ).sum(), + ).sum().astype(jnp.float32), ) def merge(self, other: Average) -> Average: