Skip to content

Commit 4e9da73

Browse files
andsteingcopybara-github
authored andcommitted
Clu metrics + __future__.annotations
PiperOrigin-RevId: 547715759
1 parent 2b28d3b commit 4e9da73

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

clu/metrics.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def evaluate(model, p_variables, test_ds):
5757
"""
5858
from __future__ import annotations
5959
from collections.abc import Mapping, Sequence
60+
import inspect
6061
from typing import Any, TypeVar, Protocol
6162

6263
from absl import logging
@@ -534,8 +535,11 @@ def empty(cls: type[C]) -> C:
534535
_reduction_counter=_ReductionCounter(jnp.array(1, dtype=jnp.int32)),
535536
**{
536537
metric_name: metric.empty()
537-
for metric_name, metric in cls.__annotations__.items()
538-
})
538+
for metric_name, metric in inspect.get_annotations(
539+
cls, eval_str=True
540+
).items()
541+
},
542+
)
539543

540544
@classmethod
541545
def _from_model_output(cls: type[C], **kwargs) -> C:
@@ -544,8 +548,11 @@ def _from_model_output(cls: type[C], **kwargs) -> C:
544548
_reduction_counter=_ReductionCounter(jnp.array(1, dtype=jnp.int32)),
545549
**{
546550
metric_name: metric.from_model_output(**kwargs)
547-
for metric_name, metric in cls.__annotations__.items()
548-
})
551+
for metric_name, metric in inspect.get_annotations(
552+
cls, eval_str=True
553+
).items()
554+
},
555+
)
549556

550557
@classmethod
551558
def single_from_model_output(cls: type[C], **kwargs) -> C:

clu/metrics_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
"""Tests for clu.metrics."""
1616

17+
from __future__ import annotations
18+
1719
import functools
1820
from unittest import mock
1921

0 commit comments

Comments
 (0)