Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
231 changes: 183 additions & 48 deletions python/metatomic_torch/metatomic/torch/_extensions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import glob
import hashlib
import os
import re
import shutil
import site
import sys
Expand All @@ -16,22 +17,136 @@
METATENSOR_TORCH_LIB_PATH = metatensor.torch._c_lib._lib_path()


def _rascaline_lib_path():
# This is kept for backward compatibility, but rascaline is now named featomic.
# This code should be removed by the middle of 2025.
import rascaline
def _find_delocate_deps(module, lib_name: str, optional=False):
"""
Find a shared library named `lib_name` that was inserted by delocate inside a wheel
on macOS.

return [rascaline._c_lib._lib_path()]
:param module: module corresponding to the wheel
:param lib_name: name of the library to find
:param optional: should we warn if the library is not found?
"""
assert sys.platform == "darwin"
# delocate puts the dependencies in <wheel>/.dylibs/
search_dir = os.path.join(os.path.dirname(module.__file__), ".dylibs")

libs_list = glob.glob(os.path.join(search_dir, f"{lib_name}.*"))
if len(libs_list) == 0 and not optional:
warnings.warn(
f"No {lib_name} shared library found in '{search_dir}'. "
"This may cause issues when loading and running the model.",
stacklevel=2,
)
elif len(libs_list) > 1:
raise RuntimeError(
f"Multiple {lib_name} shared libraries found in '{search_dir}': "
f"{libs_list}. Try to re-install in a fresh environment."
)
else: # len(libs_list) == 1
return libs_list[0]


def _find_auditwheel_deps(wheel: str, lib_name: str, optional=False):
"""
Find a shared library named `lib_name` that was inserted by auditwheel inside a
wheel on Linux.

:param wheel: name of the wheel/distribution
:param lib_name: name of the library to find
:param optional: should we warn if the library is not found?
"""
assert isinstance(wheel, str)
assert sys.platform.startswith("linux")
# auditwheel puts the dependencies in <wheel>.libs/
search_dir = f"{wheel}.libs/"
libs_list = []

for prefix in site.getsitepackages():
libs_dir = os.path.join(prefix, search_dir)
if os.path.exists(libs_dir):
libs_list = glob.glob(os.path.join(libs_dir, lib_name + "-*.so*"))
if len(libs_list) != 0:
# found it!
break

if len(libs_list) == 0 and not optional:
warnings.warn(
f"No {lib_name} shared library found in '{search_dir}'. "
"This may cause issues when loading and running the model.",
stacklevel=2,
)
elif len(libs_list) > 1:
raise RuntimeError(
f"Multiple {lib_name} shared libraries found in '{search_dir}': "
f"{libs_list}. Try to re-install in a fresh environment."
)
else: # len(libs_list) == 1
return libs_list[0]


def _find_global_deps(lib_name: str, optional=False, only_versionned=False):
"""
Find a shared library named `lib_name` in the global library directory of the
current Python environment.

:param lib_name: name of the library to find
:param optional: should we warn if the library is not found?
:param only_versionned: should we only include versionned shared libraries
(e.g. libxyz.so.12) and exclude unversionned ones (libxyz.so)?
"""
prefix = sys.prefix

if sys.platform.startswith("linux"):
lib_dir = os.path.join(prefix, "lib")
# allow both .so and .so.X.Y.Z
lib_ext = "so*"
elif sys.platform == "darwin":
lib_dir = os.path.join(prefix, "lib")
lib_ext = "dylib"
elif sys.platform == "win32":
lib_dir = os.path.join(prefix, "bin")
lib_ext = "dll"
else:
raise RuntimeError(f"unsupported platform: {sys.platform}")

libs_list = glob.glob(os.path.join(lib_dir, f"{lib_name}*.{lib_ext}"))
if len(libs_list) == 0 and not optional:
warnings.warn(
f"No {lib_name} shared library found in '{lib_dir}'. "
"This may cause issues when loading and running the model.",
stacklevel=2,
)

if only_versionned and len(libs_list) > 1:
versionned_libs = []
for lib in libs_list:
base = os.path.basename(lib)
if sys.platform.startswith("linux"):
if re.search(rf"{lib_name}\.so\.\d+.*", base):
versionned_libs.append(lib)
elif sys.platform == "darwin":
if re.search(rf"{lib_name}\.\d+.*\.dylib", base):
versionned_libs.append(lib)
elif sys.platform == "win32":
# Windows does not have versionned DLLs
pass

if len(versionned_libs) > 0:
libs_list = versionned_libs

return libs_list


def _featomic_deps_path():
import featomic

deps_path = [featomic._c_lib._lib_path()]
deps_path = []
if sys.platform.startswith("linux"):
libgomp_path = _find_auditwheel_deps("featomic_torch", "libgomp")
if libgomp_path is not None:
deps_path.append(libgomp_path)

