Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/src/python/distributed.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ made available.
init
all_sum
all_gather
all_to_all
send
recv
recv_like
15 changes: 15 additions & 0 deletions mlx/backend/cpu/distributed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,4 +100,19 @@ void ReduceScatter::eval_cpu(
std::vector<array>& outputs) {
throw std::runtime_error("[ReduceScatter] Not implemented yet.");
}

// CPU evaluation of the AllToAll collective: each rank contributes one
// equal-sized shard per peer and receives one shard from every peer.
void AllToAll::eval_cpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() == 1);
  assert(outputs.size() == 1);

  // The communication layer needs a row-contiguous source buffer; a copy is
  // only made when the input is not already contiguous.
  auto [contiguous_in, made_copy] = ensure_row_contiguous(inputs[0], stream());

  // Allocate the destination before handing it to the backend.
  outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
  distributed::detail::all_to_all(group(), contiguous_in, outputs[0], stream());

  // If a contiguous copy was made, keep it alive until the encoder is done.
  if (made_copy) {
    cpu::get_command_encoder(stream()).add_temporary(contiguous_in);
  }
}

} // namespace mlx::core::distributed
4 changes: 4 additions & 0 deletions mlx/backend/cuda/distributed.cu
Original file line number Diff line number Diff line change
Expand Up @@ -118,4 +118,8 @@ void ReduceScatter::eval_gpu(
throw std::runtime_error("Only sum scatter is supported. ");
}
}

// AllToAll has no CUDA implementation yet; fail loudly rather than silently.
void AllToAll::eval_gpu(const std::vector<array>&, std::vector<array>&) {
  static const char* msg = "[AllToAll::eval_gpu] has no CUDA implementation.";
  throw std::runtime_error(msg);
}
} // namespace mlx::core::distributed
4 changes: 4 additions & 0 deletions mlx/backend/metal/distributed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,8 @@ void ReduceScatter::eval_gpu(const std::vector<array>&, std::vector<array>&) {
"[ReduceScatter::eval_gpu] has no GPU implementation.");
}

// AllToAll has no Metal implementation yet; fail loudly rather than silently.
void AllToAll::eval_gpu(const std::vector<array>&, std::vector<array>&) {
  static const char* msg = "[AllToAll::eval_gpu] has no GPU implementation.";
  throw std::runtime_error(msg);
}

} // namespace mlx::core::distributed
1 change: 1 addition & 0 deletions mlx/backend/no_cpu/primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ NO_CPU_MULTI(AllGather)
NO_CPU_MULTI(Send)
NO_CPU_MULTI(Recv)
NO_CPU_MULTI(ReduceScatter)
NO_CPU_MULTI(AllToAll)
} // namespace distributed

} // namespace mlx::core
1 change: 1 addition & 0 deletions mlx/backend/no_gpu/primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ NO_GPU_MULTI(AllGather)
NO_GPU_MULTI(Send)
NO_GPU_MULTI(Recv)
NO_GPU_MULTI(ReduceScatter)
NO_GPU_MULTI(AllToAll)
} // namespace distributed

} // namespace mlx::core
8 changes: 8 additions & 0 deletions mlx/distributed/distributed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ void sum_scatter(
group.raw_group()->sum_scatter(input, output, stream);
}

// Dispatch the all-to-all exchange to the group's concrete implementation.
void all_to_all(Group group, const array& input, array& output, Stream stream) {
  auto impl = group.raw_group();
  impl->all_to_all(input, output, stream);
}

class EmptyGroup : public GroupImpl {
public:
Stream communication_stream(StreamOrDevice s) override {
Expand Down Expand Up @@ -98,6 +102,10 @@ class EmptyGroup : public GroupImpl {
throw std::runtime_error(
"Communication not implemented in an empty distributed group.");
}
// An empty group has no peers, so any collective communication is an error.
void all_to_all(const array&, array&, Stream) override {
  static const char* msg =
      "Communication not implemented in an empty distributed group.";
  throw std::runtime_error(msg);
}
};

} // namespace detail
Expand Down
4 changes: 4 additions & 0 deletions mlx/distributed/distributed_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class GroupImpl {
virtual void all_min(const array& input, array& output, Stream stream) = 0;
virtual void
sum_scatter(const array& input, array& output, Stream stream) = 0;
virtual void all_to_all(const array& input, array& output, Stream stream) = 0;
};

/* Define the MLX stream that the communication should happen in. */
Expand All @@ -56,4 +57,7 @@ void all_min(Group group, const array& input, array& output, Stream stream);
/** Reduce scatter with average operation */
void sum_scatter(Group group, const array& input, array& output, Stream stream);

