
Commit d663c7e

Support trace_me and xp.Trace in assume_pure (#9311)
1 parent 6e8e7db commit d663c7e

5 files changed, +61 -12 lines changed

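For context, this change lets profiler annotations created inside an `assume_pure`-wrapped function show up in captured device profiles. A minimal usage sketch modeled on the new test below; the `assume_pure` import path and the `/tmp/profile` output directory are assumptions for illustration, not part of this commit:

```python
# Minimal sketch based on the new test in test_assume_pure.py.
# Assumption: assume_pure is importable from torch_xla.experimental.assume_pure.
import torch
import torch_xla
import torch_xla.debug.profiler as xp
from torch_xla.experimental.assume_pure import assume_pure

@assume_pure
def matmul(a, b):
  # The trace annotation now survives the JAX-based lowering that assume_pure
  # performs, so 'my_matmul' shows up in the captured device profile.
  with xp.Trace('my_matmul'):
    return torch.matmul(a, b)

a = torch.randn(3, 3, device='xla')
b = torch.randn(3, 3, device='xla')

xp.start_trace('/tmp/profile')   # writes *.xplane.pb files under this directory
matmul(a, b)
torch_xla.sync(wait=True)
xp.stop_trace()
```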

test/test_assume_pure.py

Lines changed: 39 additions & 1 deletion
@@ -1,4 +1,6 @@
 from copy import deepcopy
+import glob
+import os
 from absl.testing import absltest
 from absl import flags
 import time
@@ -369,7 +371,7 @@ def original_func(a, b):
     self.assertIsNone(a_pure.grad)
     self.assertIsNone(b_pure.grad)
 
-  def test_composibility_with_call_jax(self):
+  def test_composability_with_call_jax(self):
 
     def jax_func(a, b):
       return jnp.dot(a, b)
@@ -407,6 +409,42 @@ def f(a, b):
         msg="Forward outputs do not match",
         check_device=False)
 
+  def test_assume_pure_profile(self):
+    """Test that xp.Trace works inside assume_pure."""
+    import torch_xla.debug.profiler as xp
+
+    # Arrange
+    MAGIC_STRING = 'foobar123'
+
+    @assume_pure
+    def torch_func(a, b):
+      with xp.Trace(MAGIC_STRING):
+        return torch.matmul(a, b)
+
+    # Precompile it such that it won't be traced again on CPU.
+    # This way we exclusively test the device-side profiles.
+    a = torch.randn(3, 3, device='xla')
+    b = torch.randn(3, 3, device='xla')
+    _ = torch_func(a, b)
+
+    # Act
+    tempdir = self.create_tempdir().full_path
+    xp.start_trace(tempdir)
+    _ = torch_func(a, b)
+    torch_xla.sync(wait=True)
+    xp.stop_trace()
+
+    # Assert
+    files = glob.glob(
+        os.path.join(tempdir, '**', '*.xplane.pb'), recursive=True)
+    self.assertEqual(len(files), 1)
+
+    path = files[0]
+    with open(path, 'rb') as f:
+      proto_str = str(f.read())
+    self.assertTrue(MAGIC_STRING in proto_str,
+                    f'Expected "{MAGIC_STRING}" trace in: {path}')
+
 
 FLAGS = flags.FLAGS
 flags.DEFINE_integer(

test/test_pallas.py

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,6 @@
 from torch.ao.quantization.utils import determine_qparams
 
 import torch_xla
-import torch_xla.core.xla_model as xm
 from torch_xla import runtime as xr
 from torch_xla._internal import tpu
 
@@ -26,6 +25,7 @@
 def with_jax_high_precision(func):
 
   def wrapper(*args, **kwargs):
+    import jax
     jax.config.update('jax_default_matmul_precision', "highest")
     try:
       result = func(*args, **kwargs)
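
The `import jax` move above defers the jax import until a decorated test actually runs, so importing the test module no longer requires jax. A generic sketch of the same deferred-import pattern; the `requires_jax` name is hypothetical and not part of this repository:

```python
# Hypothetical sketch of the deferred-import pattern; `requires_jax` is an
# illustrative name only.
import functools

def requires_jax(func):

  @functools.wraps(func)
  def wrapper(*args, **kwargs):
    import jax  # imported lazily, so merely importing this module does not need jax
    jax.config.update('jax_default_matmul_precision', 'highest')
    return func(*args, **kwargs)

  return wrapper
```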

test/test_profiler.py

Lines changed: 13 additions & 4 deletions
@@ -5,7 +5,7 @@
 import os
 import sys
 import tempfile
-import time
+import signal
 import unittest
 
 import args_parse
@@ -23,7 +23,7 @@ def train_worker(port, training_started):
       batch_size=16,
       momentum=0.5,
       lr=0.01,
-      num_epochs=10)
+      num_epochs=100)
   flags.fake_data = True
   flags.profiler_port = port
 
@@ -86,17 +86,26 @@ def test_trace_and_metrics(self):
     training_started = context.Event()
     p = context.Process(
         target=train_worker, args=(port, training_started), daemon=True)
+
+    # Wait for training to start.
     p.start()
-    training_started.wait(60)
+    training_started.wait(600)
 
+    # Take a profile.
     logdir = tempfile.mkdtemp()
     xp.trace(
         f'localhost:{port}',
         logdir,
         duration_ms=5000,
         num_tracing_attempts=5,
         delay_ms=1000)
-    p.terminate()
+    pid = p.pid
+    assert pid is not None, 'Process ID should not be None'
+    # Gracefully interrupt the process.
+    os.kill(pid, signal.SIGINT)
+    p.join()
+
+    # Validate the profiling output.
     path = self._check_xspace_pb_exist(logdir)
     self._check_trace_namespace_exists(path)
     self._check_metrics_warnings_exist(self.fname)
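
Replacing `p.terminate()` with a SIGINT plus `p.join()` lets the training worker unwind normally: the signal surfaces as `KeyboardInterrupt` in the child, so teardown code can run before the process exits. A standalone sketch of the pattern under that assumption, independent of this test:

```python
# Standalone sketch: interrupt a multiprocessing worker with SIGINT and wait
# for it to exit, instead of killing it mid-step with terminate().
import multiprocessing as mp
import os
import signal
import time

def worker():
  try:
    while True:
      time.sleep(0.1)  # stand-in for a training loop
  except KeyboardInterrupt:
    pass  # SIGINT arrives here, so cleanup/teardown code can still run

if __name__ == '__main__':
  p = mp.Process(target=worker, daemon=True)
  p.start()
  time.sleep(0.5)
  os.kill(p.pid, signal.SIGINT)  # graceful interrupt instead of p.terminate()
  p.join()
```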

torch_xla/debug/profiler.py

Lines changed: 7 additions & 5 deletions
@@ -107,13 +107,9 @@ class Trace(torch_xla._XLAC.profiler.TraceMe):
 
   The traces generated can then be collected using the above profiling APIs.
   The profiling server first needs to be started up and then can be sampled
-  either using Tensorboard profiler plugin
-  (https://github.com/tensorflow/profiler) or the
+  either using xprof (https://github.com/openxla/xprof) or the
   :func:`~torch_xla.debug.profiler.trace` method.
 
-  Note: currently only supports PyTorch/XLA client side trace events. i.e.,
-  the namespace won't group TPU worker side trace.
-
   Example usage:
   ```python
   server = xp.start_server(9012)
@@ -132,7 +128,13 @@ def __enter__(self):
     self.scope = torch_xla._XLAC.profiler.scope_pusher(self.name)
     super().__enter__()
 
+    # Also enter the JAX named scope, to support torchax lowering.
+    import jax
+    self._jax_scope = jax.named_scope(self.name)
+    self._jax_scope.__enter__()
+
   def __exit__(self, type, value, traceback):
+    self._jax_scope.__exit__(type, value, traceback)
     if getattr(self, 'scope', None):
       del self.scope
     super().__exit__(type, value, traceback)
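
This `__enter__`/`__exit__` change is what makes the new `assume_pure` test pass: `xp.Trace` now pushes a `jax.named_scope` alongside the existing client-side scope, so the annotation survives JAX-based lowering (torchax / `assume_pure`). A rough plain-JAX illustration of the effect; the `scoped_sin` function is hypothetical, and the scope name is expected, not guaranteed, to appear in the lowered text:

```python
# Rough illustration (plain JAX, hypothetical function name): jax.named_scope
# attaches the name to every op traced inside it, which is what xp.Trace now
# does in addition to the XLA client-side TraceMe/scope push.
import jax
import jax.numpy as jnp

def scoped_sin(x):
  with jax.named_scope('my_region'):
    return jnp.sin(x)

lowered = jax.jit(scoped_sin).lower(jnp.ones((3,)))
# 'my_region' should show up in the op metadata / locations of the lowered program.
print(lowered.as_text())
```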

torch_xla/experimental/custom_kernel.py

Lines changed: 1 addition & 1 deletion
@@ -868,7 +868,7 @@ def flash_attention(
       sm_scale, ab, partition_spec, mesh)
 
 
-# This function should only be called and excuted on runtime.
+# This function should only be called and executed on runtime.
 def _ragged_paged_attention_runtime_check(
     q,  # [max_num_batched_tokens, num_q_heads, head_dim]
     kv_pages,  # [total_num_pages, page_size, num_combined_kv_heads, head_dim]
