Skip to content

Commit 40d8cfa

Browse files
CLU Authorscopybara-github
authored andcommitted
Turn merge_reduce into a private method and add some documentation.
PiperOrigin-RevId: 527387176
1 parent d410a09 commit 40d8cfa

File tree

1 file changed

+17
-7
lines changed

1 file changed

+17
-7
lines changed

clu/metrics.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,16 @@ def merge(self, other: "Metric") -> "Metric":
132132
"""
133133
raise NotImplementedError("Must override merge()")
134134

135-
def merge_reduce(self, other: "Metric") -> "Metric":
135+
# The variant of `merge()` called inside `reduce()`. While `merge()` and
136+
# `_reduce_merge()` will be the same in many cases, there are exceptions:
137+
# see `LastValue` for an example of a `Metric` which aggregates values
138+
# differently over training steps compared how it aggregates them over
139+
# accelerators.
140+
#
141+
# `_reduce_merge()` must be associative[1], otherwise we would get
142+
# different results when using different devices.
143+
# [1] https://en.wikipedia.org/wiki/Associative_property
144+
def _reduce_merge(self, other: "Metric") -> "Metric":
136145
return self.merge(other)
137146

138147
def compute(self) -> jnp.array:
@@ -149,14 +158,14 @@ def compute_value(self) -> clu.values.Value:
149158
return clu.values.Scalar(self.compute())
150159

151160
def reduce(self) -> "Metric":
152-
"""Reduces the metric along it first axis by calling `reduce_merge()`.
161+
"""Reduces the metric along it first axis by calling `_reduce_merge()`.
153162
154163
This function primary use case is to aggregate metrics collected across
155164
multiple devices, rather than "merging" metrics across multiple steps.
156165
157166
In many cases these have the same semantics (such as `Average`), but
158-
in some such as LastValue's batch averaging, reduction across devices is
159-
averaging, while reduction across steps is taking the last value.
167+
in some such as `LastValue`'s batch averaging, reduction across devices
168+
is averaging, while reduction across steps is taking the last value.
160169
161170
See `Collection.reduce`, for usage patterns.
162171
@@ -165,7 +174,8 @@ def reduce(self) -> "Metric":
165174
"""
166175

167176
def reduce_step(reduced: Metric, metric: Metric) -> Tuple[Metric, None]:
168-
return reduced.merge_reduce(metric), None
177+
# pylint: disable-next=protected-access
178+
return reduced._reduce_merge(metric), None
169179

170180
first = jax.tree_map(lambda x: x[0], self)
171181
remainder = jax.tree_map(lambda x: x[1:], self)
@@ -681,8 +691,8 @@ def merge(self, other: "LastValue") -> "LastValue":
681691
_assert_same_shape(self.value, other.value)
682692
return other
683693

684-
def merge_reduce(self, other: "LastValue") -> "LastValue":
685-
# We need to average during reduction
694+
def _reduce_merge(self, other: "LastValue") -> "LastValue":
695+
# We need to average during reduction.
686696
_assert_same_shape(self.total, other.total)
687697
return type(self)(
688698
total=self.total + other.total,

0 commit comments

Comments
 (0)