/** All-to-all exchange */
void all_to_all(Group group, const array& input, array& output, Stream stream);

} // namespace mlx::core::distributed::detail
102 changes: 102 additions & 0 deletions mlx/distributed/jaccl/mesh.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -448,4 +448,106 @@ void MeshGroup::all_reduce(
});
}

// Two-rank all-to-all over the mesh transport: the local shard is memcpy'd
// in place and the peer shard is exchanged through a pipelined send/recv
// over pre-registered buffers, polled via ibv work completions.
// NOTE(review): looks like InfiniBand verbs-style completion handling
// (ibv_wc, wr_id-encoded work type/buffer) -- confirm against the rest of
// the jaccl transport.
void MeshGroup::all_to_all(const array& input, array& output, Stream stream) {
  // Only the pairwise case is implemented so far.
  if (size_ != 2) {
    throw std::runtime_error(
        "[jaccl] all_to_all currently supports size == 2, got " +
        std::to_string(size_) + ".");
  }
  auto in_ptr = input.data<char>();
  auto out_ptr = output.data<char>();
  // The peer-shard exchange below reads from input while writing output, so
  // aliased buffers would corrupt data.
  if (in_ptr == out_ptr) {
    throw std::runtime_error(
        "[jaccl] in-place all_to_all is not supported (input/output alias).");
  }
  int64_t n_bytes = static_cast<int64_t>(input.nbytes());

  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(input);
  encoder.set_output_array(output);
  // The actual exchange runs when the encoder executes the task; raw
  // pointers are captured, so the encoder must keep both arrays alive.
  encoder.dispatch([in_ptr, out_ptr, n_bytes, this]() {
    // Number of in-flight buffers per direction, and the max completions
    // fetched per poll (sends + recvs).
    constexpr int PIPELINE = 2;
    constexpr int WC_NUM = PIPELINE * 2;

    // With size_ == 2 the single peer is the other rank.
    int peer = 1 - rank_;
    int64_t per_peer_bytes = n_bytes / size_;

    // Local chunk: input[rank] -> output[rank]
    std::memcpy(
        out_ptr + rank_ * per_peer_bytes,
        in_ptr + rank_ * per_peer_bytes,
        per_peer_bytes);

    // Nothing to exchange with the peer for zero-sized shards.
    if (per_peer_bytes == 0)
      return;

    // Shard destined for / arriving from the peer.
    char* send_src = const_cast<char*>(in_ptr) + peer * per_peer_bytes;
    char* recv_dst = out_ptr + peer * per_peer_bytes;

    // sz selects the registered buffer class; N is the chunk size in bytes.
    auto [sz, N] = buffer_size_from_message(per_peer_bytes);

    int in_flight = 0; // outstanding send + recv work requests
    int64_t read_offset = 0; // next byte of send_src to stage
    int64_t write_offset = 0; // next byte of recv_dst to fill

    // Prefill: recv-first (deadlock prevention)
    // Posting the receive before the matching send guarantees a landing
    // buffer exists for whatever the peer sends first.
    int buff = 0;
    while (read_offset < per_peer_bytes && buff < PIPELINE) {
      recv_from(sz, peer, buff);
      in_flight++;

      // Stage the next chunk (clamped at the shard end) and send it.
      std::copy(
          send_src + read_offset,
          send_src +
              std::min(read_offset + static_cast<int64_t>(N), per_peer_bytes),
          send_buffer(sz, buff).begin<char>());
      send_to(sz, peer, buff);
      in_flight++;

      read_offset += N;
      buff++;
    }

    // Single poll loop
    // Drain completions; each completed send frees its buffer for the next
    // chunk, each completed recv is copied out and the buffer is reposted
    // while more data is still expected.
    while (in_flight > 0) {
      ibv_wc wc[WC_NUM];
      int n = connections_[peer].poll(WC_NUM, wc);

      for (int i = 0; i < n; i++) {
        // wr_id encodes the work type in the high bits and the buffer
        // index in the next byte.
        int work_type = wc[i].wr_id >> 16;
        int b = (wc[i].wr_id >> 8) & 0xff;

        in_flight--;

        if (work_type == SEND_WR) {
          // Reuse the freed send buffer for the next unsent chunk, if any.
          if (read_offset < per_peer_bytes) {
            std::copy(
                send_src + read_offset,
                send_src +
                    std::min(
                        read_offset + static_cast<int64_t>(N), per_peer_bytes),
                send_buffer(sz, b).begin<char>());
            send_to(sz, b);
            in_flight++;
            read_offset += N;
          }
        } else if (work_type == RECV_WR) {
          // Copy the received chunk out (final chunk may be partial).
          std::copy(
              recv_buffer(sz, b, peer).begin<char>(),
              recv_buffer(sz, b, peer).begin<char>() +
                  std::min(
                      static_cast<int64_t>(N), per_peer_bytes - write_offset),
              recv_dst + write_offset);
          write_offset += N;

          // Repost only while the remaining data still needs this buffer
          // beyond what the other in-flight receives will deliver.
          if (write_offset + (PIPELINE - 1) * N < per_peer_bytes) {
            recv_from(sz, peer, b);
            in_flight++;
          }
        }
      }
    }
  });
}

} // namespace mlx::core::distributed::jaccl
1 change: 1 addition & 0 deletions mlx/distributed/jaccl/mesh.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class MeshGroup : public GroupImpl {
void all_max(const array& input, array& output, Stream stream) override;
void all_min(const array& input, array& output, Stream stream) override;
void all_gather(const array& input, array& output, Stream stream) override;
void all_to_all(const array& input, array& output, Stream stream) override;
void send(const array& input, int dst, Stream stream) override;
void recv(array& out, int src, Stream stream) override;

Expand Down
4 changes: 4 additions & 0 deletions mlx/distributed/jaccl/ring.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ class RingGroup : public GroupImpl {
throw std::runtime_error("[jaccl] sum_scatter not supported.");
}

// all_to_all is not implemented for this backend. Parameters are left
// unnamed to avoid unused-parameter warnings, matching the other
// unimplemented collectives (e.g. sum_scatter above).
void all_to_all(const array&, array&, Stream) override {
  throw std::runtime_error("[jaccl] all_to_all not supported.");
}

std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
throw std::runtime_error("[jaccl] Group split not supported.");
}
Expand Down
25 changes: 25 additions & 0 deletions mlx/distributed/mpi/mpi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ struct MPIWrapper {
LOAD_SYMBOL(MPI_Allgather, all_gather);
LOAD_SYMBOL(MPI_Send, send);
LOAD_SYMBOL(MPI_Recv, recv);
LOAD_SYMBOL(MPI_Alltoall, all_to_all);
LOAD_SYMBOL(MPI_Type_contiguous, mpi_type_contiguous);
LOAD_SYMBOL(MPI_Type_commit, mpi_type_commit);
LOAD_SYMBOL(MPI_Op_create, mpi_op_create);
Expand Down Expand Up @@ -294,6 +295,14 @@ struct MPIWrapper {
int (*comm_free)(MPI_Comm*);
int (*send)(const void*, int, MPI_Datatype, int, int, MPI_Comm);
int (*recv)(void*, int, MPI_Datatype, int, int, MPI_Comm, MPI_Status*);
int (*all_to_all)(
const void*,
int,
MPI_Datatype,
void*,
int,
MPI_Datatype,
MPI_Comm);

// Objects
MPI_Comm comm_world_;
Expand Down Expand Up @@ -476,6 +485,22 @@ class MPIGroup : public GroupImpl {
throw std::runtime_error("[mpi] sum_scatter not yet implemented.");
}

// Exchange equal shards of `input` between all ranks via MPI_Alltoall.
// Each rank sends `count` elements to every other rank and receives the
// same number back into `output`.
void all_to_all(const array& input, array& output, Stream stream) override {
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(input);
  encoder.set_output_array(output);
  // Per-rank element count. NOTE(review): assumes input.size() is divisible
  // by size() and both buffers are contiguous -- confirm the op-level
  // validation (ops.cpp) guarantees this before dispatch.
  int count = input.size() / size();
  // The MPI symbol is resolved at runtime (LOAD_SYMBOL above); raw data
  // pointers are captured, so the encoder must keep both arrays alive.
  encoder.dispatch(
      mpi().all_to_all,
      input.data<void>(),
      count,
      mpi().datatype(input),
      output.data<void>(),
      count,
      mpi().datatype(output),
      comm_);
}

private:
MPI_Comm comm_;
bool global_;
Expand Down
4 changes: 4 additions & 0 deletions mlx/distributed/nccl/nccl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,10 @@ class NCCLGroup : public GroupImpl {
});
}