libgomp_path = _find_openmp_dep("featomic_torch.libs")
if libgomp_path is not None:
deps_path.insert(0, libgomp_path)
deps_path.append(featomic._c_lib._lib_path())

return deps_path

Expand All @@ -41,11 +156,11 @@ def _sphericart_deps_path():

deps_path = []

libgomp_path = _find_openmp_dep("sphericart_torch.libs")
if libgomp_path is not None:
deps_path.append(libgomp_path)

if sys.platform.startswith("linux"):
libgomp_path = _find_auditwheel_deps("sphericart_torch", "libgomp")
if libgomp_path is not None:
deps_path.append(libgomp_path)

# sphericart uses a separate library to get the CUDA stream corresponding to a
# tensor, see https://github.com/lab-cosmo/sphericart/pull/164
sphericart_torch_path = sphericart.torch._lib_path()
Expand All @@ -58,49 +173,69 @@ def _sphericart_deps_path():
return deps_path


def _find_openmp_dep(search_dir):
"""
When building code that uses OpenMP on linux, we typically dynamically link to
libgomp. `cibuildwheel` then copies ``libgomp.so`` to
``<wheel_name>.libs/libgomp-<hash>.so``, so we need to find and add this shared
library to the extensions dependencies.
"""
def _deepmd_deps_path():
import deepmd
import deepmd.lib

deps_path = []

if sys.platform == "darwin":
if deepmd.__version__ <= "3.1.0":
libmpi_path = _find_delocate_deps(deepmd, "libmpi")
if libmpi_path is not None:
deps_path.append(libmpi_path)

libpmpi_path = _find_delocate_deps(deepmd, "libpmpi")
if libpmpi_path is not None:
deps_path.append(libpmpi_path)
else:
# libmpi and libpmpi are no longer bundled since deepmd-kit 3.1.1
# but taken from the `mpich` wheel, which installs them in the
# virtualenv `lib` directory.
deps_path += _find_global_deps("libmpi", only_versionned=True)
deps_path += _find_global_deps("libpmpi", only_versionned=True)

elif sys.platform.startswith("linux"):
libgomp_path = _find_auditwheel_deps("deepmd_kit", "libgomp")
if libgomp_path is not None:
deps_path.append(libgomp_path)

if deepmd.__version__ <= "3.1.0":
libmpi_path = _find_auditwheel_deps("deepmd_kit", "libmpi")
if libmpi_path is not None:
deps_path.append(libmpi_path)

for mpi_dep in ["libfabric", "libucm", "libucp", "libucs", "libuct"]:
mpi_dep_path = _find_auditwheel_deps("deepmd_kit", mpi_dep)
if mpi_dep_path is not None:
deps_path.append(mpi_dep_path)
else:
# pull libmpi from the mpich wheel
deps_path += _find_global_deps("libmpi", only_versionned=True)
# dependencies of libmpi as installed by the mpich wheel
for mpi_dep in ["libfabric", "libucm", "libucp", "libucs", "libuct"]:
deps_path += _find_global_deps("mpich/" + mpi_dep, only_versionned=True)

libs_dir = os.path.dirname(deepmd.lib.__file__)
# libdeepmd.so/deepmd.dll/libdeepmd.dylib
deps_path += list(glob.glob(os.path.join(libs_dir, "*deepmd.*")))

if sys.platform.startswith("linux"):
libs_list = []

site_packages = site.getsitepackages()
if site.ENABLE_USER_SITE:
site_packages.append(site.getusersitepackages())

for prefix in site_packages:
libs_dir = os.path.join(prefix, search_dir)
if os.path.exists(libs_dir):
libs_list = glob.glob(os.path.join(libs_dir, "libgomp-*.so*"))
if len(libs_list) != 0:
# found it!
break

if len(libs_list) == 0:
warnings.warn(
f"No libgomp shared library found in '{search_dir}'. "
"This may cause issues when loading and running the model.",
stacklevel=2,
)
elif len(libs_list) > 1:
raise RuntimeError(
f"Multiple libgomp shared libraries found in '{search_dir}': "
f"{libs_list}. Try to re-install in a fresh environment."
)
else: # len(libs_list) == 1
return libs_list[0]
deps_path += list(glob.glob(os.path.join(libs_dir, "libdeepmd_op_cuda.so")))
deps_path += list(glob.glob(os.path.join(libs_dir, "libdeepmd_dyn_cudart.so")))
# there is also a dependency on libmpi, but it is not distributed in the wheel

# no extra dependencies on Windows, only deepmd.dll

return deps_path


# Manual definition of which TorchScript extensions have their own dependencies. The
# dependencies should be returned in the order they need to be loaded.
EXTENSIONS_WITH_DEPENDENCIES = {
"rascaline_torch": _rascaline_lib_path,
"featomic_torch": _featomic_deps_path,
"sphericart_torch": _sphericart_deps_path,
"deepmd_op_pt": _deepmd_deps_path,
}


Expand Down
Loading