diff --git a/csrc/fp8_blockscale_gemm_sm90_binding.cu b/csrc/fp8_blockscale_gemm_sm90_binding.cu
new file mode 100755
index 0000000000..426100f7e4
--- /dev/null
+++ b/csrc/fp8_blockscale_gemm_sm90_binding.cu
@@ -0,0 +1,157 @@
+
+#include <tvm/ffi/extra/module.h>
+#include "tvm_ffi_utils.h"
+#include "tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm.h"
+
+#include <cuda_bf16.h>
+#include <cuda_fp8.h>
+#include <algorithm>
+#include <cstdint>
+#include <memory>
+
+namespace kernels = tensorrt_llm::kernels::fp8_blockscale_gemm;
+
+using tvm::ffi::Function;
+using tvm::ffi::Optional;
+using tvm::ffi::TensorView;
+
+#ifdef FLASHINFER_ENABLE_FP8_E4M3
+inline bool is_fp8_e4m3fn(DLDataType dtype) {
+  return encode_dlpack_dtype(dtype) == float8_e4m3fn_code;
+}
+#else
+inline bool is_fp8_e4m3fn(DLDataType) { return false; }
+#endif
+
+/**
+ * @brief FP8 Block-Scale GEMM binding for SM90
+ *
+ * Supports:
+ * - BF16 + BF16 → BF16
+ * - BF16 + FP8 → BF16
+ *
+ * @note Output is BF16
+ */
+class Fp8BlockScaleGemmRunner : public tvm::ffi::ModuleObj {
+ public:
+  Fp8BlockScaleGemmRunner() {
+    // Instantiate one runner per supported dtype combination
+    runner_bf16_bf16_ = std::make_unique<
+        kernels::CutlassFp8BlockScaleGemmRunner<__nv_bfloat16, __nv_bfloat16, __nv_bfloat16>>();
+
+    runner_bf16_fp8_ = std::make_unique<
+        kernels::CutlassFp8BlockScaleGemmRunner<__nv_bfloat16, __nv_fp8_e4m3, __nv_bfloat16>>();
+  }
+
+  ~Fp8BlockScaleGemmRunner() = default;
+
+  const char* type_key() const { return "flashinfer.Fp8BlockScaleGemmRunner"; }
+  const char* kind() const final { return "fp8_blockscale_gemm_runner"; }
+
+  Optional<Function> GetFunction(const tvm::ffi::String& name) {
+    if (name == "gemm") {
+      return Function::FromTyped(
+          [this](TensorView input, TensorView weight, TensorView output,
+                 Optional<TensorView> scales_a, Optional<TensorView> scales_b) {
+            runGemm(input, weight, output, scales_a, scales_b);
+          });
+    } else if (name == "get_workspace_size") {
+      return Function::FromTyped(
+          [this](int64_t shape_m, int64_t shape_n, int64_t shape_k) -> int64_t {
+            return getWorkspaceSize(shape_m, shape_n, shape_k);
+          });
+    } else if (name == "configure_workspace") {
+      return Function::FromTyped(
+          [this](TensorView workspace) { configureWorkspace(workspace); });
+    }
+    return Function(nullptr);
+  }
+
+ private:
+  /**
+   * @brief Runtime dtype dispatch
+   */
+  kernels::CutlassFp8BlockScaleGemmRunnerInterface* selectRunner(bool input_is_fp8,
+                                                                 bool weight_is_fp8) {
+    if (!input_is_fp8 && !weight_is_fp8) {
+      return runner_bf16_bf16_.get();
+    } else if (!input_is_fp8 && weight_is_fp8) {
+      return runner_bf16_fp8_.get();
+    } else {
+      return nullptr;
+    }
+  }
+
+  void runGemm(const TensorView& input, const TensorView& weight, const TensorView& output,
+               const Optional<TensorView>& scales_a, const Optional<TensorView>& scales_b) {
+    auto stream = get_stream(input.device());
+
+    auto input_ptr = input.data_ptr();
+    auto weight_ptr = weight.data_ptr();
+    auto output_ptr = output.data_ptr();
+
+    int shape_m = input.size(0);
+    int shape_k = input.size(1);
+    int shape_n = weight.size(0);
+
+    TVM_FFI_ICHECK(input_ptr != nullptr) << "input is null";
+    TVM_FFI_ICHECK(weight_ptr != nullptr) << "weight is null";
+    TVM_FFI_ICHECK(output_ptr != nullptr) << "output is null";
+    TVM_FFI_ICHECK(shape_k == weight.size(1)) << "K dimension mismatch";
+
+    // Determine dtypes for runner selection
+    bool input_is_fp8 = is_fp8_e4m3fn(input.dtype());
+    bool weight_is_fp8 = is_fp8_e4m3fn(weight.dtype());
+
+    // Extract scale pointers
+    float const* scales_a_ptr =
+        scales_a.has_value() ? reinterpret_cast<float const*>(scales_a.value().data_ptr())
+                             : nullptr;
+    float const* scales_b_ptr =
+        scales_b.has_value() ? reinterpret_cast<float const*>(scales_b.value().data_ptr())
+                             : nullptr;
+
+    // Select appropriate runner
+    auto* runner = selectRunner(input_is_fp8, weight_is_fp8);
+    TVM_FFI_ICHECK(runner != nullptr) << "Unsupported dtype combination";
+    TVM_FFI_ICHECK(workspace_ != nullptr)
+        << "Workspace not configured. Call configure_workspace first.";
+
+    runner->gemm(output_ptr, input_ptr, weight_ptr, shape_m, shape_n, shape_k, stream,
+                 scales_a_ptr, scales_b_ptr);
+  }
+
+  int64_t getWorkspaceSize(int64_t shape_m, int64_t shape_n, int64_t shape_k) {
+    size_t max_size = 0;
+
+    max_size =
+        std::max(max_size, runner_bf16_bf16_->getWorkspaceSizeBase(shape_m, shape_n, shape_k, 1));
+    max_size =
+        std::max(max_size, runner_bf16_fp8_->getWorkspaceSizeBase(shape_m, shape_n, shape_k, 1));
+
+    return max_size;
+  }
+
+  void configureWorkspace(const TensorView& workspace) {
+    auto workspace_ptr = reinterpret_cast<char*>(workspace.data_ptr());
+    workspace_ = workspace_ptr;
+
+    runner_bf16_bf16_->configureWorkspace(workspace_ptr);
+    runner_bf16_fp8_->configureWorkspace(workspace_ptr);
+  }
+
+  std::unique_ptr<kernels::CutlassFp8BlockScaleGemmRunner<__nv_bfloat16, __nv_bfloat16, __nv_bfloat16>>
+      runner_bf16_bf16_;
+  std::unique_ptr<kernels::CutlassFp8BlockScaleGemmRunner<__nv_bfloat16, __nv_fp8_e4m3, __nv_bfloat16>>
+      runner_bf16_fp8_;
+
+  char* workspace_ = nullptr;
+};
+
+tvm::ffi::Module init() {
+  auto ptr = tvm::ffi::make_object<Fp8BlockScaleGemmRunner>();
+  return tvm::ffi::Module(ptr);
+}
+
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(init, init);
\ No newline at end of file
diff --git a/flashinfer/gemm/__init__.py b/flashinfer/gemm/__init__.py
index 15652268ba..7389eb5b1b 100644
--- a/flashinfer/gemm/__init__.py
+++ b/flashinfer/gemm/__init__.py
@@ -13,6 +13,12 @@
 from .gemm_base import gemm_fp8_nt_blockscaled as gemm_fp8_nt_blockscaled
 from .gemm_base import gemm_fp8_nt_groupwise as gemm_fp8_nt_groupwise
 from .gemm_base import group_gemm_fp8_nt_groupwise as group_gemm_fp8_nt_groupwise
+from .gemm_base import (
+    get_fp8_blockscale_gemm_runner_sm90 as get_fp8_blockscale_gemm_runner_sm90,
+)
+from .gemm_base import (
+    fp8_blockscale_gemm_swapab as fp8_blockscale_gemm_swapab,
+)
 
 from .routergemm_dsv3 import (
     mm_M1_16_K7168_N256 as mm_M1_16_K7168_N256,
@@ -30,5 +36,7 @@
     "gemm_fp8_nt_blockscaled",
     "gemm_fp8_nt_groupwise",
     "group_gemm_fp8_nt_groupwise",
+    "get_fp8_blockscale_gemm_runner_sm90",
+    "fp8_blockscale_gemm_swapab",
     "mm_M1_16_K7168_N256",
 ]
diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py
index 589c651aca..954f005d8a 100644
--- a/flashinfer/gemm/gemm_base.py
+++ b/flashinfer/gemm/gemm_base.py
@@ -54,6 +54,7 @@
 from ..jit.gemm import gen_trtllm_gen_gemm_module
 from ..jit.gemm import gen_tgv_gemm_sm10x_module
 from ..jit.gemm import gen_deepgemm_sm100_module
+from ..jit.gemm import gen_fp8_blockscale_gemm_sm90_module
 from ..jit.cpp_ext import get_cuda_version
 
 
@@ -3243,3 +3244,234 @@ def batch_deepgemm_fp8_nt_groupwise(
     )
 
     return out
+
+@functools.cache
+def get_fp8_blockscale_gemm_runner_sm90():
+    """Get the FP8 block-scale GEMM runner module for SM90."""
+    return gen_fp8_blockscale_gemm_sm90_module().build_and_load().init()
+
+
+def fp8_blockscale_gemm_swapab(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    input_scale: Optional[torch.Tensor] = None,
+    weight_scale: Optional[torch.Tensor] = None,
+    out: Optional[torch.Tensor] = None,
+    out_dtype: Optional[torch.dtype] = None,
+) -> torch.Tensor:
+    """
+    Perform an FP8 block-scaled GEMM with automatic swapAB optimization.
+
+    This function automatically selects between the normal and the swapAB kernel
+    based on the M dimension. For small M (< 32), it uses the swapAB kernel for
+    better performance.
+
+    Supported Dtype Combinations
+    -----------------------------
+    - **BF16 + BF16 → BF16**: Both inputs BF16, internal quantization (no scales needed)
+    - **BF16 + FP8 → BF16**: BF16 input, FP8 weight
+
+    Parameters
+    ----------
+    input : torch.Tensor
+        Input activation tensor of shape (M, K).
+        - BF16 (torch.bfloat16) with internal quantization
+    weight : torch.Tensor
+        Weight tensor of shape (N, K). Can be:
+        - FP8 (torch.float8_e4m3fn) with weight_scale required
+        - BF16 (torch.bfloat16) for internal quantization
+    input_scale : torch.Tensor, optional
+        Scaling factors for input. Required if input is FP8.
+    weight_scale : torch.Tensor, optional
+        Scaling factors for weight. Required if weight is FP8.
+    out : torch.Tensor, optional
+        Output tensor of shape (M, N). If None, will be allocated.
+    out_dtype : torch.dtype, optional
+        Output data type. Default is torch.bfloat16.
+
+    Returns
+    -------
+    torch.Tensor
+        Output tensor of shape (M, N) with dtype ``out_dtype``.
+
+    Examples
+    --------
+    >>> import torch
+    >>> from flashinfer.gemm import fp8_blockscale_gemm_swapab
+    >>>
+    >>> M, N, K = 16, 4096, 4096
+    >>> device = "cuda"
+    >>>
+    >>> # BF16 inputs
+    >>> input_bf16 = torch.randn(M, K, device=device, dtype=torch.bfloat16)
+    >>> weight_bf16 = torch.randn(N, K, device=device, dtype=torch.bfloat16)
+    >>> output = fp8_blockscale_gemm_swapab(input_bf16, weight_bf16)
+    >>> print(output.shape)  # torch.Size([16, 4096])
+    >>>
+    >>> # Mixed: BF16 input + FP8 weight
+    >>> from flashinfer.testing.utils import per_token_cast_to_fp8
+    >>> input_bf16 = torch.randn(M, K, device=device, dtype=torch.bfloat16)
+    >>> weight_bf16 = torch.randn(N, K, device=device, dtype=torch.bfloat16)
+    >>> weight_fp8, weight_scale = per_token_cast_to_fp8(weight_bf16)
+    >>> output = fp8_blockscale_gemm_swapab(input_bf16, weight_fp8, None, weight_scale)
+    >>> print(output.shape)  # torch.Size([16, 4096])
+    >>>
+    >>> # FP8 weight with 128x128 block scales
+    >>> from flashinfer.testing.utils import per_block_cast_to_fp8
+    >>> weight_bf16 = torch.randn(N, K, device=device, dtype=torch.bfloat16)
+    >>> weight_fp8, weight_scale = per_block_cast_to_fp8(weight_bf16)
+    >>> # weight_scale has shape (N // 128, K // 128)
+    >>> input_bf16 = torch.randn(M, K, device=device, dtype=torch.bfloat16)
+    >>> output = fp8_blockscale_gemm_swapab(input_bf16, weight_fp8, None, weight_scale)
+    >>> print(output.shape)  # torch.Size([16, 4096])
+
+    Notes
+    -----
+    - This function requires NVIDIA Hopper (SM90) architecture and CUDA 12.8+
+    - The swapAB kernel is automatically used when M < 32 (threshold)
+    - For FP8 inputs, scaling factors must be provided
+    - For BF16 inputs, quantization and scaling happen internally
+    - Weight scales support two granularities:
+        * Per-token (1x128 blocks): (N, K//128)
+        * Per-block (128x128 blocks): (N//128, K//128)
+    - Input scales only support the per-token format: (M, K//128)
+    - The function dispatches to the TensorRT-LLM CUTLASS FP8 block-scale kernels via JIT compilation
+    """
+    # Validate architecture support
+    if not _match_sm_version(input.device, ["90", "90a"]):
+        raise ValueError(
+            "fp8_blockscale_gemm_swapab is only supported on SM90 (Hopper) architecture."
+ ) + + # Validate tensor dimensions + if input.ndim != 2: + raise ValueError(f"Input must be 2D (M, K), got shape {input.shape}") + if weight.ndim != 2: + raise ValueError(f"Weight must be 2D (N, K), got shape {weight.shape}") + + M, K = input.shape + N, K_weight = weight.shape + + if K != K_weight: + raise ValueError( + f"K dimension mismatch: input has K={K}, weight has K={K_weight}" + ) + + # Validate K is divisible by block size (128) + BLOCK_SIZE = 128 + if K % BLOCK_SIZE != 0: + raise ValueError( + f"K dimension must be divisible by block size ({BLOCK_SIZE}), got K={K}" + ) + + # Validate dtype combinations + input_is_fp8 = input.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] + weight_is_fp8 = weight.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] + input_is_bf16 = input.dtype == torch.bfloat16 + weight_is_bf16 = weight.dtype == torch.bfloat16 + + # Explicitly reject FP8 input + BF16 weight (missing kernel implementation) + if input_is_fp8 and weight_is_bf16: + raise ValueError( + "FP8 input + BF16 weight is not supported (missing kernel implementation). " + ) + + # Validate scale requirements for FP8 inputs + if input_is_fp8: + if input_scale is None: + raise ValueError( + "input_scale is required when input is FP8. " + ) + expected_scale_shape = (M, K // BLOCK_SIZE) + if input_scale.shape != expected_scale_shape: + raise ValueError( + f"input_scale shape mismatch. Expected {expected_scale_shape}, " + f"got {input_scale.shape}" + ) + if input_scale.dtype != torch.float32: + raise ValueError( + f"input_scale must be float32, got {input_scale.dtype}" + ) + if input_scale.device != input.device: + raise ValueError( + f"input_scale device mismatch. Expected {input.device}, " + f"got {input_scale.device}" + ) + else: + if not input_is_bf16: + raise ValueError( + f"Input must be either FP8 (torch.float8_e4m3fn) or BF16 (torch.bfloat16), " + f"got {input.dtype}" + ) + if input_scale is not None: + raise ValueError( + "input_scale should not be provided for BF16 inputs. " + "Use FP8 inputs if you want to provide external scales." + ) + + if weight_is_fp8: + if weight_scale is None: + raise ValueError( + "weight_scale is required when weight is FP8. " + ) + expected_per_token_shape = (N, K // BLOCK_SIZE) + expected_per_block_shape = ((N + BLOCK_SIZE - 1) // BLOCK_SIZE, K // BLOCK_SIZE) + is_per_token = weight_scale.shape == expected_per_token_shape + is_per_block = weight_scale.shape == expected_per_block_shape + + if not (is_per_token or is_per_block): + raise ValueError( + f"weight_scale shape mismatch. Expected either {expected_per_token_shape} " + f"(per-token, 1x128 blocks) or {expected_per_block_shape} " + f"(per-block, 128x128 blocks), got {weight_scale.shape}" + ) + if weight_scale.dtype != torch.float32: + raise ValueError( + f"weight_scale must be float32, got {weight_scale.dtype}" + ) + else: + if not weight_is_bf16: + raise ValueError( + f"Weight must be either FP8 (torch.float8_e4m3fn) or BF16 (torch.bfloat16), " + f"got {weight.dtype}" + ) + if weight_scale is not None: + raise ValueError( + "weight_scale should not be provided for BF16 weights. " + "Use FP8 weights if you want to provide external scales." + ) + + # Validate output tensor if provided + if out is not None: + if out.shape != (M, N): + raise ValueError( + f"Output shape mismatch. Expected ({M}, {N}), got {out.shape}" + ) + if out.device != input.device: + raise ValueError( + f"Output device mismatch. 
+        if out_dtype is not None and out.dtype != out_dtype:
+            raise ValueError(
+                f"Output dtype mismatch. Expected {out_dtype}, got {out.dtype}"
+            )
+        out_dtype = out.dtype
+    else:
+        # Allocate output
+        out_dtype = out_dtype or torch.bfloat16
+        if out_dtype not in [torch.bfloat16, torch.float16]:
+            raise ValueError(
+                f"Output dtype must be torch.bfloat16 or torch.float16, got {out_dtype}"
+            )
+        out = torch.empty(M, N, dtype=out_dtype, device=input.device)
+
+    # Get the runner
+    runner = get_fp8_blockscale_gemm_runner_sm90()
+
+    # Allocate workspace
+    workspace_size = runner.get_workspace_size(M, N, K)
+    workspace = None
+    if workspace_size > 0:
+        workspace = torch.empty(
+            workspace_size, dtype=torch.uint8, device=input.device
+        )
+        runner.configure_workspace(workspace)
+
+    runner.gemm(input, weight, out, input_scale, weight_scale)
+
+    return out
diff --git a/flashinfer/jit/gemm/__init__.py b/flashinfer/jit/gemm/__init__.py
index f1681d3bf5..e81d51e15f 100644
--- a/flashinfer/jit/gemm/__init__.py
+++ b/flashinfer/jit/gemm/__init__.py
@@ -27,6 +27,7 @@
     gen_gemm_sm90_module,
 )
 from .deepgemm import gen_deepgemm_sm100_module
+from .fp8_blockscale import gen_fp8_blockscale_gemm_sm90_module
 
 __all__ = [
     "gen_gemm_module",
@@ -40,4 +41,5 @@
     "gen_tgv_gemm_sm10x_module",
     "gen_gemm_sm90_module",
     "gen_deepgemm_sm100_module",
+    "gen_fp8_blockscale_gemm_sm90_module",
 ]
diff --git a/flashinfer/jit/gemm/fp8_blockscale.py b/flashinfer/jit/gemm/fp8_blockscale.py
new file mode 100755
index 0000000000..70546c93fd
--- /dev/null
+++ b/flashinfer/jit/gemm/fp8_blockscale.py
@@ -0,0 +1,57 @@
+from typing import List
+
+from .. import env as jit_env
+from ..core import (
+    JitSpec,
+    gen_jit_spec,
+    sm90a_nvcc_flags,
+)
+from ..cpp_ext import is_cuda_version_at_least
+
+
+def gen_fp8_blockscale_gemm_sm90_module(use_fast_build: bool = False) -> JitSpec:
+    """Generate the JIT spec for FP8 block-scale GEMM on SM90 (Hopper)."""
+    nvcc_flags = sm90a_nvcc_flags + [
+        "-DCOMPILE_HOPPER_TMA_GEMMS",
+        "-DENABLE_BF16",
+        "-DENABLE_FP8",
+        "-DENABLE_FP8_BLOCK_SCALE" if is_cuda_version_at_least("12.8") else "",
+    ]
+
+    return gen_jit_spec(
+        "fp8_blockscale_gemm_90",
+        [
+            jit_env.FLASHINFER_CSRC_DIR
+            / "nv_internal/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm.cu",
+            jit_env.FLASHINFER_CSRC_DIR / "fp8_blockscale_gemm_sm90_binding.cu",
+            jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/envUtils.cpp",
+            jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/logger.cpp",
+            jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/stringUtils.cpp",
+            jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/tllmException.cpp",
+            jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/memoryUtils.cu",
+        ],
+        extra_cuda_cflags=nvcc_flags,
+        extra_cflags=["-DFAST_BUILD"] if use_fast_build else [],
+        extra_ldflags=["-lnvrtc", "-lcuda"],
+        extra_include_paths=[
+            jit_env.FLASHINFER_CSRC_DIR / "nv_internal",
+            jit_env.FLASHINFER_CSRC_DIR / "nv_internal" / "include",
+            jit_env.FLASHINFER_CSRC_DIR
+            / "nv_internal"
+            / "tensorrt_llm"
+            / "cutlass_extensions"
+            / "include",
+            jit_env.FLASHINFER_CSRC_DIR
+            / "nv_internal"
+            / "tensorrt_llm"
+            / "kernels"
+            / "cutlass_kernels"
+            / "include",
+            jit_env.FLASHINFER_CSRC_DIR
+            / "nv_internal"
+            / "tensorrt_llm"
+            / "kernels"
+            / "cutlass_kernels",
+        ],
+    )
+
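The wrapper in `gemm_base.py` drives the module generated above through three FFI entry points exposed by the binding: `get_workspace_size`, `configure_workspace`, and `gemm`. A minimal sketch of that lower-level flow, assuming an SM90 GPU, a successful JIT build, and illustrative shapes:

```python
import torch
from flashinfer.jit.gemm import gen_fp8_blockscale_gemm_sm90_module

# Build (or load from cache) the SM90 FP8 block-scale GEMM module and
# instantiate the runner exported by fp8_blockscale_gemm_sm90_binding.cu.
runner = gen_fp8_blockscale_gemm_sm90_module().build_and_load().init()

M, N, K = 16, 4096, 4096  # illustrative shapes; K must be a multiple of 128
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)
out = torch.empty(M, N, device="cuda", dtype=torch.bfloat16)

# Query and configure the shared workspace before the first gemm call.
workspace_size = runner.get_workspace_size(M, N, K)
if workspace_size > 0:
    workspace = torch.empty(workspace_size, dtype=torch.uint8, device="cuda")
    runner.configure_workspace(workspace)

# BF16 + BF16 path: quantization happens inside the kernel, so both scales are None.
runner.gemm(a, b, out, None, None)
```

The public `fp8_blockscale_gemm_swapab` wrapper performs exactly these steps plus input validation, so most callers never need to touch the runner module directly.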
diff --git a/tests/gemm/test_fp8_blockscale_gemm.py b/tests/gemm/test_fp8_blockscale_gemm.py
new file mode 100755
index 0000000000..e069b15c02
--- /dev/null
+++ b/tests/gemm/test_fp8_blockscale_gemm.py
@@ -0,0 +1,347 @@
+"""
+Copyright (c) 2024 by FlashInfer team.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import pytest
+import torch
+import torch.nn.functional as F
+
+import flashinfer
+from flashinfer.gemm import fp8_blockscale_gemm_swapab
+from flashinfer.testing.utils import per_token_cast_to_fp8, per_block_cast_to_fp8
+from flashinfer.utils import (
+    get_compute_capability,
+    has_flashinfer_jit_cache,
+    is_sm90a_supported,
+)
+from flashinfer.jit.gemm import gen_fp8_blockscale_gemm_sm90_module
+
+
+@pytest.fixture(
+    autouse=not has_flashinfer_jit_cache(),
+    scope="module",
+)
+def warmup_jit():
+    """Warm up JIT compilation for FP8 block-scale GEMM if not cached."""
+    if is_sm90a_supported(torch.device("cuda:0")):
+        jit_specs = [gen_fp8_blockscale_gemm_sm90_module()]
+        flashinfer.jit.build_jit_specs(jit_specs, verbose=False)
+    yield
+
+
+@pytest.mark.parametrize("m", [1, 16, 32, 64, 128])
+@pytest.mark.parametrize("n", [128, 256, 512, 1024, 4096])
+@pytest.mark.parametrize("k", [256, 512, 1024, 4096])
+@pytest.mark.parametrize("input_dtype", [torch.bfloat16])
+@pytest.mark.parametrize("weight_dtype", [torch.bfloat16])
+def test_fp8_blockscale_gemm_swapab(m, n, k, input_dtype, weight_dtype):
+    """Test FP8 block-scale GEMM with the swapAB optimization.
+
+    This test focuses on the primary usage: BF16 inputs with internal quantization.
+    The kernel automatically handles FP8 quantization with proper block-scale computation.
+ """ + compute_capability = get_compute_capability(torch.device("cuda")) + if compute_capability[0] < 9: + pytest.skip("FP8 block-scale GEMM requires SM90 (Hopper) or later") + + if not is_sm90a_supported(torch.device("cuda")): + pytest.skip("FP8 block-scale GEMM requires SM90a (Hopper) support") + + # K must be divisible by 128 (block size requirement) + if k % 128 != 0: + pytest.skip("K must be divisible by 128 for block-scale GEMM") + + device = "cuda" + torch.manual_seed(42) + + # Create BF16 inputs + input = torch.randn(m, k, device=device, dtype=input_dtype) + weight = torch.randn(n, k, device=device, dtype=weight_dtype) + + # Compute reference result + reference = torch.matmul(input, weight.T) + + # Run FP8 block-scale GEMM + output = fp8_blockscale_gemm_swapab(input, weight) + + # Verify output shape + assert output.shape == (m, n), f"Expected shape {(m, n)}, got {output.shape}" + assert output.dtype == torch.bfloat16, f"Expected BF16 output, got {output.dtype}" + + # Check correctness + cos_sim = F.cosine_similarity( + reference.flatten().float(), + output.flatten().float(), + dim=0 + ) + assert cos_sim > 0.99, f"Cosine similarity {cos_sim} is too low (expected > 0.99)" + + +@pytest.mark.parametrize("m", [1, 32, 128]) +@pytest.mark.parametrize("n", [1024, 4096]) +@pytest.mark.parametrize("k", [512, 4096]) +@pytest.mark.parametrize( + "input_dtype,weight_dtype", + [ + (torch.bfloat16, torch.bfloat16), # Both BF16 (for testing internal quantization) + (torch.bfloat16, torch.float8_e4m3fn), # BF16 input + FP8 weight + ] +) +def test_fp8_blockscale_gemm_dtypes(m, n, k, input_dtype, weight_dtype): + """Test the 2 recommended dtype combinations with proper FP8 quantization. + + Uses quantization from flashinfer.testing.utils: + - per_token_cast_to_fp8: 1x128 block quantization (for both input and weight) + + Note: Both input and weight use per_token (1x128 blocks). + The API expects scale shape (N, K//128), which per_token provides. + + These utilities return scales in the correct format (reciprocals) that + match TRT-LLM's kernel expectations. 
+    For kernel reference, see
+    csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm_kernel.cuh
+    """
+    compute_capability = get_compute_capability(torch.device("cuda"))
+    if compute_capability[0] < 9:
+        pytest.skip("FP8 block-scale GEMM requires SM90 (Hopper) or later")
+
+    if not is_sm90a_supported(torch.device("cuda")):
+        pytest.skip("FP8 block-scale GEMM requires SM90a (Hopper) support")
+
+    if k % 128 != 0:
+        pytest.skip("K must be divisible by 128 for block-scale GEMM")
+
+    device = "cuda"
+    torch.manual_seed(42)
+
+    # Create BF16 data for reference
+    input_bf16 = torch.randn(m, k, device=device, dtype=torch.bfloat16)
+    weight_bf16 = torch.randn(n, k, device=device, dtype=torch.bfloat16)
+
+    # Quantize input
+    if input_dtype == torch.float8_e4m3fn:
+        input_tensor, input_scale = per_token_cast_to_fp8(input_bf16)
+    else:
+        input_tensor, input_scale = input_bf16, None
+
+    # Quantize weight
+    if weight_dtype == torch.float8_e4m3fn:
+        weight_tensor, weight_scale = per_token_cast_to_fp8(weight_bf16)
+    else:
+        weight_tensor, weight_scale = weight_bf16, None
+
+    # Compute reference
+    reference = torch.matmul(input_bf16, weight_bf16.T)
+
+    # Run FP8 block-scale GEMM
+    output = fp8_blockscale_gemm_swapab(
+        input_tensor, weight_tensor, input_scale, weight_scale
+    )
+
+    # Verify output properties
+    assert output.shape == (m, n), f"Expected shape {(m, n)}, got {output.shape}"
+    assert output.dtype == torch.bfloat16, f"Expected BF16 output, got {output.dtype}"
+
+    # Check correctness
+    cos_sim = F.cosine_similarity(
+        reference.flatten().float(),
+        output.flatten().float(),
+        dim=0,
+    )
+
+    if input_dtype == torch.bfloat16 and weight_dtype == torch.bfloat16:
+        threshold = 0.99
+    else:
+        # BF16 + FP8: the BF16 input is quantized internally, the FP8 weight is pre-quantized
+        # TODO: check threshold
+        threshold = 0.967
+
+    assert cos_sim > threshold, (
+        f"Cosine similarity {cos_sim:.4f} too low for "
+        f"{input_dtype} + {weight_dtype} (expected > {threshold})"
+    )
+
+
+def test_fp8_blockscale_gemm_per_block_weight_scales():
+    """Test BF16 + FP8 GEMM with per-block (128x128) weight scales.
+
+    This test demonstrates 128x128 block quantization for weights together with a BF16 input.
+    """
+    compute_capability = get_compute_capability(torch.device("cuda"))
+    if compute_capability[0] < 9:
+        pytest.skip("FP8 block-scale GEMM requires SM90 (Hopper) or later")
+
+    if not is_sm90a_supported(torch.device("cuda")):
+        pytest.skip("FP8 block-scale GEMM requires SM90a (Hopper) support")
+
+    device = "cuda"
+    m, n, k = 16, 512, 512
+    torch.manual_seed(42)
+
+    # Create inputs
+    input_bf16 = torch.randn(m, k, device=device, dtype=torch.bfloat16)
+    weight_bf16 = torch.randn(n, k, device=device, dtype=torch.bfloat16)
+
+    # Quantize weight with per-block (128x128) blocks
+    weight_fp8, weight_scale = per_block_cast_to_fp8(weight_bf16)
+
+    # Verify scale shape
+    assert weight_scale.shape == (n // 128, k // 128), (
+        f"Expected weight scale shape ({n // 128}, {k // 128}), got {weight_scale.shape}"
+    )
+    assert weight_scale.min() > 0, "Weight scale should be positive (reciprocal format)"
+
+    # Run GEMM: BF16 input (internal quant) + FP8 weight (per-block scales)
+    output = fp8_blockscale_gemm_swapab(input_bf16, weight_fp8, None, weight_scale)
+
+    # Compare to BF16 reference
+    reference = torch.matmul(input_bf16, weight_bf16.T)
+
+    cos_sim = F.cosine_similarity(
+        reference.flatten().float(),
+        output.flatten().float(),
+        dim=0,
+    )
+    # TODO: check threshold
+    assert cos_sim > 0.967, f"Per-block weight scale accuracy too low: {cos_sim:.4f}"
+
+    print(f"✓ Per-block weight scales: cosine similarity = {cos_sim:.4f}")
+
+
+@pytest.mark.parametrize("m,n,k", [
+    (1, 4096, 4096),
+    (8, 4096, 4096),
+    (128, 4096, 4096),
+    (16, 8192, 8192),
+    (32, 2048, 4096),
+])
+def test_fp8_blockscale_gemm_shapes(m, n, k):
+    """Test various common shapes used in LLM inference."""
+    compute_capability = get_compute_capability(torch.device("cuda"))
+    if compute_capability[0] < 9:
+        pytest.skip("FP8 block-scale GEMM requires SM90 (Hopper) or later")
+
+    if not is_sm90a_supported(torch.device("cuda")):
+        pytest.skip("FP8 block-scale GEMM requires SM90a (Hopper) support")
+
+    if k % 128 != 0:
+        pytest.skip("K must be divisible by 128")
+
+    device = "cuda"
+    torch.manual_seed(42)
+
+    input = torch.randn(m, k, device=device, dtype=torch.bfloat16)
+    weight = torch.randn(n, k, device=device, dtype=torch.bfloat16)
+
+    reference = torch.matmul(input, weight.T)
+    output = fp8_blockscale_gemm_swapab(input, weight)
+
+    cos_sim = F.cosine_similarity(
+        reference.flatten().float(),
+        output.flatten().float(),
+        dim=0,
+    )
+    assert cos_sim > 0.99, f"Shape ({m}, {n}, {k}): cosine similarity {cos_sim} too low"
+
+
+def test_fp8_blockscale_gemm_error_handling():
+    """Test that proper errors are raised for invalid inputs."""
+    compute_capability = get_compute_capability(torch.device("cuda"))
+    if compute_capability[0] < 9:
+        pytest.skip("FP8 block-scale GEMM requires SM90 (Hopper) or later")
+
+    if not is_sm90a_supported(torch.device("cuda")):
+        pytest.skip("FP8 block-scale GEMM requires SM90a (Hopper) support")
+
+    device = "cuda"
+    m, n, k = 16, 256, 256
+
+    # Test: K not divisible by 128
+    input = torch.randn(m, 127, device=device, dtype=torch.bfloat16)
+    weight = torch.randn(n, 127, device=device, dtype=torch.bfloat16)
+    with pytest.raises(ValueError, match="divisible by block size"):
+        fp8_blockscale_gemm_swapab(input, weight)
+
+    # Test: FP16 not supported
+    input = torch.randn(m, k, device=device, dtype=torch.float16)
+    weight = torch.randn(n, k, device=device, dtype=torch.float16)
+    with pytest.raises(ValueError, match="FP8.*or BF16"):
+        fp8_blockscale_gemm_swapab(input, weight)
+
+    # Test: FP8 weight without scale (naive conversion)
+    input_bf16 = torch.randn(m, k, device=device, dtype=torch.bfloat16)
+    weight_bf16 = torch.randn(n, k, device=device, dtype=torch.bfloat16)
+    weight_fp8_naive = weight_bf16.to(torch.float8_e4m3fn)
+    with pytest.raises(ValueError, match="weight_scale is required when weight is FP8"):
+        fp8_blockscale_gemm_swapab(input_bf16, weight_fp8_naive, None, None)
+
+    # Test: BF16 input with scale (should raise error)
+    input = torch.randn(m, k, device=device, dtype=torch.bfloat16)
+    weight = torch.randn(n, k, device=device, dtype=torch.bfloat16)
+    fake_scale = torch.ones(m, k // 128, device=device, dtype=torch.float32)
+    with pytest.raises(ValueError, match="input_scale should not be provided for BF16"):
+        fp8_blockscale_gemm_swapab(input, weight, input_scale=fake_scale)
+
+    # Test: Wrong scale shape for FP8 input
+    input_bf16 = torch.randn(m, k, device=device, dtype=torch.bfloat16)
+    input_fp8, _ = per_token_cast_to_fp8(input_bf16)
+    weight = torch.randn(n, k, device=device, dtype=torch.bfloat16)
+    wrong_scale = torch.ones(m, k // 64, device=device, dtype=torch.float32)
+    with pytest.raises(ValueError):
+        fp8_blockscale_gemm_swapab(input_fp8, weight, input_scale=wrong_scale)
+
+    # Test: FP8 input + BF16 weight is NOT supported
+    input_bf16 = torch.randn(m, k, device=device, dtype=torch.bfloat16)
+    input_fp8, input_scale = per_token_cast_to_fp8(input_bf16)
+    weight = torch.randn(n, k, device=device, dtype=torch.bfloat16)
+    with pytest.raises(ValueError, match="FP8 input.*BF16 weight.*not supported"):
+        fp8_blockscale_gemm_swapab(input_fp8, weight, input_scale, None)
+
+
+def test_fp8_blockscale_gemm_output_buffer():
+    """Test providing a pre-allocated output buffer."""
+    compute_capability = get_compute_capability(torch.device("cuda"))
+    if compute_capability[0] < 9:
+        pytest.skip("FP8 block-scale GEMM requires SM90 (Hopper) or later")
+
+    if not is_sm90a_supported(torch.device("cuda")):
+        pytest.skip("FP8 block-scale GEMM requires SM90a (Hopper) support")
+
+    device = "cuda"
+    m, n, k = 16, 256, 256
+    torch.manual_seed(42)
+
+    input = torch.randn(m, k, device=device, dtype=torch.bfloat16)
+    weight = torch.randn(n, k, device=device, dtype=torch.bfloat16)
+
+    # Pre-allocate output
+    output = torch.empty(m, n, device=device, dtype=torch.bfloat16)
+
+    # Run GEMM with pre-allocated output
+    result = fp8_blockscale_gemm_swapab(input, weight, out=output)
+
+    # Verify result is the same buffer
+    assert result is output
+
+    # Verify correctness
+    reference = torch.matmul(input, weight.T)
+    cos_sim = F.cosine_similarity(
+        reference.flatten().float(),
+        output.flatten().float(),
+        dim=0,
+    )
+    assert cos_sim > 0.99
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])
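As a usage recap, the two weight-scale granularities accepted by `fp8_blockscale_gemm_swapab` differ only in the shape of the scale tensor. A minimal sketch, assuming an SM90 GPU and N = K = 512 so that N // 128 = K // 128 = 4:

```python
import torch
from flashinfer.gemm import fp8_blockscale_gemm_swapab
from flashinfer.testing.utils import per_block_cast_to_fp8, per_token_cast_to_fp8

M, N, K = 16, 512, 512
x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
w = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)

# Per-token (1x128) weight scales: shape (N, K // 128) == (512, 4).
w_fp8_token, w_scale_token = per_token_cast_to_fp8(w)
y_token = fp8_blockscale_gemm_swapab(x, w_fp8_token, None, w_scale_token)

# Per-block (128x128) weight scales: shape (N // 128, K // 128) == (4, 4).
w_fp8_block, w_scale_block = per_block_cast_to_fp8(w)
y_block = fp8_blockscale_gemm_swapab(x, w_fp8_block, None, w_scale_block)

assert y_token.shape == y_block.shape == (M, N)
```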