Skip to content

Commit 227229d

Browse files
CLU Authorscopybara-github
authored andcommitted
internal cleanup
PiperOrigin-RevId: 529693178
1 parent 2ce88bf commit 227229d

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

clu/metrics.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,9 @@ def reduce_step(reduced: M, metric: M) -> tuple[M, None]:
194194

195195
first = jax.tree_map(lambda x: x[0], self)
196196
remainder = jax.tree_map(lambda x: x[1:], self)
197-
# TODO(b/160868467) Verify this adds no significant computational cost.
197+
# According to b/160868467#comment4, usage of `jax.lax.scan` does not add a
198+
# significant computational cost for simple metrics where e.g. `jnp.sum`
199+
# could be used instead.
198200
return jax.lax.scan(reduce_step, first, remainder)[0]
199201

200202
@classmethod

0 commit comments

Comments
 (0)