From 80692ec568e3bd979b239890cbd3979b18951276 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Fri, 7 Nov 2025 14:31:55 +0100 Subject: [PATCH 1/7] Compatibility with pip-installed openmpi --- mlx/distributed/mpi/mpi.cpp | 13 ++++++++++--- python/mlx/distributed_run.py | 23 +++++++++++++++++++---- 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/mlx/distributed/mpi/mpi.cpp b/mlx/distributed/mpi/mpi.cpp index bf87425e48..a5f9e32950 100644 --- a/mlx/distributed/mpi/mpi.cpp +++ b/mlx/distributed/mpi/mpi.cpp @@ -1,5 +1,6 @@ // Copyright © 2024 Apple Inc. +#include #include #include @@ -19,11 +20,17 @@ } \ } +static const char* get_libmpi_name() { + const char* libname = std::getenv("MPI_LIBNAME"); + if (libname != nullptr) { + return libname; + } #ifdef __APPLE__ -static constexpr const char* libmpi_name = "libmpi.dylib"; + return "libmpi.dylib"; #else -static constexpr const char* libmpi_name = "libmpi.so"; + return "libmpi.so"; #endif +} namespace mlx::core::distributed::mpi { @@ -94,7 +101,7 @@ struct MPIWrapper { MPIWrapper() { initialized_ = false; - libmpi_handle_ = dlopen(libmpi_name, RTLD_NOW | RTLD_GLOBAL); + libmpi_handle_ = dlopen(get_libmpi_name(), RTLD_NOW | RTLD_GLOBAL); if (libmpi_handle_ == nullptr) { return; } diff --git a/python/mlx/distributed_run.py b/python/mlx/distributed_run.py index 448e3f9543..bebbf1ad02 100644 --- a/python/mlx/distributed_run.py +++ b/python/mlx/distributed_run.py @@ -386,15 +386,30 @@ def node_thread(rank, host, hostfile, input_queue): t.join() +def get_mpi_libname(): + try: + ompi_info = run(["which", "ompi_info"], check=True, capture_output=True) + ompi_info = ompi_info.stdout.strip().decode() + + otool_output = run(["otool", "-L", ompi_info], check=True, capture_output=True) + otool_output = otool_output.stdout.decode() + + # StopIteration if not found + libmpi_line = next(filter(lambda line: "libmpi" in line, otool_output.splitlines())) + return libmpi_line.strip().split()[0].removeprefix("@rpath/") + except: + return None + + def launch_mpi(parser, hosts, args, command): mpirun = run(["which", "mpirun"], check=True, capture_output=True) mpirun = mpirun.stdout.strip().decode() - # Homebrew libmpi doesn't work with anaconda python out of the box. - # TODO: Check if we should do this with every mpirun - if "homebrew" in mpirun: + # Compatibility with homebrew and pip installs + mpi_libname = get_mpi_libname() + if mpi_libname is not None: dyld = Path(mpirun).parent.parent / "lib" - args.env = [f"DYLD_LIBRARY_PATH={str(dyld)}"] + args.env + args.env = [f"DYLD_LIBRARY_PATH={str(dyld)}", f"MPI_LIBNAME={mpi_libname}"] + args.env log(args.verbose, f"Using '{mpirun}'") with tempfile.NamedTemporaryFile(mode="w") as f: From c7a2e09d56e630fe51b20702f73aeb15f5219f07 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Fri, 7 Nov 2025 14:47:01 +0100 Subject: [PATCH 2/7] Format --- mlx/distributed/mpi/mpi.cpp | 2 +- python/mlx/distributed_run.py | 9 +++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/mlx/distributed/mpi/mpi.cpp b/mlx/distributed/mpi/mpi.cpp index a5f9e32950..228ddd6fd5 100644 --- a/mlx/distributed/mpi/mpi.cpp +++ b/mlx/distributed/mpi/mpi.cpp @@ -1,7 +1,7 @@ // Copyright © 2024 Apple Inc. -#include #include +#include #include #include "mlx/backend/cpu/encoder.h" diff --git a/python/mlx/distributed_run.py b/python/mlx/distributed_run.py index bebbf1ad02..66661df78f 100644 --- a/python/mlx/distributed_run.py +++ b/python/mlx/distributed_run.py @@ -395,7 +395,9 @@ def get_mpi_libname(): otool_output = otool_output.stdout.decode() # StopIteration if not found - libmpi_line = next(filter(lambda line: "libmpi" in line, otool_output.splitlines())) + libmpi_line = next( + filter(lambda line: "libmpi" in line, otool_output.splitlines()) + ) return libmpi_line.strip().split()[0].removeprefix("@rpath/") except: return None @@ -409,7 +411,10 @@ def launch_mpi(parser, hosts, args, command): mpi_libname = get_mpi_libname() if mpi_libname is not None: dyld = Path(mpirun).parent.parent / "lib" - args.env = [f"DYLD_LIBRARY_PATH={str(dyld)}", f"MPI_LIBNAME={mpi_libname}"] + args.env + args.env = [ + f"DYLD_LIBRARY_PATH={str(dyld)}", + f"MPI_LIBNAME={mpi_libname}", + ] + args.env log(args.verbose, f"Using '{mpirun}'") with tempfile.NamedTemporaryFile(mode="w") as f: From cdc3cdcca8203fb138b4f0cb4c6933a499aa896c Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Fri, 7 Nov 2025 17:21:07 +0100 Subject: [PATCH 3/7] Support Linux CPU --- python/mlx/distributed_run.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/mlx/distributed_run.py b/python/mlx/distributed_run.py index 66661df78f..0459c55a42 100644 --- a/python/mlx/distributed_run.py +++ b/python/mlx/distributed_run.py @@ -5,6 +5,7 @@ import ipaddress import json import os +import platform import shlex import shutil import sys @@ -391,7 +392,10 @@ def get_mpi_libname(): ompi_info = run(["which", "ompi_info"], check=True, capture_output=True) ompi_info = ompi_info.stdout.strip().decode() - otool_output = run(["otool", "-L", ompi_info], check=True, capture_output=True) + if platform.system() == "Darwin": + otool_output = run(["otool", "-L", ompi_info], check=True, capture_output=True) + else: + otool_output = run(["ldd", ompi_info], check=True, capture_output=True) otool_output = otool_output.stdout.decode() # StopIteration if not found From 744890562b72ba49ba3df189ec2e823f1cfb3624 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Fri, 7 Nov 2025 17:23:05 +0100 Subject: [PATCH 4/7] Format --- python/mlx/distributed_run.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/mlx/distributed_run.py b/python/mlx/distributed_run.py index 0459c55a42..064796656d 100644 --- a/python/mlx/distributed_run.py +++ b/python/mlx/distributed_run.py @@ -393,7 +393,9 @@ def get_mpi_libname(): ompi_info = ompi_info.stdout.strip().decode() if platform.system() == "Darwin": - otool_output = run(["otool", "-L", ompi_info], check=True, capture_output=True) + otool_output = run( + ["otool", "-L", ompi_info], check=True, capture_output=True + ) else: otool_output = run(["ldd", ompi_info], check=True, capture_output=True) otool_output = otool_output.stdout.decode() From 8babd2779207ad61a9a98645e72f0d4dd6c9c5e0 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Fri, 7 Nov 2025 17:44:25 +0100 Subject: [PATCH 5/7] Update docs --- docs/src/usage/distributed.rst | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/docs/src/usage/distributed.rst b/docs/src/usage/distributed.rst index f1a1e4e949..0b83709e27 100644 --- a/docs/src/usage/distributed.rst +++ b/docs/src/usage/distributed.rst @@ -7,12 +7,13 @@ Distributed Communication MLX supports distributed communication operations that allow the computational cost of training or inference to be shared across many physical machines. At the -moment we support two different communication backends: +moment we support three different communication backends: * `MPI `_ a full-featured and mature distributed communications library -* A **ring** backend of our own that uses native TCP sockets and should be - faster for thunderbolt connections. +* A **ring** backend of our own that uses native TCP sockets. It should be + faster for thunderbolt connections, but it also works over Ethernet. +* `nccl `_, for use in CUDA environments. The list of all currently supported operations and their documentation can be seen in the :ref:`API docs`. @@ -84,9 +85,8 @@ Selecting Backend ^^^^^^^^^^^^^^^^^ You can select the backend you want to use when calling :func:`init` by passing -one of ``{'any', 'ring', 'mpi'}``. When passing ``any``, MLX will try to -initialize the ``ring`` backend and if it fails the ``mpi`` backend. If they -both fail then a singleton group is created. +one of ``{'any', 'ring', 'mpi', 'nccl'}``. When passing ``any``, MLX will try all +available backends. If they all fail then a singleton group is created. .. note:: After a distributed backend is successfully initialized :func:`init` will @@ -220,7 +220,7 @@ print 4 etc. Installing MPI ^^^^^^^^^^^^^^ -MPI can be installed with Homebrew, using the Anaconda package manager or +MPI can be installed with Homebrew, pip, using the Anaconda package manager, or compiled from source. Most of our testing is done using ``openmpi`` installed with the Anaconda package manager as follows: @@ -228,14 +228,16 @@ with the Anaconda package manager as follows: $ conda install conda-forge::openmpi -Installing with Homebrew may require specifying the location of ``libmpi.dyld`` +Installing with Homebrew or pip requires specifying the location of ``libmpi.dyld`` so that MLX can find it and load it at runtime. This can simply be achieved by passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun`` and it is -done automatically by ``mlx.launch``. +done automatically by ``mlx.launch``. Some environments use a non-standard +library filename that can be specified using the ``MPI_LIBNAME`` environment +variable. This is automatically taken care of by ``mlx.launch`` as well. .. code:: shell - $ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py + $ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ -x MPI_LIBNAME=libmpi.40.dylib python test.py $ # or simply $ mlx.launch -n 2 test.py From fb055fd3da450acc2e6e802799efb9ff25bc0008 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Fri, 7 Nov 2025 15:07:30 -0800 Subject: [PATCH 6/7] Change environment variable from MPI_LIBNAME to MLX_MPI_LIBNAME --- mlx/distributed/mpi/mpi.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx/distributed/mpi/mpi.cpp b/mlx/distributed/mpi/mpi.cpp index 228ddd6fd5..3b176e6e67 100644 --- a/mlx/distributed/mpi/mpi.cpp +++ b/mlx/distributed/mpi/mpi.cpp @@ -21,7 +21,7 @@ } static const char* get_libmpi_name() { - const char* libname = std::getenv("MPI_LIBNAME"); + const char* libname = std::getenv("MLX_MPI_LIBNAME"); if (libname != nullptr) { return libname; } From 7dd51b3b8105afbe04ab65f55d72915a430e5369 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Fri, 7 Nov 2025 15:08:13 -0800 Subject: [PATCH 7/7] Rename MPI_LIBNAME to MLX_MPI_LIBNAME --- python/mlx/distributed_run.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mlx/distributed_run.py b/python/mlx/distributed_run.py index 064796656d..e4b50a5cef 100644 --- a/python/mlx/distributed_run.py +++ b/python/mlx/distributed_run.py @@ -419,7 +419,7 @@ def launch_mpi(parser, hosts, args, command): dyld = Path(mpirun).parent.parent / "lib" args.env = [ f"DYLD_LIBRARY_PATH={str(dyld)}", - f"MPI_LIBNAME={mpi_libname}", + f"MLX_MPI_LIBNAME={mpi_libname}", ] + args.env log(args.verbose, f"Using '{mpirun}'")