@@ -56,7 +56,7 @@ def evaluate(model, p_variables, test_ds):
5656 return ms.unreplicate().compute()
5757"""
5858from collections .abc import Callable , Mapping , Sequence
59- from typing import Any , Optional
59+ from typing import Any , Optional , TypeVar
6060
6161from absl import logging
6262
@@ -435,6 +435,9 @@ def _check_reduction_counter_ndim(reduction_counter: _ReductionCounter):
435435 f"call a flax.jax_utils.unreplicate() or a Collections.reduce()?" )
436436
437437
438+ C = TypeVar ("C" , bound = "Collection" )
439+
440+
438441@flax .struct .dataclass
439442class Collection :
440443 """Updates a collection of `Metric` from model outputs.
@@ -512,7 +515,7 @@ class MyMetrics(metrics.Collection):
512515 return collection_class (_reduction_counter = counter , ** metrics )
513516
514517 @classmethod
515- def empty (cls ) -> "Collection" :
518+ def empty (cls : type [ C ] ) -> C :
516519 return cls (
517520 _reduction_counter = _ReductionCounter (jnp .array (1 , dtype = jnp .int32 )),
518521 ** {
@@ -521,7 +524,7 @@ def empty(cls) -> "Collection":
521524 })
522525
523526 @classmethod
524- def _from_model_output (cls , ** kwargs ) -> "Collection" :
527+ def _from_model_output (cls : type [ C ] , ** kwargs ) -> C :
525528 """Creates a `Collection` from model outputs."""
526529 return cls (
527530 _reduction_counter = _ReductionCounter (jnp .array (1 , dtype = jnp .int32 )),
@@ -531,7 +534,7 @@ def _from_model_output(cls, **kwargs) -> "Collection":
531534 })
532535
533536 @classmethod
534- def single_from_model_output (cls , ** kwargs ) -> "Collection" :
537+ def single_from_model_output (cls : type [ C ] , ** kwargs ) -> C :
535538 """Creates a `Collection` from model outputs.
536539
537540 Note: This function should only be called when metrics are collected in a
@@ -546,9 +549,7 @@ def single_from_model_output(cls, **kwargs) -> "Collection":
546549 return cls ._from_model_output (** kwargs )
547550
548551 @classmethod
549- def gather_from_model_output (cls ,
550- axis_name = "batch" ,
551- ** kwargs ) -> "Collection" :
552+ def gather_from_model_output (cls : type [C ], axis_name = "batch" , ** kwargs ) -> C :
552553 """Creates a `Collection` from model outputs in a distributed setting.
553554
554555 Args:
@@ -563,14 +564,14 @@ def gather_from_model_output(cls,
563564 return jax .lax .all_gather (
564565 cls ._from_model_output (** kwargs ), axis_name = axis_name ).reduce ()
565566
566- def merge (self , other : "Collection" ) -> "Collection" :
567+ def merge (self : C , other : C ) -> C :
567568 """Returns `Collection` that is the accumulation of `self` and `other`."""
568569 return type (self )(** {
569570 metric_name : metric .merge (getattr (other , metric_name ))
570571 for metric_name , metric in vars (self ).items ()
571572 })
572573
573- def reduce (self ) -> "Collection" :
574+ def reduce (self : C ) -> C :
574575 """Reduces the collection by calling `Metric.reduce()` on each metric.
575576
576577 The primary use case is to reduce collection that was gathered
@@ -620,7 +621,7 @@ def compute_values(self) -> dict[str, clu.values.Value]:
620621 if metric_name != "_reduction_counter"
621622 }
622623
623- def unreplicate (self ) -> "Collection" :
624+ def unreplicate (self : C ) -> C :
624625 """Short-hand for `flax.jax_utils.unreplicate(self)`.
625626
626627 The collection should be gathered and `reduce`d inside pmap,
0 commit comments