Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 25 additions & 10 deletions mlx/backend/cuda/jit_module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,25 @@ void check_nvrtc_error(const char* name, nvrtcResult err) {
}
}

// Return the default path to the CUDA toolkit.
//
// The result is computed once on first call and cached for the lifetime of
// the process.
//
// - Windows: scans the sub-directories of the standard installer location and
//   returns the first one that contains "include/cuda.h". Directory
//   enumeration order is unspecified, so when several toolkit versions are
//   installed any one of them may be picked. Returns an empty path when the
//   location is missing, unreadable, or holds no usable toolkit; callers must
//   check for empty().
// - Other platforms: returns "/usr/local/cuda" without checking that it
//   exists.
const std::filesystem::path& default_cuda_toolkit_path() {
#if defined(_WIN32)
  static const std::filesystem::path cached_path =
      []() -> std::filesystem::path {
    const std::filesystem::path root(
        LR"(C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA)");
    // Use the non-throwing overloads: a missing or unreadable install
    // directory must produce an empty result so the caller can report its own
    // "headers not found" error, instead of a filesystem_error escaping from
    // this static initializer.
    std::error_code ec;
    std::filesystem::directory_iterator it(root, ec);
    const std::filesystem::directory_iterator end;
    for (; !ec && it != end; it.increment(ec)) {
      std::error_code exists_ec;
      if (std::filesystem::exists(
              it->path() / "include" / "cuda.h", exists_ec)) {
        return it->path();
      }
    }
    return {};
  }();
#else
  static const std::filesystem::path cached_path = "/usr/local/cuda";
#endif
  return cached_path;
}

// Return the --include-path args used for invoking NVRTC.
const std::vector<std::string>& include_path_args() {
static std::vector<std::string> cached_args = []() {
Expand All @@ -47,26 +66,22 @@ const std::vector<std::string>& include_path_args() {
// Add path to CUDA runtime headers, try local-installed python package
// first and then system-installed headers.
path = root_dir.parent_path() / "nvidia" / "cuda_runtime" / "include";
if (std::filesystem::exists(path)) {
args.push_back(fmt::format("--include-path={}", path.string()));
} else {
if (!std::filesystem::exists(path)) {
const char* home = std::getenv("CUDA_HOME");
if (!home) {
home = std::getenv("CUDA_PATH");
}
#if defined(__linux__)
if (!home) {
home = "/usr/local/cuda";
path = home ? std::filesystem::path(home) : default_cuda_toolkit_path();
if (!path.empty()) {
path = path / "include";
}
#endif
if (home && std::filesystem::exists(home)) {
args.push_back(fmt::format("--include-path={}/include", home));
} else {
if (path.empty() || !std::filesystem::exists(path)) {
throw std::runtime_error(
"Can not find locations of CUDA headers, please set environment "
"variable CUDA_HOME or CUDA_PATH.");
}
}
args.push_back(fmt::format("--include-path={}", path.string()));
return args;
}();
return cached_args;
Expand Down
Loading