Skip to content

Commit 20d271a

Browse files
CLU Authors and copybara-github
authored and committed
Fix TPU time measurements in PeriodicAction.timed().
The current code uses the time difference between the completions of two `_squareit()` calls to measure the TPU time spent in `yield`. But the two calls to `_squareit()` are started on a worker thread and are not synchronized with the calling thread.

```
with report.timed('foobar'):
  tpu_fn()  # <- this can slip in before start_measurement()
tpu_fn()  # <- this can slip in before stop_measurement()
```

To guarantee that the TPU programs are dispatched in the right order, we need to dispatch them on the main thread. `block_until_ready()` is still called on the worker thread, so the main thread isn't blocked.

Drive-by changes:
- Split the two code paths (`wait_jax_async_dispatch`), as they share even less common code now.
- No longer launder variables through `_time_per_thread`. Instead, pass the start time in a future.

PiperOrigin-RevId: 538146372
1 parent dd4af73 commit 20d271a

File tree

1 file changed

+32
-24
lines changed

1 file changed

+32
-24
lines changed

clu/periodic_actions.py

Lines changed: 32 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import concurrent.futures
2020
import contextlib
2121
import os
22-
import queue
2322
import time
2423
from typing import Callable, Iterable, Optional, Sequence
2524

@@ -46,16 +45,6 @@ def _squareit(x):
4645
return x**2
4746

4847

49-
def _wait_jax_async_dispatch():
50-
"""Creates a simple JAX program and waits for its completion.
51-
52-
Since JAX operations are put in a queue and dispatched one after the other,
53-
all previously enqueued computations will be finished after a call to this
54-
function.
55-
"""
56-
_squareit(jnp.array(0.)).block_until_ready()
57-
58-
5948
def _format_secs(secs: float):
6049
"""Formats seconds like 123456.7 to strings like "1d10h17m"."""
6150
s = ""
@@ -197,7 +186,6 @@ def __init__(self,
197186
num_train_steps = None
198187
self._num_train_steps = num_train_steps
199188
self._writer = writer
200-
self._waiting_for_part = collections.defaultdict(queue.Queue)
201189
self._time_per_part = collections.defaultdict(float)
202190
self._t0 = time.monotonic()
203191
# Using max_worker=1 guarantees that the calls to _wait_jax_async_dispatch()
@@ -269,20 +257,40 @@ def timed(self, name: str, wait_jax_async_dispatch: bool = True):
269257
dispatch queue, then both measurements are identical.
270258
"""
271259
# pylint: enable=g-doc-return-or-yield
272-
def start_measurement():
273-
if wait_jax_async_dispatch:
274-
_wait_jax_async_dispatch()
275-
self._waiting_for_part[name].put(time.monotonic())
276-
self._executor.submit(start_measurement)
277-
260+
if not wait_jax_async_dispatch:
261+
# Easy case, just measure walltime.
262+
start = time.monotonic()
263+
yield
264+
self._time_per_part[name] += time.monotonic() - start
265+
return
266+
267+
def start_measurement(barrier: jax.Array) -> float:
268+
barrier.block_until_ready()
269+
return time.monotonic()
270+
271+
def stop_measurement(
272+
start_future: concurrent.futures.Future[float], barrier: jax.Array
273+
):
274+
barrier.block_until_ready()
275+
self._time_per_part[name] += time.monotonic() - start_future.result()
276+
277+
# Call _squareit on this thread so that it is guaranteed to be dispatched
278+
# to the TPU before any computations inside `yield`.
279+
start_future = self._executor.submit(
280+
start_measurement, barrier=_squareit(jnp.array(0.0))
281+
)
278282
yield
279283

280-
def stop_measurement():
281-
if wait_jax_async_dispatch:
282-
_wait_jax_async_dispatch()
283-
dt = time.monotonic() - self._waiting_for_part[name].get()
284-
self._time_per_part[name] += dt
285-
self._executor.submit(stop_measurement)
284+
# Same pattern: _squareit is dispatched after any programs dispatched from
285+
# within `yield` and before any programs following this method. The time
286+
# difference between the completion of the first _squareit and this one
287+
# is the time the TPU spent executing programs dispatched from within
288+
# `yield`.
289+
self._executor.submit(
290+
stop_measurement,
291+
start_future=start_future,
292+
barrier=_squareit(jnp.array(0.0)),
293+
)
286294

287295

288296
class Profile(PeriodicAction):

0 commit comments

Comments
 (0)