// All-to-all has no NCCL implementation yet; fail loudly when invoked.
void all_to_all(const array&, array&, Stream) override {
  static const char* msg = "[nccl] all_to_all not yet implemented.";
  throw std::runtime_error(msg);
}

template <typename T>
void all_reduce_impl(
const array& input,
Expand Down
23 changes: 23 additions & 0 deletions mlx/distributed/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,4 +183,27 @@ array sum_scatter(
std::make_shared<ReduceScatter>(stream, group, ReduceScatter::Sum),
{x});
}

// Public all-to-all op: splits x along axis 0 into one shard per rank,
// sends shard i to rank i, and returns the gathered shards in rank order.
// The result has the same shape and dtype as the input.
array all_to_all(
    const array& x,
    std::optional<Group> group_ /* = std::nullopt */,
    StreamOrDevice s /* = {} */) {
  auto resolved = to_group(group_);
  const int n_ranks = resolved.size();

  // A single-rank exchange is the identity.
  if (n_ranks == 1) {
    return x;
  }

  // Validate the input before constructing the primitive.
  if (x.ndim() < 1) {
    throw std::invalid_argument("[all_to_all] Input must be at least 1-D.");
  }
  if (x.shape(0) % n_ranks != 0) {
    std::ostringstream msg;
    msg << "[all_to_all] Invalid shape=" << x.shape() << " for a group of size "
        << n_ranks
        << ". The first dimension (axis 0) must be divisible by the group size.";
    throw std::invalid_argument(msg.str());
  }

  auto stream = detail::communication_stream(resolved, s);
  return array(
      x.shape(), x.dtype(), std::make_shared<AllToAll>(stream, resolved), {x});
}
} // namespace mlx::core::distributed
5 changes: 5 additions & 0 deletions mlx/distributed/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,9 @@ MLX_API array sum_scatter(
std::optional<Group> group = std::nullopt,
StreamOrDevice s = {});

