|
1 | 1 | from copy import deepcopy |
| 2 | +import glob |
| 3 | +import os |
2 | 4 | from absl.testing import absltest |
3 | 5 | from absl import flags |
4 | 6 | import time |
@@ -369,7 +371,7 @@ def original_func(a, b): |
369 | 371 | self.assertIsNone(a_pure.grad) |
370 | 372 | self.assertIsNone(b_pure.grad) |
371 | 373 |
|
372 | | - def test_composibility_with_call_jax(self): |
| 374 | + def test_composability_with_call_jax(self): |
373 | 375 |
|
374 | 376 | def jax_func(a, b): |
375 | 377 | return jnp.dot(a, b) |
@@ -407,6 +409,42 @@ def f(a, b): |
407 | 409 | msg="Forward outputs do not match", |
408 | 410 | check_device=False) |
409 | 411 |
|
| 412 | + def test_assume_pure_profile(self): |
| 413 | + """Test that xp.Trace works inside assume_pure.""" |
| 414 | + import torch_xla.debug.profiler as xp |
| 415 | + |
| 416 | + # Arrange |
| 417 | + MAGIC_STRING = 'foobar123' |
| 418 | + |
| 419 | + @assume_pure |
| 420 | + def torch_func(a, b): |
| 421 | + with xp.Trace(MAGIC_STRING): |
| 422 | + return torch.matmul(a, b) |
| 423 | + |
| 424 | + # Precompile it such that it won't be traced again on CPU. |
| 425 | + # This way we exclusively test the device-side profiles. |
| 426 | + a = torch.randn(3, 3, device='xla') |
| 427 | + b = torch.randn(3, 3, device='xla') |
| 428 | + _ = torch_func(a, b) |
| 429 | + |
| 430 | + # Act |
| 431 | + tempdir = self.create_tempdir().full_path |
| 432 | + xp.start_trace(tempdir) |
| 433 | + _ = torch_func(a, b) |
| 434 | + torch_xla.sync(wait=True) |
| 435 | + xp.stop_trace() |
| 436 | + |
| 437 | + # Assert |
| 438 | + files = glob.glob( |
| 439 | + os.path.join(tempdir, '**', '*.xplane.pb'), recursive=True) |
| 440 | + self.assertEqual(len(files), 1) |
| 441 | + |
| 442 | + path = files[0] |
| 443 | + with open(path, 'rb') as f: |
| 444 | + proto_str = str(f.read()) |
| 445 | + self.assertTrue(MAGIC_STRING in proto_str, |
| 446 | + f'Expected "{MAGIC_STRING}" trace in: {path}') |
| 447 | + |
410 | 448 |
|
411 | 449 | FLAGS = flags.FLAGS |
412 | 450 | flags.DEFINE_integer( |
|
0 commit comments