Skip to content

Commit dd4af73

Browse files
andresusanopinto and copybara-github
authored and committed
Fix LastValue to support tree_manipulations that map leaves to other objects (e.g. None).
PiperOrigin-RevId: 537804323
1 parent 8542e55 commit dd4af73

File tree

2 files changed

+44
-14
lines changed

2 files changed

+44
-14
lines changed

clu/metrics.py

Lines changed: 29 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -650,6 +650,10 @@ def unreplicate(self: C) -> C:
650650
return flax.jax_utils.unreplicate(self)
651651

652652

653+
# Sentinel to make LastValue.__init__ support tree manipulations that use None.
654+
_default = object()
655+
656+
653657
@flax.struct.dataclass
654658
class LastValue(Metric):
655659
"""Keeps the last average global batch value.
@@ -660,32 +664,43 @@ class LastValue(Metric):
660664
in cases when batch is distributed across multiple devices and need
661665
to be averaged later. However, we don't inherit from Average to
662666
maintain backward compatibility in case of isinstance(metric, Average)
663-
check. For backward compatibility this class can be initialized using the
664-
keyword `LastValue(value=10)` or `total` and `count`.
667+
check. For backward compatibility this class can also be initialized as
668+
if the constructor was __init__(value).
665669
"""
666670
total: jnp.array
667671
count: jnp.array
668672

669673
def __init__(
670674
self,
671-
total: jnp.array | None = None,
672-
count: jnp.array | None = None,
673-
value: jnp.array | None = None,
675+
total: jnp.array | _default = _default,
676+
count: jnp.array | _default = _default,
677+
value: jnp.array | _default = _default,
674678
):
675-
"""Constructor which supports keyword argument value as initializer.
679+
"""Backward compatibility constructor.
676680
677-
If "value" is provided, then "total" should *not* be provided.
681+
It is intended to be constructed as __init__(total, count). When doing so
682+
the arguments are assigned as instance attributes without extra operations.
683+
For backward compatibility it also supports __init__(value) code paths.
678684
679685
Args:
680686
total: Total value.
681-
count: Count of examples, 1 if not provided
682-
value: Value, if provided, will be assumed to be "count" of values.
687+
count: Count of examples, 1 if not provided.
688+
value: Value, if provided, "total" will be computed as value * count.
683689
"""
684-
count = count if count is not None else jnp.array(1, dtype=jnp.int32)
685-
if value is not None:
686-
if total is not None:
687-
raise ValueError("Only one of 'total' and 'value' should be None. "
688-
f'Got {total}, {value}')
690+
# Note: This code should not use None to detect a default argument, also it
691+
# should avoid doing any logic when it's being called by tree_utils.
692+
# That is a requirement for tree manipulations that map leaves to other
693+
# values like shapes/sharding information or even None.
694+
# Per https://flax.readthedocs.io/en/latest/api_reference/flax.struct.html
695+
# classes should provide a static create() method, but here we overload
696+
# the constructor for backward compatibility when it was LastValue(value).
697+
count = count if count is not _default else jnp.array(1, dtype=jnp.int32)
698+
if (value is _default) == (total is _default):
699+
raise ValueError(
700+
"Exactly one of 'total' and 'value' should be passed. "
701+
f"Got {total}, {value}"
702+
)
703+
if total is _default:
689704
total = value * count
690705
object.__setattr__(self, "total", total)
691706
object.__setattr__(self, "count", count)

clu/metrics_test.py

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -184,6 +184,21 @@ def test_metric_last_value_legacy_kwarg_value(self):
184184
metric = metrics.LastValue(value=2.0, count=3)
185185
self.assertEqual(metric.total, 6.0)
186186

187+
def test_metric_last_value_tree_manipulation(self):
188+
# Test mapping leaves to other non array values (e.g.: None).
189+
metric = metrics.LastValue(value=2.0)
190+
metric = jax.tree_map(lambda x: None, metric)
191+
self.assertIsNone(metric.total, None)
192+
self.assertIsNone(metric.count, None)
193+
metric = metrics.LastValue(value=2.0, count=3)
194+
metric = jax.tree_map(lambda x: None, metric)
195+
self.assertIsNone(metric.total, None)
196+
self.assertIsNone(metric.count, None)
197+
metric = metrics.LastValue(2.0)
198+
metric = jax.tree_map(lambda x: None, metric)
199+
self.assertIsNone(metric.total, None)
200+
self.assertIsNone(metric.count, None)
201+
187202
def test_from_fun_with_single_output(self):
188203

189204
def accuracy(*, logits, labels, **_):

0 commit comments

Comments
 (0)