Skip to content

Commit fff0476

Browse files
josephlrcopybara-github
authored andcommitted
Internal change
PiperOrigin-RevId: 527429713
1 parent 903f1ef commit fff0476

File tree

2 files changed

+23
-15
lines changed

2 files changed

+23
-15
lines changed

clu/metrics.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def evaluate(model, p_variables, test_ds):
5656
return ms.unreplicate().compute()
5757
"""
5858
from collections.abc import Callable, Mapping, Sequence
59-
from typing import Any, Optional
59+
from typing import Any, Optional, TypeVar
6060

6161
from absl import logging
6262

@@ -435,6 +435,9 @@ def _check_reduction_counter_ndim(reduction_counter: _ReductionCounter):
435435
f"call a flax.jax_utils.unreplicate() or a Collections.reduce()?")
436436

437437

438+
C = TypeVar("C", bound="Collection")
439+
440+
438441
@flax.struct.dataclass
439442
class Collection:
440443
"""Updates a collection of `Metric` from model outputs.
@@ -512,7 +515,7 @@ class MyMetrics(metrics.Collection):
512515
return collection_class(_reduction_counter=counter, **metrics)
513516

514517
@classmethod
515-
def empty(cls) -> "Collection":
518+
def empty(cls: type[C]) -> C:
516519
return cls(
517520
_reduction_counter=_ReductionCounter(jnp.array(1, dtype=jnp.int32)),
518521
**{
@@ -521,7 +524,7 @@ def empty(cls) -> "Collection":
521524
})
522525

523526
@classmethod
524-
def _from_model_output(cls, **kwargs) -> "Collection":
527+
def _from_model_output(cls: type[C], **kwargs) -> C:
525528
"""Creates a `Collection` from model outputs."""
526529
return cls(
527530
_reduction_counter=_ReductionCounter(jnp.array(1, dtype=jnp.int32)),
@@ -531,7 +534,7 @@ def _from_model_output(cls, **kwargs) -> "Collection":
531534
})
532535

533536
@classmethod
534-
def single_from_model_output(cls, **kwargs) -> "Collection":
537+
def single_from_model_output(cls: type[C], **kwargs) -> C:
535538
"""Creates a `Collection` from model outputs.
536539
537540
Note: This function should only be called when metrics are collected in a
@@ -546,9 +549,7 @@ def single_from_model_output(cls, **kwargs) -> "Collection":
546549
return cls._from_model_output(**kwargs)
547550

548551
@classmethod
549-
def gather_from_model_output(cls,
550-
axis_name="batch",
551-
**kwargs) -> "Collection":
552+
def gather_from_model_output(cls: type[C], axis_name="batch", **kwargs) -> C:
552553
"""Creates a `Collection` from model outputs in a distributed setting.
553554
554555
Args:
@@ -563,14 +564,14 @@ def gather_from_model_output(cls,
563564
return jax.lax.all_gather(
564565
cls._from_model_output(**kwargs), axis_name=axis_name).reduce()
565566

566-
def merge(self, other: "Collection") -> "Collection":
567+
def merge(self: C, other: C) -> C:
567568
"""Returns `Collection` that is the accumulation of `self` and `other`."""
568569
return type(self)(**{
569570
metric_name: metric.merge(getattr(other, metric_name))
570571
for metric_name, metric in vars(self).items()
571572
})
572573

573-
def reduce(self) -> "Collection":
574+
def reduce(self: C) -> C:
574575
"""Reduces the collection by calling `Metric.reduce()` on each metric.
575576
576577
The primary use case is to reduce collection that was gathered
@@ -620,7 +621,7 @@ def compute_values(self) -> dict[str, clu.values.Value]:
620621
if metric_name != "_reduction_counter"
621622
}
622623

623-
def unreplicate(self) -> "Collection":
624+
def unreplicate(self: C) -> C:
624625
"""Short-hand for `flax.jax_utils.unreplicate(self)`.
625626
626627
The collection should be gathered and `reduce`d inside pmap,

clu/metrics_test.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ def setUp(self):
8686
dict(mask=mask, **model_output)
8787
for mask, model_output in zip(masks, self.model_outputs))
8888

89+
self.count = 4
90+
self.count_masked = 2
8991
self.results = {
9092
"train_accuracy": 0.75,
9193
"learning_rate": 0.01,
@@ -367,18 +369,23 @@ def test_collection_create_collection(self):
367369
("_masked", True),
368370
)
369371
def test_collection_single(self, masked):
370-
372+
@jax.jit
371373
def compute_collection(model_outputs):
372374
collection = Collection.empty()
373375
for model_output in model_outputs:
374376
update = Collection.single_from_model_output(**model_output)
375377
collection = collection.merge(update)
376-
return collection.compute()
378+
return collection
377379

380+
model_outputs = self.model_outputs_masked if masked else self.model_outputs
381+
collection = compute_collection(model_outputs)
378382
chex.assert_trees_all_close(
379-
jax.jit(compute_collection)(
380-
self.model_outputs_masked if masked else self.model_outputs),
381-
self.results_masked if masked else self.results)
383+
collection.compute(), self.results_masked if masked else self.results
384+
)
385+
self.assertEqual(
386+
collection.train_accuracy.count,
387+
self.count_masked if masked else self.count,
388+
)
382389

383390
@parameterized.named_parameters(
384391
("", False),

0 commit comments

Comments
 (0)