From 66cc5f8ffcf2d6046c6fef8ff5731a1492e81ff5 Mon Sep 17 00:00:00 2001 From: 0xDaizz Date: Tue, 24 Feb 2026 19:34:17 +0900 Subject: [PATCH] Add all_to_all collective primitive MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add `mx.distributed.all_to_all(x)` — splits input along axis 0, sends chunk i to rank i, and concatenates received chunks. - CPU eval via backend-specific GroupImpl::all_to_all - MPI backend (MPI_Alltoall) and JACCL backend (RDMA pipelined) - VJP support (all_to_all is its own transpose) - GPU/CUDA stubs (not-implemented exceptions) Co-Authored-By: Claude Opus 4.6 --- docs/src/python/distributed.rst | 1 + mlx/backend/cpu/distributed.cpp | 15 +++ mlx/backend/cuda/distributed.cu | 4 + mlx/backend/metal/distributed.cpp | 4 + mlx/backend/no_cpu/primitives.cpp | 1 + mlx/backend/no_gpu/primitives.cpp | 1 + mlx/distributed/distributed.cpp | 8 ++ mlx/distributed/distributed_impl.h | 4 + mlx/distributed/jaccl/mesh.cpp | 102 +++++++++++++++++++ mlx/distributed/jaccl/mesh.h | 1 + mlx/distributed/jaccl/ring.h | 4 + mlx/distributed/mpi/mpi.cpp | 25 +++++ mlx/distributed/nccl/nccl.cpp | 4 + mlx/distributed/ops.cpp | 23 +++++ mlx/distributed/ops.h | 5 + mlx/distributed/primitives.cpp | 21 ++++ mlx/distributed/primitives.h | 25 +++++ mlx/distributed/ring/ring.cpp | 4 + python/src/distributed.cpp | 36 +++++++ python/tests/mlx_distributed_tests.py | 138 ++++++++++++++++++++++++++ 20 files changed, 426 insertions(+) diff --git a/docs/src/python/distributed.rst b/docs/src/python/distributed.rst index 8b48d727e0..31a0e91025 100644 --- a/docs/src/python/distributed.rst +++ b/docs/src/python/distributed.rst @@ -17,6 +17,7 @@ made available. init all_sum all_gather + all_to_all send recv recv_like diff --git a/mlx/backend/cpu/distributed.cpp b/mlx/backend/cpu/distributed.cpp index 22dc4b4cc8..c2dcc05258 100644 --- a/mlx/backend/cpu/distributed.cpp +++ b/mlx/backend/cpu/distributed.cpp @@ -100,4 +100,19 @@ void ReduceScatter::eval_cpu( std::vector& outputs) { throw std::runtime_error("[ReduceScatter] Not implemented yet."); } + +void AllToAll::eval_cpu( + const std::vector& inputs, + std::vector& outputs) { + assert(inputs.size() == 1); + assert(outputs.size() == 1); + auto [in, copied] = ensure_row_contiguous(inputs[0], stream()); + outputs[0].set_data(allocator::malloc(outputs[0].nbytes())); + distributed::detail::all_to_all(group(), in, outputs[0], stream()); + if (copied) { + auto& enc = cpu::get_command_encoder(stream()); + enc.add_temporary(in); + } +} + } // namespace mlx::core::distributed diff --git a/mlx/backend/cuda/distributed.cu b/mlx/backend/cuda/distributed.cu index ac79875789..44dd2d06c0 100644 --- a/mlx/backend/cuda/distributed.cu +++ b/mlx/backend/cuda/distributed.cu @@ -118,4 +118,8 @@ void ReduceScatter::eval_gpu( throw std::runtime_error("Only sum scatter is supported. "); } } + +void AllToAll::eval_gpu(const std::vector&, std::vector&) { + throw std::runtime_error("[AllToAll::eval_gpu] has no CUDA implementation."); +} } // namespace mlx::core::distributed diff --git a/mlx/backend/metal/distributed.cpp b/mlx/backend/metal/distributed.cpp index 217ee3c946..88c322b660 100644 --- a/mlx/backend/metal/distributed.cpp +++ b/mlx/backend/metal/distributed.cpp @@ -35,4 +35,8 @@ void ReduceScatter::eval_gpu(const std::vector&, std::vector&) { "[ReduceScatter::eval_gpu] has no GPU implementation."); } +void AllToAll::eval_gpu(const std::vector&, std::vector&) { + throw std::runtime_error("[AllToAll::eval_gpu] has no GPU implementation."); +} + } // namespace mlx::core::distributed diff --git a/mlx/backend/no_cpu/primitives.cpp b/mlx/backend/no_cpu/primitives.cpp index ae51dd9b2f..ba9304329b 100644 --- a/mlx/backend/no_cpu/primitives.cpp +++ b/mlx/backend/no_cpu/primitives.cpp @@ -141,6 +141,7 @@ NO_CPU_MULTI(AllGather) NO_CPU_MULTI(Send) NO_CPU_MULTI(Recv) NO_CPU_MULTI(ReduceScatter) +NO_CPU_MULTI(AllToAll) } // namespace distributed } // namespace mlx::core diff --git a/mlx/backend/no_gpu/primitives.cpp b/mlx/backend/no_gpu/primitives.cpp index 4819ed2724..2384ba29aa 100644 --- a/mlx/backend/no_gpu/primitives.cpp +++ b/mlx/backend/no_gpu/primitives.cpp @@ -180,6 +180,7 @@ NO_GPU_MULTI(AllGather) NO_GPU_MULTI(Send) NO_GPU_MULTI(Recv) NO_GPU_MULTI(ReduceScatter) +NO_GPU_MULTI(AllToAll) } // namespace distributed } // namespace mlx::core diff --git a/mlx/distributed/distributed.cpp b/mlx/distributed/distributed.cpp index 3cde6a263b..60e1119648 100644 --- a/mlx/distributed/distributed.cpp +++ b/mlx/distributed/distributed.cpp @@ -50,6 +50,10 @@ void sum_scatter( group.raw_group()->sum_scatter(input, output, stream); } +void all_to_all(Group group, const array& input, array& output, Stream stream) { + group.raw_group()->all_to_all(input, output, stream); +} + class EmptyGroup : public GroupImpl { public: Stream communication_stream(StreamOrDevice s) override { @@ -98,6 +102,10 @@ class EmptyGroup : public GroupImpl { throw std::runtime_error( "Communication not implemented in an empty distributed group."); } + void all_to_all(const array&, array&, Stream) override { + throw std::runtime_error( + "Communication not implemented in an empty distributed group."); + } }; } // namespace detail diff --git a/mlx/distributed/distributed_impl.h b/mlx/distributed/distributed_impl.h index d889587abc..198c8f4554 100644 --- a/mlx/distributed/distributed_impl.h +++ b/mlx/distributed/distributed_impl.h @@ -30,6 +30,7 @@ class GroupImpl { virtual void all_min(const array& input, array& output, Stream stream) = 0; virtual void sum_scatter(const array& input, array& output, Stream stream) = 0; + virtual void all_to_all(const array& input, array& output, Stream stream) = 0; }; /* Define the MLX stream that the communication should happen in. */ @@ -56,4 +57,7 @@ void all_min(Group group, const array& input, array& output, Stream stream); /** Reduce scatter with average operation */ void sum_scatter(Group group, const array& input, array& output, Stream stream); +/** All-to-all exchange */ +void all_to_all(Group group, const array& input, array& output, Stream stream); + } // namespace mlx::core::distributed::detail diff --git a/mlx/distributed/jaccl/mesh.cpp b/mlx/distributed/jaccl/mesh.cpp index c8df4e6745..51f74758cd 100644 --- a/mlx/distributed/jaccl/mesh.cpp +++ b/mlx/distributed/jaccl/mesh.cpp @@ -448,4 +448,106 @@ void MeshGroup::all_reduce( }); } +void MeshGroup::all_to_all(const array& input, array& output, Stream stream) { + if (size_ != 2) { + throw std::runtime_error( + "[jaccl] all_to_all currently supports size == 2, got " + + std::to_string(size_) + "."); + } + auto in_ptr = input.data(); + auto out_ptr = output.data(); + if (in_ptr == out_ptr) { + throw std::runtime_error( + "[jaccl] in-place all_to_all is not supported (input/output alias)."); + } + int64_t n_bytes = static_cast(input.nbytes()); + + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(input); + encoder.set_output_array(output); + encoder.dispatch([in_ptr, out_ptr, n_bytes, this]() { + constexpr int PIPELINE = 2; + constexpr int WC_NUM = PIPELINE * 2; + + int peer = 1 - rank_; + int64_t per_peer_bytes = n_bytes / size_; + + // Local chunk: input[rank] -> output[rank] + std::memcpy( + out_ptr + rank_ * per_peer_bytes, + in_ptr + rank_ * per_peer_bytes, + per_peer_bytes); + + if (per_peer_bytes == 0) + return; + + char* send_src = const_cast(in_ptr) + peer * per_peer_bytes; + char* recv_dst = out_ptr + peer * per_peer_bytes; + + auto [sz, N] = buffer_size_from_message(per_peer_bytes); + + int in_flight = 0; + int64_t read_offset = 0; + int64_t write_offset = 0; + + // Prefill: recv-first (deadlock prevention) + int buff = 0; + while (read_offset < per_peer_bytes && buff < PIPELINE) { + recv_from(sz, peer, buff); + in_flight++; + + std::copy( + send_src + read_offset, + send_src + + std::min(read_offset + static_cast(N), per_peer_bytes), + send_buffer(sz, buff).begin()); + send_to(sz, peer, buff); + in_flight++; + + read_offset += N; + buff++; + } + + // Single poll loop + while (in_flight > 0) { + ibv_wc wc[WC_NUM]; + int n = connections_[peer].poll(WC_NUM, wc); + + for (int i = 0; i < n; i++) { + int work_type = wc[i].wr_id >> 16; + int b = (wc[i].wr_id >> 8) & 0xff; + + in_flight--; + + if (work_type == SEND_WR) { + if (read_offset < per_peer_bytes) { + std::copy( + send_src + read_offset, + send_src + + std::min( + read_offset + static_cast(N), per_peer_bytes), + send_buffer(sz, b).begin()); + send_to(sz, peer, b); + in_flight++; + read_offset += N; + } + } else if (work_type == RECV_WR) { + std::copy( + recv_buffer(sz, b, peer).begin(), + recv_buffer(sz, b, peer).begin() + + std::min( + static_cast(N), per_peer_bytes - write_offset), + recv_dst + write_offset); + write_offset += N; + + if (write_offset + (PIPELINE - 1) * N < per_peer_bytes) { + recv_from(sz, peer, b); + in_flight++; + } + } + } + } + }); +} + } // namespace mlx::core::distributed::jaccl diff --git a/mlx/distributed/jaccl/mesh.h b/mlx/distributed/jaccl/mesh.h index 6f779e9ccb..2a4ec51f28 100644 --- a/mlx/distributed/jaccl/mesh.h +++ b/mlx/distributed/jaccl/mesh.h @@ -41,6 +41,7 @@ class MeshGroup : public GroupImpl { void all_max(const array& input, array& output, Stream stream) override; void all_min(const array& input, array& output, Stream stream) override; void all_gather(const array& input, array& output, Stream stream) override; + void all_to_all(const array& input, array& output, Stream stream) override; void send(const array& input, int dst, Stream stream) override; void recv(array& out, int src, Stream stream) override; diff --git a/mlx/distributed/jaccl/ring.h b/mlx/distributed/jaccl/ring.h index a59ceb3dd8..e7dfbbf5ab 100644 --- a/mlx/distributed/jaccl/ring.h +++ b/mlx/distributed/jaccl/ring.h @@ -52,6 +52,10 @@ class RingGroup : public GroupImpl { throw std::runtime_error("[jaccl] sum_scatter not supported."); } + void all_to_all(const array& input, array& output, Stream stream) override { + throw std::runtime_error("[jaccl] all_to_all not supported."); + } + std::shared_ptr split(int color, int key = -1) override { throw std::runtime_error("[jaccl] Group split not supported."); } diff --git a/mlx/distributed/mpi/mpi.cpp b/mlx/distributed/mpi/mpi.cpp index 3b176e6e67..65aa3f7b19 100644 --- a/mlx/distributed/mpi/mpi.cpp +++ b/mlx/distributed/mpi/mpi.cpp @@ -131,6 +131,7 @@ struct MPIWrapper { LOAD_SYMBOL(MPI_Allgather, all_gather); LOAD_SYMBOL(MPI_Send, send); LOAD_SYMBOL(MPI_Recv, recv); + LOAD_SYMBOL(MPI_Alltoall, all_to_all); LOAD_SYMBOL(MPI_Type_contiguous, mpi_type_contiguous); LOAD_SYMBOL(MPI_Type_commit, mpi_type_commit); LOAD_SYMBOL(MPI_Op_create, mpi_op_create); @@ -294,6 +295,14 @@ struct MPIWrapper { int (*comm_free)(MPI_Comm*); int (*send)(const void*, int, MPI_Datatype, int, int, MPI_Comm); int (*recv)(void*, int, MPI_Datatype, int, int, MPI_Comm, MPI_Status*); + int (*all_to_all)( + const void*, + int, + MPI_Datatype, + void*, + int, + MPI_Datatype, + MPI_Comm); // Objects MPI_Comm comm_world_; @@ -476,6 +485,22 @@ class MPIGroup : public GroupImpl { throw std::runtime_error("[mpi] sum_scatter not yet implemented."); } + void all_to_all(const array& input, array& output, Stream stream) override { + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(input); + encoder.set_output_array(output); + int count = input.size() / size(); + encoder.dispatch( + mpi().all_to_all, + input.data(), + count, + mpi().datatype(input), + output.data(), + count, + mpi().datatype(output), + comm_); + } + private: MPI_Comm comm_; bool global_; diff --git a/mlx/distributed/nccl/nccl.cpp b/mlx/distributed/nccl/nccl.cpp index d8244bf94f..5b3ed53c9a 100644 --- a/mlx/distributed/nccl/nccl.cpp +++ b/mlx/distributed/nccl/nccl.cpp @@ -356,6 +356,10 @@ class NCCLGroup : public GroupImpl { }); } + void all_to_all(const array&, array&, Stream) override { + throw std::runtime_error("[nccl] all_to_all not yet implemented."); + } + template void all_reduce_impl( const array& input, diff --git a/mlx/distributed/ops.cpp b/mlx/distributed/ops.cpp index 1762f0e6bc..5afb2f04b7 100644 --- a/mlx/distributed/ops.cpp +++ b/mlx/distributed/ops.cpp @@ -183,4 +183,27 @@ array sum_scatter( std::make_shared(stream, group, ReduceScatter::Sum), {x}); } + +array all_to_all( + const array& x, + std::optional group_ /* = std::nullopt */, + StreamOrDevice s /* = {} */) { + auto group = to_group(group_); + if (group.size() == 1) { + return x; + } + if (x.ndim() < 1) { + throw std::invalid_argument("[all_to_all] Input must be at least 1-D."); + } + if (x.shape(0) % group.size() != 0) { + std::ostringstream msg; + msg << "[all_to_all] Invalid shape=" << x.shape() << " for a group of size " + << group.size() + << ". The first dimension (axis 0) must be divisible by the group size."; + throw std::invalid_argument(msg.str()); + } + auto stream = detail::communication_stream(group, s); + return array( + x.shape(), x.dtype(), std::make_shared(stream, group), {x}); +} } // namespace mlx::core::distributed diff --git a/mlx/distributed/ops.h b/mlx/distributed/ops.h index e223c5bea2..daf9bc20da 100644 --- a/mlx/distributed/ops.h +++ b/mlx/distributed/ops.h @@ -54,4 +54,9 @@ MLX_API array sum_scatter( std::optional group = std::nullopt, StreamOrDevice s = {}); +MLX_API array all_to_all( + const array& x, + std::optional group = std::nullopt, + StreamOrDevice s = {}); + } // namespace mlx::core::distributed diff --git a/mlx/distributed/primitives.cpp b/mlx/distributed/primitives.cpp index 5e8d5327a1..42281349c6 100644 --- a/mlx/distributed/primitives.cpp +++ b/mlx/distributed/primitives.cpp @@ -92,4 +92,25 @@ std::pair, std::vector> Send::vmap( return {{send(inputs[0], dst_, group(), stream())}, axes}; } +std::pair, std::vector> AllToAll::vmap( + const std::vector& inputs, + const std::vector& axes) { + return {{all_to_all(inputs[0], group(), stream())}, axes}; +} + +std::vector AllToAll::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector&) { + return {all_to_all(tangents[0], group(), stream())}; +} + +std::vector AllToAll::vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector&, + const std::vector&) { + return {all_to_all(cotangents[0], group(), stream())}; +} + } // namespace mlx::core::distributed diff --git a/mlx/distributed/primitives.h b/mlx/distributed/primitives.h index 18a0d65f5f..c683888989 100644 --- a/mlx/distributed/primitives.h +++ b/mlx/distributed/primitives.h @@ -153,4 +153,29 @@ class ReduceScatter : public DistPrimitive { private: ReduceType reduce_type_; }; + +class AllToAll : public DistPrimitive { + public: + AllToAll(Stream stream, Group group) : DistPrimitive(stream, group) {} + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + std::pair, std::vector> vmap( + const std::vector& inputs, + const std::vector& axes) override; + std::vector jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) override; + std::vector vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector& outputs) override; + + DEFINE_NAME(AllToAll); +}; } // namespace mlx::core::distributed diff --git a/mlx/distributed/ring/ring.cpp b/mlx/distributed/ring/ring.cpp index ea40042844..78cb69d937 100644 --- a/mlx/distributed/ring/ring.cpp +++ b/mlx/distributed/ring/ring.cpp @@ -579,6 +579,10 @@ class RingGroup : public GroupImpl { throw std::runtime_error("[ring] sum_scatter not supported."); } + void all_to_all(const array&, array&, Stream) override { + throw std::runtime_error("[ring] all_to_all not supported."); + } + private: template void all_reduce( diff --git a/python/src/distributed.cpp b/python/src/distributed.cpp index 9f4a7cb59e..a28c4c390e 100644 --- a/python/src/distributed.cpp +++ b/python/src/distributed.cpp @@ -349,4 +349,40 @@ void init_distributed(nb::module_& parent_module) { Returns: array: The output array with shape ``[x.shape[0] // group.size(), *x.shape[1:]]``. )pbdoc"); + + m.def( + "all_to_all", + [](const ScalarOrArray& x, + std::optional group, + mx::StreamOrDevice s) { + return mx::distributed::all_to_all(to_array(x), group, s); + }, + "x"_a, + nb::kw_only(), + "group"_a = nb::none(), + "stream"_a = nb::none(), + nb::sig( + "def all_to_all(x: array, *, group: Optional[Group] = None, " + "stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + All-to-all exchange of data between processes. + + Each process splits its input along the first axis into ``group.size()`` + chunks and sends chunk *i* to process *i*. All processes receive one chunk + from every other process and concatenate them in rank order. The output + has the same shape as the input. + + ``x.shape[0]`` must be divisible by the group size. + + Args: + x (array): Input array. + group (Group): The group of processes that will participate in the + exchange. If set to ``None`` the global group is used. Default: + ``None``. + stream (Stream, optional): Stream or device. Defaults to ``None`` + in which case the default stream of the default device is used. + + Returns: + array: The result of the all-to-all exchange. + )pbdoc"); } diff --git a/python/tests/mlx_distributed_tests.py b/python/tests/mlx_distributed_tests.py index 644da793be..cc1e6606a3 100644 --- a/python/tests/mlx_distributed_tests.py +++ b/python/tests/mlx_distributed_tests.py @@ -317,6 +317,144 @@ def sharding(path, weight): y2 = smod(x) self.assertTrue(mx.allclose(y1, y2, atol=1e-6, rtol=1e-4)) + def _skip_if_all_to_all_unsupported(self, group): + try: + test = mx.distributed.all_to_all(mx.zeros((group.size(),)), group=group) + mx.eval(test) + except RuntimeError as e: + msg = str(e) + if "not supported" in msg or "not yet implemented" in msg or "currently supports size" in msg: + self.skipTest(f"all_to_all not supported: {msg}") + raise + + def test_all_to_all(self): + group = mx.distributed.init() + self._skip_if_all_to_all_unsupported(group) + world_size = group.size() + rank = group.rank() + + # Test multiple dtypes + dtypes = [mx.float32, mx.float16, mx.bfloat16, mx.int32] + + for dt in dtypes: + # Create a [world_size * 4, 8] tensor with rank-specific values + rows = world_size * 4 + cols = 8 + x = (mx.ones((rows, cols), dtype=dt) * (rank * 100)) + mx.broadcast_to( + mx.arange(rows).reshape(-1, 1), (rows, cols) + ).astype(dt) + + y = mx.distributed.all_to_all(x, group=group) + mx.eval(y) + + # Output shape should equal input shape + self.assertEqual(y.shape, x.shape) + + if world_size == 1: + # For single process: all_to_all is identity + self.assertTrue(mx.array_equal(y, x).item()) + else: + # For multi-process: verify the all-to-all permutation + # Each rank's output chunk i should come from rank i's input chunk rank + chunk_size = rows // world_size + for src_rank in range(world_size): + out_chunk = y[src_rank * chunk_size : (src_rank + 1) * chunk_size] + # This chunk should be what src_rank sent to us (our rank-th chunk of src_rank's input) + expected_vals = ( + mx.ones((chunk_size, cols), dtype=dt) * (src_rank * 100) + ) + mx.broadcast_to( + mx.arange(rank * chunk_size, (rank + 1) * chunk_size).reshape( + -1, 1 + ), + (chunk_size, cols), + ).astype( + dt + ) + self.assertTrue(mx.array_equal(out_chunk, expected_vals).item()) + + def test_all_to_all_sizes(self): + group = mx.distributed.init() + self._skip_if_all_to_all_unsupported(group) + world_size = group.size() + + # Test various input sizes + sizes = [ + (world_size,), # minimal 1D + (world_size * 256, 64), # medium 2D + (world_size * 2, 3, 4, 5), # multi-dimensional + ] + + for sh in sizes: + x = mx.ones(sh, dtype=mx.float32) + y = mx.distributed.all_to_all(x, group=group) + mx.eval(y) + + self.assertEqual(y.shape, x.shape) + if world_size == 1: + self.assertTrue(mx.array_equal(y, x).item()) + + def test_all_to_all_non_contiguous(self): + group = mx.distributed.init() + self._skip_if_all_to_all_unsupported(group) + world_size = group.size() + + # Create a non-contiguous input via transpose then slice + base = mx.random.normal((8, world_size * 4)) + x_non_contig = base.T # shape (world_size * 4, 8), non-contiguous + + # Create contiguous copy + x_contig = mx.array(x_non_contig) + + y1 = mx.distributed.all_to_all(x_non_contig, group=group) + y2 = mx.distributed.all_to_all(x_contig, group=group) + mx.eval(y1, y2) + + self.assertTrue(mx.allclose(y1, y2).item()) + + def test_all_to_all_vjp(self): + group = mx.distributed.init() + self._skip_if_all_to_all_unsupported(group) + world_size = group.size() + + x = mx.random.normal((world_size * 4, 8)) + mx.eval(x) + + # Test mx.grad + grad_fn = mx.grad(lambda x: mx.distributed.all_to_all(x, group=group).sum()) + g = grad_fn(x) + mx.eval(g) + + if world_size == 1: + # For single process: gradient of identity + sum is all ones + self.assertTrue(mx.allclose(g, mx.ones_like(g)).item()) + + # Test mx.value_and_grad + val_grad_fn = mx.value_and_grad( + lambda x: mx.distributed.all_to_all(x, group=group).sum() + ) + val, g2 = val_grad_fn(x) + mx.eval(val, g2) + + self.assertEqual(g2.shape, x.shape) + + def test_all_to_all_shape_validation(self): + group = mx.distributed.init() + if group.size() == 1: + self.skipTest("requires world_size > 1") + self._skip_if_all_to_all_unsupported(group) + world_size = group.size() + + # Test that scalar input raises an exception + scalar = mx.array(1.0) + with self.assertRaises(Exception): + mx.eval(mx.distributed.all_to_all(scalar, group=group)) + + # Test that x.shape[0] % world_size != 0 raises (only meaningful for world_size > 1) + if world_size > 1: + bad = mx.ones((world_size * 4 + 1, 8)) + with self.assertRaises(Exception): + mx.eval(mx.distributed.all_to_all(bad, group=group)) + def test_all_gather(self): world = mx.distributed.init() dtypes = [