Skip to content

Commit a22f97f

Browse files
CLU Authorscopybara-github
authored andcommitted
Change the count type from int32 to float32 to avoid overflows.
PiperOrigin-RevId: 723026512
1 parent 43acbbd commit a22f97f

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

clu/metrics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -784,7 +784,7 @@ class Average(Metric):
784784

785785
@classmethod
786786
def empty(cls) -> Average:
787-
return cls(total=jnp.array(0, jnp.float32), count=jnp.array(0, jnp.int32))
787+
return cls(total=jnp.array(0, jnp.float32), count=jnp.array(0, jnp.float32))
788788

789789
@classmethod
790790
def from_model_output(

0 commit comments

Comments
 (0)