@@ -55,8 +55,8 @@ def evaluate(model, p_variables, test_ds):
5555 ms = p_eval_step(ms, model, p_variables, inputs, labels)
5656 return ms.unreplicate().compute()
5757"""
58-
59- from typing import Any , Callable , Dict , Mapping , Optional , Sequence , Tuple , Type
58+ from collections . abc import Callable , Mapping , Sequence
59+ from typing import Any , Optional
6060
6161from absl import logging
6262
@@ -173,7 +173,7 @@ def reduce(self) -> "Metric":
173173 reduced metric.
174174 """
175175
176- def reduce_step (reduced : Metric , metric : Metric ) -> Tuple [Metric , None ]:
176+ def reduce_step (reduced : Metric , metric : Metric ) -> tuple [Metric , None ]:
177177 # pylint: disable-next=protected-access
178178 return reduced ._reduce_merge (metric ), None
179179
@@ -363,7 +363,7 @@ def merge(update):
363363 return ms.compute()
364364 """
365365
366- values : Dict [str , Tuple [np .ndarray , ...]]
366+ values : dict [str , tuple [np .ndarray , ...]]
367367
368368 @classmethod
369369 def empty (cls ) -> "CollectingMetric" :
@@ -390,7 +390,7 @@ def reduce(self) -> "CollectingMetric":
390390 return type (self )(
391391 {name : jnp .concatenate (values ) for name , values in self .values .items ()})
392392
393- def compute (self ) -> Dict [str , np .ndarray ]:
393+ def compute (self ) -> dict [str , np .ndarray ]:
394394 return {k : np .concatenate (v ) for k , v in self .values .items ()}
395395
396396 @classmethod
@@ -458,7 +458,7 @@ class Metrics(Collection):
458458 _reduction_counter : _ReductionCounter
459459
460460 @classmethod
461- def create (cls , ** metrics : Type [Metric ]) -> Type ["Collection" ]:
461+ def create (cls , ** metrics : type [Metric ]) -> type ["Collection" ]:
462462 """Handy short-cut to define a `Collection` inline.
463463
464464 Instead declaring a `Collection` dataclass:
@@ -602,7 +602,7 @@ def reduce(self) -> "Collection":
602602 for metric_name , metric in vars (self ).items ()
603603 })
604604
605- def compute (self ) -> Dict [str , jnp .array ]:
605+ def compute (self ) -> dict [str , jnp .array ]:
606606 """Returns a dictionary mapping metric field name to `Metric.compute()`."""
607607 _check_reduction_counter_ndim (self ._reduction_counter )
608608 return {
@@ -611,7 +611,7 @@ def compute(self) -> Dict[str, jnp.array]:
611611 if metric_name != "_reduction_counter"
612612 }
613613
614- def compute_values (self ) -> Dict [str , clu .values .Value ]:
614+ def compute_values (self ) -> dict [str , clu .values .Value ]:
615615 """Computes metrics and returns them as clu.values.Value."""
616616 _check_reduction_counter_ndim (self ._reduction_counter )
617617 return {
0 commit comments