From b539fdd337cb6eb62b00d6935cafe31e87bd2205 Mon Sep 17 00:00:00 2001 From: Cheng Date: Thu, 19 Mar 2026 01:08:31 +0100 Subject: [PATCH] [CUDA] Search system-installed CUDA toolkit for headers --- mlx/backend/cuda/jit_module.cpp | 35 +++++++++++++++++++++++---------- 1 file changed, 25 insertions(+), 10 deletions(-) diff --git a/mlx/backend/cuda/jit_module.cpp b/mlx/backend/cuda/jit_module.cpp index 953fdb8a0d..d4f1b4919c 100644 --- a/mlx/backend/cuda/jit_module.cpp +++ b/mlx/backend/cuda/jit_module.cpp @@ -26,6 +26,25 @@ void check_nvrtc_error(const char* name, nvrtcResult err) { } } +// Return the default path to CUDA toolkit. +const std::filesystem::path& default_cuda_toolkit_path() { +#if defined(_WIN32) + static auto cached_path = []() -> std::filesystem::path { + std::filesystem::path root( + LR"(C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA)"); + for (auto& file : std::filesystem::directory_iterator(root)) { + if (std::filesystem::exists(file.path() / "include" / "cuda.h")) { + return file.path(); + } + } + return {}; + }(); +#else + static std::filesystem::path cached_path = "/usr/local/cuda"; +#endif + return cached_path; +} + // Return the --include-path args used for invoking NVRTC. const std::vector& include_path_args() { static std::vector cached_args = []() { @@ -47,26 +66,22 @@ const std::vector& include_path_args() { // Add path to CUDA runtime headers, try local-installed python package // first and then system-installed headers. path = root_dir.parent_path() / "nvidia" / "cuda_runtime" / "include"; - if (std::filesystem::exists(path)) { - args.push_back(fmt::format("--include-path={}", path.string())); - } else { + if (!std::filesystem::exists(path)) { const char* home = std::getenv("CUDA_HOME"); if (!home) { home = std::getenv("CUDA_PATH"); } -#if defined(__linux__) - if (!home) { - home = "/usr/local/cuda"; + path = home ? std::filesystem::path(home) : default_cuda_toolkit_path(); + if (!path.empty()) { + path = path / "include"; } -#endif - if (home && std::filesystem::exists(home)) { - args.push_back(fmt::format("--include-path={}/include", home)); - } else { + if (path.empty() || !std::filesystem::exists(path)) { throw std::runtime_error( "Can not find locations of CUDA headers, please set environment " "variable CUDA_HOME or CUDA_PATH."); } } + args.push_back(fmt::format("--include-path={}", path.string())); return args; }(); return cached_args;