Skip to content

Commit ec67350

Browse files
Qwlousecopybara-github
authored andcommitted
Fix usage of mask for computing total in Std metric.
PiperOrigin-RevId: 529696402
1 parent 227229d commit ec67350

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

clu/metrics.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -815,8 +815,10 @@ def from_model_output(
815815
if mask is None:
816816
mask = jnp.ones(values.shape[0], dtype=jnp.int32)
817817
return cls(
818-
total=values.sum(),
819-
sum_of_squares=jnp.where(mask, values**2, jnp.zeros_like(values)).sum(),
818+
total=jnp.where(mask, values, jnp.zeros_like(values)).sum(),
819+
sum_of_squares=jnp.where(
820+
mask, values**2, jnp.zeros_like(values)
821+
).sum(),
820822
count=mask.sum(),
821823
)
822824

0 commit comments

Comments
 (0)