Skip to content

Commit dc37789

Browse files
authored
refactor: pass hopper deepgemm include directory through python (#2090)
<!-- .github/pull_request_template.md --> ## 📌 Description This PR implements the refactor mentioned in https://github.com/flashinfer-ai/flashinfer/pull/1969/files#r2461856020. In our current design we rely on calling `pip show flashinfer-python 2>/dev/null || uv pip show flashinfer-python 2>/dev/null` to obtain the DeepGEMM JIT include directory, which is error-prone (e.g. it fails if the user does not have `pip` available in the environment). This PR instead passes the DeepGEMM JIT include directory through Python APIs. ## 🔍 Related Issues #1969 ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [x] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes cc @djmmoss @jiahanc @nvmbreughe <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Modules now set DeepGEMM JIT include directories at runtime so fused MoE modules have correct JIT include paths during initialization. * **Chores** * JIT compiler API and module build updated to accept and propagate externally provided include directories. * Minor header/build adjustments to support the new initialization flow. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent b14408b commit dc37789

File tree

7 files changed

+82
-83
lines changed

7 files changed

+82
-83
lines changed
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
/*
2+
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#include <tvm/ffi/extra/module.h>
18+
19+
#include <filesystem>
20+
21+
#include "nv_internal/tensorrt_llm/deep_gemm/compiler.cuh"
22+
23+
namespace flashinfer {
24+
25+
/// Converts the include directories received through the TVM FFI boundary into
/// filesystem paths and hands them to deep_gemm::jit::Compiler::setIncludeDirs.
///
/// @param include_dirs Directory strings; they are forwarded in the order given.
void set_deepgemm_jit_include_dirs(tvm::ffi::Array<tvm::ffi::String> include_dirs) {
  std::vector<std::filesystem::path> jit_include_paths;
  for (const auto& entry : include_dirs) {
    // Each FFI string is converted to std::string first, then stored as a path.
    jit_include_paths.emplace_back(std::string(entry));
  }
  deep_gemm::jit::Compiler::setIncludeDirs(jit_include_paths);
}
32+
33+
} // namespace flashinfer
34+
35+
TVM_FFI_DLL_EXPORT_TYPED_FUNC(set_deepgemm_jit_include_dirs,
36+
flashinfer::set_deepgemm_jit_include_dirs);

csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh

Lines changed: 25 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@
3636
#include "nvrtc.h"
3737
#include "runtime.cuh"
3838
#include "scheduler.cuh"
39+
#include "tensorrt_llm/common/assert.h"
40+
#include "tensorrt_llm/common/cudaUtils.h"
41+
#include "tensorrt_llm/common/logger.h"
3942

4043
#ifdef _WIN32
4144
#include <windows.h>
@@ -44,7 +47,7 @@
4447
namespace deep_gemm::jit {
4548

4649
// Generate a unique ID for temporary directories to avoid collisions
47-
std::string generateUniqueId() {
50+
inline std::string generateUniqueId() {
4851
// Use current time and random number to generate a unique ID
4952
static std::mt19937 gen(std::random_device{}());
5053
static std::uniform_int_distribution<> distrib(0, 999999);
@@ -59,7 +62,7 @@ std::string generateUniqueId() {
5962
return std::to_string(value) + "_" + std::to_string(random_value);
6063
}
6164

62-
std::filesystem::path getDefaultUserDir() {
65+
inline std::filesystem::path getDefaultUserDir() {
6366
static std::filesystem::path userDir;
6467
if (userDir.empty()) {
6568
char const* cacheDir = getenv("TRTLLM_DG_CACHE_DIR");
@@ -91,7 +94,7 @@ inline std::filesystem::path getTmpDir() { return getDefaultUserDir() / "tmp"; }
9194

9295
inline std::filesystem::path getCacheDir() { return getDefaultUserDir() / "cache"; }
9396

94-
std::string getNvccCompiler() {
97+
inline std::string getNvccCompiler() {
9598
static std::string compiler;
9699
if (compiler.empty()) {
97100
// Check environment variable
@@ -121,75 +124,21 @@ std::string getNvccCompiler() {
121124
return compiler;
122125
}
123126

124-
std::vector<std::filesystem::path> getJitIncludeDirs() {
127+
inline std::vector<std::filesystem::path>& getJitIncludeDirs() {
125128
static std::vector<std::filesystem::path> includeDirs;
126-
if (includeDirs.empty()) {
127-
// Command to execute - try pip first, fallback to uv pip
128-
char const* cmd =
129-
"pip show flashinfer-python 2>/dev/null || uv pip show flashinfer-python 2>/dev/null";
130-
131-
// Buffer to store the output
132-
std::array<char, 128> buffer;
133-
std::string result;
134-
135-
// Open pipe to command
136-
#ifdef _MSC_VER
137-
FILE* pipe = _popen(cmd, "r");
138-
#else
139-
FILE* pipe = popen(cmd, "r");
140-
#endif
141-
142-
if (pipe) {
143-
// Read the output
144-
while (fgets(buffer.data(), buffer.size(), pipe) != nullptr) {
145-
result += buffer.data();
146-
}
147-
148-
// Close the pipe
149-
#ifdef _MSC_VER
150-
_pclose(pipe);
151-
#else
152-
pclose(pipe);
153-
#endif
154-
155-
// Parse the location using regex
156-
// `pip show tensorrt_llm` will output something like:
157-
// Location: /usr/local/lib/python3.12/dist-packages
158-
// Editable project location: /code
159-
std::regex locationRegex("(Location|Editable project location): (.+)");
160-
161-
// Find all matches
162-
auto match_begin = std::sregex_iterator(result.begin(), result.end(), locationRegex);
163-
auto match_end = std::sregex_iterator();
164-
165-
// Get the number of matches
166-
auto match_count = std::distance(match_begin, match_end);
167-
168-
if (match_count > 0) {
169-
// Get the last match
170-
auto last_match_iter = match_begin;
171-
std::advance(last_match_iter, match_count - 1);
172-
173-
// Get the path from the second capture group
174-
std::string location = last_match_iter->str(2);
175-
location.erase(location.find_last_not_of(" \n\r\t") + 1);
176-
177-
// Set the include directory based on the package location
178-
includeDirs.push_back(std::filesystem::path(location) / "flashinfer" / "data" / "csrc" /
179-
"nv_internal" / "tensorrt_llm");
180-
}
181-
} else {
182-
TLLM_LOG_WARNING("Failed to find FlashInfer installation, DeepGEMM will be disabled.");
183-
}
184-
}
185129
return includeDirs;
186130
}
187131

188-
std::string generateKernel(uint32_t const shape_n, uint32_t const shape_k, uint32_t const block_m,
189-
uint32_t const block_n, uint32_t const block_k,
190-
uint32_t const num_groups, uint32_t const num_stages,
191-
uint32_t const num_tma_multicast, deep_gemm::GemmType const gemm_type,
192-
bool swapAB = false) {
132+
inline void setJitIncludeDirs(std::vector<std::filesystem::path> const& dirs) {
133+
static std::vector<std::filesystem::path>& includeDirs = getJitIncludeDirs();
134+
includeDirs = dirs;
135+
}
136+
137+
inline std::string generateKernel(uint32_t const shape_n, uint32_t const shape_k,
138+
uint32_t const block_m, uint32_t const block_n,
139+
uint32_t const block_k, uint32_t const num_groups,
140+
uint32_t const num_stages, uint32_t const num_tma_multicast,
141+
deep_gemm::GemmType const gemm_type, bool swapAB = false) {
193142
constexpr uint32_t kNumTMAThreads = 128;
194143
constexpr uint32_t kNumMathThreadsPerGroup = 128;
195144

@@ -289,7 +238,12 @@ class Compiler {
289238
return instance;
290239
}
291240

292-
[[nodiscard]] bool isValid() const { return !includeDirs_.empty(); }
241+
// Reports whether at least one JIT include directory is currently registered.
[[nodiscard]] bool isValid() const {
  auto const& registeredDirs = getJitIncludeDirs();
  return !registeredDirs.empty();
}
242+
243+
// Set include directories before the singleton is initialized
244+
static void setIncludeDirs(std::vector<std::filesystem::path> const& dirs) {
245+
setJitIncludeDirs(dirs);
246+
}
293247

294248
// Build function
295249
Runtime* build(uint32_t const shape_n, uint32_t const shape_k, uint32_t const block_m,
@@ -362,7 +316,7 @@ class Compiler {
362316
std::filesystem::create_directories(path);
363317
}
364318

365-
for (auto const& dir : includeDirs_) {
319+
for (auto const& dir : getJitIncludeDirs()) {
366320
flags.push_back("-I" + dir.string());
367321
}
368322

@@ -518,10 +472,8 @@ class Compiler {
518472
}
519473

520474
private:
521-
std::vector<std::filesystem::path> includeDirs_;
522-
523475
// Private constructor for singleton pattern
524-
Compiler() : includeDirs_(getJitIncludeDirs()) {
476+
Compiler() {
525477
// Create necessary directories
526478
if (kJitUseNvcc || kJitDumpCubin) {
527479
std::filesystem::create_directories(getTmpDir());

csrc/nv_internal/tensorrt_llm/deep_gemm/jit_utils.cuh

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
*/
1717

1818
#pragma once
19+
#include <cuda.h>
1920
#include <cuda_runtime.h>
2021
#include <nvrtc.h>
2122

@@ -67,7 +68,7 @@ GemmConfig get_best_gemm_config(uint32_t shape_m, uint32_t shape_n, uint32_t sha
6768

6869
namespace deep_gemm::jit {
6970

70-
std::string gemm_type_to_string(deep_gemm::GemmType gemm_type) {
71+
inline std::string gemm_type_to_string(deep_gemm::GemmType gemm_type) {
7172
switch (gemm_type) {
7273
case deep_gemm::GemmType::Normal:
7374
return std::string("Normal");
@@ -85,10 +86,10 @@ std::string gemm_type_to_string(deep_gemm::GemmType gemm_type) {
8586
}
8687
}
8788

88-
int div_up(int a, int b) { return (a + b - 1) / b; }
89+
// Integer division of a by b, rounded up (computed as (a + b - 1) / b).
inline int div_up(int a, int b) {
  int const biased_numerator = a + b - 1;
  return biased_numerator / b;
}
8990

90-
int get_smem_size(int num_stages, int k, int block_m, int block_n, int block_k = 128,
91-
bool swap_ab = false) {
91+
inline int get_smem_size(int num_stages, int k, int block_m, int block_n, int block_k = 128,
92+
bool swap_ab = false) {
9293
if (!swap_ab) {
9394
int smem_d = block_m * block_n * 2;
9495
int smem_a_per_stage = block_m * block_k;
@@ -126,16 +127,16 @@ int get_smem_size(int num_stages, int k, int block_m, int block_n, int block_k =
126127
}
127128
}
128129

129-
bool is_tma_multicast_legal(int n, int block_n, int num_tma_multicast, int num_sms) {
130+
// Checks whether a multicast factor can be used for the given problem size.
// A factor of 1 is always accepted. Otherwise n must be a multiple of
// block_n * num_tma_multicast and num_sms must be a multiple of
// num_tma_multicast.
inline bool is_tma_multicast_legal(int n, int block_n, int num_tma_multicast, int num_sms) {
  if (num_tma_multicast != 1) {
    int const multicast_span = block_n * num_tma_multicast;
    if (n % multicast_span != 0) {
      return false;
    }
    return num_sms % num_tma_multicast == 0;
  }
  return true;
}
135136

136-
GemmConfig get_best_gemm_config(uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
137-
int num_groups, int num_device_sms,
138-
bool is_grouped_contiguous = false, bool swap_ab = false) {
137+
inline GemmConfig get_best_gemm_config(uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
138+
int num_groups, int num_device_sms,
139+
bool is_grouped_contiguous = false, bool swap_ab = false) {
139140
// Choose candidate block sizes
140141
std::vector<int> block_ms;
141142
block_ms.push_back((!is_grouped_contiguous && shape_m <= 64) ? 64 : 128);

csrc/nv_internal/tensorrt_llm/deep_gemm/runtime.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,6 @@ class RuntimeCache {
181181
};
182182

183183
// Global function to access the singleton
184-
RuntimeCache& getGlobalRuntimeCache() { return RuntimeCache::getInstance(); }
184+
inline RuntimeCache& getGlobalRuntimeCache() {
  // Free-function shorthand for RuntimeCache::getInstance().
  RuntimeCache& cache = RuntimeCache::getInstance();
  return cache;
}
185185

186186
} // namespace deep_gemm::jit

flashinfer/fused_moe/core.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,14 @@ def get_cutlass_fused_moe_module(backend: str = "100", use_fast_build: bool = Fa
328328
else:
329329
raise ValueError(f"Invalid backend: {backend}")
330330

331+
# Set DeepGEMM JIT include directories after module is loaded
332+
from ..jit import env as jit_env
333+
334+
deepgemm_include_dir = str(
335+
jit_env.FLASHINFER_CSRC_DIR / "nv_internal" / "tensorrt_llm"
336+
)
337+
module.set_deepgemm_jit_include_dirs([deepgemm_include_dir])
338+
331339
class MoERunner(TunableRunner):
332340
# avoid overhead of creating a new runner in forward pass
333341
runner_dict: Dict[

flashinfer/jit/fused_moe.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,9 @@ def gen_cutlass_fused_moe_module(
164164
jit_env.FLASHINFER_CSRC_DIR
165165
/ "nv_internal/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm.cu",
166166
jit_env.FLASHINFER_CSRC_DIR
167-
/ "fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu",
167+
/ "fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu",
168+
jit_env.FLASHINFER_CSRC_DIR
169+
/ "fused_moe/cutlass_backend/deepgemm_jit_setup.cu",
168170
jit_env.FLASHINFER_CSRC_DIR
169171
/ "fused_moe/cutlass_backend/cutlass_fused_moe_instantiation.cu",
170172
# Add all generated kernels

0 commit comments

Comments
(0)