MLX_API array all_to_all(
const array& x,
std::optional<Group> group = std::nullopt,
StreamOrDevice s = {});

} // namespace mlx::core::distributed
21 changes: 21 additions & 0 deletions mlx/distributed/primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,4 +92,25 @@ std::pair<std::vector<array>, std::vector<int>> Send::vmap(
return {{send(inputs[0], dst_, group(), stream())}, axes};
}

// vmap: the exchange applies unchanged to the batched array, so the mapped
// axes pass straight through.
std::pair<std::vector<array>, std::vector<int>> AllToAll::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes) {
  auto out = all_to_all(inputs[0], group(), stream());
  return {{std::move(out)}, axes};
}

// jvp: all_to_all is linear in its input, so the tangent is pushed through
// the same exchange.
std::vector<array> AllToAll::jvp(
    const std::vector<array>& primals,
    const std::vector<array>& tangents,
    const std::vector<int>&) {
  auto out = all_to_all(tangents[0], group(), stream());
  return {std::move(out)};
}

// vjp: the exchange is a permutation of equal shards and is its own
// adjoint, so the cotangent is routed back with another all_to_all.
std::vector<array> AllToAll::vjp(
    const std::vector<array>& primals,
    const std::vector<array>& cotangents,
    const std::vector<int>&,
    const std::vector<array>&) {
  auto out = all_to_all(cotangents[0], group(), stream());
  return {std::move(out)};
}

} // namespace mlx::core::distributed
25 changes: 25 additions & 0 deletions mlx/distributed/primitives.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,4 +153,29 @@ class ReduceScatter : public DistPrimitive {
private:
ReduceType reduce_type_;
};

/**
 * All-to-all exchange primitive.
 *
 * Each rank splits its input into `group size` equal shards along axis 0,
 * sends shard i to rank i, and receives one shard from every rank, so the
 * output has the same shape and dtype as the input.
 */
class AllToAll : public DistPrimitive {
 public:
  AllToAll(Stream stream, Group group) : DistPrimitive(stream, group) {}

  // Backend evaluation entry points (CPU and GPU).
  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;
  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;

  // Transformations: the exchange is linear, so vmap/jvp/vjp each reduce to
  // applying all_to_all to the transformed arrays (see primitives.cpp).
  std::pair<std::vector<array>, std::vector<int>> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;
  std::vector<array> jvp(
      const std::vector<array>& primals,
      const std::vector<array>& tangents,
      const std::vector<int>& argnums) override;
  std::vector<array> vjp(
      const std::vector<array>& primals,
      const std::vector<array>& cotangents,
      const std::vector<int>& argnums,
      const std::vector<array>& outputs) override;

  DEFINE_NAME(AllToAll);
};
} // namespace mlx::core::distributed
4 changes: 4 additions & 0 deletions mlx/distributed/ring/ring.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,10 @@ class RingGroup : public GroupImpl {
throw std::runtime_error("[ring] sum_scatter not supported.");
}

// All-to-all is not supported on the ring backend; fail loudly.
void all_to_all(const array&, array&, Stream) override {
  static const char* msg = "[ring] all_to_all not supported.";
  throw std::runtime_error(msg);
}

private:
template <typename T, typename ReduceOp>
void all_reduce(
Expand Down
Loading