Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions docs/src/usage/distributed.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@ Distributed Communication

MLX supports distributed communication operations that allow the computational cost
of training or inference to be shared across many physical machines. At the
moment we support two different communication backends:
moment we support three different communication backends:

* `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ a
full-featured and mature distributed communications library
* A **ring** backend of our own that uses native TCP sockets and should be
faster for thunderbolt connections.
* A **ring** backend of our own that uses native TCP sockets. It should be
faster for thunderbolt connections, but it also works over Ethernet.
* `nccl <https://developer.nvidia.com/nccl>`_, for use in CUDA environments.

The list of all currently supported operations and their documentation can be
seen in the :ref:`API docs<distributed>`.
Expand Down Expand Up @@ -84,9 +85,8 @@ Selecting Backend
^^^^^^^^^^^^^^^^^

You can select the backend you want to use when calling :func:`init` by passing
one of ``{'any', 'ring', 'mpi'}``. When passing ``any``, MLX will try to
initialize the ``ring`` backend and if it fails the ``mpi`` backend. If they
both fail then a singleton group is created.
one of ``{'any', 'ring', 'mpi', 'nccl'}``. When passing ``any``, MLX will try all
available backends. If they all fail then a singleton group is created.

.. note::
After a distributed backend is successfully initialized :func:`init` will
Expand Down Expand Up @@ -220,22 +220,24 @@ print 4 etc.
Installing MPI
^^^^^^^^^^^^^^

MPI can be installed with Homebrew, using the Anaconda package manager or
MPI can be installed with Homebrew, pip, using the Anaconda package manager, or
compiled from source. Most of our testing is done using ``openmpi`` installed
with the Anaconda package manager as follows:

.. code:: shell

$ conda install conda-forge::openmpi

Installing with Homebrew may require specifying the location of ``libmpi.dyld``
Installing with Homebrew or pip requires specifying the location of ``libmpi.dyld``
so that MLX can find it and load it at runtime. This can simply be achieved by
passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun`` and it is
done automatically by ``mlx.launch``.
done automatically by ``mlx.launch``. Some environments use a non-standard
library filename that can be specified using the ``MPI_LIBNAME`` environment
variable. This is automatically taken care of by ``mlx.launch`` as well.

.. code:: shell

$ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py
$ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ -x MPI_LIBNAME=libmpi.40.dylib python test.py
$ # or simply
$ mlx.launch -n 2 test.py

Expand Down
13 changes: 10 additions & 3 deletions mlx/distributed/mpi/mpi.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright © 2024 Apple Inc.

#include <dlfcn.h>
#include <cstdlib>
#include <iostream>

#include "mlx/backend/cpu/encoder.h"
Expand All @@ -19,11 +20,17 @@
} \
}

static const char* get_libmpi_name() {
const char* libname = std::getenv("MLX_MPI_LIBNAME");
if (libname != nullptr) {
return libname;
}
#ifdef __APPLE__
static constexpr const char* libmpi_name = "libmpi.dylib";
return "libmpi.dylib";
#else
static constexpr const char* libmpi_name = "libmpi.so";
return "libmpi.so";
#endif
}

namespace mlx::core::distributed::mpi {

Expand Down Expand Up @@ -94,7 +101,7 @@ struct MPIWrapper {
MPIWrapper() {
initialized_ = false;

libmpi_handle_ = dlopen(libmpi_name, RTLD_NOW | RTLD_GLOBAL);
libmpi_handle_ = dlopen(get_libmpi_name(), RTLD_NOW | RTLD_GLOBAL);
if (libmpi_handle_ == nullptr) {
return;
}
Expand Down
34 changes: 30 additions & 4 deletions python/mlx/distributed_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import ipaddress
import json
import os
import platform
import shlex
import shutil
import sys
Expand Down Expand Up @@ -386,15 +387,40 @@ def node_thread(rank, host, hostfile, input_queue):
t.join()


def get_mpi_libname():
try:
ompi_info = run(["which", "ompi_info"], check=True, capture_output=True)
ompi_info = ompi_info.stdout.strip().decode()

if platform.system() == "Darwin":
otool_output = run(
["otool", "-L", ompi_info], check=True, capture_output=True
)
else:
otool_output = run(["ldd", ompi_info], check=True, capture_output=True)
otool_output = otool_output.stdout.decode()

# StopIteration if not found
libmpi_line = next(
filter(lambda line: "libmpi" in line, otool_output.splitlines())
)
return libmpi_line.strip().split()[0].removeprefix("@rpath/")
except:
return None


def launch_mpi(parser, hosts, args, command):
mpirun = run(["which", "mpirun"], check=True, capture_output=True)
mpirun = mpirun.stdout.strip().decode()

# Homebrew libmpi doesn't work with anaconda python out of the box.
# TODO: Check if we should do this with every mpirun
if "homebrew" in mpirun:
# Compatibility with homebrew and pip installs
mpi_libname = get_mpi_libname()
if mpi_libname is not None:
dyld = Path(mpirun).parent.parent / "lib"
args.env = [f"DYLD_LIBRARY_PATH={str(dyld)}"] + args.env
args.env = [
f"DYLD_LIBRARY_PATH={str(dyld)}",
f"MLX_MPI_LIBNAME={mpi_libname}",
] + args.env

log(args.verbose, f"Using '{mpirun}'")
with tempfile.NamedTemporaryFile(mode="w") as f:
Expand Down