Skip to content

Commit 1a8e9c7

Browse files
andsteingcopybara-github
authored andcommitted
Updates test to make it thread safe.
PiperOrigin-RevId: 542847213
1 parent cf784a7 commit 1a8e9c7

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

clu/metrics_test.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -466,17 +466,21 @@ def test_collecting_metric_reduce(self):
466466
chex.assert_trees_all_close(reduced.compute(), {"value": np.ones([16, 4])})
467467

468468
def test_collecting_metric_async(self):
469-
metric = CollectingMetricAccuracy.empty()
470469
pool = asynclib.Pool()
471470

472471
@pool
473-
def merge(update):
474-
nonlocal metric
475-
metric = metric.merge(update)
472+
def copy_to_host(update):
473+
return jax.tree_map(np.asarray, update)
476474

475+
futures = []
476+
from_model_output = jax.jit(CollectingMetricAccuracy.from_model_output)
477477
for model_output in self.model_outputs:
478-
merge(jax.jit(CollectingMetricAccuracy.from_model_output)(**model_output))
479-
pool.join()
478+
futures.append(copy_to_host(from_model_output(**model_output)))
479+
480+
metric = CollectingMetricAccuracy.empty()
481+
for future in futures:
482+
metric = metric.merge(future.result())
483+
480484
result = metric.compute()
481485
chex.assert_trees_all_close(result, 0.75)
482486

0 commit comments

Comments
 (0)