diff --git a/csrc/bf16_gemm_cutlass.cu b/csrc/bf16_gemm_cutlass.cu
new file mode 100644
index 0000000000..f5e3750ad2
--- /dev/null
+++ b/csrc/bf16_gemm_cutlass.cu
@@ -0,0 +1,161 @@
+/*
+ * Copyright (c) 2025, FlashInfer.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <cuda_bf16.h>
+
+#include <cstdint>
+#include <cuda_fp16.h>
+#include <functional>
+#include <type_traits>
+#include <vector>
+
+#include "flashinfer/gemm/bf16_gemm_cutlass.h"
+#include "flashinfer/gemm/bf16_gemm_cutlass_template.h"
+#include "flashinfer/gemm/cutlass_gemm_configs.h"
+#include "tvm_ffi_utils.h"
+
+using flashinfer::gemm::ClusterShape;
+using flashinfer::gemm::CutlassBf16GemmRunner;
+using flashinfer::gemm::CutlassBf16GemmRunnerInterface;
+using flashinfer::gemm::CutlassGemmConfig;
+using flashinfer::gemm::CutlassTileConfigSM100;
+using flashinfer::gemm::EpilogueScheduleType;
+using flashinfer::gemm::MainloopScheduleType;
+
+namespace flashinfer {
+namespace gemm {
+template class CutlassBf16GemmRunner<__nv_bfloat16>;
+template class CutlassBf16GemmRunner<half>;
+}  // namespace gemm
+}  // namespace flashinfer
+
+namespace torch_ext {
+
+namespace {
+
+CutlassGemmConfig getBf16GemmConfig(int64_t m, int64_t n, int64_t k, int64_t tactic) {
+  auto getCutlassBf16GemmConfigs = []() {
+    CutlassBf16GemmRunner<__nv_bfloat16> gemmRunner;
+    return gemmRunner.getConfigs();
+  };
+  static std::vector<CutlassGemmConfig> globalConfigs = getCutlassBf16GemmConfigs();
+  TVM_FFI_ICHECK(tactic >= 0 && tactic < static_cast<int64_t>(globalConfigs.size()))
+      << "tactic must be between 0 and " << globalConfigs.size();
+  return globalConfigs[tactic];
+}
+
+template <typename T>
+void runGemm(TensorView out, TensorView mat1, TensorView mat2, int64_t m, int64_t n, int64_t k,
+             int64_t b, CutlassGemmConfig const& gemmConfig, TensorView workspace_buffer) {
+  CutlassBf16GemmRunner<T> gemmRunner;
+
+  int64_t const required_workspace_size = gemmRunner.getWorkspaceSize(m, n, k);
+  int64_t const provided_workspace_size =
+      workspace_buffer.numel() * get_element_size(workspace_buffer);
+
+  auto runKernel = [&](void* workspace) {
+    gemmRunner.gemm(static_cast<__nv_bfloat16*>(mat1.data_ptr()),
+                    static_cast<__nv_bfloat16*>(mat2.data_ptr()), out.data_ptr(), m, n, k, b,
+                    gemmConfig, static_cast<char*>(workspace), required_workspace_size,
+                    get_stream(mat1.device()));
+  };
+
+  if (provided_workspace_size < required_workspace_size) {
+    Tensor new_workspace =
+        alloc_tensor({required_workspace_size}, DLDataType{kDLInt, 8, 1}, mat1.device());
+    runKernel(new_workspace.data_ptr());
+  } else {
+    runKernel(workspace_buffer.data_ptr());
+  }
+}
+
+void bf16_bmm_impl(TensorView mat1, TensorView mat2, TensorView out, TensorView workspace_buffer,
+                   int64_t tactic) {
+  CHECK_INPUT_AND_TYPE(mat1, dl_bfloat16);
+  CHECK_INPUT_AND_TYPE(mat2, dl_bfloat16);
+
+  int64_t m, n, k, b;
+  if (mat1.ndim() == 2) {
+    TVM_FFI_ICHECK_EQ(mat2.ndim(), 2) << "mat2 must be a matrix";
+    TVM_FFI_ICHECK_EQ(mat1.size(1), mat2.size(1))
+        << "mat1 and mat2 shapes cannot be multiplied (" << mat1.size(0) << "x" << mat1.size(1)
+        << " and " << mat2.size(0) << "x" << mat2.size(1) << ")";
+    m = mat1.size(0);
+    n = mat2.size(0);
+    k = mat2.size(1);
+    b = 1;
+  } else if (mat1.ndim() == 3) {
+    TVM_FFI_ICHECK_EQ(mat2.ndim(), 3) << "mat2 must be a batch of matrices";
+    TVM_FFI_ICHECK_EQ(mat1.size(0), mat2.size(0))
+        << "mat1 and mat2 must have the same batch size (" << mat1.size(0) << " and "
+        << mat2.size(0) << ")";
+    TVM_FFI_ICHECK_EQ(mat1.size(2), mat2.size(2))
+        << "mat1 and mat2 shapes cannot be multiplied (" << mat1.size(1) << "x" << mat1.size(2)
+        << " and " << mat2.size(1) << "x" << mat2.size(2) << ")";
+    m = mat1.size(1);
+    n = mat2.size(1);
+    k = mat2.size(2);
+    b = mat1.size(0);
+  } else {
+    TVM_FFI_LOG_AND_THROW(NotImplementedError) << "mat1 must be a matrix or a batch of matrices";
+  }
+
+  if (tactic == -1) {
+    tactic = 0;
+  }
+  auto config = getBf16GemmConfig(m, n, k, tactic);
+
+  std::vector<int64_t> out_shape =
+      mat1.ndim() == 2 ? std::vector<int64_t>{m, n} : std::vector<int64_t>{b, m, n};
+  TVM_FFI_ICHECK_EQ(out.ndim(), static_cast<int>(out_shape.size()))
+      << "out must have " << out_shape.size() << " dimensions, but got " << out.ndim();
+  for (int i = 0; i < static_cast<int>(out_shape.size()); ++i) {
+    TVM_FFI_ICHECK_EQ(out.size(i), out_shape[i])
+        << "out shape mismatch at dimension " << i << ": expected " << out_shape[i] << ", got "
+        << out.size(i);
+  }
+
+  switch (encode_dlpack_dtype(out.dtype())) {
+    case float16_code:
+      runGemm<half>(out, mat1, mat2, m, n, k, b, config, workspace_buffer);
+      break;
+    case bfloat16_code:
+      runGemm<__nv_bfloat16>(out, mat1, mat2, m, n, k, b, config, workspace_buffer);
+      break;
+    default:
+      TVM_FFI_LOG_AND_THROW(NotImplementedError) << "out_dtype must be one of fp16/bf16.";
+  }
+}
+
+}  // namespace
+
+void bf16_gemm(TensorView mat1, TensorView mat2, TensorView out, TensorView workspace_buffer,
+               int64_t tactic) {
+  bf16_bmm_impl(mat1, mat2, out, workspace_buffer, tactic);
+}
+
+int64_t bf16_gemm_tactic_num() {
+  auto getCutlassConfigs = []() {
+    CutlassBf16GemmRunner<__nv_bfloat16> gemmRunner;
+    return gemmRunner.getConfigs();
+  };
+  static int64_t totalTactics = getCutlassConfigs().size();
+  return totalTactics;
+}
+
+}  // namespace torch_ext
+
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(bf16_gemm, torch_ext::bf16_gemm);
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(bf16_gemm_tactic_num, torch_ext::bf16_gemm_tactic_num);
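A note on tactic selection in the FFI layer above: `tactic` is an index into the config list built by `getConfigs()` (a fixed tile-shape x cluster-shape enumeration), and `-1` falls back to tactic 0 rather than autotuning at this level. A minimal sketch of probing every tactic through the JIT module, assuming `a`, `b_t`, `out`, and `ws` are pre-allocated CUDA tensors with `b_t` already transposed to (n, k):

    # Illustrative sketch; this mirrors what the autotuner does internally.
    from flashinfer.jit.gemm import gen_gemm_sm100_module_cutlass_bf16

    module = gen_gemm_sm100_module_cutlass_bf16().build_and_load()
    for tactic in range(module.bf16_gemm_tactic_num()):
        module.bf16_gemm(a, b_t, out, ws, tactic)  # one launch per candidate config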
diff --git a/csrc/bf16_gemm_cutlass.jinja b/csrc/bf16_gemm_cutlass.jinja
new file mode 100644
index 0000000000..0e8a5f0f9f
--- /dev/null
+++ b/csrc/bf16_gemm_cutlass.jinja
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2025, FlashInfer.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "flashinfer/gemm/bf16_gemm_template_sm100.h"
+
+namespace flashinfer {
+namespace gemm {
+  INSTANCE_BF16_GEMM_TEMPLATE_SM100({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 1, 1, 1, _1SM);
+  INSTANCE_BF16_GEMM_TEMPLATE_SM100({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 1, 2, 1, _1SM);
+  INSTANCE_BF16_GEMM_TEMPLATE_SM100({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 1, 4, 1, _1SM);
+  INSTANCE_BF16_GEMM_TEMPLATE_SM100({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 2, 1, 1, _2SM);
+  INSTANCE_BF16_GEMM_TEMPLATE_SM100({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 2, 2, 1, _2SM);
+}  // namespace gemm
+}  // namespace flashinfer
diff --git a/docs/api/gemm.rst b/docs/api/gemm.rst
index 8c9fbeeea6..c0c99a7f92 100644
--- a/docs/api/gemm.rst
+++ b/docs/api/gemm.rst
@@ -7,6 +7,15 @@ flashinfer.gemm
 
 This module provides a set of GEMM operations.
 
+BF16 GEMM
+---------
+
+.. autosummary::
+   :toctree: ../generated
+
+   mm_bf16
+   bmm_bf16
+
 FP4 GEMM
 --------
diff --git a/flashinfer/__init__.py b/flashinfer/__init__.py
index faad4f12a3..9f4d8de1f4 100644
--- a/flashinfer/__init__.py
+++ b/flashinfer/__init__.py
@@ -85,7 +85,9 @@
     trtllm_fp8_per_tensor_scale_moe,
 )
 from .gemm import SegmentGEMMWrapper as SegmentGEMMWrapper
+from .gemm import bmm_bf16 as bmm_bf16
 from .gemm import bmm_fp8 as bmm_fp8
+from .gemm import mm_bf16 as mm_bf16
 from .gemm import mm_fp4 as mm_fp4
 from .gemm import mm_fp8 as mm_fp8
 from .gemm import tgv_gemm_sm100 as tgv_gemm_sm100
diff --git a/flashinfer/gemm/__init__.py b/flashinfer/gemm/__init__.py
index 15652268ba..f84a71e920 100644
--- a/flashinfer/gemm/__init__.py
+++ b/flashinfer/gemm/__init__.py
@@ -1,5 +1,7 @@
 from .gemm_base import SegmentGEMMWrapper as SegmentGEMMWrapper
+from .gemm_base import bmm_bf16 as bmm_bf16
 from .gemm_base import bmm_fp8 as bmm_fp8
+from .gemm_base import mm_bf16 as mm_bf16
 from .gemm_base import mm_fp4 as mm_fp4
 from .gemm_base import mm_fp8 as mm_fp8
 from .gemm_base import tgv_gemm_sm100 as tgv_gemm_sm100
@@ -20,7 +22,9 @@
 __all__ = [
     "SegmentGEMMWrapper",
+    "bmm_bf16",
     "bmm_fp8",
+    "mm_bf16",
     "mm_fp4",
     "mm_fp8",
     "tgv_gemm_sm100",
diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py
index 15b26f02ee..489ae9755c 100644
--- a/flashinfer/gemm/gemm_base.py
+++ b/flashinfer/gemm/gemm_base.py
@@ -52,6 +52,7 @@
 from ..jit.gemm import gen_gemm_sm120_module_cutlass_fp4
 from ..jit.gemm import gen_gemm_sm100_module_cutlass_fp4
 from ..jit.gemm import gen_gemm_sm100_module_cutlass_fp8
+from ..jit.gemm import gen_gemm_sm100_module_cutlass_bf16
 from ..jit.gemm import gen_trtllm_gen_gemm_module
 from ..jit.gemm import gen_tgv_gemm_sm10x_module
 from ..jit.gemm import gen_deepgemm_sm100_module
@@ -180,6 +181,161 @@ def _fake_cutlass_segment_gemm(
     return _gemm_module
 
 
+@supported_compute_capability([100])
+def mm_bf16(
+    a: torch.Tensor,
+    b: torch.Tensor,
+    out: Optional[torch.Tensor] = None,
+    out_dtype: torch.dtype = torch.bfloat16,
+    backend: Literal["cutlass"] = "cutlass",
+) -> torch.Tensor:
+    r"""BF16 matrix multiplication (``out = a @ b``) using a CUTLASS SM100 kernel.
+
+    Parameters
+    ----------
+    a: torch.Tensor
+        Input tensor, shape (m, k), bf16.
+
+    b: torch.Tensor
+        Weight tensor, shape (k, n), bf16.
+
+    out: Optional[torch.Tensor]
+        Output tensor, shape (m, n), bf16 or fp16; allocated automatically when ``None`` (default).
+
+    out_dtype: torch.dtype
+        Output dtype, bf16 (default) or fp16.
+
+    backend: Literal["cutlass"]
+        Backend to use, defaults to "cutlass".
+
+    Returns
+    -------
+    torch.Tensor
+        Output tensor, shape (m, n), bf16 or fp16.
+
+    Examples
+    --------
+    >>> import torch
+    >>> import torch.nn.functional as F
+    >>> import flashinfer
+    >>> input = torch.randn([48, 64], device="cuda", dtype=torch.bfloat16)
+    >>> weight = torch.randn([80, 64], device="cuda", dtype=torch.bfloat16).transpose(-2, -1)
+    >>> out = flashinfer.mm_bf16(input, weight)
+    >>> print(out.shape)
+    torch.Size([48, 80])
+    >>> out.dtype
+    torch.bfloat16
+    """
+    if backend != "cutlass":
+        raise ValueError(f"Unsupported backend: {backend}. Only cutlass is available.")
+    if out_dtype not in (torch.bfloat16, torch.float16):
+        raise ValueError("Only bf16 and fp16 outputs are supported.")
+
+    if out is None:
+        out = torch.empty(
+            (a.shape[0], b.shape[1]),
+            device=a.device,
+            dtype=out_dtype,
+        )
+    else:
+        if out.shape != (a.shape[0], b.shape[1]):
+            raise ValueError(
+                f"Output shape mismatch. Expected {(a.shape[0], b.shape[1])}, got {out.shape}."
+            )
+        if out.device != a.device:
+            raise ValueError(
+                f"Output device mismatch. Expected {a.device}, got {out.device}."
+            )
+        if out.dtype != out_dtype:
+            raise ValueError(
+                f"Output dtype mismatch. Expected {out_dtype}, got {out.dtype}."
+            )
+
+    workspace_buffer = _get_cache_buf(
+        "mm_bf16_workspace", DEFAULT_WORKSPACE_SIZE, a.device
+    )
+    bf16_gemm_sm100(a, b, out, workspace_buffer)
+    return out
+
+
+@supported_compute_capability([100])
+def bmm_bf16(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    out: Optional[torch.Tensor] = None,
+    out_dtype: torch.dtype = torch.bfloat16,
+    backend: Literal["cutlass"] = "cutlass",
+) -> torch.Tensor:
+    r"""Batched BF16 matrix multiplication (``out = A @ B``) using a CUTLASS SM100 kernel.
+
+    Parameters
+    ----------
+    A: torch.Tensor
+        Input tensor, shape (b, m, k), bf16.
+
+    B: torch.Tensor
+        Weight tensor, shape (b, k, n), bf16.
+
+    out: Optional[torch.Tensor]
+        Output tensor, shape (b, m, n), bf16 or fp16; allocated automatically when ``None`` (default).
+
+    out_dtype: torch.dtype
+        Output dtype, bf16 (default) or fp16.
+
+    backend: Literal["cutlass"]
+        Backend to use, defaults to "cutlass".
+
+    Returns
+    -------
+    torch.Tensor
+        Output tensor, shape (b, m, n), bf16 or fp16.
+
+    Examples
+    --------
+    >>> import torch
+    >>> import torch.nn.functional as F
+    >>> import flashinfer
+    >>> input = torch.randn([16, 48, 64], device="cuda", dtype=torch.bfloat16)
+    >>> weight = torch.randn([16, 80, 64], device="cuda", dtype=torch.bfloat16).transpose(-2, -1)
+    >>> out = flashinfer.bmm_bf16(input, weight)
+    >>> print(out.shape)
+    torch.Size([16, 48, 80])
+    >>> out.dtype
+    torch.bfloat16
+    """
+    if backend != "cutlass":
+        raise ValueError(f"Unsupported backend: {backend}. Only cutlass is available.")
+    if out_dtype not in (torch.bfloat16, torch.float16):
+        raise ValueError("Only bf16 and fp16 outputs are supported.")
+
+    expected_shape = (A.shape[0], A.shape[1], B.shape[2])
+    if out is None:
+        out = torch.empty(
+            expected_shape,
+            device=A.device,
+            dtype=out_dtype,
+        )
+    else:
+        if out.shape != expected_shape:
+            raise ValueError(
+                f"Output shape mismatch. Expected {expected_shape}, got {out.shape}."
+            )
+        if out.device != A.device:
+            raise ValueError(
+                f"Output device mismatch. Expected {A.device}, got {out.device}."
+            )
+        if out.dtype != out_dtype:
+            raise ValueError(
+                f"Output dtype mismatch. Expected {out_dtype}, got {out.dtype}."
+            )
+
+    workspace_buffer = _get_cache_buf(
+        "bmm_bf16_workspace", DEFAULT_WORKSPACE_SIZE, A.device
+    )
+    bf16_gemm_sm100(A, B, out, workspace_buffer)
+    return out
+
+
 @functools.cache
 def get_gemm_sm100_module():
     module = gen_gemm_sm100_module().build_and_load()
@@ -356,6 +512,43 @@ def forward(
         )
 
 
+@functools.cache
+def get_gemm_sm100_module_cutlass_bf16():
+    module = gen_gemm_sm100_module_cutlass_bf16().build_and_load()
+
+    def cutlass_bf16_gemm_runner():
+        class CutlassBf16GemmRunner(TunableRunner):
+            def get_valid_tactics(
+                self,
+                inputs: List[torch.Tensor],
+                profile: OptimizationProfile,
+            ) -> List[int]:
+                return list(range(module.bf16_gemm_tactic_num()))
+
+            def forward(
+                self,
+                inputs: List[torch.Tensor],
+                tactic: int = -1,
+                do_preparation: bool = False,
+                **kwargs,
+            ) -> torch.Tensor:
+                a, b, out, workspace_buffer = inputs
+                module.bf16_gemm(
+                    a,
+                    b.transpose(-2, -1),
+                    out,
+                    workspace_buffer,
+                    tactic,
+                )
+                return out
+
+        return CutlassBf16GemmRunner()
+
+    return SimpleNamespace(
+        cutlass_bf16_gemm_runner=cutlass_bf16_gemm_runner,
+    )
+
+
 _FP8_GEMM_SM100_TUNING_CONFIG = TuningConfig(
     dynamic_tensor_specs=(
         DynamicTensorSpec(
@@ -407,6 +600,47 @@ def fp8_gemm_sm100(
     runner(inputs=inputs, tactic=tactic)
 
 
+def bf16_gemm_sm100(
+    a: torch.Tensor,
+    b: torch.Tensor,
+    out: torch.Tensor,
+    workspace_buffer: torch.Tensor,
+) -> None:
+    runners = []
+    if _match_sm_version(a.device, ["100"]):
+        runners.append(get_gemm_sm100_module_cutlass_bf16().cutlass_bf16_gemm_runner())
+    assert runners, "No suitable runners found"
+
+    tuner = AutoTuner.get()
+    a_tensor_index = 0
+    out_tensor_index = 2
+    tuning_config = TuningConfig(
+        dynamic_tensor_specs=(
+            DynamicTensorSpec(
+                (a_tensor_index,),
+                (-2,),
+                get_last_power_of_2_num_tokens_buckets,
+                last_positive_power_of_2,
+            ),
+        ),
+        constraint_specs=(
+            ConstraintSpec(
+                out_tensor_index, -2, lambda shapes: shapes[a_tensor_index][-2]
+            ),
+        ),
+    )
+
+    inputs = [a, b, out, workspace_buffer]
+    runner, tactic = tuner.choose_one(
+        "bf16_gemm",
+        runners,
+        tuning_config,
+        inputs,
+    )
+
+    runner(inputs=inputs, tactic=tactic)
+
+
 def _create_cutlass_fp4_gemm_module(module, op_name: str, tuner_name: str):
     """Helper function to create cutlass FP4 GEMM module."""
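The public wrappers above feed the autotuner-backed `bf16_gemm_sm100`. A short usage sketch under the documented API (shapes follow the docstrings; requesting `torch.float16` output exercises the `runGemm<half>` path on the C++ side):

    import torch
    import flashinfer

    a = torch.randn([48, 64], device="cuda", dtype=torch.bfloat16)    # (m, k)
    b = torch.randn([80, 64], device="cuda", dtype=torch.bfloat16).T  # (k, n) view
    with flashinfer.autotune():  # optional: profile candidate tactics once
        out = flashinfer.mm_bf16(a, b, out_dtype=torch.float16)       # (m, n), fp16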
diff --git a/flashinfer/jit/gemm/__init__.py b/flashinfer/jit/gemm/__init__.py
index f1681d3bf5..617327c451 100644
--- a/flashinfer/jit/gemm/__init__.py
+++ b/flashinfer/jit/gemm/__init__.py
@@ -19,6 +19,7 @@
     gen_gemm_sm100_module_cutlass_fp4,
     gen_gemm_sm120_module_cutlass_fp4,
     gen_gemm_sm100_module_cutlass_fp8,
+    gen_gemm_sm100_module_cutlass_bf16,
     gen_gemm_sm100_module,
     gen_gemm_sm120_module,
     gen_trtllm_gen_gemm_module,
@@ -33,6 +34,7 @@
     "gen_gemm_sm100_module_cutlass_fp4",
     "gen_gemm_sm120_module_cutlass_fp4",
     "gen_gemm_sm100_module_cutlass_fp8",
+    "gen_gemm_sm100_module_cutlass_bf16",
    "gen_gemm_sm100_module",
     "gen_gemm_sm120_module",
     "gen_trtllm_gen_gemm_module",
diff --git a/flashinfer/jit/gemm/core.py b/flashinfer/jit/gemm/core.py
index 7873d0de14..5d40b510ac 100644
--- a/flashinfer/jit/gemm/core.py
+++ b/flashinfer/jit/gemm/core.py
@@ -190,6 +190,52 @@ def gen_gemm_sm100_module_cutlass_fp8() -> JitSpec:
     )
 
 
+def gen_gemm_sm100_module_cutlass_bf16() -> JitSpec:
+    gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / "gen_gemm_sm100_cutlass_bf16"
+    os.makedirs(gen_directory, exist_ok=True)
+    source_paths = [
+        jit_env.FLASHINFER_CSRC_DIR / "bf16_gemm_cutlass.cu",
+    ]
+
+    with open(jit_env.FLASHINFER_CSRC_DIR / "bf16_gemm_cutlass.jinja") as f:
+        kernel_inst_templ = jinja2.Template(f.read())
+    dtype_list = ["__nv_bfloat16", "half"]
+    cta_m_n_k_list = [
+        (64, 64, 128),
+        (64, 128, 128),
+        (64, 256, 128),
+        (128, 64, 128),
+        (128, 128, 128),
+    ]
+    for cta_m, cta_n, cta_k in cta_m_n_k_list:
+        for dtype in dtype_list:
+            dest_path = (
+                gen_directory
+                / f"bf16_gemm_cutlass_{dtype}_{cta_m}_{cta_n}_{cta_k}.cu"
+            )
+            source_paths.append(dest_path)
+            source = kernel_inst_templ.render(
+                type=dtype,
+                cta_m=cta_m,
+                cta_n=cta_n,
+                cta_k=cta_k,
+            )
+            write_if_different(dest_path, source)
+
+    nvcc_flags = current_compilation_context.get_nvcc_flags_list(
+        supported_major_versions=[10, 11, 12]
+    )
+
+    return gen_jit_spec(
+        "bf16_gemm_cutlass",
+        source_paths,
+        extra_cuda_cflags=nvcc_flags + ["-DENABLE_BF16"],
+        extra_cflags=[
+            "-DFAST_BUILD",
+        ],
+    )
+
+
 def gen_gemm_sm100_module() -> JitSpec:
     gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / "gen_gemm_sm100"
     os.makedirs(gen_directory, exist_ok=True)
diff --git a/include/flashinfer/gemm/bf16_gemm_cutlass.h b/include/flashinfer/gemm/bf16_gemm_cutlass.h
new file mode 100644
index 0000000000..6011075a19
--- /dev/null
+++ b/include/flashinfer/gemm/bf16_gemm_cutlass.h
@@ -0,0 +1,62 @@
+/*
+ * Copyright (c) 2025, FlashInfer.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FLASHINFER_BF16_GEMM_CUTLASS_H_
+#define FLASHINFER_BF16_GEMM_CUTLASS_H_
+
+#include <cuda_bf16.h>
+
+#include <vector>
+
+#include "flashinfer/gemm/cutlass_gemm_configs.h"
+
+namespace flashinfer {
+namespace gemm {
+
+class CutlassBf16GemmRunnerInterface {
+ public:
+  CutlassBf16GemmRunnerInterface() = default;
+  virtual ~CutlassBf16GemmRunnerInterface() = default;
+
+  virtual void gemm(__nv_bfloat16 const* A, __nv_bfloat16 const* B, void* D, int m, int n, int k,
+                    int b, CutlassGemmConfig gemmConfig, char* workspacePtr,
+                    size_t const workspaceBytes, cudaStream_t stream) = 0;
+
+  virtual size_t getWorkspaceSize(int m, int n, int k) = 0;
+
+  virtual std::vector<CutlassGemmConfig> getConfigs() const = 0;
+};
+
+template <typename T>
+class CutlassBf16GemmRunner : public CutlassBf16GemmRunnerInterface {
+ public:
+  CutlassBf16GemmRunner() = default;
+  ~CutlassBf16GemmRunner() = default;
+
+  void gemm(__nv_bfloat16 const* A, __nv_bfloat16 const* B, void* D, int m, int n, int k, int b,
+            CutlassGemmConfig gemmConfig, char* workspacePtr, size_t const workspaceBytes,
+            cudaStream_t stream) override;
+  size_t getWorkspaceSize(int m, int n, int k) override;
+  std::vector<CutlassGemmConfig> getConfigs() const override;
+
+ private:
+  size_t getWorkspaceSizeImpl(int m, int n, int k);
+};
+
+}  // namespace gemm
+}  // namespace flashinfer
+
+#endif  // FLASHINFER_BF16_GEMM_CUTLASS_H_
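For orientation, `gen_gemm_sm100_module_cutlass_bf16` above emits one generated translation unit per (dtype, CTA tile) pair, and each unit instantiates all five cluster shapes through the jinja template. A sketch of the resulting file set (file names mirror the f-string in the generator):

    dtypes = ["__nv_bfloat16", "half"]
    tiles = [(64, 64, 128), (64, 128, 128), (64, 256, 128), (128, 64, 128), (128, 128, 128)]
    files = [f"bf16_gemm_cutlass_{d}_{m}_{n}_{k}.cu" for m, n, k in tiles for d in dtypes]
    assert len(files) == 10  # x 5 cluster shapes each = 50 kernel instantiations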
diff --git a/include/flashinfer/gemm/bf16_gemm_cutlass_template.h b/include/flashinfer/gemm/bf16_gemm_cutlass_template.h
new file mode 100644
index 0000000000..f73ea1bde2
--- /dev/null
+++ b/include/flashinfer/gemm/bf16_gemm_cutlass_template.h
@@ -0,0 +1,207 @@
+/*
+ * Copyright (c) 2025, FlashInfer.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FLASHINFER_BF16_GEMM_CUTLASS_TEMPLATE_H_
+#define FLASHINFER_BF16_GEMM_CUTLASS_TEMPLATE_H_
+
+#ifdef __GNUC__  // Check if the compiler is GCC or Clang
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wstrict-aliasing"
+#endif  // __GNUC__
+#include "cutlass/arch/arch.h"
+#include "cutlass/cutlass.h"
+#include "cutlass/epilogue/collective/collective_builder.hpp"
+#include "cutlass/gemm/collective/collective_builder.hpp"
+#include "cutlass/gemm/device/gemm_universal_adapter.h"
+#include "cutlass/gemm/gemm.h"
+#include "flashinfer/arch_condition.h"
+#include "flashinfer/cutlass_utils.cuh"
+
+#ifdef __GNUC__  // Check if the compiler is GCC or Clang
+#pragma GCC diagnostic pop
+#endif  // __GNUC__
+
+#include <cuda_bf16.h>
+#include <unordered_map>
+
+#include "cutlass/bfloat16.h"
+
+namespace flashinfer {
+namespace gemm {
+
+struct _1SM {};
+struct _2SM {};
+
+template <typename T, typename arch, int CTA_M_, int CTA_N_, int CTA_K_, typename ClusterShape_,
+          typename SM_>
+size_t genericBf16GemmKernelLauncherSm100(__nv_bfloat16 const* A, __nv_bfloat16 const* B, T* D,
+                                          int m, int n, int k, int b, CutlassGemmConfig config,
+                                          char* workspacePtr, size_t const workspaceBytes,
+                                          cudaStream_t stream);
+
+template <typename T, typename arch, int CTA_M_, int CTA_N_, int CTA_K_>
+size_t dispatchGemmClusterShapeSm100(__nv_bfloat16 const* A, __nv_bfloat16 const* B, T* D, int m,
+                                     int n, int k, int b, CutlassGemmConfig gemmConfig,
+                                     char* workspacePtr, size_t const workspaceBytes,
+                                     cudaStream_t stream) {
+  using namespace cute;
+
+  switch (gemmConfig.cluster_shape) {
+    case ClusterShape::ClusterShape_1x1x1:
+      return genericBf16GemmKernelLauncherSm100<T, arch, CTA_M_, CTA_N_, CTA_K_,
+                                                Shape<_1, _1, _1>, _1SM>(
+          A, B, D, m, n, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
+      break;
+    case ClusterShape::ClusterShape_1x2x1:
+      return genericBf16GemmKernelLauncherSm100<T, arch, CTA_M_, CTA_N_, CTA_K_,
+                                                Shape<_1, _2, _1>, _1SM>(
+          A, B, D, m, n, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
+      break;
+    case ClusterShape::ClusterShape_1x4x1:
+      return genericBf16GemmKernelLauncherSm100<T, arch, CTA_M_, CTA_N_, CTA_K_,
+                                                Shape<_1, _4, _1>, _1SM>(
+          A, B, D, m, n, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
+      break;
+    case ClusterShape::ClusterShape_2x1x1:
+      return genericBf16GemmKernelLauncherSm100<T, arch, CTA_M_, CTA_N_, CTA_K_,
+                                                Shape<_2, _1, _1>, _2SM>(
+          A, B, D, m, n, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
+      break;
+    case ClusterShape::ClusterShape_2x2x1:
+      return genericBf16GemmKernelLauncherSm100<T, arch, CTA_M_, CTA_N_, CTA_K_,
+                                                Shape<_2, _2, _1>, _2SM>(
+          A, B, D, m, n, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
+      break;
+    default:
+      throw std::runtime_error("invalid config for bf16 gemm");
+      break;
+  }
+}
+
+template <typename T>
+size_t dispatchToArch(__nv_bfloat16 const* A, __nv_bfloat16 const* B, void* D, int m, int n, int k,
+                      int b, CutlassGemmConfig gemmConfig, char* workspacePtr,
+                      size_t const workspaceBytes, cudaStream_t stream) {
+  using arch = cutlass::arch::Sm100;
+
+  // Operands and problem sizes are swapped (B, A and n, m): the kernel writes a
+  // column-major D, so computing D^T = B^T x A^T yields the row-major output expected here.
+  switch (gemmConfig.tile_config_sm100) {
+    case CutlassTileConfigSM100::CtaShape64x64x128B:
+      return dispatchGemmClusterShapeSm100<T, arch, 64, 64, 128>(
+          B, A, static_cast<T*>(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
+      break;
+    case CutlassTileConfigSM100::CtaShape64x128x128B:
+      return dispatchGemmClusterShapeSm100<T, arch, 64, 128, 128>(
+          B, A, static_cast<T*>(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
+      break;
+    case CutlassTileConfigSM100::CtaShape64x256x128B:
+      return dispatchGemmClusterShapeSm100<T, arch, 64, 256, 128>(
+          B, A, static_cast<T*>(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
+      break;
+    case CutlassTileConfigSM100::CtaShape128x64x128B:
+      return dispatchGemmClusterShapeSm100<T, arch, 128, 64, 128>(
+          B, A, static_cast<T*>(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
+      break;
+    case CutlassTileConfigSM100::CtaShape128x128x128B:
+      return dispatchGemmClusterShapeSm100<T, arch, 128, 128, 128>(
+          B, A, static_cast<T*>(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
+      break;
+
+    default:
+      throw std::runtime_error("unsupported tile config for bf16 gemm");
+      break;
+  }
+}
+
+template <typename T>
+void CutlassBf16GemmRunner<T>::gemm(__nv_bfloat16 const* A, __nv_bfloat16 const* B, void* D, int m,
+                                    int n, int k, int b, CutlassGemmConfig gemmConfig,
+                                    char* workspacePtr, size_t const workspaceBytes,
+                                    cudaStream_t stream) {
+  dispatchToArch<T>(A, B, reinterpret_cast<void*>(D), m, n, k, b, gemmConfig, workspacePtr,
+                    workspaceBytes, stream);
+}
+
+template <typename T>
+size_t CutlassBf16GemmRunner<T>::getWorkspaceSizeImpl(int m, int n, int k) {
+  size_t workspace_size = 0;
+  auto gemmConfigs = CutlassBf16GemmRunner<T>{}.getConfigs();
+  for (auto const& gemmConfig : gemmConfigs) {
+    try {
+      size_t curr_workspace_size =
+          dispatchToArch<T>(nullptr, nullptr, nullptr, m, n, k, 1, gemmConfig, nullptr, 0, nullptr);
+      workspace_size = std::max(workspace_size, curr_workspace_size);
+    } catch (std::runtime_error&) {
+      // Swallow errors when SMEM exceeds maximum allowed
+      continue;
+    }
+  }
+  return workspace_size;
+}
+
+template <typename T>
+size_t CutlassBf16GemmRunner<T>::getWorkspaceSize(int m, int n, int k) {
+  using MNK = std::tuple<int, int, int>;
+
+  struct MNKHash {
+    size_t operator()(const MNK& mnk) const {
+      auto h1 = std::hash<int>{}(std::get<0>(mnk));
+      auto h2 = std::hash<int>{}(std::get<1>(mnk));
+      auto h3 = std::hash<int>{}(std::get<2>(mnk));
+      return h1 ^ h2 ^ h3;
+    }
+  };
+
+  static std::unordered_map<MNK, size_t, MNKHash> workspace_hashmap;
+
+  size_t workspace_size = 0;
+  if (workspace_hashmap.find(std::make_tuple(m, n, k)) == workspace_hashmap.end()) {
+    workspace_size = CutlassBf16GemmRunner<T>::getWorkspaceSizeImpl(m, n, k);
+    workspace_hashmap[std::make_tuple(m, n, k)] = workspace_size;
+  } else {
+    workspace_size = workspace_hashmap[std::make_tuple(m, n, k)];
+  }
+  return workspace_size;
+}
+
+template <typename T>
+std::vector<CutlassGemmConfig> CutlassBf16GemmRunner<T>::getConfigs() const {
+  std::vector<CutlassGemmConfig> candidate_configs;
+
+  std::vector<CutlassTileConfigSM100> tilesSm100 = {
+      CutlassTileConfigSM100::CtaShape64x64x128B,  CutlassTileConfigSM100::CtaShape64x128x128B,
+      CutlassTileConfigSM100::CtaShape64x256x128B, CutlassTileConfigSM100::CtaShape128x64x128B,
+      CutlassTileConfigSM100::CtaShape128x128x128B,
+  };
+
+  std::vector<ClusterShape> clusterShapes = {
+      ClusterShape::ClusterShape_1x1x1, ClusterShape::ClusterShape_1x2x1,
+      ClusterShape::ClusterShape_1x4x1, ClusterShape::ClusterShape_2x1x1,
+      ClusterShape::ClusterShape_2x2x1,
+  };
+
+  for (auto const& tile_config : tilesSm100) {
+    for (auto const& cluster_config : clusterShapes) {
+      CutlassGemmConfig config(tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
+                               cluster_config);
+      candidate_configs.push_back(config);
+    }
+  }
+  return candidate_configs;
+}
+
+}  // namespace gemm
+}  // namespace flashinfer
+
+#endif  // FLASHINFER_BF16_GEMM_CUTLASS_TEMPLATE_H_
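`getConfigs()` iterates tile shapes in the outer loop and cluster shapes in the inner loop, which fixes the tactic numbering consumed by `getBf16GemmConfig`. A sketch of that mapping (the string names are illustrative shorthand for the enum values):

    tiles = ["64x64x128B", "64x128x128B", "64x256x128B", "128x64x128B", "128x128x128B"]
    clusters = ["1x1x1", "1x2x1", "1x4x1", "2x1x1", "2x2x1"]
    configs = [(t, c) for t in tiles for c in clusters]
    assert len(configs) == 25  # == bf16_gemm_tactic_num()
    assert configs[7] == ("64x128x128B", "1x4x1")  # tactic = tile_idx * 5 + cluster_idx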
diff --git a/include/flashinfer/gemm/bf16_gemm_template_sm100.h b/include/flashinfer/gemm/bf16_gemm_template_sm100.h
new file mode 100644
index 0000000000..1ba9e773e6
--- /dev/null
+++ b/include/flashinfer/gemm/bf16_gemm_template_sm100.h
@@ -0,0 +1,192 @@
+/*
+ * Copyright (c) 2025, FlashInfer.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FLASHINFER_BF16_GEMM_TEMPLATE_SM100_H_
+#define FLASHINFER_BF16_GEMM_TEMPLATE_SM100_H_
+
+#ifdef __GNUC__  // Check if the compiler is GCC or Clang
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wstrict-aliasing"
+#endif  // __GNUC__
+#include "cutlass/arch/arch.h"
+#include "cutlass/cutlass.h"
+#include "cutlass/epilogue/collective/collective_builder.hpp"
+#include "cutlass/gemm/collective/collective_builder.hpp"
+#include "cutlass/gemm/device/gemm_universal_adapter.h"
+#include "cutlass/gemm/gemm.h"
+#include "cutlass/numeric_conversion.h"
+#include "flashinfer/arch_condition.h"
+#include "flashinfer/cutlass_utils.cuh"
+
+#ifdef __GNUC__  // Check if the compiler is GCC or Clang
+#pragma GCC diagnostic pop
+#endif  // __GNUC__
+
+#include <cuda_bf16.h>
+#include <stdexcept>
+
+#include "cutlass/bfloat16.h"
+#include "flashinfer/gemm/cutlass_gemm_configs.h"
+
+namespace flashinfer {
+namespace gemm {
+
+template <typename SM_>
+struct SMTypeAdapter {};
+
+struct _1SM;
+struct _2SM;
+
+template <>
+struct SMTypeAdapter<_1SM> {
+  static int const Scale = 1;
+  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized1Sm;
+  using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized1SmSm100;
+};
+
+template <>
+struct SMTypeAdapter<_2SM> {
+  static int const Scale = 2;
+  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized2Sm;
+  using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmSm100;
+};
+
+template <typename T, typename arch, int CTA_M_, int CTA_N_, int CTA_K_, typename ClusterShape_,
+          typename SM_>
+size_t genericBf16GemmKernelLauncherSm100(__nv_bfloat16 const* A, __nv_bfloat16 const* B, T* D,
+                                          int m, int n, int k, int b, CutlassGemmConfig config,
+                                          char* workspacePtr, size_t const workspaceBytes,
+                                          cudaStream_t stream) {
+  using namespace cute;
+
+  using ElementA = cutlass::bfloat16_t;
+  using LayoutA = cutlass::layout::RowMajor;
+  constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
+
+  using ElementB = cutlass::bfloat16_t;
+  using LayoutB = cutlass::layout::ColumnMajor;
+  constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
+
+  using ElementOutput_ =
+      typename cutlass::platform::conditional<cutlass::platform::is_same<T, half>::value,
+                                              cutlass::half_t, T>::type;
+#ifdef ENABLE_BF16
+  using ElementOutput = typename cutlass::platform::conditional<
+      cutlass::platform::is_same<T, __nv_bfloat16>::value, cutlass::bfloat16_t,
+      ElementOutput_>::type;
+#else
+  using ElementOutput = ElementOutput_;
+#endif
+
+  using ElementC = ElementOutput;
+  using LayoutC = cutlass::layout::ColumnMajor;
+  constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
+
+  using ElementD = ElementC;
+  using LayoutD = LayoutC;
+  constexpr int AlignmentD = AlignmentC;
+
+  using ElementAccumulator = float;
+  using ElementCompute = float;
+  using ArchTag = cutlass::arch::Sm100;
+  using OperatorClass = cutlass::arch::OpClassTensorOp;
+  using TileShape = cute::Shape<cute::Int<CTA_M_ * SMTypeAdapter<SM_>::Scale>, cute::Int<CTA_N_>,
+                                cute::Int<CTA_K_>>;
+
+  using ClusterShape = ClusterShape_;
+  using EpilogueSchedule = typename SMTypeAdapter<SM_>::EpilogueSchedule;
+  using MainloopSchedule = typename SMTypeAdapter<SM_>::MainloopSchedule;
+  using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
+
+  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+      ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType, ElementAccumulator,
+      ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD,
+      EpilogueSchedule>::CollectiveOp;
+
+  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+      ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, LayoutB, AlignmentB,
+      ElementAccumulator, TileShape, ClusterShape,
+      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
+          sizeof(typename CollectiveEpilogue::SharedStorage))>,
+      MainloopSchedule>::CollectiveOp;
+
+  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<cute::Shape<int, int, int, int>,
+                                                          CollectiveMainloop, CollectiveEpilogue>;
+
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+
+  using StrideA = typename Gemm::GemmKernel::StrideA;
+  using StrideB = typename Gemm::GemmKernel::StrideB;
+  using StrideC = typename Gemm::GemmKernel::StrideC;
+  using StrideD = typename Gemm::GemmKernel::StrideD;
+
+  auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, b));
+  auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, b));
+  auto stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, b));
+  auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(m, n, b));
+
+  typename Gemm::Arguments arguments{
+      cutlass::gemm::GemmUniversalMode::kGemm,
+      {m, n, k, b},
+      {reinterpret_cast<ElementA const*>(A), stride_A, reinterpret_cast<ElementB const*>(B),
+       stride_B},
+      {{}, nullptr, stride_C, reinterpret_cast<ElementD*>(D), stride_D}};
+
+  auto& fusion_args = arguments.epilogue.thread;
+  fusion_args.alpha = 1.0f;
+  fusion_args.beta = 0.0f;
+
+  Gemm gemm;
+
+  // Return workspace size
+  if (!A && !B && !D) {
+    return gemm.get_workspace_size(arguments);
+  }
+
+  if (gemm.get_workspace_size(arguments) > workspaceBytes) {
+    throw std::runtime_error("[Bf16 Gemm Runner] insufficient workspace");
+  }
+
+  auto can_implement = gemm.can_implement(arguments);
+  if (can_implement != cutlass::Status::kSuccess) {
+    throw std::runtime_error("[Bf16 Gemm Runner] cutlass kernel not implemented given the params");
+  }
+
+  auto initStatus = gemm.initialize(arguments, workspacePtr);
+  if (initStatus != cutlass::Status::kSuccess) {
+    throw std::runtime_error("[Bf16 Gemm Runner] failed to initialize");
+  }
+
+  auto runStatus = gemm.run(stream);
+  if (runStatus != cutlass::Status::kSuccess) {
+    throw std::runtime_error("[Bf16 Gemm Runner] failed to run");
+  }
+
+  return gemm.get_workspace_size(arguments);
+}
+
+}  // namespace gemm
+}  // namespace flashinfer
+
+#define INSTANCE_BF16_GEMM_TEMPLATE_SM100(RET_TYPE, TILE_M, TILE_N, TILE_K, CGA_M_, CGA_N_, \
+                                          CGA_K_, SM_TYPE)                                  \
+  template size_t genericBf16GemmKernelLauncherSm100<                                       \
+      RET_TYPE, cutlass::arch::Sm100, TILE_M, TILE_N, TILE_K,                               \
+      cute::Shape<cute::Int<CGA_M_>, cute::Int<CGA_N_>, cute::Int<CGA_K_>>, SM_TYPE>(       \
+      __nv_bfloat16 const* A, __nv_bfloat16 const* B, RET_TYPE* D, int m, int n, int k,     \
+      int b, CutlassGemmConfig config, char* workspacePtr, size_t const workspaceBytes,     \
+      cudaStream_t stream);
+
+#endif  // FLASHINFER_BF16_GEMM_TEMPLATE_SM100_H_
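One subtlety in the launcher above: for `_2SM` schedules the CTA tile's M extent is multiplied by `SMTypeAdapter<SM_>::Scale`, so the MMA tile spans both SMs of the cluster pair. A sketch of the effective tile computation (`effective_tile_shape` is a hypothetical helper mirroring the `TileShape` alias):

    def effective_tile_shape(cta_m, cta_n, cta_k, sm_scale):
        # sm_scale is 1 for _1SM schedules, 2 for _2SM schedules
        return (cta_m * sm_scale, cta_n, cta_k)

    assert effective_tile_shape(64, 128, 128, 2) == (128, 128, 128)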
diff --git a/tests/gemm/test_bmm_bf16.py b/tests/gemm/test_bmm_bf16.py
new file mode 100644
index 0000000000..b6b47e5860
--- /dev/null
+++ b/tests/gemm/test_bmm_bf16.py
@@ -0,0 +1,37 @@
+import pytest
+import torch
+import torch.nn.functional as F
+
+from flashinfer import autotune, bmm_bf16
+from flashinfer.utils import get_compute_capability
+
+
+@pytest.mark.parametrize("b", [1, 16])
+@pytest.mark.parametrize("m", [48, 128])
+@pytest.mark.parametrize("n", [80, 64])
+@pytest.mark.parametrize("k", [64, 256])
+@pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16])
+def test_bmm_bf16(b, m, n, k, res_dtype):
+    compute_capability = get_compute_capability(torch.device(device="cuda"))
+    compute_capability_number = compute_capability[0] * 10 + compute_capability[1]
+    if not bmm_bf16.is_compute_capability_supported(compute_capability_number):
+        pytest.skip(
+            f"bmm_bf16 requires one of the following compute capabilities: "
+            f"{sorted(bmm_bf16._supported_ccs)}. "
+            f"Detected sm{compute_capability_number}."
+        )
+    torch.manual_seed(7)
+    input = torch.randn([b, m, k], device="cuda", dtype=torch.bfloat16)
+    mat2 = torch.randn([b, n, k], device="cuda", dtype=torch.bfloat16).transpose(-2, -1)
+    reference = torch.bmm(input, mat2)
+
+    out = torch.empty([b, m, n], device="cuda", dtype=res_dtype)
+    with autotune():
+        bmm_bf16(input, mat2, out=out, out_dtype=res_dtype)
+
+    cos_sim = F.cosine_similarity(reference.reshape(-1), out.reshape(-1), dim=0)
+    assert cos_sim > 0.99
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
diff --git a/tests/gemm/test_mm_bf16.py b/tests/gemm/test_mm_bf16.py
new file mode 100644
index 0000000000..6ccd7518f4
--- /dev/null
+++ b/tests/gemm/test_mm_bf16.py
@@ -0,0 +1,38 @@
+import pytest
+import torch
+import torch.nn.functional as F
+
+from flashinfer import autotune, mm_bf16
+from flashinfer.utils import get_compute_capability
+
+
+@pytest.mark.parametrize("m", [1, 48, 128, 256, 512])
+@pytest.mark.parametrize("n", [128, 256, 512])
+@pytest.mark.parametrize("k", [128, 256, 512])
+@pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16])
+def test_mm_bf16(m: int, n: int, k: int, res_dtype: torch.dtype):
+    compute_capability = get_compute_capability(torch.device(device="cuda"))
+    compute_capability_number = compute_capability[0] * 10 + compute_capability[1]
+    if not mm_bf16.is_compute_capability_supported(compute_capability_number):
+        pytest.skip(
+            f"mm_bf16 requires one of the following compute capabilities: "
+            f"{sorted(mm_bf16._supported_ccs)}. "
+            f"Detected sm{compute_capability_number}."
+        )
+
+    torch.manual_seed(42)
+    input = torch.randn([m, k], device="cuda", dtype=torch.bfloat16)
+    mat2 = torch.randn([n, k], device="cuda", dtype=torch.bfloat16)
+
+    reference = torch.mm(input, mat2.T)
+
+    out = torch.empty([m, n], device="cuda", dtype=res_dtype)
+    with autotune():
+        mm_bf16(input, mat2.T, out=out, out_dtype=res_dtype)
+
+    cos_sim = F.cosine_similarity(reference.reshape(-1), out.reshape(-1), dim=0)
+    assert cos_sim > 0.99
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
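Beyond the correctness checks above, a quick timing sketch for the new path (assumes a CUDA device and a previously tuned cache; event-based timing keeps the measurement on-stream):

    import torch
    import flashinfer

    a = torch.randn([4096, 4096], device="cuda", dtype=torch.bfloat16)
    b = torch.randn([4096, 4096], device="cuda", dtype=torch.bfloat16)
    with flashinfer.autotune():
        flashinfer.mm_bf16(a, b)  # warm-up and tactic selection

    start = torch.cuda.Event(enable_timing=True)
    stop = torch.cuda.Event(enable_timing=True)
    start.record()
    out = flashinfer.mm_bf16(a, b)
    stop.record()
    torch.cuda.synchronize()
    print(f"mm_bf16 4096^3: {start.elapsed_time(stop):.3f} ms")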