File tree Expand file tree Collapse file tree 1 file changed +10
-6
lines changed
Expand file tree Collapse file tree 1 file changed +10
-6
lines changed Original file line number Diff line number Diff line change @@ -466,17 +466,21 @@ def test_collecting_metric_reduce(self):
466466 chex .assert_trees_all_close (reduced .compute (), {"value" : np .ones ([16 , 4 ])})
467467
468468 def test_collecting_metric_async (self ):
469- metric = CollectingMetricAccuracy .empty ()
470469 pool = asynclib .Pool ()
471470
472471 @pool
473- def merge (update ):
474- nonlocal metric
475- metric = metric .merge (update )
472+ def copy_to_host (update ):
473+ return jax .tree_map (np .asarray , update )
476474
475+ futures = []
476+ from_model_output = jax .jit (CollectingMetricAccuracy .from_model_output )
477477 for model_output in self .model_outputs :
478- merge (jax .jit (CollectingMetricAccuracy .from_model_output )(** model_output ))
479- pool .join ()
478+ futures .append (copy_to_host (from_model_output (** model_output )))
479+
480+ metric = CollectingMetricAccuracy .empty ()
481+ for future in futures :
482+ metric = metric .merge (future .result ())
483+
480484 result = metric .compute ()
481485 chex .assert_trees_all_close (result , 0.75 )
482486
You can’t perform that action at this time.
0 commit comments