Skip to content

Commit cf784a7

Browse files
andsteingcopybara-github
authored andcommitted
Fixes code snippet for collecting CollectingMetric using a thread pool.
PiperOrigin-RevId: 542168193
1 parent 3746749 commit cf784a7

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

clu/metrics.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -361,18 +361,20 @@ def compute(self):
361361
from clu import asynclib
362362
363363
def evaluate(params):
364-
ms = MyCollection.empty()
365364
pool = asynclib.Pool()
366365
367366
@pool
368-
def merge(update):
369-
nonlocal ms
370-
ms = ms.merge(update)
367+
def copy_to_host(update):
368+
return jax.tree_map(np.asarray, update)
371369
370+
futures = []
372371
for batch in eval_ds:
373-
merge(eval_step(params, batch))
372+
futures.append(copy_to_host(eval_step(params, batch)))
373+
374+
ms = MyCollection.empty()
375+
for future in futures:
376+
ms = ms.merge(future.result())
374377
375-
pool.join()
376378
return ms.compute()
377379
"""
378380

0 commit comments

Comments
 (0)