@@ -334,34 +334,27 @@ def test_collection_create(self):
334334
335335 def test_collection_create_custom_mask (self ):
336336
337- def with_head1 (logits , labels , head1_mask , ** _ ):
338- return dict (logits = logits , labels = labels , mask = head1_mask )
337+ def with_head1 (logits , labels , mask , head1_mask , ** _ ):
338+ return dict (logits = logits , labels = labels , mask = head1_mask & mask )
339339
340- def with_head2 (logits , labels , head2_mask , ** _ ):
341- return dict (logits = logits , labels = labels , mask = head2_mask )
340+ def with_head2 (logits , labels , mask , head2_mask , ** _ ):
341+ return dict (logits = logits , labels = labels , mask = head2_mask & mask )
342342
343343 collection = metrics .Collection .create (
344344 head1_accuracy = metrics .Accuracy .from_fun (with_head1 ),
345345 head2_accuracy = metrics .Accuracy .from_fun (with_head2 )
346346 )
347- with self .assertRaisesRegex (
348- ValueError , "but a 'mask' field was already given" ):
349- collection .single_from_model_output (
350- logits = jnp .array ([[- 1. , 1. ], [1. , - 1. ]]),
351- labels = jnp .array ([0 , 0 ]), # i.e. 1st incorrect, 2nd correct
352- head1_mask = jnp .array ([True , False ]), # ignore the 2nd.
353- head2_mask = jnp .array ([True , False ]), # ignore the 2nd.
354- mask = jnp .array ([False , True ]), # raises the error.
355- )
356347
357348 chex .assert_trees_all_close (
358349 collection .single_from_model_output (
359- logits = jnp .array ([[- 1. , 1. ], [1. , - 1. ]]),
350+ logits = jnp .array ([[- 1.0 , 1.0 ], [1.0 , - 1.0 ]]),
360351 labels = jnp .array ([0 , 0 ]), # i.e. 1st incorrect, 2nd correct
352+ mask = jnp .array ([True , True ]),
361353 head1_mask = jnp .array ([True , False ]), # ignore the 2nd.
362354 head2_mask = jnp .array ([False , True ]), # ignore the 1st.
363355 ).compute (),
364- {"head1_accuracy" : 0.0 , "head2_accuracy" : 1.0 })
356+ {"head1_accuracy" : 0.0 , "head2_accuracy" : 1.0 },
357+ )
365358
366359 def test_collection_create_collection (self ):
367360 collection = metrics .Collection .create_collection (
0 commit comments