From a654093141b9c26df6ca1d76c522b8972d194925 Mon Sep 17 00:00:00 2001 From: raayandhar Date: Sun, 9 Nov 2025 19:53:24 -0800 Subject: [PATCH 01/11] (in progress) BF16 GEMM using CUTLASS backend for SM100 Signed-off-by: raayandhar --- csrc/bf16_gemm_cutlass.cu | 161 ++++++++++ csrc/bf16_gemm_cutlass.jinja | 27 ++ docs/api/gemm.rst | 9 + flashinfer/__init__.py | 2 + flashinfer/gemm/gemm_base.py | 284 +++++++++++++++++- flashinfer/jit/gemm/__init__.py | 2 + flashinfer/jit/gemm/core.py | 47 +++ include/flashinfer/gemm/bf16_gemm_cutlass.h | 62 ++++ .../gemm/bf16_gemm_cutlass_template.h | 215 +++++++++++++ .../gemm/bf16_gemm_template_sm100.h | 187 ++++++++++++ .../gemm/fp8_gemm_cutlass_template.h | 1 - tests/gemm/test_bmm_bf16.py | 38 +++ tests/gemm/test_mm_bf16.py | 37 +++ 13 files changed, 1061 insertions(+), 11 deletions(-) create mode 100644 csrc/bf16_gemm_cutlass.cu create mode 100644 csrc/bf16_gemm_cutlass.jinja create mode 100644 include/flashinfer/gemm/bf16_gemm_cutlass.h create mode 100644 include/flashinfer/gemm/bf16_gemm_cutlass_template.h create mode 100644 include/flashinfer/gemm/bf16_gemm_template_sm100.h create mode 100644 tests/gemm/test_bmm_bf16.py create mode 100644 tests/gemm/test_mm_bf16.py diff --git a/csrc/bf16_gemm_cutlass.cu b/csrc/bf16_gemm_cutlass.cu new file mode 100644 index 0000000000..f5e3750ad2 --- /dev/null +++ b/csrc/bf16_gemm_cutlass.cu @@ -0,0 +1,161 @@ +/* + * Copyright (c) 2025, FlashInfer. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
 */

#include <cuda_bf16.h>

#include <cstdint>
#include <vector>

#include "flashinfer/gemm/bf16_gemm_cutlass.h"
#include "flashinfer/gemm/bf16_gemm_cutlass_template.h"
#include "flashinfer/gemm/cutlass_gemm_configs.h"
#include "tvm_ffi_utils.h"

using flashinfer::gemm::ClusterShape;
using flashinfer::gemm::CutlassBf16GemmRunner;
using flashinfer::gemm::CutlassBf16GemmRunnerInterface;
using flashinfer::gemm::CutlassGemmConfig;
using flashinfer::gemm::CutlassTileConfigSM100;
using flashinfer::gemm::EpilogueScheduleType;
using flashinfer::gemm::MainloopScheduleType;

namespace flashinfer {
namespace gemm {
template class CutlassBf16GemmRunner<__nv_bfloat16>;
template class CutlassBf16GemmRunner<half>;
}  // namespace gemm
}  // namespace flashinfer

namespace torch_ext {

namespace {

CutlassGemmConfig getBf16GemmConfig(int64_t m, int64_t n, int64_t k, int64_t tactic) {
  auto getCutlassBf16GemmConfigs = []() {
    CutlassBf16GemmRunner<__nv_bfloat16> gemmRunner;
    return gemmRunner.getConfigs();
  };
  static std::vector<CutlassGemmConfig> globalConfigs = getCutlassBf16GemmConfigs();
  TVM_FFI_ICHECK(tactic >= 0 && tactic < static_cast<int64_t>(globalConfigs.size()))
      << "tactic must be between 0 and " << globalConfigs.size();
  return globalConfigs[tactic];
}

template <typename T>
void runGemm(TensorView out, TensorView mat1, TensorView mat2, int64_t m, int64_t n, int64_t k,
             int64_t b, CutlassGemmConfig const& gemmConfig, TensorView workspace_buffer) {
  CutlassBf16GemmRunner<T> gemmRunner;

  int64_t const required_workspace_size = gemmRunner.getWorkspaceSize(m, n, k);
  int64_t const provided_workspace_size =
      workspace_buffer.numel() * get_element_size(workspace_buffer);

  auto runKernel = [&](void* workspace) {
    gemmRunner.gemm(static_cast<__nv_bfloat16*>(mat1.data_ptr()),
                    static_cast<__nv_bfloat16*>(mat2.data_ptr()), out.data_ptr(), m, n, k, b,
                    gemmConfig, static_cast<char*>(workspace), required_workspace_size,
                    get_stream(mat1.device()));
  };

  if (provided_workspace_size < required_workspace_size) {
    Tensor new_workspace =
        alloc_tensor({required_workspace_size}, DLDataType{kDLInt, 8, 1}, mat1.device());
    runKernel(new_workspace.data_ptr());
  } else {
    runKernel(workspace_buffer.data_ptr());
  }
}

void bf16_bmm_impl(TensorView mat1, TensorView mat2, TensorView out, TensorView workspace_buffer,
                   int64_t tactic) {
  CHECK_INPUT_AND_TYPE(mat1, dl_bfloat16);
  CHECK_INPUT_AND_TYPE(mat2, dl_bfloat16);

  int64_t m, n, k, b;
  if (mat1.ndim() == 2) {
    TVM_FFI_ICHECK_EQ(mat2.ndim(), 2) << "mat2 must be a matrix";
    TVM_FFI_ICHECK_EQ(mat1.size(1), mat2.size(1))
        << "mat1 and mat2 shapes cannot be multiplied (" << mat1.size(0) << "x" << mat1.size(1)
        << " and " << mat2.size(0) << "x" << mat2.size(1) << ")";
    m = mat1.size(0);
    n = mat2.size(0);
    k = mat2.size(1);
    b = 1;
  } else if (mat1.ndim() == 3) {
    TVM_FFI_ICHECK_EQ(mat2.ndim(), 3) << "mat2 must be a batch of matrices";
    TVM_FFI_ICHECK_EQ(mat1.size(0), mat2.size(0)) << "mat1 and mat2 must have the same batch size ("
                                                  << mat1.size(0) << " and " << mat2.size(0) << ")";
    TVM_FFI_ICHECK_EQ(mat1.size(2), mat2.size(2))
        << "mat1 and mat2 shapes cannot be multiplied (" << mat1.size(1) << "x" << mat1.size(2)
        << " and " << mat2.size(1) << "x" << mat2.size(2) << ")";
    m = mat1.size(1);
    n = mat2.size(1);
    k = mat2.size(2);
    b = mat1.size(0);
  } else {
    TVM_FFI_LOG_AND_THROW(NotImplementedError) << "mat1 must be a matrix or a batch of matrices";
  }

  if (tactic == -1) {
    tactic = 0;
  }
  auto config =
getBf16GemmConfig(m, n, k, tactic);

  std::vector<int64_t> out_shape =
      mat1.ndim() == 2 ? std::vector<int64_t>{m, n} : std::vector<int64_t>{b, m, n};
  TVM_FFI_ICHECK_EQ(out.ndim(), static_cast<int>(out_shape.size()))
      << "out must have " << out_shape.size() << " dimensions, but got " << out.ndim();
  for (int i = 0; i < static_cast<int>(out_shape.size()); ++i) {
    TVM_FFI_ICHECK_EQ(out.size(i), out_shape[i])
        << "out shape mismatch at dimension " << i << ": expected " << out_shape[i] << ", got "
        << out.size(i);
  }

  switch (encode_dlpack_dtype(out.dtype())) {
    case float16_code:
      runGemm<half>(out, mat1, mat2, m, n, k, b, config, workspace_buffer);
      break;
    case bfloat16_code:
      runGemm<__nv_bfloat16>(out, mat1, mat2, m, n, k, b, config, workspace_buffer);
      break;
    default:
      TVM_FFI_LOG_AND_THROW(NotImplementedError) << "out_dtype must be one of fp16/bf16.";
  }
}

}  // namespace

void bf16_gemm(TensorView mat1, TensorView mat2, TensorView out, TensorView workspace_buffer,
               int64_t tactic) {
  bf16_bmm_impl(mat1, mat2, out, workspace_buffer, tactic);
}

int64_t bf16_gemm_tactic_num() {
  auto getCutlassConfigs = []() {
    CutlassBf16GemmRunner<__nv_bfloat16> gemmRunner;
    return gemmRunner.getConfigs();
  };
  static int64_t totalTactics = getCutlassConfigs().size();
  return totalTactics;
}

}  // namespace torch_ext

TVM_FFI_DLL_EXPORT_TYPED_FUNC(bf16_gemm, torch_ext::bf16_gemm);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(bf16_gemm_tactic_num, torch_ext::bf16_gemm_tactic_num);
diff --git a/csrc/bf16_gemm_cutlass.jinja b/csrc/bf16_gemm_cutlass.jinja
new file mode 100644
index 0000000000..27d6913346
--- /dev/null
+++ b/csrc/bf16_gemm_cutlass.jinja
@@ -0,0 +1,27 @@
/*
 * Copyright (c) 2025, FlashInfer.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "flashinfer/gemm/bf16_gemm_template_sm100.h"

namespace flashinfer {
namespace gemm {
  INSTANCE_BF16_GEMM_TEMPLATE_SM100({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 1, 1, 1, _1SM);
  // INSTANCE_BF16_GEMM_TEMPLATE_SM100({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 1, 2, 1, _1SM);
  // INSTANCE_BF16_GEMM_TEMPLATE_SM100({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 1, 4, 1, _1SM);
  // INSTANCE_BF16_GEMM_TEMPLATE_SM100({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 2, 1, 1, _2SM);
  // INSTANCE_BF16_GEMM_TEMPLATE_SM100({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 2, 2, 1, _2SM);
} // namespace gemm
} // namespace flashinfer
diff --git a/docs/api/gemm.rst b/docs/api/gemm.rst
index 8c9fbeeea6..c0c99a7f92 100644
--- a/docs/api/gemm.rst
+++ b/docs/api/gemm.rst
@@ -7,6 +7,15 @@ flashinfer.gemm
 
 This module provides a set of GEMM operations.
 
+BF16 GEMM
+---------
+
+..
autosummary:: + :toctree: ../generated + + mm_bf16 + bmm_bf16 + FP4 GEMM -------- diff --git a/flashinfer/__init__.py b/flashinfer/__init__.py index faad4f12a3..9f4d8de1f4 100644 --- a/flashinfer/__init__.py +++ b/flashinfer/__init__.py @@ -85,7 +85,9 @@ trtllm_fp8_per_tensor_scale_moe, ) from .gemm import SegmentGEMMWrapper as SegmentGEMMWrapper +from .gemm import bmm_bf16 as bmm_bf16 from .gemm import bmm_fp8 as bmm_fp8 +from .gemm import mm_bf16 as mm_bf16 from .gemm import mm_fp4 as mm_fp4 from .gemm import mm_fp8 as mm_fp8 from .gemm import tgv_gemm_sm100 as tgv_gemm_sm100 diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index ac0fbab4a0..8ed2acd6b2 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -44,16 +44,17 @@ backend_requirement, supported_compute_capability, ) -from ..jit.gemm import gen_gemm_sm90_module -from ..jit.gemm import gen_gemm_module -from ..jit.gemm import gen_gemm_sm100_module -from ..jit.gemm import gen_gemm_sm120_module -from ..jit.gemm import gen_gemm_sm120_module_cutlass_fp4 -from ..jit.gemm import gen_gemm_sm100_module_cutlass_fp4 -from ..jit.gemm import gen_gemm_sm100_module_cutlass_fp8 -from ..jit.gemm import gen_trtllm_gen_gemm_module -from ..jit.gemm import gen_tgv_gemm_sm10x_module -from ..jit.gemm import gen_deepgemm_sm100_module +from .jit.gemm import gen_gemm_sm90_module +from .jit.gemm import gen_gemm_module +from .jit.gemm import gen_gemm_sm100_module +from .jit.gemm import gen_gemm_sm120_module +from .jit.gemm import gen_gemm_sm120_module_cutlass_fp4 +from .jit.gemm import gen_gemm_sm100_module_cutlass_fp4 +from .jit.gemm import gen_gemm_sm100_module_cutlass_fp8 +from .jit.gemm import gen_gemm_sm100_module_cutlass_bf16 +from .jit.gemm import gen_trtllm_gen_gemm_module +from .jit.gemm import gen_tgv_gemm_sm10x_module +from .jit.gemm import gen_deepgemm_sm100_module CUDNN_AVAILABLE = False @@ -178,6 +179,140 @@ def _fake_cutlass_segment_gemm( return _gemm_module +@supported_compute_capability([100]) +def mm_bf16( + a: torch.Tensor, + b: torch.Tensor, + out: Optional[torch.Tensor] = None, + out_dtype: torch.dtype = torch.bfloat16, + backend: Literal["cutlass"] = "cutlass", +) -> torch.Tensor: + r"""MM BF16 + + Parameters + ---------- + a: torch.Tensor + Input tensor, shape (m, k), bf16 row-major. + + b: torch.Tensor + Weight tensor, shape (k, n), bf16 row-major. This tensor is interpreted + as a column-major (n, k) matrix internally. + + out: Optional[torch.Tensor] + Out tensor, shape (m, n), bf16 or fp16, defaults to ``None``. + + out_dtype: torch.dtype + Output dtype, bf16 (default) or fp16. + + backend: Literal["cutlass"] + Backend to use, defaults to "cutlass". + + Returns + ------- + torch.Tensor + Out tensor, shape (m, n), bf16 or fp16. + + # Note: add Examples section here + """ + if backend != "cutlass": + raise ValueError(f"Unsupported backend: {backend}. Only cutlass is available.") + if out_dtype not in (torch.bfloat16, torch.float16): + raise ValueError("Only bf16 and fp16 outputs are supported.") + + if out is None: + out = torch.empty( + (a.shape[0], b.shape[1]), + device=a.device, + dtype=out_dtype, + ) + else: + if out.shape != (a.shape[0], b.shape[1]): + raise ValueError( + f"Output shape mismatch. Expected {(a.shape[0], b.shape[1])}, got {out.shape}." + ) + if out.device != a.device: + raise ValueError( + f"Output device mismatch. Expected {a.device}, got {out.device}." + ) + if out.dtype != out_dtype: + raise ValueError( + f"Output dtype mismatch. 
Expected {out_dtype}, got {out.dtype}." + ) + + workspace_buffer = _get_cache_buf( + "mm_bf16_workspace", DEFAULT_WORKSPACE_SIZE, a.device + ) + bf16_gemm_sm100(a, b, out, workspace_buffer) + return out + + +@supported_compute_capability([100]) +def bmm_bf16( + a: torch.Tensor, + b: torch.Tensor, + out: Optional[torch.Tensor] = None, + out_dtype: torch.dtype = torch.bfloat16, + backend: Literal["cutlass"] = "cutlass", +) -> torch.Tensor: + r"""BMM BF16 + + Parameters + ---------- + a: torch.Tensor + Input tensor, shape (b, m, k), bf16. + + b: torch.Tensor + Weight tensor, shape (b, k, n), bf16. + + out: Optional[torch.Tensor] + Out tensor, shape (b, m, n), bf16 or fp16, defaults to ``None``. + + out_dtype: torch.dtype + Output dtype, bf16 (default) or fp16. + + backend: Literal["cutlass"] + Backend to use, defaults to "cutlass". + + Returns + ------- + torch.Tensor + Out tensor, shape (b, m, n), bf16 or fp16. + + # Note: add Examples section here + """ + if backend != "cutlass": + raise ValueError(f"Unsupported backend: {backend}. Only cutlass is available.") + if out_dtype not in (torch.bfloat16, torch.float16): + raise ValueError("Only bf16 and fp16 outputs are supported.") + + expected_shape = (a.shape[0], a.shape[1], b.shape[2]) + if out is None: + out = torch.empty( + expected_shape, + device=a.device, + dtype=out_dtype, + ) + else: + if out.shape != expected_shape: + raise ValueError( + f"Output shape mismatch. Expected {expected_shape}, got {out.shape}." + ) + if out.device != a.device: + raise ValueError( + f"Output device mismatch. Expected {a.device}, got {out.device}." + ) + if out.dtype != out_dtype: + raise ValueError( + f"Output dtype mismatch. Expected {out_dtype}, got {out.dtype}." + ) + + workspace_buffer = _get_cache_buf( + "bmm_bf16_workspace", DEFAULT_WORKSPACE_SIZE, a.device + ) + bf16_gemm_sm100(a, b, out, workspace_buffer) + return out + + @functools.cache def get_gemm_sm100_module(): module = gen_gemm_sm100_module().build_and_load() @@ -354,6 +489,43 @@ def forward( ) +@functools.cache +def get_gemm_sm100_module_cutlass_bf16(): + module = gen_gemm_sm100_module_cutlass_bf16().build_and_load() + + def cutlass_bf16_gemm_runner(): + class CutlassBf16GemmRunner(TunableRunner): + def get_valid_tactics( + self, + inputs: List[torch.Tensor], + profile: OptimizationProfile, + ) -> List[int]: + return list(range(module.bf16_gemm_tactic_num())) + + def forward( + self, + inputs: List[torch.Tensor], + tactic: int = -1, + do_preparation: bool = False, + **kwargs, + ) -> torch.Tensor: + a, b, out, workspace_buffer = inputs + module.bf16_gemm( + a, + b.transpose(-2, -1), + out, + workspace_buffer, + tactic, + ) + return out + + return CutlassBf16GemmRunner() + + return SimpleNamespace( + cutlass_bf16_gemm_runner=cutlass_bf16_gemm_runner, + ) + + def fp8_gemm_sm100( a: torch.Tensor, b: torch.Tensor, @@ -403,6 +575,98 @@ def fp8_gemm_sm100( runner(inputs=inputs, tactic=tactic) +def bf16_gemm_sm100( + a: torch.Tensor, + b: torch.Tensor, + out: torch.Tensor, + workspace_buffer: torch.Tensor, +) -> None: + runners = [] + is_sm_supported = _match_sm_version(a.device, ["100"]) + + if is_sm_supported: + runners.append(get_gemm_sm100_module_cutlass_bf16().cutlass_bf16_gemm_runner()) + + if len(runners) == 0: + major, minor = get_compute_capability(torch.device("cuda")) + raise ValueError(f"No valid runner found for current device sm{major}{minor}") + + tuner = AutoTuner.get() + a_tensor_index = 0 + out_tensor_index = 2 + tuning_config = TuningConfig( + dynamic_tensor_specs=( + 
DynamicTensorSpec( + (a_tensor_index,), + (-2,), + get_last_power_of_2_num_tokens_buckets, + last_positive_power_of_2, + ), + ), + constraint_specs=( + ConstraintSpec( + out_tensor_index, -2, lambda shapes: shapes[a_tensor_index][-2] + ), + ), + ) + + inputs = [a, b, out, workspace_buffer] + runner, tactic = tuner.choose_one( + "bf16_gemm", + runners, + tuning_config, + inputs, + ) + + runner(inputs=inputs, tactic=tactic) + + +def bf16_gemm_sm100( + a: torch.Tensor, + b: torch.Tensor, + out: torch.Tensor, + workspace_buffer: torch.Tensor, +) -> None: + runners = [] + is_sm_supported = _match_sm_version(a.device, ["100"]) + + if is_sm_supported: + runners.append(get_gemm_sm100_module_cutlass_bf16().cutlass_bf16_gemm_runner()) + + if len(runners) == 0: + major, minor = get_compute_capability(torch.device("cuda")) + raise ValueError(f"No valid runner found for current device sm{major}{minor}") + + tuner = AutoTuner.get() + a_tensor_index = 0 + out_tensor_index = 2 + tuning_config = TuningConfig( + dynamic_tensor_specs=( + DynamicTensorSpec( + (a_tensor_index,), + (-2,), + get_last_power_of_2_num_tokens_buckets, + last_positive_power_of_2, + ), + ), + constraint_specs=( + ConstraintSpec( + out_tensor_index, -2, lambda shapes: shapes[a_tensor_index][-2] + ), + ), + ) + + inputs = [a, b, out, workspace_buffer] + runner, tactic = tuner.choose_one( + "bf16_gemm", + runners, + tuning_config, + inputs, + ) + + runner(inputs=inputs, tactic=tactic) + + def _create_cutlass_fp4_gemm_module(module, op_name: str, tuner_name: str): """Helper function to create cutlass FP4 GEMM module.""" diff --git a/flashinfer/jit/gemm/__init__.py b/flashinfer/jit/gemm/__init__.py index f1681d3bf5..617327c451 100644 --- a/flashinfer/jit/gemm/__init__.py +++ b/flashinfer/jit/gemm/__init__.py @@ -19,6 +19,7 @@ gen_gemm_sm100_module_cutlass_fp4, gen_gemm_sm120_module_cutlass_fp4, gen_gemm_sm100_module_cutlass_fp8, + gen_gemm_sm100_module_cutlass_bf16, gen_gemm_sm100_module, gen_gemm_sm120_module, gen_trtllm_gen_gemm_module, @@ -33,6 +34,7 @@ "gen_gemm_sm100_module_cutlass_fp4", "gen_gemm_sm120_module_cutlass_fp4", "gen_gemm_sm100_module_cutlass_fp8", + "gen_gemm_sm100_module_cutlass_bf16", "gen_gemm_sm100_module", "gen_gemm_sm120_module", "gen_trtllm_gen_gemm_module", diff --git a/flashinfer/jit/gemm/core.py b/flashinfer/jit/gemm/core.py index 7873d0de14..383a38a87e 100644 --- a/flashinfer/jit/gemm/core.py +++ b/flashinfer/jit/gemm/core.py @@ -190,6 +190,53 @@ def gen_gemm_sm100_module_cutlass_fp8() -> JitSpec: ) +def gen_gemm_sm100_module_cutlass_bf16() -> JitSpec: + gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / "gen_gemm_sm100_cutlass_bf16" + os.makedirs(gen_directory, exist_ok=True) + source_paths = [ + jit_env.FLASHINFER_CSRC_DIR / "bf16_gemm_cutlass.cu", + ] + + with open(jit_env.FLASHINFER_CSRC_DIR / "bf16_gemm_cutlass.jinja") as f: + kernel_inst_templ = jinja2.Template(f.read()) + dtype_list = ["__nv_bfloat16", "half"] + cta_m_n_k_list = [ + (64, 64, 128), + # (64, 128, 128), + # (64, 256, 128), + # (128, 64, 128), + # (128, 128, 128), + # (128, 256, 128), + ] + for cta_m, cta_n, cta_k in cta_m_n_k_list: + for dtype in dtype_list: + dest_path = ( + gen_directory + / f"bf16_gemm_cutlass_{dtype}_{cta_m}_{cta_n}_{cta_k}.cu" + ) + source_paths.append(dest_path) + source = kernel_inst_templ.render( + type=dtype, + cta_m=cta_m, + cta_n=cta_n, + cta_k=cta_k, + ) + write_if_different(dest_path, source) + + nvcc_flags = current_compilation_context.get_nvcc_flags_list( + supported_major_versions=[10, 11, 12] + ) + + return 
gen_jit_spec(
        "bf16_gemm_cutlass",
        source_paths,
        extra_cuda_cflags=nvcc_flags + ["-DENABLE_BF16"],
        extra_cflags=[
            "-DFAST_BUILD",
        ],
    )


def gen_gemm_sm100_module() -> JitSpec:
    gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / "gen_gemm_sm100"
    os.makedirs(gen_directory, exist_ok=True)
diff --git a/include/flashinfer/gemm/bf16_gemm_cutlass.h b/include/flashinfer/gemm/bf16_gemm_cutlass.h
new file mode 100644
index 0000000000..6011075a19
--- /dev/null
+++ b/include/flashinfer/gemm/bf16_gemm_cutlass.h
@@ -0,0 +1,62 @@
/*
 * Copyright (c) 2025, FlashInfer.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef FLASHINFER_BF16_GEMM_CUTLASS_H_
#define FLASHINFER_BF16_GEMM_CUTLASS_H_

#include <cuda_bf16.h>

#include <vector>

#include "flashinfer/gemm/cutlass_gemm_configs.h"

namespace flashinfer {
namespace gemm {

class CutlassBf16GemmRunnerInterface {
 public:
  CutlassBf16GemmRunnerInterface() = default;
  virtual ~CutlassBf16GemmRunnerInterface() = default;

  virtual void gemm(__nv_bfloat16 const* A, __nv_bfloat16 const* B, void* D, int m, int n, int k,
                    int b, CutlassGemmConfig gemmConfig, char* workspacePtr,
                    size_t const workspaceBytes, cudaStream_t stream) = 0;

  virtual size_t getWorkspaceSize(int m, int n, int k) = 0;

  virtual std::vector<CutlassGemmConfig> getConfigs() const = 0;
};

template <typename T>
class CutlassBf16GemmRunner : public CutlassBf16GemmRunnerInterface {
 public:
  CutlassBf16GemmRunner() = default;
  ~CutlassBf16GemmRunner() = default;

  void gemm(__nv_bfloat16 const* A, __nv_bfloat16 const* B, void* D, int m, int n, int k, int b,
            CutlassGemmConfig gemmConfig, char* workspacePtr, size_t const workspaceBytes,
            cudaStream_t stream) override;
  size_t getWorkspaceSize(int m, int n, int k) override;
  std::vector<CutlassGemmConfig> getConfigs() const override;

 private:
  size_t getWorkspaceSizeImpl(int m, int n, int k);
};

}  // namespace gemm
}  // namespace flashinfer

#endif  // FLASHINFER_BF16_GEMM_CUTLASS_H_
diff --git a/include/flashinfer/gemm/bf16_gemm_cutlass_template.h b/include/flashinfer/gemm/bf16_gemm_cutlass_template.h
new file mode 100644
index 0000000000..54953e3043
--- /dev/null
+++ b/include/flashinfer/gemm/bf16_gemm_cutlass_template.h
@@ -0,0 +1,215 @@
/*
 * Copyright (c) 2025, FlashInfer.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef FLASHINFER_BF16_GEMM_CUTLASS_TEMPLATE_H_
#define FLASHINFER_BF16_GEMM_CUTLASS_TEMPLATE_H_

#ifdef __GNUC__  // Check if the compiler is GCC or Clang
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif  // __GNUC__
#include "cutlass/arch/arch.h"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/gemm.h"
#include "flashinfer/arch_condition.h"
#include "flashinfer/cutlass_utils.cuh"

#ifdef __GNUC__  // Check if the compiler is GCC or Clang
#pragma GCC diagnostic pop
#endif  // __GNUC__

#include <cuda_bf16.h>
#include <vector>

#include "cutlass/bfloat16.h"

namespace flashinfer {
namespace gemm {

struct _1SM {};
struct _2SM {};

template <typename T, typename arch, int CTA_M_, int CTA_N_, int CTA_K_, typename ClusterShape_,
          typename SMType_>
size_t genericBf16GemmKernelLauncherSm100(__nv_bfloat16 const* A, __nv_bfloat16 const* B, T* D,
                                          int m, int n, int k, int b, CutlassGemmConfig config,
                                          char* workspacePtr, size_t const workspaceBytes,
                                          cudaStream_t stream);

template <typename T, typename arch, int CTA_M_, int CTA_N_, int CTA_K_>
size_t dispatchGemmClusterShapeSm100(__nv_bfloat16 const* A, __nv_bfloat16 const* B, T* D, int m,
                                     int n, int k, int b, CutlassGemmConfig gemmConfig,
                                     char* workspacePtr, size_t const workspaceBytes,
                                     cudaStream_t stream) {
  using namespace cute;

  switch (gemmConfig.cluster_shape) {
    case ClusterShape::ClusterShape_1x1x1:
      return genericBf16GemmKernelLauncherSm100<T, arch, CTA_M_, CTA_N_, CTA_K_, Shape<_1, _1, _1>,
                                                _1SM>(A, B, D, m, n, k, b, gemmConfig, workspacePtr,
                                                      workspaceBytes, stream);
      break;
    // case ClusterShape::ClusterShape_1x2x1:
    //   return genericBf16GemmKernelLauncherSm100<T, arch, CTA_M_, CTA_N_, CTA_K_, Shape<_1, _2, _1>, _1SM>(
    //       A, B, D, m, n, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
    //   break;
    // case ClusterShape::ClusterShape_1x4x1:
    //   return genericBf16GemmKernelLauncherSm100<T, arch, CTA_M_, CTA_N_, CTA_K_, Shape<_1, _4, _1>, _1SM>(
    //       A, B, D, m, n, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
    //   break;
    // case ClusterShape::ClusterShape_2x1x1:
    //   return genericBf16GemmKernelLauncherSm100<T, arch, CTA_M_, CTA_N_, CTA_K_, Shape<_2, _1, _1>, _2SM>(
    //       A, B, D, m, n, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
    //   break;
    // case ClusterShape::ClusterShape_2x2x1:
    //   return genericBf16GemmKernelLauncherSm100<T, arch, CTA_M_, CTA_N_, CTA_K_, Shape<_2, _2, _1>, _2SM>(
    //       A, B, D, m, n, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
    //   break;
    default:
      throw std::runtime_error("invalid config for bf16 gemm");
      break;
  }
}

template <typename T>
size_t dispatchToArch(__nv_bfloat16 const* A, __nv_bfloat16 const* B, void* D, int m, int n, int k,
                      int b, CutlassGemmConfig gemmConfig, char* workspacePtr,
                      size_t const workspaceBytes, cudaStream_t stream) {
  using arch = cutlass::arch::Sm100;

  switch (gemmConfig.tile_config_sm100) {
    case CutlassTileConfigSM100::CtaShape64x64x128B:
      return dispatchGemmClusterShapeSm100<T, arch, 64, 64, 128>(
          B, A, static_cast<T*>(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
      break;
    // case CutlassTileConfigSM100::CtaShape64x128x128B:
    //   return dispatchGemmClusterShapeSm100<T, arch, 64, 128, 128>(
    //       B, A, static_cast<T*>(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes,
    //       stream);
    //   break;
    // case CutlassTileConfigSM100::CtaShape64x256x128B:
    //   return dispatchGemmClusterShapeSm100<T, arch, 64, 256, 128>(
    //       B, A, static_cast<T*>(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes,
    //       stream);
    //   break;
    // case CutlassTileConfigSM100::CtaShape128x64x128B:
    //   return dispatchGemmClusterShapeSm100<T, arch, 128, 64, 128>(
    //       B, A, static_cast<T*>(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes,
    //       stream);
    //   break;
    // case CutlassTileConfigSM100::CtaShape128x128x128B:
    //   return dispatchGemmClusterShapeSm100<T, arch, 128, 128, 128>(
    //       B, A, static_cast<T*>(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes,
    //       stream);
    //   break;
    // case CutlassTileConfigSM100::CtaShape128x256x128B:
    //   return dispatchGemmClusterShapeSm100<T, arch, 128, 256, 128>(
    //       B, A, static_cast<T*>(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes,
    //       stream);
    //   break;

    default:
      throw std::runtime_error("unsupported tile config for bf16 gemm");
      break;
  }
}

template <typename T>
void CutlassBf16GemmRunner<T>::gemm(__nv_bfloat16 const* A, __nv_bfloat16 const* B, void* D, int m,
                                    int n, int k, int b, CutlassGemmConfig gemmConfig,
                                    char* workspacePtr, size_t const workspaceBytes,
                                    cudaStream_t stream) {
  dispatchToArch<T>(A, B, reinterpret_cast<void*>(D), m, n, k, b, gemmConfig, workspacePtr,
                    workspaceBytes, stream);
}

template <typename T>
size_t CutlassBf16GemmRunner<T>::getWorkspaceSizeImpl(int m, int n, int k) {
  size_t workspace_size = 0;
  auto gemmConfigs = CutlassBf16GemmRunner<T>{}.getConfigs();
  for (auto const& gemmConfig : gemmConfigs) {
    try {
      size_t curr_workspace_size =
          dispatchToArch<T>(nullptr, nullptr, nullptr, m, n, k, 1, gemmConfig, nullptr, 0, nullptr);
      workspace_size = std::max(workspace_size, curr_workspace_size);
    } catch (std::runtime_error&) {
      continue;
    }
  }
  return workspace_size;
}

template <typename T>
size_t CutlassBf16GemmRunner<T>::getWorkspaceSize(int m, int n, int k) {
  using MNK = std::tuple<int, int, int>;

  struct MNKHash {
    size_t operator()(const MNK& mnk) const {
      auto h1 = std::hash<int>{}(std::get<0>(mnk));
      auto h2 = std::hash<int>{}(std::get<1>(mnk));
      auto h3 = std::hash<int>{}(std::get<2>(mnk));
      return h1 ^ h2 ^ h3;
    }
  };

  static std::unordered_map<MNK, size_t, MNKHash> workspace_hashmap;

  size_t workspace_size = 0;
  if (workspace_hashmap.find(std::make_tuple(m, n, k)) == workspace_hashmap.end()) {
    workspace_size = CutlassBf16GemmRunner<T>::getWorkspaceSizeImpl(m, n, k);
    workspace_hashmap[std::make_tuple(m, n, k)] = workspace_size;
  } else {
    workspace_size = workspace_hashmap[std::make_tuple(m, n, k)];
  }
  return workspace_size;
}

template <typename T>
std::vector<CutlassGemmConfig> CutlassBf16GemmRunner<T>::getConfigs() const {
  std::vector<CutlassGemmConfig> candidate_configs;

  std::vector<CutlassTileConfigSM100> tilesSm100 = {
      CutlassTileConfigSM100::CtaShape64x64x128B,  // CutlassTileConfigSM100::CtaShape64x128x128B,
      // CutlassTileConfigSM100::CtaShape64x256x128B, CutlassTileConfigSM100::CtaShape128x64x128B,
      // CutlassTileConfigSM100::CtaShape128x128x128B, CutlassTileConfigSM100::CtaShape128x256x128B,
  };

  std::vector<ClusterShape> clusterShapes = {
      ClusterShape::ClusterShape_1x1x1,  // ClusterShape::ClusterShape_1x2x1,
      // ClusterShape::ClusterShape_1x4x1, ClusterShape::ClusterShape_2x1x1,
      // ClusterShape::ClusterShape_2x2x1,
  };

  for (auto const& tile_config : tilesSm100) {
    for (auto const& cluster_config : clusterShapes) {
      CutlassGemmConfig config(tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
                               cluster_config);
      candidate_configs.push_back(config);
    }
  }
  return candidate_configs;
}

}  // namespace gemm
}  // namespace flashinfer

#endif  // FLASHINFER_BF16_GEMM_CUTLASS_TEMPLATE_H_
diff --git a/include/flashinfer/gemm/bf16_gemm_template_sm100.h b/include/flashinfer/gemm/bf16_gemm_template_sm100.h
new file mode 100644
index 0000000000..990fcd910a
--- /dev/null
+++ b/include/flashinfer/gemm/bf16_gemm_template_sm100.h
@@ -0,0 +1,187 @@
/*
 * Copyright (c) 2025, FlashInfer.
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef FLASHINFER_BF16_GEMM_TEMPLATE_SM100_H_ +#define FLASHINFER_BF16_GEMM_TEMPLATE_SM100_H_ + +#ifdef __GNUC__ // Check if the compiler is GCC or Clang +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif // __GNUC__ +#include "cutlass/arch/arch.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/numeric_conversion.h" +#include "flashinfer/arch_condition.h" +#include "flashinfer/cutlass_utils.cuh" + +#ifdef __GNUC__ // Check if the compiler is GCC or Clang +#pragma GCC diagnostic pop +#endif // __GNUC__ + +#include +#include + +#include "cutlass/bfloat16.h" +#include "flashinfer/gemm/cutlass_gemm_configs.h" + +namespace flashinfer { +namespace gemm { + +template +struct SMTypeAdapter {}; + +struct _1SM; +struct _2SM; + +template <> +struct SMTypeAdapter<_1SM> { + static int const Scale = 1; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized1Sm; + using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized1SmSm100; +}; + +template <> +struct SMTypeAdapter<_2SM> { + static int const Scale = 2; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized2Sm; + using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmSm100; +}; + +template +size_t genericBf16GemmKernelLauncherSm100(__nv_bfloat16 const* A, __nv_bfloat16 const* B, T* D, + int m, int n, int k, int b, CutlassGemmConfig config, + char* workspacePtr, size_t const workspaceBytes, + cudaStream_t stream) { + using namespace cute; + + using ElementA = cutlass::bfloat16_t; + using LayoutA = cutlass::layout::RowMajor; + constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + + using ElementB = cutlass::bfloat16_t; + using LayoutB = cutlass::layout::ColumnMajor; + constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + + using ElementOutput_ = + typename cutlass::platform::conditional::value, + cutlass::half_t, T>::type; +#ifdef ENABLE_BF16 + using ElementOutput = typename cutlass::platform::conditional< + cutlass::platform::is_same::value, cutlass::bfloat16_t, + ElementOutput_>::type; +#else + using ElementOutput = ElementOutput_; +#endif + + using ElementC = ElementOutput; + using LayoutC = cutlass::layout::ColumnMajor; + constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + + using ElementD = ElementC; + using LayoutD = LayoutC; + constexpr int AlignmentD = AlignmentC; + + using ElementAccumulator = float; + using ElementCompute = float; + using ArchTag = cutlass::arch::Sm100; + using OperatorClass = cutlass::arch::OpClassTensorOp; + using TileShape = cute::Shape::Scale>, cute::Int, + cute::Int>; + + using ClusterShape = ClusterShape_; + using EpilogueSchedule = typename SMTypeAdapter::EpilogueSchedule; + using MainloopSchedule = typename SMTypeAdapter::MainloopSchedule; + 
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType, ElementAccumulator, + ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, + EpilogueSchedule>::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, LayoutB, AlignmentB, + ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal, + CollectiveMainloop, CollectiveEpilogue>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, b)); + auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, b)); + auto stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, b)); + auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(m, n, b)); + + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {m, n, k, b}, + {reinterpret_cast(A), stride_A, reinterpret_cast(B), + stride_B}, + {{}, nullptr, stride_C, reinterpret_cast(D), stride_D}}; + + // Is not the right way to do this? + auto& fusion_args = arguments.epilogue.thread; + fusion_args.alpha = 1.0f; + fusion_args.beta = 0.0f; + + Gemm gemm; + + CUTLASS_CHECK(gemm.can_implement(arguments)); + + size_t workspace_size = gemm.get_workspace_size(arguments); + if (workspace_size > workspaceBytes) { + throw std::runtime_error("[Bf16 Gemm Runner] insufficient workspace"); + } + + // NOTE: These can also be simplified using CUTLASS_CHECK. Same goes for some of the other files. 
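  // The calls below follow the usual CUTLASS 3.x launch sequence: initialize()
  // validates the arguments and stages the workspace, then run() enqueues the
  // kernel on `stream`; any non-kSuccess status surfaces as a runtime_error.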
  cutlass::Status initStatus = gemm.initialize(arguments, workspacePtr, stream);
  if (initStatus != cutlass::Status::kSuccess) {
    throw std::runtime_error("[Bf16 Gemm Runner] failed to initialize");
  }

  cutlass::Status runStatus = gemm.run(stream);
  if (runStatus != cutlass::Status::kSuccess) {
    throw std::runtime_error("[Bf16 Gemm Runner] failed to run");
  }

  return workspace_size;
}

}  // namespace gemm
}  // namespace flashinfer

#define INSTANCE_BF16_GEMM_TEMPLATE_SM100(RET_TYPE, TILE_M, TILE_N, TILE_K, CGA_M_, CGA_N_,     \
                                          CGA_K_, SM_TYPE)                                      \
  template size_t genericBf16GemmKernelLauncherSm100<                                           \
      RET_TYPE, cutlass::arch::Sm100, TILE_M, TILE_N, TILE_K,                                   \
      cute::Shape<cute::Int<CGA_M_>, cute::Int<CGA_N_>, cute::Int<CGA_K_>>, SM_TYPE>(           \
      __nv_bfloat16 const* A, __nv_bfloat16 const* B, RET_TYPE* D, int m, int n, int k, int b,  \
      CutlassGemmConfig config, char* workspacePtr, size_t const workspaceBytes,                \
      cudaStream_t stream);

#endif  // FLASHINFER_BF16_GEMM_TEMPLATE_SM100_H_
diff --git a/include/flashinfer/gemm/fp8_gemm_cutlass_template.h b/include/flashinfer/gemm/fp8_gemm_cutlass_template.h
index a3c9d0f2c4..c346fe53f6 100644
--- a/include/flashinfer/gemm/fp8_gemm_cutlass_template.h
+++ b/include/flashinfer/gemm/fp8_gemm_cutlass_template.h
@@ -68,7 +68,6 @@ size_t dispatchGemmClusterShapeSm100(__nv_fp8_e4m3 const* A, __nv_fp8_e4m3 const
                                                _1SM>(
           A, B, scale_a, scale_b, D, m, n, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
       break;
-
     case ClusterShape::ClusterShape_2x1x1:
       return genericFp8GemmKernelLauncherSm100<T, arch, CTA_M_, CTA_N_, CTA_K_,
                                                cute::Shape<cute::_2, cute::_1, cute::_1>, _2SM>(
           A, B, scale_a, scale_b, D, m, n, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
diff --git a/tests/gemm/test_bmm_bf16.py b/tests/gemm/test_bmm_bf16.py
new file mode 100644
index 0000000000..6c635b5a17
--- /dev/null
+++ b/tests/gemm/test_bmm_bf16.py
@@ -0,0 +1,38 @@
import pytest
import torch
import torch.nn.functional as F

from flashinfer import autotune, bmm_bf16
from flashinfer.utils import get_compute_capability


@pytest.mark.parametrize("b", [1, 16])
@pytest.mark.parametrize("m", [48, 128])
@pytest.mark.parametrize("n", [80, 64])
@pytest.mark.parametrize("k", [64, 256])
@pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16])
def test_bmm_bf16(b, m, n, k, res_dtype):
    compute_capability = get_compute_capability(torch.device(device="cuda"))
    print(compute_capability)
    cc_number = compute_capability[0] * 10 + compute_capability[1]
    if not bmm_bf16.is_compute_capability_supported(cc_number):
        pytest.skip(
            f"bmm_bf16 requires one of the following compute capabilities: "
            f"{sorted(bmm_bf16._supported_ccs)}. "
            f"Detected sm{cc_number}."
+ ) + torch.manual_seed(7) + a = torch.randn([b, m, k], device="cuda", dtype=torch.bfloat16) + b = torch.randn([b, k, n], device="cuda", dtype=torch.bfloat16) + reference = torch.bmm(a.float(), b.float()) + + out = torch.empty([b, m, n], device="cuda", dtype=res_dtype) + with autotune(): + bmm_bf16(a, b, out=out, out_dtype=res_dtype) + + cos_sim = F.cosine_similarity(reference.reshape(-1), out.float().reshape(-1), dim=0) + assert cos_sim > 0.99 + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/gemm/test_mm_bf16.py b/tests/gemm/test_mm_bf16.py new file mode 100644 index 0000000000..dc27bcebfe --- /dev/null +++ b/tests/gemm/test_mm_bf16.py @@ -0,0 +1,37 @@ +import pytest +import torch +import torch.nn.functional as F + +from flashinfer import autotune, mm_bf16 +from flashinfer.utils import get_compute_capability + + +@pytest.mark.parametrize("m", [1, 48, 128, 256, 512]) +@pytest.mark.parametrize("n", [128, 256, 512]) +@pytest.mark.parametrize("k", [128, 256, 512]) +@pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16]) +def test_mm_bf16(m: int, n: int, k: int, res_dtype: torch.dtype): + compute_capability = get_compute_capability(torch.device(device="cuda")) + cc_number = compute_capability[0] * 10 + compute_capability[1] + if not mm_bf16.is_compute_capability_supported(cc_number): + pytest.skip( + f"mm_bf16 requires one of the following compute capabilities: " + f"{sorted(mm_bf16._supported_ccs)}. " + f"Detected sm{cc_number}." + ) + + torch.manual_seed(42) + a = torch.randn([m, k], device="cuda", dtype=torch.bfloat16) + b = torch.randn([k, n], device="cuda", dtype=torch.bfloat16) + reference = torch.mm(a.float(), b.float()) + + out = torch.empty([m, n], device="cuda", dtype=res_dtype) + with autotune(): + mm_bf16(a, b, out=out, out_dtype=res_dtype) + + cos_sim = F.cosine_similarity(reference.reshape(-1), out.float().reshape(-1), dim=0) + assert cos_sim > 0.99 + + +if __name__ == "__main__": + pytest.main([__file__]) From 696295a3920b759834e9eab5a8f0812a97bc28b1 Mon Sep 17 00:00:00 2001 From: raayandhar Date: Sun, 9 Nov 2025 20:12:00 -0800 Subject: [PATCH 02/11] fix rebase Signed-off-by: raayandhar --- flashinfer/gemm/gemm_base.py | 46 ------------------------------------ 1 file changed, 46 deletions(-) diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index 8ed2acd6b2..cba52b2a22 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -621,52 +621,6 @@ def bf16_gemm_sm100( runner(inputs=inputs, tactic=tactic) -def bf16_gemm_sm100( - a: torch.Tensor, - b: torch.Tensor, - out: torch.Tensor, - workspace_buffer: torch.Tensor, -) -> None: - runners = [] - is_sm_supported = _match_sm_version(a.device, ["100"]) - - if is_sm_supported: - runners.append(get_gemm_sm100_module_cutlass_bf16().cutlass_bf16_gemm_runner()) - - if len(runners) == 0: - major, minor = get_compute_capability(torch.device("cuda")) - raise ValueError(f"No valid runner found for current device sm{major}{minor}") - - tuner = AutoTuner.get() - a_tensor_index = 0 - out_tensor_index = 2 - tuning_config = TuningConfig( - dynamic_tensor_specs=( - DynamicTensorSpec( - (a_tensor_index,), - (-2,), - get_last_power_of_2_num_tokens_buckets, - last_positive_power_of_2, - ), - ), - constraint_specs=( - ConstraintSpec( - out_tensor_index, -2, lambda shapes: shapes[a_tensor_index][-2] - ), - ), - ) - - inputs = [a, b, out, workspace_buffer] - runner, tactic = tuner.choose_one( - "bf16_gemm", - runners, - tuning_config, - inputs, - ) - - 
runner(inputs=inputs, tactic=tactic) - - def _create_cutlass_fp4_gemm_module(module, op_name: str, tuner_name: str): """Helper function to create cutlass FP4 GEMM module.""" From b0dd1f3b55c1f960252c200c8888a3ba3e1f0568 Mon Sep 17 00:00:00 2001 From: raayandhar Date: Sun, 9 Nov 2025 20:17:27 -0800 Subject: [PATCH 03/11] fix import path from bad rebase Signed-off-by: raayandhar --- flashinfer/gemm/gemm_base.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index cba52b2a22..1889acc4f3 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -44,17 +44,17 @@ backend_requirement, supported_compute_capability, ) -from .jit.gemm import gen_gemm_sm90_module -from .jit.gemm import gen_gemm_module -from .jit.gemm import gen_gemm_sm100_module -from .jit.gemm import gen_gemm_sm120_module -from .jit.gemm import gen_gemm_sm120_module_cutlass_fp4 -from .jit.gemm import gen_gemm_sm100_module_cutlass_fp4 -from .jit.gemm import gen_gemm_sm100_module_cutlass_fp8 -from .jit.gemm import gen_gemm_sm100_module_cutlass_bf16 -from .jit.gemm import gen_trtllm_gen_gemm_module -from .jit.gemm import gen_tgv_gemm_sm10x_module -from .jit.gemm import gen_deepgemm_sm100_module +from ..jit.gemm import gen_gemm_sm90_module +from ..jit.gemm import gen_gemm_module +from ..jit.gemm import gen_gemm_sm100_module +from ..jit.gemm import gen_gemm_sm120_module +from ..jit.gemm import gen_gemm_sm120_module_cutlass_fp4 +from ..jit.gemm import gen_gemm_sm100_module_cutlass_fp4 +from ..jit.gemm import gen_gemm_sm100_module_cutlass_fp8 +from ..jit.gemm import gen_gemm_sm100_module_cutlass_bf16 +from ..jit.gemm import gen_trtllm_gen_gemm_module +from ..jit.gemm import gen_tgv_gemm_sm10x_module +from ..jit.gemm import gen_deepgemm_sm100_module CUDNN_AVAILABLE = False From cf8182c8c8ff34f98b6d68f4ae7e67d7fa44eccb Mon Sep 17 00:00:00 2001 From: raayandhar Date: Sun, 9 Nov 2025 20:19:33 -0800 Subject: [PATCH 04/11] add missing exports Signed-off-by: raayandhar --- flashinfer/gemm/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/flashinfer/gemm/__init__.py b/flashinfer/gemm/__init__.py index 15652268ba..f84a71e920 100644 --- a/flashinfer/gemm/__init__.py +++ b/flashinfer/gemm/__init__.py @@ -1,5 +1,7 @@ from .gemm_base import SegmentGEMMWrapper as SegmentGEMMWrapper +from .gemm_base import bmm_bf16 as bmm_bf16 from .gemm_base import bmm_fp8 as bmm_fp8 +from .gemm_base import mm_bf16 as mm_bf16 from .gemm_base import mm_fp4 as mm_fp4 from .gemm_base import mm_fp8 as mm_fp8 from .gemm_base import tgv_gemm_sm100 as tgv_gemm_sm100 @@ -20,7 +22,9 @@ __all__ = [ "SegmentGEMMWrapper", + "bmm_bf16", "bmm_fp8", + "mm_bf16", "mm_fp4", "mm_fp8", "tgv_gemm_sm100", From 1e8da31e4d96bea4a7dc971a0a7141adcdfc35f5 Mon Sep 17 00:00:00 2001 From: "Raayan Dhar raayan.dhar@gmail.com" Date: Tue, 11 Nov 2025 19:33:41 -0800 Subject: [PATCH 05/11] address coderabbit comments + try to fix contiguous check? 
Signed-off-by: Raayan Dhar raayan.dhar@gmail.com --- flashinfer/gemm/gemm_base.py | 29 ++++++++----------- .../gemm/bf16_gemm_cutlass_template.h | 1 + tests/gemm/test_bmm_bf16.py | 17 +++++------ tests/gemm/test_mm_bf16.py | 16 +++++----- 4 files changed, 29 insertions(+), 34 deletions(-) diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index 1889acc4f3..29939139c8 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -248,8 +248,8 @@ def mm_bf16( @supported_compute_capability([100]) def bmm_bf16( - a: torch.Tensor, - b: torch.Tensor, + A: torch.Tensor, + B: torch.Tensor, out: Optional[torch.Tensor] = None, out_dtype: torch.dtype = torch.bfloat16, backend: Literal["cutlass"] = "cutlass", @@ -258,10 +258,10 @@ def bmm_bf16( Parameters ---------- - a: torch.Tensor + A: torch.Tensor Input tensor, shape (b, m, k), bf16. - b: torch.Tensor + B: torch.Tensor Weight tensor, shape (b, k, n), bf16. out: Optional[torch.Tensor] @@ -285,11 +285,11 @@ def bmm_bf16( if out_dtype not in (torch.bfloat16, torch.float16): raise ValueError("Only bf16 and fp16 outputs are supported.") - expected_shape = (a.shape[0], a.shape[1], b.shape[2]) + expected_shape = (A.shape[0], A.shape[1], B.shape[2]) if out is None: out = torch.empty( expected_shape, - device=a.device, + device=A.device, dtype=out_dtype, ) else: @@ -297,9 +297,9 @@ def bmm_bf16( raise ValueError( f"Output shape mismatch. Expected {expected_shape}, got {out.shape}." ) - if out.device != a.device: + if out.device != A.device: raise ValueError( - f"Output device mismatch. Expected {a.device}, got {out.device}." + f"Output device mismatch. Expected {A.device}, got {out.device}." ) if out.dtype != out_dtype: raise ValueError( @@ -307,9 +307,9 @@ def bmm_bf16( ) workspace_buffer = _get_cache_buf( - "bmm_bf16_workspace", DEFAULT_WORKSPACE_SIZE, a.device + "bmm_bf16_workspace", DEFAULT_WORKSPACE_SIZE, A.device ) - bf16_gemm_sm100(a, b, out, workspace_buffer) + bf16_gemm_sm100(A, B, out, workspace_buffer) return out @@ -582,14 +582,9 @@ def bf16_gemm_sm100( workspace_buffer: torch.Tensor, ) -> None: runners = [] - is_sm_supported = _match_sm_version(a.device, ["100"]) - - if is_sm_supported: + if _match_sm_version(a.device, ["100"]): runners.append(get_gemm_sm100_module_cutlass_bf16().cutlass_bf16_gemm_runner()) - - if len(runners) == 0: - major, minor = get_compute_capability(torch.device("cuda")) - raise ValueError(f"No valid runner found for current device sm{major}{minor}") + assert runners, "No suitable runners found" tuner = AutoTuner.get() a_tensor_index = 0 diff --git a/include/flashinfer/gemm/bf16_gemm_cutlass_template.h b/include/flashinfer/gemm/bf16_gemm_cutlass_template.h index 54953e3043..82c863fe05 100644 --- a/include/flashinfer/gemm/bf16_gemm_cutlass_template.h +++ b/include/flashinfer/gemm/bf16_gemm_cutlass_template.h @@ -152,6 +152,7 @@ size_t CutlassBf16GemmRunner::getWorkspaceSizeImpl(int m, int n, int k) { dispatchToArch(nullptr, nullptr, nullptr, m, n, k, 1, gemmConfig, nullptr, 0, nullptr); workspace_size = std::max(workspace_size, curr_workspace_size); } catch (std::runtime_error&) { + // Swallow errors when SMEM exceeds maximum allowed continue; } } diff --git a/tests/gemm/test_bmm_bf16.py b/tests/gemm/test_bmm_bf16.py index 6c635b5a17..9e678db06e 100644 --- a/tests/gemm/test_bmm_bf16.py +++ b/tests/gemm/test_bmm_bf16.py @@ -13,24 +13,23 @@ @pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16]) def test_bmm_bf16(b, m, n, k, res_dtype): compute_capability = 
get_compute_capability(torch.device(device="cuda")) - print(compute_capability) - cc_number = compute_capability[0] * 10 + compute_capability[1] - if not bmm_bf16.is_compute_capability_supported(cc_number): + compute_capability_number = compute_capability[0] * 10 + compute_capability[1] + if not bmm_bf16.is_compute_capability_supported(compute_capability_number): pytest.skip( f"bmm_bf16 requires one of the following compute capabilities: " f"{sorted(bmm_bf16._supported_ccs)}. " - f"Detected sm{cc_number}." + f"Detected sm{compute_capability_number}." ) torch.manual_seed(7) - a = torch.randn([b, m, k], device="cuda", dtype=torch.bfloat16) - b = torch.randn([b, k, n], device="cuda", dtype=torch.bfloat16) - reference = torch.bmm(a.float(), b.float()) + input = torch.randn([b, m, k], device="cuda", dtype=torch.bfloat16) + mat2 = torch.randn([b, n, k], device="cuda", dtype=torch.bfloat16).tranpose(-2, -1) + reference = torch.bmm(input, mat2) out = torch.empty([b, m, n], device="cuda", dtype=res_dtype) with autotune(): - bmm_bf16(a, b, out=out, out_dtype=res_dtype) + bmm_bf16(input, mat2, out=out, out_dtype=res_dtype) - cos_sim = F.cosine_similarity(reference.reshape(-1), out.float().reshape(-1), dim=0) + cos_sim = F.cosine_similarity(reference.reshape(-1), out.reshape(-1), dim=0) assert cos_sim > 0.99 diff --git a/tests/gemm/test_mm_bf16.py b/tests/gemm/test_mm_bf16.py index dc27bcebfe..cb8de90a29 100644 --- a/tests/gemm/test_mm_bf16.py +++ b/tests/gemm/test_mm_bf16.py @@ -12,24 +12,24 @@ @pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16]) def test_mm_bf16(m: int, n: int, k: int, res_dtype: torch.dtype): compute_capability = get_compute_capability(torch.device(device="cuda")) - cc_number = compute_capability[0] * 10 + compute_capability[1] - if not mm_bf16.is_compute_capability_supported(cc_number): + compute_capability_number = compute_capability[0] * 10 + compute_capability[1] + if not mm_bf16.is_compute_capability_supported(compute_capability_number): pytest.skip( f"mm_bf16 requires one of the following compute capabilities: " f"{sorted(mm_bf16._supported_ccs)}. " - f"Detected sm{cc_number}." + f"Detected sm{compute_capability_number}." 
) torch.manual_seed(42) - a = torch.randn([m, k], device="cuda", dtype=torch.bfloat16) - b = torch.randn([k, n], device="cuda", dtype=torch.bfloat16) - reference = torch.mm(a.float(), b.float()) + input = torch.randn([m, k], device="cuda", dtype=torch.bfloat16) + mat2 = torch.randn([k, n], device="cuda", dtype=torch.bfloat16) + reference = torch.mm(input, mat2.T) out = torch.empty([m, n], device="cuda", dtype=res_dtype) with autotune(): - mm_bf16(a, b, out=out, out_dtype=res_dtype) + mm_bf16(input, mat2.T, out=out, out_dtype=res_dtype) - cos_sim = F.cosine_similarity(reference.reshape(-1), out.float().reshape(-1), dim=0) + cos_sim = F.cosine_similarity(reference.reshape(-1), out.reshape(-1), dim=0) assert cos_sim > 0.99 From ccc6641950112d9a37217d84305fa59dca9d0a93 Mon Sep 17 00:00:00 2001 From: "Raayan Dhar raayan.dhar@gmail.com" Date: Tue, 11 Nov 2025 19:46:25 -0800 Subject: [PATCH 06/11] small fixes Signed-off-by: Raayan Dhar raayan.dhar@gmail.com --- flashinfer/gemm/gemm_base.py | 5 ++--- tests/gemm/test_bmm_bf16.py | 2 +- tests/gemm/test_mm_bf16.py | 3 ++- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index 29939139c8..9122fe5927 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -192,11 +192,10 @@ def mm_bf16( Parameters ---------- a: torch.Tensor - Input tensor, shape (m, k), bf16 row-major. + Input tensor, shape (m, k), bf16. b: torch.Tensor - Weight tensor, shape (k, n), bf16 row-major. This tensor is interpreted - as a column-major (n, k) matrix internally. + Weight tensor, shape (k, n), bf16. out: Optional[torch.Tensor] Out tensor, shape (m, n), bf16 or fp16, defaults to ``None``. diff --git a/tests/gemm/test_bmm_bf16.py b/tests/gemm/test_bmm_bf16.py index 9e678db06e..b6b47e5860 100644 --- a/tests/gemm/test_bmm_bf16.py +++ b/tests/gemm/test_bmm_bf16.py @@ -22,7 +22,7 @@ def test_bmm_bf16(b, m, n, k, res_dtype): ) torch.manual_seed(7) input = torch.randn([b, m, k], device="cuda", dtype=torch.bfloat16) - mat2 = torch.randn([b, n, k], device="cuda", dtype=torch.bfloat16).tranpose(-2, -1) + mat2 = torch.randn([b, n, k], device="cuda", dtype=torch.bfloat16).transpose(-2, -1) reference = torch.bmm(input, mat2) out = torch.empty([b, m, n], device="cuda", dtype=res_dtype) diff --git a/tests/gemm/test_mm_bf16.py b/tests/gemm/test_mm_bf16.py index cb8de90a29..6ccd7518f4 100644 --- a/tests/gemm/test_mm_bf16.py +++ b/tests/gemm/test_mm_bf16.py @@ -22,7 +22,8 @@ def test_mm_bf16(m: int, n: int, k: int, res_dtype: torch.dtype): torch.manual_seed(42) input = torch.randn([m, k], device="cuda", dtype=torch.bfloat16) - mat2 = torch.randn([k, n], device="cuda", dtype=torch.bfloat16) + mat2 = torch.randn([n, k], device="cuda", dtype=torch.bfloat16) + reference = torch.mm(input, mat2.T) out = torch.empty([m, n], device="cuda", dtype=res_dtype) From 3c6393cce3049ba7ef71b1c0f02e2cbdacf1ab20 Mon Sep 17 00:00:00 2001 From: "Raayan Dhar raayan.dhar@gmail.com" Date: Sun, 16 Nov 2025 21:09:25 -0800 Subject: [PATCH 07/11] small notes, enable other tile sizes Signed-off-by: Raayan Dhar raayan.dhar@gmail.com --- csrc/bf16_gemm_cutlass.jinja | 8 +- flashinfer/gemm/gemm_base.py | 26 +++++- flashinfer/jit/gemm/core.py | 9 +- .../gemm/bf16_gemm_cutlass_template.h | 93 +++++++++---------- .../gemm/bf16_gemm_template_sm100.h | 20 ++-- 5 files changed, 87 insertions(+), 69 deletions(-) diff --git a/csrc/bf16_gemm_cutlass.jinja b/csrc/bf16_gemm_cutlass.jinja index 27d6913346..0e8a5f0f9f 100644 --- 
a/csrc/bf16_gemm_cutlass.jinja
+++ b/csrc/bf16_gemm_cutlass.jinja
@@ -19,9 +19,9 @@
 namespace flashinfer {
 namespace gemm {
   INSTANCE_BF16_GEMM_TEMPLATE_SM100({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 1, 1, 1, _1SM);
-  // INSTANCE_BF16_GEMM_TEMPLATE_SM100({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 1, 2, 1, _1SM);
-  // INSTANCE_BF16_GEMM_TEMPLATE_SM100({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 1, 4, 1, _1SM);
-  // INSTANCE_BF16_GEMM_TEMPLATE_SM100({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 2, 1, 1, _2SM);
-  // INSTANCE_BF16_GEMM_TEMPLATE_SM100({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 2, 2, 1, _2SM);
+  INSTANCE_BF16_GEMM_TEMPLATE_SM100({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 1, 2, 1, _1SM);
+  INSTANCE_BF16_GEMM_TEMPLATE_SM100({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 1, 4, 1, _1SM);
+  INSTANCE_BF16_GEMM_TEMPLATE_SM100({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 2, 1, 1, _2SM);
+  INSTANCE_BF16_GEMM_TEMPLATE_SM100({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 2, 2, 1, _2SM);
 } // namespace gemm
 } // namespace flashinfer
diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py
index 9122fe5927..e6ca2d6e35 100644
--- a/flashinfer/gemm/gemm_base.py
+++ b/flashinfer/gemm/gemm_base.py
@@ -211,7 +211,18 @@ def mm_bf16(
     torch.Tensor
         Out tensor, shape (m, n), bf16 or fp16.
 
-    # Note: add Examples section here
+    Examples
+    --------
+    >>> import torch
+    >>> import torch.nn.functional as F
+    >>> import flashinfer
+    >>> input = torch.randn([48, 64], device="cuda", dtype=torch.bfloat16)
+    >>> weight = torch.randn([80, 64], device="cuda", dtype=torch.bfloat16).transpose(-2, -1)
+    >>> out = flashinfer.mm_bf16(input, weight)
+    >>> print(out.shape)
+    torch.Size([48, 80])
+    >>> out.dtype
+    torch.bfloat16
     """
     if backend != "cutlass":
         raise ValueError(f"Unsupported backend: {backend}. Only cutlass is available.")
@@ -277,7 +288,18 @@ def bmm_bf16(
     torch.Tensor
         Out tensor, shape (b, m, n), bf16 or fp16.
 
-    # Note: add Examples section here
+    Examples
+    --------
+    >>> import torch
+    >>> import torch.nn.functional as F
+    >>> import flashinfer
+    >>> input = torch.randn([16, 48, 64], device="cuda", dtype=torch.bfloat16)
+    >>> weight = torch.randn([16, 80, 64], device="cuda", dtype=torch.bfloat16).transpose(-2, -1)
+    >>> out = flashinfer.bmm_bf16(input, weight)
+    >>> print(out.shape)
+    torch.Size([16, 48, 80])
+    >>> out.dtype
+    torch.bfloat16
     """
     if backend != "cutlass":
         raise ValueError(f"Unsupported backend: {backend}. 
Only cutlass is available.") diff --git a/flashinfer/jit/gemm/core.py b/flashinfer/jit/gemm/core.py index 383a38a87e..5d40b510ac 100644 --- a/flashinfer/jit/gemm/core.py +++ b/flashinfer/jit/gemm/core.py @@ -202,11 +202,10 @@ def gen_gemm_sm100_module_cutlass_bf16() -> JitSpec: dtype_list = ["__nv_bfloat16", "half"] cta_m_n_k_list = [ (64, 64, 128), - # (64, 128, 128), - # (64, 256, 128), - # (128, 64, 128), - # (128, 128, 128), - # (128, 256, 128), + (64, 128, 128), + (64, 256, 128), + (128, 64, 128), + (128, 128, 128), ] for cta_m, cta_n, cta_k in cta_m_n_k_list: for dtype in dtype_list: diff --git a/include/flashinfer/gemm/bf16_gemm_cutlass_template.h b/include/flashinfer/gemm/bf16_gemm_cutlass_template.h index 82c863fe05..f73ea1bde2 100644 --- a/include/flashinfer/gemm/bf16_gemm_cutlass_template.h +++ b/include/flashinfer/gemm/bf16_gemm_cutlass_template.h @@ -64,26 +64,26 @@ size_t dispatchGemmClusterShapeSm100(__nv_bfloat16 const* A, __nv_bfloat16 const _1SM>(A, B, D, m, n, k, b, gemmConfig, workspacePtr, workspaceBytes, stream); break; - // case ClusterShape::ClusterShape_1x2x1: - // return genericBf16GemmKernelLauncherSm100, _1SM>( - // A, B, D, m, n, k, b, gemmConfig, workspacePtr, workspaceBytes, stream); - // break; - // case ClusterShape::ClusterShape_1x4x1: - // return genericBf16GemmKernelLauncherSm100, _1SM>( - // A, B, D, m, n, k, b, gemmConfig, workspacePtr, workspaceBytes, stream); - // break; - // case ClusterShape::ClusterShape_2x1x1: - // return genericBf16GemmKernelLauncherSm100, _2SM>( - // A, B, D, m, n, k, b, gemmConfig, workspacePtr, workspaceBytes, stream); - // break; - // case ClusterShape::ClusterShape_2x2x1: - // return genericBf16GemmKernelLauncherSm100, _2SM>( - // A, B, D, m, n, k, b, gemmConfig, workspacePtr, workspaceBytes, stream); - // break; + case ClusterShape::ClusterShape_1x2x1: + return genericBf16GemmKernelLauncherSm100, + _1SM>(A, B, D, m, n, k, b, gemmConfig, workspacePtr, + workspaceBytes, stream); + break; + case ClusterShape::ClusterShape_1x4x1: + return genericBf16GemmKernelLauncherSm100, + _1SM>(A, B, D, m, n, k, b, gemmConfig, workspacePtr, + workspaceBytes, stream); + break; + case ClusterShape::ClusterShape_2x1x1: + return genericBf16GemmKernelLauncherSm100, + _2SM>(A, B, D, m, n, k, b, gemmConfig, workspacePtr, + workspaceBytes, stream); + break; + case ClusterShape::ClusterShape_2x2x1: + return genericBf16GemmKernelLauncherSm100, + _2SM>(A, B, D, m, n, k, b, gemmConfig, workspacePtr, + workspaceBytes, stream); + break; default: throw std::runtime_error("invalid config for bf16 gemm"); break; @@ -101,31 +101,22 @@ size_t dispatchToArch(__nv_bfloat16 const* A, __nv_bfloat16 const* B, void* D, i return dispatchGemmClusterShapeSm100( B, A, static_cast(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes, stream); break; - // case CutlassTileConfigSM100::CtaShape64x128x128B: - // return dispatchGemmClusterShapeSm100( - // B, A, static_cast(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes, - // stream); - // break; - // case CutlassTileConfigSM100::CtaShape64x256x128B: - // return dispatchGemmClusterShapeSm100( - // B, A, static_cast(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes, - // stream); - // break; - // case CutlassTileConfigSM100::CtaShape128x64x128B: - // return dispatchGemmClusterShapeSm100( - // B, A, static_cast(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes, - // stream); - // break; - // case CutlassTileConfigSM100::CtaShape128x128x128B: - // return dispatchGemmClusterShapeSm100( - // B, A, 
diff --git a/include/flashinfer/gemm/bf16_gemm_cutlass_template.h b/include/flashinfer/gemm/bf16_gemm_cutlass_template.h
index 82c863fe05..f73ea1bde2 100644
--- a/include/flashinfer/gemm/bf16_gemm_cutlass_template.h
+++ b/include/flashinfer/gemm/bf16_gemm_cutlass_template.h
@@ -64,26 +64,26 @@ size_t dispatchGemmClusterShapeSm100(__nv_bfloat16 const* A, __nv_bfloat16 const
                                                 _1SM>(A, B, D, m, n, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
       break;
-    // case ClusterShape::ClusterShape_1x2x1:
-    //   return genericBf16GemmKernelLauncherSm100<T, CTAShape, Shape<_1, _2, _1>, _1SM>(
-    //       A, B, D, m, n, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
-    //   break;
-    // case ClusterShape::ClusterShape_1x4x1:
-    //   return genericBf16GemmKernelLauncherSm100<T, CTAShape, Shape<_1, _4, _1>, _1SM>(
-    //       A, B, D, m, n, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
-    //   break;
-    // case ClusterShape::ClusterShape_2x1x1:
-    //   return genericBf16GemmKernelLauncherSm100<T, CTAShape, Shape<_2, _1, _1>, _2SM>(
-    //       A, B, D, m, n, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
-    //   break;
-    // case ClusterShape::ClusterShape_2x2x1:
-    //   return genericBf16GemmKernelLauncherSm100<T, CTAShape, Shape<_2, _2, _1>, _2SM>(
-    //       A, B, D, m, n, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
-    //   break;
+    case ClusterShape::ClusterShape_1x2x1:
+      return genericBf16GemmKernelLauncherSm100<T, CTAShape, Shape<_1, _2, _1>,
+                                                _1SM>(A, B, D, m, n, k, b, gemmConfig, workspacePtr,
+                                                      workspaceBytes, stream);
+      break;
+    case ClusterShape::ClusterShape_1x4x1:
+      return genericBf16GemmKernelLauncherSm100<T, CTAShape, Shape<_1, _4, _1>,
+                                                _1SM>(A, B, D, m, n, k, b, gemmConfig, workspacePtr,
+                                                      workspaceBytes, stream);
+      break;
+    case ClusterShape::ClusterShape_2x1x1:
+      return genericBf16GemmKernelLauncherSm100<T, CTAShape, Shape<_2, _1, _1>,
+                                                _2SM>(A, B, D, m, n, k, b, gemmConfig, workspacePtr,
+                                                      workspaceBytes, stream);
+      break;
+    case ClusterShape::ClusterShape_2x2x1:
+      return genericBf16GemmKernelLauncherSm100<T, CTAShape, Shape<_2, _2, _1>,
+                                                _2SM>(A, B, D, m, n, k, b, gemmConfig, workspacePtr,
+                                                      workspaceBytes, stream);
+      break;
 
     default:
       throw std::runtime_error("invalid config for bf16 gemm");
       break;
@@ -101,31 +101,22 @@ size_t dispatchToArch(__nv_bfloat16 const* A, __nv_bfloat16 const* B, void* D, i
       return dispatchGemmClusterShapeSm100<T, CutlassTileConfigSM100::CtaShape64x64x128B>(
           B, A, static_cast<T*>(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
       break;
-    // case CutlassTileConfigSM100::CtaShape64x128x128B:
-    //   return dispatchGemmClusterShapeSm100<T, CutlassTileConfigSM100::CtaShape64x128x128B>(
-    //       B, A, static_cast<T*>(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes,
-    //       stream);
-    //   break;
-    // case CutlassTileConfigSM100::CtaShape64x256x128B:
-    //   return dispatchGemmClusterShapeSm100<T, CutlassTileConfigSM100::CtaShape64x256x128B>(
-    //       B, A, static_cast<T*>(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes,
-    //       stream);
-    //   break;
-    // case CutlassTileConfigSM100::CtaShape128x64x128B:
-    //   return dispatchGemmClusterShapeSm100<T, CutlassTileConfigSM100::CtaShape128x64x128B>(
-    //       B, A, static_cast<T*>(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes,
-    //       stream);
-    //   break;
-    // case CutlassTileConfigSM100::CtaShape128x128x128B:
-    //   return dispatchGemmClusterShapeSm100<T, CutlassTileConfigSM100::CtaShape128x128x128B>(
-    //       B, A, static_cast<T*>(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes,
-    //       stream);
-    //   break;
-    // case CutlassTileConfigSM100::CtaShape128x256x128B:
-    //   return dispatchGemmClusterShapeSm100<T, CutlassTileConfigSM100::CtaShape128x256x128B>(
-    //       B, A, static_cast<T*>(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes,
-    //       stream);
-    //   break;
+    case CutlassTileConfigSM100::CtaShape64x128x128B:
+      return dispatchGemmClusterShapeSm100<T, CutlassTileConfigSM100::CtaShape64x128x128B>(
+          B, A, static_cast<T*>(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
+      break;
+    case CutlassTileConfigSM100::CtaShape64x256x128B:
+      return dispatchGemmClusterShapeSm100<T, CutlassTileConfigSM100::CtaShape64x256x128B>(
+          B, A, static_cast<T*>(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
+      break;
+    case CutlassTileConfigSM100::CtaShape128x64x128B:
+      return dispatchGemmClusterShapeSm100<T, CutlassTileConfigSM100::CtaShape128x64x128B>(
+          B, A, static_cast<T*>(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
+      break;
+    case CutlassTileConfigSM100::CtaShape128x128x128B:
+      return dispatchGemmClusterShapeSm100<T, CutlassTileConfigSM100::CtaShape128x128x128B>(
+          B, A, static_cast<T*>(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
+      break;
 
     default:
       throw std::runtime_error("unsupported tile config for bf16 gemm");
@@ -189,15 +180,15 @@ std::vector<CutlassGemmConfig> CutlassBf16GemmRunner<T>::getConfigs() const {
   std::vector<CutlassGemmConfig> candidate_configs;
 
   std::vector<CutlassTileConfigSM100> tilesSm100 = {
-      CutlassTileConfigSM100::CtaShape64x64x128B,  // CutlassTileConfigSM100::CtaShape64x128x128B,
-      // CutlassTileConfigSM100::CtaShape64x256x128B, CutlassTileConfigSM100::CtaShape128x64x128B,
-      // CutlassTileConfigSM100::CtaShape128x128x128B, CutlassTileConfigSM100::CtaShape128x256x128B,
+      CutlassTileConfigSM100::CtaShape64x64x128B,   CutlassTileConfigSM100::CtaShape64x128x128B,
+      CutlassTileConfigSM100::CtaShape64x256x128B,  CutlassTileConfigSM100::CtaShape128x64x128B,
+      CutlassTileConfigSM100::CtaShape128x128x128B,
   };
 
   std::vector<ClusterShape> clusterShapes = {
-      ClusterShape::ClusterShape_1x1x1,  // ClusterShape::ClusterShape_1x2x1,
-      // ClusterShape::ClusterShape_1x4x1, ClusterShape::ClusterShape_2x1x1,
-      // ClusterShape::ClusterShape_2x2x1,
+      ClusterShape::ClusterShape_1x1x1, ClusterShape::ClusterShape_1x2x1,
+      ClusterShape::ClusterShape_1x4x1, ClusterShape::ClusterShape_2x1x1,
+      ClusterShape::ClusterShape_2x2x1,
   };
 
   for (auto const& tile_config : tilesSm100) {
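Restoring all five tile shapes and all five cluster shapes widens the tactic space that `getConfigs()` exposes to the autotuner. Assuming the loop pairs every tile with every cluster shape before any legality filtering (an assumption; invalid tile/cluster pairs may well be pruned), the candidate count grows from 1 to 25:

from itertools import product

tiles = ["CtaShape64x64x128B", "CtaShape64x128x128B", "CtaShape64x256x128B",
         "CtaShape128x64x128B", "CtaShape128x128x128B"]
clusters = ["1x1x1", "1x2x1", "1x4x1", "2x1x1", "2x2x1"]

# Full cross-product of the two lists above.
candidates = list(product(tiles, clusters))
print(len(candidates))  # 25 tactics, up from the single 64x64x128B / 1x1x1 pair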
diff --git a/include/flashinfer/gemm/bf16_gemm_template_sm100.h b/include/flashinfer/gemm/bf16_gemm_template_sm100.h
index 990fcd910a..fea3a0f2bd 100644
--- a/include/flashinfer/gemm/bf16_gemm_template_sm100.h
+++ b/include/flashinfer/gemm/bf16_gemm_template_sm100.h
@@ -151,25 +151,31 @@ size_t genericBf16GemmKernelLauncherSm100(__nv_bfloat16 const* A, __nv_bfloat16
 
   Gemm gemm;
 
-  CUTLASS_CHECK(gemm.can_implement(arguments));
+  // Return workspace size
+  if (!A && !B && !D) {
+    return gemm.get_workspace_size(arguments);
+  }
 
-  size_t workspace_size = gemm.get_workspace_size(arguments);
-  if (workspace_size > workspaceBytes) {
+  if (gemm.get_workspace_size(arguments) > workspaceBytes) {
     throw std::runtime_error("[Bf16 Gemm Runner] insufficient workspace");
   }
 
-  // NOTE: These can also be simplified using CUTLASS_CHECK. Same goes for some of the other files.
-  cutlass::Status initStatus = gemm.initialize(arguments, workspacePtr, stream);
+  auto can_implement = gemm.can_implement(arguments);
+  if (can_implement != cutlass::Status::kSuccess) {
+    throw std::runtime_error("[Bf16 Gemm Runner] cutlass kernel not implemented given the params");
+  }
+
+  auto initStatus = gemm.initialize(arguments, workspacePtr);
   if (initStatus != cutlass::Status::kSuccess) {
     throw std::runtime_error("[Bf16 Gemm Runner] failed to initialize");
   }
 
-  cutlass::Status runStatus = gemm.run(stream);
+  auto runStatus = gemm.run(stream);
   if (runStatus != cutlass::Status::kSuccess) {
     throw std::runtime_error("[Bf16 Gemm Runner] failed to run");
   }
 
-  return workspace_size;
+  return gemm.get_workspace_size(arguments);
 }
 
 }  // namespace gemm
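The launcher change above establishes a size-query convention: a call with null `A`, `B`, and `D` returns the required workspace bytes instead of launching anything, so the workspace probe can reuse the launch path. A Python stand-in for that control flow; the function name and the placeholder size are illustrative only:

def launcher(A, B, D, workspace_bytes):
    required = 1 << 22  # stand-in for gemm.get_workspace_size(arguments)
    if A is None and B is None and D is None:
        return required  # size-query path: no kernel work happens
    if required > workspace_bytes:
        raise RuntimeError("[Bf16 Gemm Runner] insufficient workspace")
    # ... can_implement / initialize / run would go here ...
    return required

# getWorkspaceSize-style probe, mirroring how the runner can size its buffer:
needed = launcher(None, None, None, 0)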
From 6830dcc391483fdc96c5d5970beb82e1b9e15831 Mon Sep 17 00:00:00 2001
From: raayandhar
Date: Sun, 16 Nov 2025 22:10:14 -0800
Subject: [PATCH 08/11] remove extraneous comment

Signed-off-by: raayandhar
---
 include/flashinfer/gemm/bf16_gemm_template_sm100.h | 1 -
 1 file changed, 1 deletion(-)

diff --git a/include/flashinfer/gemm/bf16_gemm_template_sm100.h b/include/flashinfer/gemm/bf16_gemm_template_sm100.h
index fea3a0f2bd..1ba9e773e6 100644
--- a/include/flashinfer/gemm/bf16_gemm_template_sm100.h
+++ b/include/flashinfer/gemm/bf16_gemm_template_sm100.h
@@ -144,7 +144,6 @@ size_t genericBf16GemmKernelLauncherSm100(__nv_bfloat16 const* A, __nv_bfloat16
                        stride_B},
                       {{}, nullptr, stride_C, reinterpret_cast<ElementD*>(D), stride_D}};
 
-  // Is not the right way to do this?
   auto& fusion_args = arguments.epilogue.thread;
   fusion_args.alpha = 1.0f;
   fusion_args.beta = 0.0f;

From a56d74b067f13b4ed6a94016b50bd796abeef791 Mon Sep 17 00:00:00 2001
From: raayandhar
Date: Sun, 16 Nov 2025 22:12:01 -0800
Subject: [PATCH 09/11] add back space

Signed-off-by: raayandhar
---
 include/flashinfer/gemm/fp8_gemm_cutlass_template.h | 1 +
 1 file changed, 1 insertion(+)

diff --git a/include/flashinfer/gemm/fp8_gemm_cutlass_template.h b/include/flashinfer/gemm/fp8_gemm_cutlass_template.h
index c346fe53f6..a3c9d0f2c4 100644
--- a/include/flashinfer/gemm/fp8_gemm_cutlass_template.h
+++ b/include/flashinfer/gemm/fp8_gemm_cutlass_template.h
@@ -68,6 +68,7 @@ size_t dispatchGemmClusterShapeSm100(__nv_fp8_e4m3 const* A, __nv_fp8_e4m3 const
                                                _1SM>(
           A, B, scale_a, scale_b, D, m, n, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
       break;
+
     case ClusterShape::ClusterShape_2x1x1:
       return genericFp8GemmKernelLauncherSm100<T, CTAShape, Shape<_2, _1, _1>, _2SM>(
           A, B, scale_a, scale_b, D, m, n, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);

From db00e5151569a160ec19a8af18b07316e7c9e234 Mon Sep 17 00:00:00 2001
From: "Raayan Dhar raayan.dhar@gmail.com"
Date: Wed, 26 Nov 2025 20:25:08 -0800
Subject: [PATCH 10/11] fix precommit

Signed-off-by: Raayan Dhar raayan.dhar@gmail.com
---
 flashinfer/gemm/gemm_base.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py
index bed4707d55..626d71792c 100644
--- a/flashinfer/gemm/gemm_base.py
+++ b/flashinfer/gemm/gemm_base.py
@@ -547,8 +547,8 @@ def forward(
     return SimpleNamespace(
         cutlass_bf16_gemm_runner=cutlass_bf16_gemm_runner,
     )
-
-
+
+
 _FP8_GEMM_SM100_TUNING_CONFIG = TuningConfig(
     dynamic_tensor_specs=(
         DynamicTensorSpec(

From 323e6fa7c4236dc24142071a5e493926e99cf832 Mon Sep 17 00:00:00 2001
From: "Raayan Dhar raayan.dhar@gmail.com"
Date: Wed, 26 Nov 2025 20:31:54 -0800
Subject: [PATCH 11/11] fix docstring

Signed-off-by: Raayan Dhar raayan.dhar@gmail.com
---
 flashinfer/gemm/gemm_base.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py
index 626d71792c..489ae9755c 100644
--- a/flashinfer/gemm/gemm_base.py
+++ b/flashinfer/gemm/gemm_base.py
@@ -218,7 +218,7 @@ def mm_bf16(
     >>> import torch
     >>> import torch.nn.functional as F
     >>> import flashinfer
-    >>> input = torch.randn([48, 64], device="cuda", dtype=torch.bfloat16
+    >>> input = torch.randn([48, 64], device="cuda", dtype=torch.bfloat16)
     >>> weight = torch.randn([80, 64], device="cuda", dtype=torch.bfloat16).transpose(-2, -1)
     >>> out = flashinfer.mm_bf16(input, weight)
     >>> print(out.shape)
@@ -295,7 +295,7 @@ def bmm_bf16(
     >>> import torch
     >>> import torch.nn.functional as F
     >>> import flashinfer
-    >>> input = torch.randn([16, 48, 64], device="cuda", dtype=torch.bfloat16
+    >>> input = torch.randn([16, 48, 64], device="cuda", dtype=torch.bfloat16)
     >>> weight = torch.randn([16, 80, 64], device="cuda", dtype=torch.bfloat16).transpose(-2, -1)
     >>> out = flashinfer.bmm_bf16(input, weight)
     >>> print(out.shape)
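With the closing parenthesis restored, the batched doctest runs as written. An end-to-end check in the same spirit, again assuming a CUDA build of this branch on an SM100 device; torch.bmm and the tolerances are only an illustrative reference:

import torch
import flashinfer

a = torch.randn([16, 48, 64], device="cuda", dtype=torch.bfloat16)                    # (b, m, k)
w = torch.randn([16, 80, 64], device="cuda", dtype=torch.bfloat16).transpose(-2, -1)  # (b, k, n) view

out = flashinfer.bmm_bf16(a, w)
assert out.shape == (16, 48, 80) and out.dtype == torch.bfloat16
torch.testing.assert_close(out, torch.bmm(a, w), rtol=2e-2, atol=2e-2)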