@@ -650,6 +650,10 @@ def unreplicate(self: C) -> C:
650650 return flax .jax_utils .unreplicate (self )
651651
652652
653+ # Sentinel meaning "argument not passed", so LastValue.__init__ can accept None as a real value from tree manipulations.
654+ _default = object ()
655+
656+
653657@flax .struct .dataclass
654658class LastValue (Metric ):
655659 """Keeps the last average global batch value.
@@ -660,32 +664,43 @@ class LastValue(Metric):
660664 in cases when batch is distributed across multiple devices and need
661665 to be averaged later. However, we don't inherit from Average to
662666 maintain backward compatibility in case of isinstance(metric, Average)
663- check. For backward compatibility this class can be initialized using the
664- keyword `LastValue(value=10)` or `total` and `count` .
667+ check. For backward compatibility this class can also be initialized as
668+ if the constructor were __init__(value).
665669 """
666670 total : jnp .array
667671 count : jnp .array
668672
669673 def __init__ (
670674 self ,
671- total : jnp .array | None = None ,
672- count : jnp .array | None = None ,
673- value : jnp .array | None = None ,
675+ total : jnp .array | _default = _default ,  # NOTE(review): _default is an instance, not a type; confirm annotations are lazily evaluated.
676+ count : jnp .array | _default = _default ,
677+ value : jnp .array | _default = _default ,
674678 ):
675- """Constructor which supports keyword argument value as initializer .
679+ """Backward compatibility constructor .
676680
677- If "value" is provided, then "total" should *not* be provided.
681+ It is intended to be constructed as __init__(total, count). When doing so
682+ the arguments are assigned as instance attributes without extra operations.
683+ For backward compatibility it also supports __init__(value) code paths.
678684
679685 Args:
680686 total: Total value.
681- count: Count of examples, 1 if not provided
682- value: Value, if provided, will be assumed to be "count " of values.
687+ count: Count of examples, 1 if not provided.
688+ value: Value; when passed instead of "total", total is computed as value * count.
683689 """
684- count = count if count is not None else jnp .array (1 , dtype = jnp .int32 )
685- if value is not None :
686- if total is not None :
687- raise ValueError ("Only one of 'total' and 'value' should be None. "
688- f'Got { total } , { value } ' )
690+ # Note: This code must not use None to detect a default argument, and it
691+ # should avoid doing any extra logic when it is being called by tree_utils.
692+ # This is required for tree manipulations whose leaves hold other values,
693+ # such as shapes/sharding information or even None.
694+ # Per https://flax.readthedocs.io/en/latest/api_reference/flax.struct.html
695+ # classes should provide a static create() method, but here we overload
696+ # the constructor for backward compatibility with the LastValue(value) form.
697+ count = count if count is not _default else jnp .array (1 , dtype = jnp .int32 )
698+ if (value is _default ) == (total is _default ):
699+ raise ValueError (
700+ "Exactly one of 'total' and 'value' should be passed. "
701+ f"Got { total } , { value } "
702+ )
703+ if total is _default :
689704 total = value * count
690705 object .__setattr__ (self , "total" , total )
691706 object .__setattr__ (self , "count" , count )
0 commit comments