Skip to content

Commit 3b9f5be

Browse files
josephlr and copybara-github
authored and committed
clu.metrics: Fix type annotations for Metric
Similar to the previous change for `Collection`, this allows the methods and classmethods on `Metric` and its subclasses to be properly typed. This allows us to remove the `disable=g-bare-generic` lint suppressions. This also enables typechecking for `metrics_test.py`, which failed to typecheck before this change. To get `Metric.from_fun()` to properly typecheck, we introduce a `FromFunCallable` [Callback Protocol](https://mypy.readthedocs.io/en/stable/protocols.html#callback-protocols) to encode the type of "a function that takes in any number of `ArrayLike` keyword arguments and outputs an `Array` or a mapping from `str` keys to `Array` values". Finally, this change uses `Protocol` instead of `abc.ABC` for `Value`, so that we can call `v.value` on abstract `Value` types. PiperOrigin-RevId: 529495726
1 parent f8eec70 commit 3b9f5be

File tree

2 files changed

+63
-50
lines changed

2 files changed

+63
-50
lines changed

clu/metrics.py

Lines changed: 59 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,9 @@ def evaluate(model, p_variables, test_ds):
5555
ms = p_eval_step(ms, model, p_variables, inputs, labels)
5656
return ms.unreplicate().compute()
5757
"""
58-
from collections.abc import Callable, Mapping, Sequence
59-
from typing import Any, Optional, TypeVar
58+
from __future__ import annotations
59+
from collections.abc import Mapping, Sequence
60+
from typing import Any, TypeVar, Protocol
6061

6162
from absl import logging
6263

@@ -67,6 +68,17 @@ def evaluate(model, p_variables, test_ds):
6768
import jax.numpy as jnp
6869
import numpy as np
6970

71+
Array = jax.Array
72+
ArrayLike = jax.typing.ArrayLike
73+
74+
75+
class FromFunCallable(Protocol):
76+
"""The type of functions that can be passed to `Metrics.from_fun()`."""
77+
78+
def __call__(self, **kwargs: ArrayLike) -> Array | Mapping[str, Array]:
79+
"""Returns the argument/arguments passed to the base from_model_output()."""
80+
81+
7082
# TODO(b/200953513): Migrate away from logging imports (on module level)
7183
# to logging the actual usage. See b/200953513.
7284

@@ -78,6 +90,9 @@ def _assert_same_shape(a: jnp.array, b: jnp.array):
7890
raise ValueError(f"Expected same shape: {a.shape} != {b.shape}")
7991

8092

93+
M = TypeVar("M", bound="Metric")
94+
95+
8196
class Metric:
8297
"""Interface for computing metrics from intermediate values.
8398
@@ -114,11 +129,11 @@ def compute(self):
114129
"""
115130

116131
@classmethod
117-
def from_model_output(cls, *args, **kwargs) -> "Metric":
132+
def from_model_output(cls: type[M], *args, **kwargs) -> M:
118133
"""Creates a `Metric` from model outputs."""
119134
raise NotImplementedError("Must override from_model_output()")
120135

121-
def merge(self, other: "Metric") -> "Metric":
136+
def merge(self: M, other: M) -> M:
122137
"""Returns `Metric` that is the accumulation of `self` and `other`.
123138
124139
Args:
@@ -141,23 +156,23 @@ def merge(self, other: "Metric") -> "Metric":
141156
# `_reduce_merge()` must be associative[1], otherwise we would get
142157
# different results when using different devices.
143158
# [1] https://en.wikipedia.org/wiki/Associative_property
144-
def _reduce_merge(self, other: "Metric") -> "Metric":
159+
def _reduce_merge(self: M, other: M) -> M:
145160
return self.merge(other)
146161

147162
def compute(self) -> jnp.array:
148163
"""Computes final metrics from intermediate values."""
149164
raise NotImplementedError("Must override compute()")
150165

151166
@classmethod
152-
def empty(cls) -> "Metric":
167+
def empty(cls: type[M]) -> M:
153168
"""Returns an empty instance (i.e. `.merge(Metric.empty())` is a no-op)."""
154169
raise NotImplementedError("Must override empty()")
155170

156171
def compute_value(self) -> clu.values.Value:
157172
"""Wraps compute() and returns a values.Value."""
158173
return clu.values.Scalar(self.compute())
159174

160-
def reduce(self) -> "Metric":
175+
def reduce(self: M) -> M:
161176
"""Reduces the metric along it first axis by calling `_reduce_merge()`.
162177
163178
This function primary use case is to aggregate metrics collected across
@@ -173,7 +188,7 @@ def reduce(self) -> "Metric":
173188
reduced metric.
174189
"""
175190

176-
def reduce_step(reduced: Metric, metric: Metric) -> tuple[Metric, None]:
191+
def reduce_step(reduced: M, metric: M) -> tuple[M, None]:
177192
# pylint: disable-next=protected-access
178193
return reduced._reduce_merge(metric), None
179194

@@ -183,7 +198,7 @@ def reduce_step(reduced: Metric, metric: Metric) -> tuple[Metric, None]:
183198
return jax.lax.scan(reduce_step, first, remainder)[0]
184199

185200
@classmethod
186-
def from_fun(cls, fun: Callable): # pylint: disable=g-bare-generic
201+
def from_fun(cls, fun: FromFunCallable): # No way to annotate return type
187202
"""Calls `cls.from_model_output` with the return value from `fun`.
188203
189204
Returns a `Metric` derived from `cls` whose `.from_model_output` (1) calls
@@ -233,7 +248,7 @@ class FromFun(cls):
233248
"""Wrapper Metric class that collects output after applying `fun`."""
234249

235250
@classmethod
236-
def from_model_output(cls, **model_output) -> Metric:
251+
def from_model_output(cls: type[M], **model_output) -> M:
237252
mask = model_output.get("mask")
238253
output = fun(**model_output)
239254
if isinstance(output, Mapping) and "mask" in output:
@@ -266,7 +281,7 @@ def from_model_output(cls, **model_output) -> Metric:
266281
return FromFun
267282

268283
@classmethod
269-
def from_output(cls, name: str): # pylint: disable=g-bare-generic
284+
def from_output(cls, name: str): # No way to annotate return type
270285
"""Calls `cls.from_model_output` with model output named `name`.
271286
272287
Synopsis:
@@ -295,7 +310,7 @@ class FromOutput(cls):
295310
"""Wrapper Metric class that collects output named `name`."""
296311

297312
@classmethod
298-
def from_model_output(cls, **model_output) -> Metric:
313+
def from_model_output(cls: type[M], **model_output) -> M:
299314
output = jnp.array(model_output[name])
300315
mask = model_output.get("mask")
301316
if mask is not None and (output.shape or [0])[0] != mask.shape[0]:
@@ -366,10 +381,10 @@ def merge(update):
366381
values: dict[str, tuple[np.ndarray, ...]]
367382

368383
@classmethod
369-
def empty(cls) -> "CollectingMetric":
384+
def empty(cls) -> CollectingMetric:
370385
return cls(values={})
371386

372-
def merge(self, other: "CollectingMetric") -> "CollectingMetric":
387+
def merge(self, other: CollectingMetric) -> CollectingMetric:
373388
values = {
374389
name: (*value, *other.values[name])
375390
for name, value in self.values.items()
@@ -384,25 +399,24 @@ def merge(self, other: "CollectingMetric") -> "CollectingMetric":
384399
return self
385400
return type(self)(jax.tree_map(np.asarray, values))
386401

387-
def reduce(self) -> "CollectingMetric":
402+
def reduce(self) -> CollectingMetric:
388403
# Note that this is usually called from inside a `pmap()` via
389404
# `Collection.gather_from_model_output()` so we concatenate using jnp.
390405
return type(self)(
391406
{name: jnp.concatenate(values) for name, values in self.values.items()})
392407

393-
def compute(self) -> dict[str, np.ndarray]:
408+
def compute(self): # No return type annotation, so subclasses can override
394409
return {k: np.concatenate(v) for k, v in self.values.items()}
395410

396411
@classmethod
397-
def from_outputs(cls, names: Sequence[str]):
412+
def from_outputs(cls, names: Sequence[str]) -> type[CollectingMetric]:
398413
"""Returns a metric class that collects all model outputs named `names`."""
399414

400415
@flax.struct.dataclass
401416
class FromOutputs(cls): # pylint:disable=missing-class-docstring
402417

403418
@classmethod
404-
def from_model_output(cls, **model_output) -> Metric:
405-
419+
def from_model_output(cls: type[M], **model_output) -> M:
406420
def make_array(value):
407421
value = jnp.array(value)
408422
# Can't jnp.concatenate() scalars, promote to shape=(1,) in that case.
@@ -420,10 +434,10 @@ class _ReductionCounter(Metric):
420434
value: jnp.array
421435

422436
@classmethod
423-
def empty(cls):
437+
def empty(cls) -> _ReductionCounter:
424438
return cls(value=jnp.array(1, jnp.int32))
425439

426-
def merge(self, other: "_ReductionCounter") -> "_ReductionCounter":
440+
def merge(self, other: _ReductionCounter) -> _ReductionCounter:
427441
return _ReductionCounter(self.value + other.value)
428442

429443

@@ -461,7 +475,7 @@ class Metrics(Collection):
461475
_reduction_counter: _ReductionCounter
462476

463477
@classmethod
464-
def create(cls, **metrics: type[Metric]) -> type["Collection"]:
478+
def create(cls, **metrics: type[Metric]) -> type[Collection]:
465479
"""Handy short-cut to define a `Collection` inline.
466480
467481
Instead declaring a `Collection` dataclass:
@@ -487,7 +501,7 @@ class MyMetrics(metrics.Collection):
487501
type("_InlineCollection", (Collection,), {"__annotations__": metrics}))
488502

489503
@classmethod
490-
def create_collection(cls, **metrics: Metric) -> "Collection":
504+
def create_collection(cls, **metrics: Metric) -> Collection:
491505
"""Creates a custom collection object with fields metrics.
492506
493507
This object will be an instance of custom subclass of `Collection` with
@@ -650,10 +664,12 @@ class LastValue(Metric):
650664
total: jnp.array
651665
count: jnp.array
652666

653-
def __init__(self, total: Optional[jnp.array] = None,
654-
count: Optional[jnp.array] = None,
655-
value: Optional[jnp.array] = None,
656-
):
667+
def __init__(
668+
self,
669+
total: jnp.array | None = None,
670+
count: jnp.array | None = None,
671+
value: jnp.array | None = None,
672+
):
657673
"""Constructor which supports keyword argument value as initializer.
658674
659675
If "value" is provided, then "total" should *not* be provided.
@@ -673,26 +689,25 @@ def __init__(self, total: Optional[jnp.array] = None,
673689
object.__setattr__(self, "count", count)
674690

675691
@classmethod
676-
def empty(cls):
692+
def empty(cls) -> LastValue:
677693
return cls(jnp.array(0, jnp.float32), count=jnp.array(0, jnp.int32))
678694

679695
@classmethod
680-
def from_model_output(cls,
681-
value: jnp.array,
682-
mask: Optional[jnp.array] = None,
683-
**_) -> Metric:
696+
def from_model_output(
697+
cls, value: jnp.array, mask: jnp.array | None = None, **_
698+
) -> LastValue:
684699
if mask is None:
685700
mask = jnp.ones((value.shape or [()])[0])
686701
return cls(
687702
total=jnp.where(mask, value, jnp.zeros_like(value)).sum(),
688703
count=mask.sum().astype(jnp.int32),
689704
)
690705

691-
def merge(self, other: "LastValue") -> "LastValue":
706+
def merge(self, other: LastValue) -> LastValue:
692707
_assert_same_shape(self.value, other.value)
693708
return other
694709

695-
def _reduce_merge(self, other: "LastValue") -> "LastValue":
710+
def _reduce_merge(self, other: LastValue) -> LastValue:
696711
# We need to average during reduction.
697712
_assert_same_shape(self.total, other.total)
698713
return type(self)(
@@ -730,14 +745,13 @@ class Average(Metric):
730745
count: jnp.array
731746

732747
@classmethod
733-
def empty(cls) -> Metric:
748+
def empty(cls) -> Average:
734749
return cls(total=jnp.array(0, jnp.float32), count=jnp.array(0, jnp.int32))
735750

736751
@classmethod
737-
def from_model_output(cls,
738-
values: jnp.array,
739-
mask: Optional[jnp.array] = None,
740-
**_) -> Metric:
752+
def from_model_output(
753+
cls, values: jnp.array, mask: jnp.array | None = None, **_
754+
) -> Average:
741755
if values.ndim == 0:
742756
values = values[None]
743757
if mask is None:
@@ -760,7 +774,7 @@ def from_model_output(cls,
760774
jnp.zeros_like(values, dtype=jnp.int32)).sum(),
761775
)
762776

763-
def merge(self, other: "Average") -> "Average":
777+
def merge(self, other: Average) -> Average:
764778
_assert_same_shape(self.total, other.total)
765779
return type(self)(
766780
total=self.total + other.total,
@@ -783,17 +797,16 @@ class Std(Metric):
783797
count: jnp.array
784798

785799
@classmethod
786-
def empty(cls):
800+
def empty(cls) -> Std:
787801
return cls(
788802
total=jnp.array(0, jnp.float32),
789803
sum_of_squares=jnp.array(0, jnp.float32),
790804
count=jnp.array(0, jnp.int32))
791805

792806
@classmethod
793-
def from_model_output(cls,
794-
values: jnp.array,
795-
mask: Optional[jnp.array] = None,
796-
**_) -> Metric:
807+
def from_model_output(
808+
cls, values: jnp.array, mask: jnp.array | None = None, **_
809+
) -> Std:
797810
if values.ndim == 0:
798811
values = values[None]
799812
utils.check_param(values, ndim=1)
@@ -805,7 +818,7 @@ def from_model_output(cls,
805818
count=mask.sum(),
806819
)
807820

808-
def merge(self, other: "Std") -> "Std":
821+
def merge(self, other: Std) -> Std:
809822
_assert_same_shape(self.total, other.total)
810823
return type(self)(
811824
total=self.total + other.total,

clu/values.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,8 @@
1717
A Metric should return one of the following types when compute() is called.
1818
"""
1919

20-
import abc
2120
import dataclasses
22-
from typing import Any, Union
21+
from typing import Any, Union, Protocol, runtime_checkable
2322

2423
import jax.numpy as jnp
2524
import numpy as np
@@ -28,13 +27,14 @@
2827
ScalarType = Union[int, float, np.number, np.ndarray, jnp.ndarray]
2928

3029

31-
class Value(abc.ABC):
30+
@runtime_checkable
31+
class Value(Protocol):
3232
"""Class defining available metric computation return values.
3333
3434
Types mirror those available in MetricWriter. See
3535
clu/metric_writers/interface.py
3636
"""
37-
pass
37+
value: Any
3838

3939

4040
@dataclasses.dataclass

0 commit comments

Comments
 (0)