Skip to content

Commit 9cfdeed

Browse files
josephlrcopybara-github
authored andcommitted
Fix Accuracy.from_model_output return type
PiperOrigin-RevId: 529504950
1 parent 3b9f5be commit 9cfdeed

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

clu/metrics.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -852,11 +852,14 @@ class Accuracy(Average):
852852
"""
853853

854854
@classmethod
855-
def from_model_output(cls, *, logits: jnp.array, labels: jnp.array,
856-
**kwargs) -> Metric:
855+
def from_model_output(
856+
cls, *, logits: jnp.array, labels: jnp.array, **kwargs
857+
) -> Accuracy:
857858
if logits.ndim != labels.ndim + 1 or labels.dtype != jnp.int32:
858859
raise ValueError(
859860
f"Expected labels.dtype==jnp.int32 and logits.ndim={logits.ndim}=="
860861
f"labels.ndim+1={labels.ndim + 1}")
861-
return super().from_model_output(
862-
values=(logits.argmax(axis=-1) == labels).astype(jnp.float32), **kwargs)
862+
metric = super().from_model_output(
863+
values=(logits.argmax(axis=-1) == labels).astype(jnp.float32), **kwargs
864+
)
865+
return cls(**vars(metric)) # cls(metrics) doesn't work for a dataclass

0 commit comments

Comments
 (0)