@@ -132,7 +132,16 @@ def merge(self, other: "Metric") -> "Metric":
132132 """
133133 raise NotImplementedError ("Must override merge()" )
134134
135- def merge_reduce (self , other : "Metric" ) -> "Metric" :
135+ # The variant of `merge()` called inside `reduce()`. While `merge()` and
136+ # `_reduce_merge()` will be the same in many cases, there are exceptions:
137+ # see `LastValue` for an example of a `Metric` which aggregates values
138+ # differently over training steps compared to how it aggregates them over
139+ # accelerators.
140+ #
141+ # `_reduce_merge()` must be associative[1], otherwise we would get
142+ # different results when using different devices.
143+ # [1] https://en.wikipedia.org/wiki/Associative_property
144+ def _reduce_merge (self , other : "Metric" ) -> "Metric" :
136145 return self .merge (other )
137146
138147 def compute (self ) -> jnp .array :
@@ -149,14 +158,14 @@ def compute_value(self) -> clu.values.Value:
149158 return clu .values .Scalar (self .compute ())
150159
151160 def reduce (self ) -> "Metric" :
152- """Reduces the metric along it first axis by calling `reduce_merge ()`.
161+ """Reduces the metric along its first axis by calling `_reduce_merge()`.
153162
154163 This function's primary use case is to aggregate metrics collected across
155164 multiple devices, rather than "merging" metrics across multiple steps.
156165
157166 In many cases these have the same semantics (such as `Average`), but
158- in some such as LastValue's batch averaging, reduction across devices is
159- averaging, while reduction across steps is taking the last value.
167+ in some, such as `LastValue`'s batch averaging, reduction across devices
168+ is averaging, while reduction across steps is taking the last value.
160169
161170 See `Collection.reduce`, for usage patterns.
162171
@@ -165,7 +174,8 @@ def reduce(self) -> "Metric":
165174 """
166175
167176 def reduce_step (reduced : Metric , metric : Metric ) -> Tuple [Metric , None ]:
168- return reduced .merge_reduce (metric ), None
177+ # pylint: disable-next=protected-access
178+ return reduced ._reduce_merge (metric ), None
169179
170180 first = jax .tree_map (lambda x : x [0 ], self )
171181 remainder = jax .tree_map (lambda x : x [1 :], self )
@@ -681,8 +691,8 @@ def merge(self, other: "LastValue") -> "LastValue":
681691 _assert_same_shape (self .value , other .value )
682692 return other
683693
684- def merge_reduce (self , other : "LastValue" ) -> "LastValue" :
685- # We need to average during reduction
694+ def _reduce_merge (self , other : "LastValue" ) -> "LastValue" :
695+ # We need to average during reduction.
686696 _assert_same_shape (self .total , other .total )
687697 return type (self )(
688698 total = self .total + other .total ,
0 commit comments