Skip to content

Commit 3746749

Browse files
CLU Authorscopybara-github
authored andcommitted
Allow overwriting the mask from Metric.from_fun.
The docstring already claims that this is supported. But the code raised an error. PiperOrigin-RevId: 538503893
1 parent 20d271a commit 3746749

File tree

2 files changed

+8
-19
lines changed

2 files changed

+8
-19
lines changed

clu/metrics.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -257,10 +257,6 @@ def from_model_output(cls: type[M], **model_output) -> M:
257257
output = dict(output)
258258
# pop mask to avoid multiple arg error later.
259259
output_mask = output.pop("mask", None)
260-
if mask is not None:
261-
raise ValueError(
262-
"fun %s provided a mask, but a 'mask' field was already "
263-
"given in the model output" % (fun,))
264260
mask = output_mask
265261
# Ignore the mask if its first dimension doesn't match that of the
266262
# output of `fun`.

clu/metrics_test.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -334,34 +334,27 @@ def test_collection_create(self):
334334

335335
def test_collection_create_custom_mask(self):
336336

337-
def with_head1(logits, labels, head1_mask, **_):
338-
return dict(logits=logits, labels=labels, mask=head1_mask)
337+
def with_head1(logits, labels, mask, head1_mask, **_):
338+
return dict(logits=logits, labels=labels, mask=head1_mask & mask)
339339

340-
def with_head2(logits, labels, head2_mask, **_):
341-
return dict(logits=logits, labels=labels, mask=head2_mask)
340+
def with_head2(logits, labels, mask, head2_mask, **_):
341+
return dict(logits=logits, labels=labels, mask=head2_mask & mask)
342342

343343
collection = metrics.Collection.create(
344344
head1_accuracy=metrics.Accuracy.from_fun(with_head1),
345345
head2_accuracy=metrics.Accuracy.from_fun(with_head2)
346346
)
347-
with self.assertRaisesRegex(
348-
ValueError, "but a 'mask' field was already given"):
349-
collection.single_from_model_output(
350-
logits=jnp.array([[-1., 1.], [1., -1.]]),
351-
labels=jnp.array([0, 0]), # i.e. 1st incorrect, 2nd correct
352-
head1_mask=jnp.array([True, False]), # ignore the 2nd.
353-
head2_mask=jnp.array([True, False]), # ignore the 2nd.
354-
mask=jnp.array([False, True]), # raises the error.
355-
)
356347

357348
chex.assert_trees_all_close(
358349
collection.single_from_model_output(
359-
logits=jnp.array([[-1., 1.], [1., -1.]]),
350+
logits=jnp.array([[-1.0, 1.0], [1.0, -1.0]]),
360351
labels=jnp.array([0, 0]), # i.e. 1st incorrect, 2nd correct
352+
mask=jnp.array([True, True]),
361353
head1_mask=jnp.array([True, False]), # ignore the 2nd.
362354
head2_mask=jnp.array([False, True]), # ignore the 1st.
363355
).compute(),
364-
{"head1_accuracy": 0.0, "head2_accuracy": 1.0})
356+
{"head1_accuracy": 0.0, "head2_accuracy": 1.0},
357+
)
365358

366359
def test_collection_create_collection(self):
367360
collection = metrics.Collection.create_collection(

0 commit comments

Comments
 (0)