def _select_single_lib(libs_list, lib_name: str, search_dir: str, optional: bool):
    """
    Pick the one shared library expected among the candidates in ``libs_list``.

    This is the shared tail of :py:func:`_find_delocate_deps` and
    :py:func:`_find_auditwheel_deps`.

    :param libs_list: candidate paths found while searching ``search_dir``
    :param lib_name: name of the library, only used in messages
    :param search_dir: directory that was searched, only used in messages
    :param optional: if true, do not warn when ``libs_list`` is empty
    :return: the single path in ``libs_list``, or ``None`` if it is empty
    :raises RuntimeError: if ``libs_list`` contains more than one path
    """
    if len(libs_list) > 1:
        raise RuntimeError(
            f"Multiple {lib_name} shared libraries found in '{search_dir}': "
            f"{libs_list}. Try to re-install in a fresh environment."
        )

    if len(libs_list) == 1:
        return libs_list[0]

    # Nothing was found. Return None instead of indexing into the empty list, which
    # used to raise IndexError when `optional` was true.
    if not optional:
        warnings.warn(
            f"No {lib_name} shared library found in '{search_dir}'. "
            "This may cause issues when loading and running the model.",
            # 1 is this helper, 2 is the `_find_*_deps` function, 3 is its caller
            stacklevel=3,
        )
    return None


def _filter_versionned_libs(libs_list, lib_name: str, platform: str):
    """
    Keep only the paths in ``libs_list`` whose file name carries a version number.

    :param libs_list: paths of candidate shared libraries
    :param lib_name: name of the library; it can contain a sub-directory (e.g.
        ``mpich/libfabric``), only its last component is matched against file names
    :param platform: platform name, in the format used by ``sys.platform``
    :return: the matching paths, in the same order as in ``libs_list``. This is
        always empty on platforms other than Linux and macOS.
    """
    # The pattern is matched against file names only, so any directory part of
    # `lib_name` has to be dropped. The name is escaped since it can contain regex
    # metacharacters such as `+` or `.`.
    name = re.escape(os.path.basename(lib_name))

    if platform.startswith("linux"):
        # libxyz.so.12, libxyz.so.12.1.0
        pattern = re.compile(rf"{name}\.so\.\d+")
    elif platform == "darwin":
        # libxyz.12.dylib
        pattern = re.compile(rf"{name}\.\d+.*\.dylib")
    else:
        # Windows does not have versionned DLLs
        return []

    return [lib for lib in libs_list if pattern.search(os.path.basename(lib))]


def _find_delocate_deps(module, lib_name: str, optional=False):
    """
    Find a shared library named `lib_name` that was inserted by delocate inside a wheel
    on macOS.

    :param module: module corresponding to the wheel
    :param lib_name: name of the library to find
    :param optional: if true, do not warn when the library is not found
    :return: path to the library, or ``None`` if it was not found
    :raises RuntimeError: if more than one matching library is found
    """
    assert sys.platform == "darwin"
    # delocate puts the dependencies in a `.dylibs` directory next to the module
    search_dir = os.path.join(os.path.dirname(module.__file__), ".dylibs")

    libs_list = glob.glob(os.path.join(search_dir, f"{lib_name}.*"))
    return _select_single_lib(libs_list, lib_name, search_dir, optional)


def _find_auditwheel_deps(wheel: str, lib_name: str, optional=False):
    """
    Find a shared library named `lib_name` that was inserted by auditwheel inside a
    wheel on Linux.

    :param wheel: name of the wheel/distribution
    :param lib_name: name of the library to find
    :param optional: if true, do not warn when the library is not found
    :return: path to the library, or ``None`` if it was not found
    :raises RuntimeError: if more than one matching library is found in the first
        directory containing a match
    """
    assert isinstance(wheel, str)
    assert sys.platform.startswith("linux")
    # auditwheel puts the dependencies in `<wheel>.libs/`
    search_dir = f"{wheel}.libs/"

    site_packages = list(site.getsitepackages())
    if site.ENABLE_USER_SITE:
        # also look in the per-user site-packages, after the global ones
        site_packages.append(site.getusersitepackages())

    libs_list = []
    for prefix in site_packages:
        libs_dir = os.path.join(prefix, search_dir)
        if os.path.exists(libs_dir):
            libs_list = glob.glob(os.path.join(libs_dir, lib_name + "-*.so*"))
            if len(libs_list) != 0:
                # found it!
                break

    return _select_single_lib(libs_list, lib_name, search_dir, optional)


