diff --git a/setup.py b/setup.py index e54e5c8..a0ecd4d 100644 --- a/setup.py +++ b/setup.py @@ -64,15 +64,26 @@ def parse_library_names(libdir): "-Wl,-rpath,@loader_path/../kintera/lib", ] else: + cuda_linker = [] + cuda_libraries = [lib for lib in libraries if "cuda" in lib] + + if cuda_libraries: + for lib in cuda_libraries: + libraries.remove(lib) + cuda_linker = ( + ["-Wl,--no-as-needed"] + + [f"-l{lib}" for lib in cuda_libraries] + + ["-Wl,--as-needed"] + ) + extra_link_args = [ - "-Wl,--no-as-needed", "-Wl,-rpath,$ORIGIN/lib", "-Wl,-rpath,$ORIGIN/../torch/lib", "-Wl,-rpath,$ORIGIN/../pydisort/lib", "-Wl,-rpath,$ORIGIN/../pyharp/lib", "-Wl,-rpath,$ORIGIN/../kintera/lib", - "-Wl,--as-needed", ] + extra_link_args += cuda_linker ext_module = cpp_extension.CppExtension( name='snapy.snapy',