Commit 5522c69

implement send and recv using collective_permute (#9373)
#9315
1 parent abf18e4 commit 5522c69

File tree

4 files changed (+66, -67 lines)

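Before the per-file diffs, the gist of the change paraphrased as a minimal sketch (not the exact committed code): on the 'xla' backend, a point-to-point send or recv is now lowered to a CollectivePermute with a single [source, target] pair instead of raw XLA Send/Recv ops. The helper names send_via_permute and recv_via_permute below are illustrative only; the real implementation is ProcessGroupXla.send/recv in torch_xla/distributed/xla_backend.py (diff below). The sketch assumes one process per device and an initialized 'xla' process group.

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

def send_via_permute(tensor, dst_rank):
  # The sender is the source of a single [source, target] pair.
  result = xm.collective_permute(tensor, pairs=[[xr.global_ordinal(), dst_rank]])
  torch_xla.sync()  # flush the pending permute to the device
  return result

def recv_via_permute(out_tensor, src_rank):
  # The receiver is the target of a single [source, target] pair.
  result = xm.collective_permute(out_tensor, pairs=[[src_rank, xr.global_ordinal()]])
  torch_xla.sync()
  with torch.no_grad():
    out_tensor.copy_(result)  # preserve dist.recv's in-place contract
  return result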

test/pjrt/test_collective_ops_tpu.py

Lines changed: 50 additions & 0 deletions
@@ -336,6 +336,56 @@ def test_all_to_all_single(self, use_dynamo):
                             expected.sort().values),
                         f"Got {val}, expected {expected}")
 
+  @staticmethod
+  def _send_recv_pipeline():
+    dist.init_process_group("xla", init_method='xla://')
+    device = torch_xla.device()
+    world_size = xr.world_size()
+    cutoff = world_size // 2
+    index = xr.global_ordinal()
+    tensor = torch.tensor([index], dtype=torch.float, device=device)
+    if index < cutoff:
+      dist.send(tensor, index + cutoff)
+    else:
+      dist.recv(tensor, index - cutoff)
+    return tensor.cpu()
+
+  @staticmethod
+  def _send_recv_permute():
+    dist.init_process_group("xla", init_method='xla://')
+    device = torch_xla.device()
+    world_size = xr.world_size()
+    index = xr.global_ordinal()
+    sending_tensor = torch.tensor([index], dtype=torch.float, device=device)
+    receiving_tensor = torch.tensor([-1.0], dtype=torch.float, device=device)
+    if index % 2 == 0:
+      dist.send(sending_tensor, (index + 1) % world_size)
+      dist.recv(receiving_tensor, (index - 1) % world_size)
+    else:
+      dist.recv(receiving_tensor, (index - 1) % world_size)
+      dist.send(sending_tensor, (index + 1) % world_size)
+    return receiving_tensor.cpu()
+
+  @absltest.skipUnless(tpu.num_available_devices() % 2 == 0,
+                       "Send/Recv test requires even number of devices")
+  def test_send_recv_pipeline(self):
+    """Send tensors on first N/2 devices to second N/2 devices."""
+    results = pjrt.run_multiprocess(self._send_recv_pipeline)
+    world_size = tpu.num_expected_global_devices()
+    for ordinal, value in results.items():
+      expected = ordinal if ordinal < world_size // 2 else ordinal - world_size // 2
+      np.testing.assert_array_equal(value, [expected])
+
+  @absltest.skipUnless(tpu.num_available_devices() % 2 == 0,
+                       "Send/Recv test requires even number of devices")
+  def test_send_recv_permute(self):
+    """Send tensor on device i to i + 1 (modulo world size)."""
+    results = pjrt.run_multiprocess(self._send_recv_permute)
+    world_size = tpu.num_expected_global_devices()
+    for ordinal, value in results.items():
+      expected = (ordinal - 1) % world_size
+      np.testing.assert_array_equal(value, [expected])
+
   @staticmethod
   def _all_to_all():
     dist.init_process_group("xla", init_method='xla://')
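As a quick sanity check of the two communication patterns exercised above, the expected per-rank values can be worked out without any devices. The sketch below is illustrative only and not part of the test file; the world size of 8 is an arbitrary assumption.

# Hypothetical world size, chosen only for this illustration.
world_size = 8
cutoff = world_size // 2

# Pipeline pattern: rank i < cutoff sends its value i to rank i + cutoff,
# so each receiver ends up holding its sender's ordinal (i - cutoff),
# while senders keep their own value.
pipeline_expected = [i if i < cutoff else i - cutoff for i in range(world_size)]

# Permute pattern: rank i sends to (i + 1) % world_size, so every rank
# receives the value of its left neighbor (i - 1) % world_size.
permute_expected = [(i - 1) % world_size for i in range(world_size)]

print(pipeline_expected)  # [0, 1, 2, 3, 0, 1, 2, 3]
print(permute_expected)   # [7, 0, 1, 2, 3, 4, 5, 6]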

test/test_torch_distributed_xla_backend.py

Lines changed: 0 additions & 41 deletions
@@ -166,47 +166,6 @@ def test_reduce_scatter_coalesced(self):
     # purge all computations attached the device.
     torch_xla.sync()
 
-  @patch_world(0, 6)
-  def test_send(self):
-    device = torch_xla.device()
-    tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()
-    input_list = [tensor]
-
-    with mock.patch.object(
-        torch_xla.distributed.xla_backend.ProcessGroupXla,
-        'make_send_channel_id',
-        new=lambda self, dst_rank, tag: dst_rank * 2):
-      dist.send(tensor, 1)
-
-    send_pattern = r'%send\.\d+ = .+ send\(.+\), channel_id=2'
-    senddone_pattern = r'%send\-done\.\d+ = .+ send\-done\(.+\), channel_id=2'
-    hlo = torch_xla._XLAC._get_xla_tensors_hlo([tensor])
-    hlo_matches(hlo, send_pattern)
-    hlo_matches(hlo, senddone_pattern)
-
-    # Don't try to run Send on CPU because it's not implemented
-    torch_xla._XLAC._clear_pending_irs(str(torch_xla.device()))
-
-  @patch_world(0, 6)
-  def test_recv(self):
-    device = torch_xla.device()
-    tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()
-
-    with mock.patch.object(
-        torch_xla.distributed.xla_backend.ProcessGroupXla,
-        'make_recv_channel_id',
-        new=lambda self, src_rank, tag: src_rank * 3):
-      dist.recv(tensor, 1)
-
-    recv_pattern = r'%recv\.\d+ = .+ recv\(.+\), channel_id=3'
-    recvdone_pattern = r'%recv\-done\.\d+ = .+ recv\-done\(.+\), channel_id=3'
-    hlo = torch_xla._XLAC._get_xla_tensors_hlo([tensor])
-    hlo_matches(hlo, recv_pattern)
-    hlo_matches(hlo, recvdone_pattern)
-
-    # Don't try to run Recv on CPU because it's not implemented
-    torch_xla._XLAC._clear_pending_irs(str(torch_xla.device()))
-
   @patch_world(rank=0, size=12)
   def test_new_group_no_ranks(self):
     with new_group_barrier_disabled():

torch_xla/core/xla_model.py

Lines changed: 0 additions & 3 deletions
@@ -748,9 +748,6 @@ def collective_permute(value: torch.Tensor,
                        pairs: List[List[int]]) -> torch.Tensor:
   """Performs a XLA `CollectivePermute()` operation on the input tensor.
 
-  WARNING: This function is not very reliable, may produce wrong results under
-  certain inputs. Use it at your own risk.
-
   See: https://www.tensorflow.org/xla/operation_semantics#collectivepermute
 
   Args:
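With the warning above removed, a minimal usage sketch of collective_permute may be useful for reference. This is an illustration, not part of the commit; the function name swap_first_two_devices is made up here, and the sketch assumes one process per device (for example, launched via pjrt.run_multiprocess) with at least two devices. Each pair is [source_replica, target_replica], and every replica must execute the same call.

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

def swap_first_two_devices():
  device = torch_xla.device()
  value = torch.tensor([float(xr.global_ordinal())], device=device)
  # Devices 0 and 1 exchange their tensors; pairs are [source, target].
  swapped = xm.collective_permute(value, pairs=[[0, 1], [1, 0]])
  torch_xla.sync()
  return swapped.cpu()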

torch_xla/distributed/xla_backend.py

Lines changed: 16 additions & 23 deletions
@@ -337,41 +337,34 @@ def scatter(self, output_tensor_list: list[torch.Tensor],
     rs_opts.reduceOp = dist.ReduceOp.SUM
     return self.reduce_scatter(output_tensor_list, inputs, rs_opts)
 
-  # Dummy channel id maker. Different backend (TPU, GPU, etc) should replace
-  # the maker with their specific one. See unit test in
-  # test/test_torch_distributed_xla_backend.py for an example.
-  def make_send_channel_id(self, dst_rank, tag):
-    raise NotImplementedError
-
   # Call site e.g.
   # https://github.com/pytorch/pytorch/blob/release/1.10/torch/distributed/distributed_c10d.py#L877
   def send(self, tensors, dst_rank, tag=0):
+    logging.warning(
+        "Individual send/recv ops are inefficient on an XLA device. Consider using xla_model.collective_permute() and specifying all source-target pairs."
+    )
     results = []
     for t in tensors:
-      channel_id = self.make_send_channel_id(dst_rank, tag)
-      # The input will be returned as result.
-      input_as_result = xm.send(t, channel_id)
-      # Make the sent tensor depend on the token, such that the `send`
-      # op can actually be built into the computation graph.
-      with torch.no_grad():
-        t.copy_(input_as_result)
-      results.append(input_as_result)
+      result_t = xm.collective_permute(
+          t, pairs=[[xr.global_ordinal(), dst_rank]])
+      torch_xla.sync()
+      results.append(result_t)
     return _ret_work(results)
 
-  # Dummy channel id maker. Different backend (TPU, GPU, etc) should replace
-  # the maker with their specific one. See unit test in
-  # test/test_torch_distributed_xla_backend.py for an example.
-  def make_recv_channel_id(self, src_rank, tag):
-    raise NotImplementedError
-
   # Call site e.g.
   # https://github.com/pytorch/pytorch/blob/release/1.10/torch/distributed/distributed_c10d.py#L913
   def recv(self, out_tensors, src_rank, tag=0):
+    logging.warning(
+        "Individual send/recv ops are inefficient on an XLA device. Consider using xla_model.collective_permute() and specifying all source-target pairs."
+    )
     results = []
     for ot in out_tensors:
-      channel_id = self.make_recv_channel_id(src_rank, tag)
-      result = xm.recv(ot, channel_id)
-      results.append(result)
+      result_t = xm.collective_permute(
+          ot, pairs=[[src_rank, xr.global_ordinal()]])
+      torch_xla.sync()
+      with torch.no_grad():
+        ot.copy_(result_t)
+      results.append(result_t)
     return _ret_work(results)
 
   def recv_anysource(self, *args):
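The warning added to send/recv above recommends batching all source-target pairs into one collective_permute instead of issuing a one-pair permute per call. A hedged sketch of that recommended pattern, here a full ring shift; the function name ring_shift is illustrative, and the sketch assumes one process per device with an initialized runtime.

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

def ring_shift(value: torch.Tensor) -> torch.Tensor:
  world_size = xr.world_size()
  # One collective covers every rank: rank i forwards its tensor to
  # rank (i + 1) % world_size, rather than world_size separate send/recv calls.
  pairs = [[i, (i + 1) % world_size] for i in range(world_size)]
  shifted = xm.collective_permute(value, pairs)
  torch_xla.sync()
  return shifted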
