@@ -55,8 +55,9 @@ def evaluate(model, p_variables, test_ds):
5555 ms = p_eval_step(ms, model, p_variables, inputs, labels)
5656 return ms.unreplicate().compute()
5757"""
58- from collections .abc import Callable , Mapping , Sequence
59- from typing import Any , Optional , TypeVar
58+ from __future__ import annotations
59+ from collections .abc import Mapping , Sequence
60+ from typing import Any , TypeVar , Protocol
6061
6162from absl import logging
6263
@@ -67,6 +68,17 @@ def evaluate(model, p_variables, test_ds):
6768import jax .numpy as jnp
6869import numpy as np
6970
71+ Array = jax .Array
72+ ArrayLike = jax .typing .ArrayLike
73+
74+
75+ class FromFunCallable (Protocol ):
76+ """The type of functions that can be passed to `Metrics.from_fun()`."""
77+
78+ def __call__ (self , ** kwargs : ArrayLike ) -> Array | Mapping [str , Array ]:
79+ """Returns the argument/arguments passed to the base from_model_output()."""
80+
81+
7082# TODO(b/200953513): Migrate away from logging imports (on module level)
7183# to logging the actual usage. See b/200953513.
7284
@@ -78,6 +90,9 @@ def _assert_same_shape(a: jnp.array, b: jnp.array):
7890 raise ValueError (f"Expected same shape: { a .shape } != { b .shape } " )
7991
8092
93+ M = TypeVar ("M" , bound = "Metric" )
94+
95+
8196class Metric :
8297 """Interface for computing metrics from intermediate values.
8398
@@ -114,11 +129,11 @@ def compute(self):
114129 """
115130
116131 @classmethod
117- def from_model_output (cls , * args , ** kwargs ) -> "Metric" :
132+ def from_model_output (cls : type [ M ] , * args , ** kwargs ) -> M :
118133 """Creates a `Metric` from model outputs."""
119134 raise NotImplementedError ("Must override from_model_output()" )
120135
121- def merge (self , other : "Metric" ) -> "Metric" :
136+ def merge (self : M , other : M ) -> M :
122137 """Returns `Metric` that is the accumulation of `self` and `other`.
123138
124139 Args:
@@ -141,23 +156,23 @@ def merge(self, other: "Metric") -> "Metric":
141156 # `_reduce_merge()` must be associative[1], otherwise we would get
142157 # different results when using different devices.
143158 # [1] https://en.wikipedia.org/wiki/Associative_property
144- def _reduce_merge (self , other : "Metric" ) -> "Metric" :
159+ def _reduce_merge (self : M , other : M ) -> M :
145160 return self .merge (other )
146161
147162 def compute (self ) -> jnp .array :
148163 """Computes final metrics from intermediate values."""
149164 raise NotImplementedError ("Must override compute()" )
150165
151166 @classmethod
152- def empty (cls ) -> "Metric" :
167+ def empty (cls : type [ M ] ) -> M :
153168 """Returns an empty instance (i.e. `.merge(Metric.empty())` is a no-op)."""
154169 raise NotImplementedError ("Must override empty()" )
155170
156171 def compute_value (self ) -> clu .values .Value :
157172 """Wraps compute() and returns a values.Value."""
158173 return clu .values .Scalar (self .compute ())
159174
160- def reduce (self ) -> "Metric" :
175+ def reduce (self : M ) -> M :
161176 """Reduces the metric along it first axis by calling `_reduce_merge()`.
162177
163178 This function primary use case is to aggregate metrics collected across
@@ -173,7 +188,7 @@ def reduce(self) -> "Metric":
173188 reduced metric.
174189 """
175190
176- def reduce_step (reduced : Metric , metric : Metric ) -> tuple [Metric , None ]:
191+ def reduce_step (reduced : M , metric : M ) -> tuple [M , None ]:
177192 # pylint: disable-next=protected-access
178193 return reduced ._reduce_merge (metric ), None
179194
@@ -183,7 +198,7 @@ def reduce_step(reduced: Metric, metric: Metric) -> tuple[Metric, None]:
183198 return jax .lax .scan (reduce_step , first , remainder )[0 ]
184199
185200 @classmethod
186- def from_fun (cls , fun : Callable ): # pylint: disable=g-bare-generic
201+ def from_fun (cls , fun : FromFunCallable ): # No way to annotate return type
187202 """Calls `cls.from_model_output` with the return value from `fun`.
188203
189204 Returns a `Metric` derived from `cls` whose `.from_model_output` (1) calls
@@ -233,7 +248,7 @@ class FromFun(cls):
233248 """Wrapper Metric class that collects output after applying `fun`."""
234249
235250 @classmethod
236- def from_model_output (cls , ** model_output ) -> Metric :
251+ def from_model_output (cls : type [ M ] , ** model_output ) -> M :
237252 mask = model_output .get ("mask" )
238253 output = fun (** model_output )
239254 if isinstance (output , Mapping ) and "mask" in output :
@@ -266,7 +281,7 @@ def from_model_output(cls, **model_output) -> Metric:
266281 return FromFun
267282
268283 @classmethod
269- def from_output (cls , name : str ): # pylint: disable=g-bare-generic
284+ def from_output (cls , name : str ): # No way to annotate return type
270285 """Calls `cls.from_model_output` with model output named `name`.
271286
272287 Synopsis:
@@ -295,7 +310,7 @@ class FromOutput(cls):
295310 """Wrapper Metric class that collects output named `name`."""
296311
297312 @classmethod
298- def from_model_output (cls , ** model_output ) -> Metric :
313+ def from_model_output (cls : type [ M ] , ** model_output ) -> M :
299314 output = jnp .array (model_output [name ])
300315 mask = model_output .get ("mask" )
301316 if mask is not None and (output .shape or [0 ])[0 ] != mask .shape [0 ]:
@@ -366,10 +381,10 @@ def merge(update):
366381 values : dict [str , tuple [np .ndarray , ...]]
367382
368383 @classmethod
369- def empty (cls ) -> " CollectingMetric" :
384+ def empty (cls ) -> CollectingMetric :
370385 return cls (values = {})
371386
372- def merge (self , other : " CollectingMetric" ) -> " CollectingMetric" :
387+ def merge (self , other : CollectingMetric ) -> CollectingMetric :
373388 values = {
374389 name : (* value , * other .values [name ])
375390 for name , value in self .values .items ()
@@ -384,25 +399,24 @@ def merge(self, other: "CollectingMetric") -> "CollectingMetric":
384399 return self
385400 return type (self )(jax .tree_map (np .asarray , values ))
386401
387- def reduce (self ) -> " CollectingMetric" :
402+ def reduce (self ) -> CollectingMetric :
388403 # Note that this is usually called from inside a `pmap()` via
389404 # `Collection.gather_from_model_output()` so we concatenate using jnp.
390405 return type (self )(
391406 {name : jnp .concatenate (values ) for name , values in self .values .items ()})
392407
393- def compute (self ) -> dict [ str , np . ndarray ]:
408+ def compute (self ): # No return type annotation, so subclasses can override
394409 return {k : np .concatenate (v ) for k , v in self .values .items ()}
395410
396411 @classmethod
397- def from_outputs (cls , names : Sequence [str ]):
412+ def from_outputs (cls , names : Sequence [str ]) -> type [ CollectingMetric ] :
398413 """Returns a metric class that collects all model outputs named `names`."""
399414
400415 @flax .struct .dataclass
401416 class FromOutputs (cls ): # pylint:disable=missing-class-docstring
402417
403418 @classmethod
404- def from_model_output (cls , ** model_output ) -> Metric :
405-
419+ def from_model_output (cls : type [M ], ** model_output ) -> M :
406420 def make_array (value ):
407421 value = jnp .array (value )
408422 # Can't jnp.concatenate() scalars, promote to shape=(1,) in that case.
@@ -420,10 +434,10 @@ class _ReductionCounter(Metric):
420434 value : jnp .array
421435
422436 @classmethod
423- def empty (cls ):
437+ def empty (cls ) -> _ReductionCounter :
424438 return cls (value = jnp .array (1 , jnp .int32 ))
425439
426- def merge (self , other : " _ReductionCounter" ) -> " _ReductionCounter" :
440+ def merge (self , other : _ReductionCounter ) -> _ReductionCounter :
427441 return _ReductionCounter (self .value + other .value )
428442
429443
@@ -461,7 +475,7 @@ class Metrics(Collection):
461475 _reduction_counter : _ReductionCounter
462476
463477 @classmethod
464- def create (cls , ** metrics : type [Metric ]) -> type [" Collection" ]:
478+ def create (cls , ** metrics : type [Metric ]) -> type [Collection ]:
465479 """Handy short-cut to define a `Collection` inline.
466480
467481 Instead declaring a `Collection` dataclass:
@@ -487,7 +501,7 @@ class MyMetrics(metrics.Collection):
487501 type ("_InlineCollection" , (Collection ,), {"__annotations__" : metrics }))
488502
489503 @classmethod
490- def create_collection (cls , ** metrics : Metric ) -> " Collection" :
504+ def create_collection (cls , ** metrics : Metric ) -> Collection :
491505 """Creates a custom collection object with fields metrics.
492506
493507 This object will be an instance of custom subclass of `Collection` with
@@ -650,10 +664,12 @@ class LastValue(Metric):
650664 total : jnp .array
651665 count : jnp .array
652666
653- def __init__ (self , total : Optional [jnp .array ] = None ,
654- count : Optional [jnp .array ] = None ,
655- value : Optional [jnp .array ] = None ,
656- ):
667+ def __init__ (
668+ self ,
669+ total : jnp .array | None = None ,
670+ count : jnp .array | None = None ,
671+ value : jnp .array | None = None ,
672+ ):
657673 """Constructor which supports keyword argument value as initializer.
658674
659675 If "value" is provided, then "total" should *not* be provided.
@@ -673,26 +689,25 @@ def __init__(self, total: Optional[jnp.array] = None,
673689 object .__setattr__ (self , "count" , count )
674690
675691 @classmethod
676- def empty (cls ):
692+ def empty (cls ) -> LastValue :
677693 return cls (jnp .array (0 , jnp .float32 ), count = jnp .array (0 , jnp .int32 ))
678694
679695 @classmethod
680- def from_model_output (cls ,
681- value : jnp .array ,
682- mask : Optional [jnp .array ] = None ,
683- ** _ ) -> Metric :
696+ def from_model_output (
697+ cls , value : jnp .array , mask : jnp .array | None = None , ** _
698+ ) -> LastValue :
684699 if mask is None :
685700 mask = jnp .ones ((value .shape or [()])[0 ])
686701 return cls (
687702 total = jnp .where (mask , value , jnp .zeros_like (value )).sum (),
688703 count = mask .sum ().astype (jnp .int32 ),
689704 )
690705
691- def merge (self , other : " LastValue" ) -> " LastValue" :
706+ def merge (self , other : LastValue ) -> LastValue :
692707 _assert_same_shape (self .value , other .value )
693708 return other
694709
695- def _reduce_merge (self , other : " LastValue" ) -> " LastValue" :
710+ def _reduce_merge (self , other : LastValue ) -> LastValue :
696711 # We need to average during reduction.
697712 _assert_same_shape (self .total , other .total )
698713 return type (self )(
@@ -730,14 +745,13 @@ class Average(Metric):
730745 count : jnp .array
731746
732747 @classmethod
733- def empty (cls ) -> Metric :
748+ def empty (cls ) -> Average :
734749 return cls (total = jnp .array (0 , jnp .float32 ), count = jnp .array (0 , jnp .int32 ))
735750
736751 @classmethod
737- def from_model_output (cls ,
738- values : jnp .array ,
739- mask : Optional [jnp .array ] = None ,
740- ** _ ) -> Metric :
752+ def from_model_output (
753+ cls , values : jnp .array , mask : jnp .array | None = None , ** _
754+ ) -> Average :
741755 if values .ndim == 0 :
742756 values = values [None ]
743757 if mask is None :
@@ -760,7 +774,7 @@ def from_model_output(cls,
760774 jnp .zeros_like (values , dtype = jnp .int32 )).sum (),
761775 )
762776
763- def merge (self , other : " Average" ) -> " Average" :
777+ def merge (self , other : Average ) -> Average :
764778 _assert_same_shape (self .total , other .total )
765779 return type (self )(
766780 total = self .total + other .total ,
@@ -783,17 +797,16 @@ class Std(Metric):
783797 count : jnp .array
784798
785799 @classmethod
786- def empty (cls ):
800+ def empty (cls ) -> Std :
787801 return cls (
788802 total = jnp .array (0 , jnp .float32 ),
789803 sum_of_squares = jnp .array (0 , jnp .float32 ),
790804 count = jnp .array (0 , jnp .int32 ))
791805
792806 @classmethod
793- def from_model_output (cls ,
794- values : jnp .array ,
795- mask : Optional [jnp .array ] = None ,
796- ** _ ) -> Metric :
807+ def from_model_output (
808+ cls , values : jnp .array , mask : jnp .array | None = None , ** _
809+ ) -> Std :
797810 if values .ndim == 0 :
798811 values = values [None ]
799812 utils .check_param (values , ndim = 1 )
@@ -805,7 +818,7 @@ def from_model_output(cls,
805818 count = mask .sum (),
806819 )
807820
808- def merge (self , other : " Std" ) -> " Std" :
821+ def merge (self , other : Std ) -> Std :
809822 _assert_same_shape (self .total , other .total )
810823 return type (self )(
811824 total = self .total + other .total ,
0 commit comments