def _find_global_deps(lib_name: str, optional=False, only_versionned=False):
    """
    Find a shared library named `lib_name` in the global library directory of the
    current Python environment.

    :param lib_name: name of the library to find
    :param optional: if true, do not warn when the library is not found
    :param only_versionned: should we only include versionned shared libraries
        (e.g. libxyz.so.12) and exclude unversionned ones (libxyz.so)? This only
        applies when multiple libraries are found and at least one is versionned.
    :return: sorted list of paths to the matching libraries, which can be empty
    :raises RuntimeError: if the current platform is not Linux, macOS or Windows
    """
    prefix = sys.prefix

    if sys.platform.startswith("linux"):
        lib_dir = os.path.join(prefix, "lib")
        # allow both .so and .so.X.Y.Z
        lib_ext = "so*"
    elif sys.platform == "darwin":
        lib_dir = os.path.join(prefix, "lib")
        lib_ext = "dylib"
    elif sys.platform == "win32":
        lib_dir = os.path.join(prefix, "bin")
        lib_ext = "dll"
    else:
        raise RuntimeError(f"unsupported platform: {sys.platform}")

    # glob returns files in arbitrary order, sort them to get reproducible results
    libs_list = sorted(glob.glob(os.path.join(lib_dir, f"{lib_name}*.{lib_ext}")))
    if len(libs_list) == 0 and not optional:
        warnings.warn(
            f"No {lib_name} shared library found in '{lib_dir}'. "
            "This may cause issues when loading and running the model.",
            stacklevel=2,
        )

    if only_versionned and len(libs_list) > 1:
        versionned_libs = _filter_versionned_libs(libs_list, lib_name, sys.platform)
        # keep everything if none of the libraries is versionned
        if len(versionned_libs) > 0:
            libs_list = versionned_libs

    return libs_list


def _featomic_deps_path():
    """
    Get the shared libraries to load for ``featomic_torch``.

    :return: list of paths; on Linux the ``libgomp`` found in the ``featomic_torch``
        wheel (if any) comes first, followed by featomic's own library
    """
    import featomic

    deps_path = []
    if sys.platform.startswith("linux"):
        libgomp_path = _find_auditwheel_deps("featomic_torch", "libgomp")
        if libgomp_path is not None:
            deps_path.append(libgomp_path)

    deps_path.append(featomic._c_lib._lib_path())

    return deps_path
def _deepmd_bundles_mpi(version: str) -> bool:
    """
    Check if the given deepmd-kit version bundles its MPI libraries inside the wheel,
    i.e. if it is version 3.1.0 or older.

    Versions are compared through their numeric release components. Comparing the
    version strings directly gives wrong answers, for example ``"10.0.0" <= "3.1.0"``
    is true for strings.

    :param version: version string, e.g. ``deepmd.__version__``
    :return: ``True`` if the release number is at most 3.1.0, ``False`` if it is
        larger or if ``version`` does not start with a number
    """
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        return False

    release = tuple(int(part) for part in match.group().split("."))
    return release <= (3, 1, 0)


def _deepmd_deps_path():
    """
    Get the shared libraries to load for ``deepmd_op_pt``.

    :return: list of paths, with the OpenMP/MPI dependencies first and the deepmd
        libraries last
    """
    import deepmd
    import deepmd.lib

    # dependencies of libmpi on Linux
    mpi_deps = ("libfabric", "libucm", "libucp", "libucs", "libuct")

    deps_path = []

    if sys.platform == "darwin":
        if _deepmd_bundles_mpi(deepmd.__version__):
            libmpi_path = _find_delocate_deps(deepmd, "libmpi")
            if libmpi_path is not None:
                deps_path.append(libmpi_path)

            libpmpi_path = _find_delocate_deps(deepmd, "libpmpi")
            if libpmpi_path is not None:
                deps_path.append(libpmpi_path)
        else:
            # libmpi and libpmpi are no longer bundled since deepmd-kit 3.1.1
            # but taken from the `mpich` wheel, which installs them in the
            # virtualenv `lib` directory.
            deps_path += _find_global_deps("libmpi", only_versionned=True)
            deps_path += _find_global_deps("libpmpi", only_versionned=True)

    elif sys.platform.startswith("linux"):
        libgomp_path = _find_auditwheel_deps("deepmd_kit", "libgomp")
        if libgomp_path is not None:
            deps_path.append(libgomp_path)

        if _deepmd_bundles_mpi(deepmd.__version__):
            libmpi_path = _find_auditwheel_deps("deepmd_kit", "libmpi")
            if libmpi_path is not None:
                deps_path.append(libmpi_path)

            for mpi_dep in mpi_deps:
                mpi_dep_path = _find_auditwheel_deps("deepmd_kit", mpi_dep)
                if mpi_dep_path is not None:
                    deps_path.append(mpi_dep_path)
        else:
            # pull libmpi from the mpich wheel
            deps_path += _find_global_deps("libmpi", only_versionned=True)
            # dependencies of libmpi as installed by the mpich wheel
            for mpi_dep in mpi_deps:
                deps_path += _find_global_deps("mpich/" + mpi_dep, only_versionned=True)

    libs_dir = os.path.dirname(deepmd.lib.__file__)
    # libdeepmd.so/deepmd.dll/libdeepmd.dylib
    deps_path += glob.glob(os.path.join(libs_dir, "*deepmd.*"))

    if sys.platform.startswith("linux"):
        # glob is used so that these libraries are skipped when they do not exist
        deps_path += glob.glob(os.path.join(libs_dir, "libdeepmd_op_cuda.so"))
        deps_path += glob.glob(os.path.join(libs_dir, "libdeepmd_dyn_cudart.so"))
        # libmpi itself is handled above: taken from the wheel or from the
        # environment, depending on the deepmd-kit version

    # no extra dependencies on Windows, only deepmd.dll

    return deps_path