Skip to content

Commit 903f1ef

Browse files
josephlrcopybara-github
authored andcommitted
Internal change
PiperOrigin-RevId: 527427711
1 parent 40d8cfa commit 903f1ef

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

clu/metrics.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@ def evaluate(model, p_variables, test_ds):
5555
ms = p_eval_step(ms, model, p_variables, inputs, labels)
5656
return ms.unreplicate().compute()
5757
"""
58-
59-
from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Type
58+
from collections.abc import Callable, Mapping, Sequence
59+
from typing import Any, Optional
6060

6161
from absl import logging
6262

@@ -173,7 +173,7 @@ def reduce(self) -> "Metric":
173173
reduced metric.
174174
"""
175175

176-
def reduce_step(reduced: Metric, metric: Metric) -> Tuple[Metric, None]:
176+
def reduce_step(reduced: Metric, metric: Metric) -> tuple[Metric, None]:
177177
# pylint: disable-next=protected-access
178178
return reduced._reduce_merge(metric), None
179179

@@ -363,7 +363,7 @@ def merge(update):
363363
return ms.compute()
364364
"""
365365

366-
values: Dict[str, Tuple[np.ndarray, ...]]
366+
values: dict[str, tuple[np.ndarray, ...]]
367367

368368
@classmethod
369369
def empty(cls) -> "CollectingMetric":
@@ -390,7 +390,7 @@ def reduce(self) -> "CollectingMetric":
390390
return type(self)(
391391
{name: jnp.concatenate(values) for name, values in self.values.items()})
392392

393-
def compute(self) -> Dict[str, np.ndarray]:
393+
def compute(self) -> dict[str, np.ndarray]:
394394
return {k: np.concatenate(v) for k, v in self.values.items()}
395395

396396
@classmethod
@@ -458,7 +458,7 @@ class Metrics(Collection):
458458
_reduction_counter: _ReductionCounter
459459

460460
@classmethod
461-
def create(cls, **metrics: Type[Metric]) -> Type["Collection"]:
461+
def create(cls, **metrics: type[Metric]) -> type["Collection"]:
462462
"""Handy short-cut to define a `Collection` inline.
463463
464464
Instead declaring a `Collection` dataclass:
@@ -602,7 +602,7 @@ def reduce(self) -> "Collection":
602602
for metric_name, metric in vars(self).items()
603603
})
604604

605-
def compute(self) -> Dict[str, jnp.array]:
605+
def compute(self) -> dict[str, jnp.array]:
606606
"""Returns a dictionary mapping metric field name to `Metric.compute()`."""
607607
_check_reduction_counter_ndim(self._reduction_counter)
608608
return {
@@ -611,7 +611,7 @@ def compute(self) -> Dict[str, jnp.array]:
611611
if metric_name != "_reduction_counter"
612612
}
613613

614-
def compute_values(self) -> Dict[str, clu.values.Value]:
614+
def compute_values(self) -> dict[str, clu.values.Value]:
615615
"""Computes metrics and returns them as clu.values.Value."""
616616
_check_reduction_counter_ndim(self._reduction_counter)
617617
return {

0 commit comments

Comments
 (0)