diff --git a/csrc/fp8_blockscale_gemm_sm90_binding.cu b/csrc/fp8_blockscale_gemm_sm90_binding.cu new file mode 100644 index 0000000000..4962d5023c --- /dev/null +++ b/csrc/fp8_blockscale_gemm_sm90_binding.cu @@ -0,0 +1,216 @@ + +#include + +#include +#include +#include +#include +#include + +#include "tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm.h" +#include "tvm_ffi_utils.h" + +namespace kernels = tensorrt_llm::kernels::fp8_blockscale_gemm; + +using tvm::ffi::Function; +using tvm::ffi::Optional; +using tvm::ffi::TensorView; + +#ifdef FLASHINFER_ENABLE_FP8_E4M3 +inline bool is_fp8_e4m3fn(DLDataType dtype) { + return encode_dlpack_dtype(dtype) == float8_e4m3fn_code; +} +#else +inline bool is_fp8_e4m3fn(DLDataType) { return false; } +#endif + +/** + * @brief FP8 Block-Scale GEMM binding for SM90 + * + * Supports: + * - BF16 + BF16 → BF16 + * - BF16 + FP8 → BF16 (weight-only quantization) + * - FP8 + FP8 → BF16 (W8A8 full quantization) + * + * @note Output is always BF16 + */ +class Fp8BlockScaleGemmRunner : public tvm::ffi::ModuleObj { + public: + Fp8BlockScaleGemmRunner() { + // Instantiate runners for all supported combinations + runner_bf16_bf16_ = std::make_unique< + kernels::CutlassFp8BlockScaleGemmRunner<__nv_bfloat16, __nv_bfloat16, __nv_bfloat16>>(); + + runner_bf16_fp8_ = std::make_unique< + kernels::CutlassFp8BlockScaleGemmRunner<__nv_bfloat16, __nv_fp8_e4m3, __nv_bfloat16>>(); + + runner_fp8_fp8_ = std::make_unique< + kernels::CutlassFp8BlockScaleGemmRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16>>(); + } + + ~Fp8BlockScaleGemmRunner() = default; + + const char* type_key() const { return "flashinfer.Fp8BlockScaleGemmRunner"; } + const char* kind() const final { return "fp8_blockscale_gemm_runner"; } + + Optional GetFunction(const tvm::ffi::String& name) { + if (name == "gemm") { + return Function::FromTyped([this](TensorView input, TensorView weight, TensorView output, + Optional scales_a, + Optional scales_b) { + runGemm(input, weight, output, scales_a, scales_b); + }); + } else if (name == "get_workspace_size") { + return Function::FromTyped( + [this](int64_t shape_m, int64_t shape_n, int64_t shape_k) -> int64_t { + return getWorkspaceSize(shape_m, shape_n, shape_k); + }); + } else if (name == "configure_workspace") { + return Function::FromTyped([this](TensorView workspace) { configureWorkspace(workspace); }); + } + return Function(nullptr); + } + + private: + /** + * @brief Runtime dtype dispatch + */ + kernels::CutlassFp8BlockScaleGemmRunnerInterface* selectRunner(bool input_is_fp8, + bool weight_is_fp8) { + if (!input_is_fp8 && !weight_is_fp8) { + return runner_bf16_bf16_.get(); + } else if (!input_is_fp8 && weight_is_fp8) { + return runner_bf16_fp8_.get(); + } else if (input_is_fp8 && weight_is_fp8) { + return runner_fp8_fp8_.get(); // W8A8 + } else { + // FP8 input + BF16 weight is not supported by TensorRT-LLM + return nullptr; + } + } + + void runGemm(const TensorView& input, const TensorView& weight, const TensorView& output, + const Optional& scales_a, const Optional& scales_b) { + auto stream = get_stream(input.device()); + + auto input_ptr = input.data_ptr(); + auto weight_ptr = weight.data_ptr(); + auto output_ptr = output.data_ptr(); + + int shape_m = input.size(0); + int shape_k = input.size(1); + int shape_n = weight.size(0); + + TVM_FFI_ICHECK(input_ptr != nullptr) << "input is null"; + TVM_FFI_ICHECK(weight_ptr != nullptr) << "weight is null"; + TVM_FFI_ICHECK(output_ptr != nullptr) << "output is null"; + 
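// Layout convention (grounded in the checks and ld arguments below): input is (M, K), weight is (N, K), output is (M, N); the K dimensions are checked to match below. +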
TVM_FFI_ICHECK(shape_k == weight.size(1)) << "K dimension mismatch"; + + // Determine dtypes for runner selection + bool input_is_fp8 = is_fp8_e4m3fn(input.dtype()); + bool weight_is_fp8 = is_fp8_e4m3fn(weight.dtype()); + + // Validate scale requirements + if (input_is_fp8) { + TVM_FFI_ICHECK(scales_a.has_value() && scales_a.value().data_ptr() != nullptr) + << "scales_a is required for FP8 input"; + // TensorRT-LLM expects scale shape: (K/128, M) after transpose + int64_t expected_scale_k = (shape_k + 127) / 128; + TVM_FFI_ICHECK(scales_a.value().size(0) == expected_scale_k && + scales_a.value().size(1) == shape_m) + << "scales_a shape mismatch: expected (" << expected_scale_k << ", " << shape_m + << "), got (" << scales_a.value().size(0) << ", " << scales_a.value().size(1) << ")"; + } + + if (weight_is_fp8) { + TVM_FFI_ICHECK(scales_b.has_value() && scales_b.value().data_ptr() != nullptr) + << "scales_b is required for FP8 weight"; + // Validate scale shape: should be (N, K/128) for per-token or (N/128, K/128) for per-block + int64_t expected_scale_k = (shape_k + 127) / 128; + int64_t scale_dim0 = scales_b.value().size(0); + int64_t scale_dim1 = scales_b.value().size(1); + + bool is_per_token = (scale_dim0 == shape_n && scale_dim1 == expected_scale_k); + bool is_per_block = (scale_dim0 == (shape_n + 127) / 128 && scale_dim1 == expected_scale_k); + + TVM_FFI_ICHECK(is_per_token || is_per_block) + << "scales_b shape mismatch: expected (" << shape_n << ", " << expected_scale_k + << ") for per-token or (" << ((shape_n + 127) / 128) << ", " << expected_scale_k + << ") for per-block, got (" << scale_dim0 << ", " << scale_dim1 << ")"; + } + + // Extract scale pointers + float const* scales_a_ptr = scales_a.has_value() + ? reinterpret_cast<float const*>(scales_a.value().data_ptr()) + : nullptr; + float const* scales_b_ptr = scales_b.has_value() + ? reinterpret_cast<float const*>(scales_b.value().data_ptr()) + : nullptr; + + // Select appropriate runner + auto* runner = selectRunner(input_is_fp8, weight_is_fp8); + TVM_FFI_ICHECK(runner != nullptr) << "Unsupported dtype combination"; + TVM_FFI_ICHECK(workspace_ != nullptr) + << "Workspace not configured. Call configure_workspace first."; + + // TensorRT-LLM has two gemm() methods: + // 1. gemm(void*, ...) - for internal quantization (BF16 inputs) + // 2. gemm(__nv_fp8_e4m3*, int, __nv_fp8_e4m3*, int, ...)
- for pre-quantized FP8 inputs + if (input_is_fp8 && weight_is_fp8) { + // W8A8: Use the pre-quantized FP8 path + auto* fp8_input = reinterpret_cast<__nv_fp8_e4m3*>(input_ptr); + auto* fp8_weight = reinterpret_cast<__nv_fp8_e4m3*>(weight_ptr); + auto* bf16_output = reinterpret_cast<__nv_bfloat16*>(output_ptr); + + runner->gemm(fp8_input, shape_k, // input with leading dimension + fp8_weight, shape_k, // weight with leading dimension + bf16_output, shape_n, // output with leading dimension + shape_m, shape_n, shape_k, scales_a_ptr, scales_b_ptr, stream); + } else { + // BF16+BF16 or BF16+FP8: Use internal quantization path + runner->gemm(output_ptr, input_ptr, weight_ptr, shape_m, shape_n, shape_k, stream, + scales_a_ptr, scales_b_ptr); + } + } + + int64_t getWorkspaceSize(int64_t shape_m, int64_t shape_n, int64_t shape_k) { + size_t max_size = 0; + + max_size = + std::max(max_size, runner_bf16_bf16_->getWorkspaceSizeBase(shape_m, shape_n, shape_k, 1)); + max_size = + std::max(max_size, runner_bf16_fp8_->getWorkspaceSizeBase(shape_m, shape_n, shape_k, 1)); + max_size = + std::max(max_size, runner_fp8_fp8_->getWorkspaceSizeBase(shape_m, shape_n, shape_k, 1)); + + return max_size; + } + + void configureWorkspace(const TensorView& workspace) { + auto workspace_ptr = reinterpret_cast(workspace.data_ptr()); + workspace_ = workspace_ptr; + + runner_bf16_bf16_->configureWorkspace(workspace_ptr); + runner_bf16_fp8_->configureWorkspace(workspace_ptr); + runner_fp8_fp8_->configureWorkspace(workspace_ptr); + } + + std::unique_ptr< + kernels::CutlassFp8BlockScaleGemmRunner<__nv_bfloat16, __nv_bfloat16, __nv_bfloat16>> + runner_bf16_bf16_; + std::unique_ptr< + kernels::CutlassFp8BlockScaleGemmRunner<__nv_bfloat16, __nv_fp8_e4m3, __nv_bfloat16>> + runner_bf16_fp8_; + std::unique_ptr< + kernels::CutlassFp8BlockScaleGemmRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16>> + runner_fp8_fp8_; + + char* workspace_ = nullptr; +}; + +tvm::ffi::Module init() { + auto ptr = tvm::ffi::make_object(); + return tvm::ffi::Module(ptr); +} + +TVM_FFI_DLL_EXPORT_TYPED_FUNC(init, init); diff --git a/flashinfer/__init__.py b/flashinfer/__init__.py index faad4f12a3..cb239b388f 100644 --- a/flashinfer/__init__.py +++ b/flashinfer/__init__.py @@ -89,6 +89,7 @@ from .gemm import mm_fp4 as mm_fp4 from .gemm import mm_fp8 as mm_fp8 from .gemm import tgv_gemm_sm100 as tgv_gemm_sm100 +from .gemm import get_fp8_blockscale_gemm_runner_sm90 as Fp8BlockScaleGemmRunner from .mla import BatchMLAPagedAttentionWrapper as BatchMLAPagedAttentionWrapper from .norm import fused_add_rmsnorm as fused_add_rmsnorm from .norm import layernorm as layernorm diff --git a/flashinfer/gemm/__init__.py b/flashinfer/gemm/__init__.py index 15652268ba..7389eb5b1b 100644 --- a/flashinfer/gemm/__init__.py +++ b/flashinfer/gemm/__init__.py @@ -13,6 +13,12 @@ from .gemm_base import gemm_fp8_nt_blockscaled as gemm_fp8_nt_blockscaled from .gemm_base import gemm_fp8_nt_groupwise as gemm_fp8_nt_groupwise from .gemm_base import group_gemm_fp8_nt_groupwise as group_gemm_fp8_nt_groupwise +from .gemm_base import ( + get_fp8_blockscale_gemm_runner_sm90 as get_fp8_blockscale_gemm_runner_sm90, +) +from .gemm_base import ( + fp8_blockscale_gemm_swapab as fp8_blockscale_gemm_swapab, +) from .routergemm_dsv3 import ( mm_M1_16_K7168_N256 as mm_M1_16_K7168_N256, @@ -30,5 +36,7 @@ "gemm_fp8_nt_blockscaled", "gemm_fp8_nt_groupwise", "group_gemm_fp8_nt_groupwise", + "get_fp8_blockscale_gemm_runner_sm90", + "fp8_blockscale_gemm_swapab", "mm_M1_16_K7168_N256", ] diff --git 
a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index ac0fbab4a0..c716905100 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -54,6 +54,7 @@ from ..jit.gemm import gen_trtllm_gen_gemm_module from ..jit.gemm import gen_tgv_gemm_sm10x_module from ..jit.gemm import gen_deepgemm_sm100_module +from ..jit.gemm import gen_fp8_blockscale_gemm_sm90_module CUDNN_AVAILABLE = False @@ -3111,3 +3112,254 @@ def batch_deepgemm_fp8_nt_groupwise( ) return out + + +@functools.cache +def get_fp8_blockscale_gemm_runner_sm90(): + """Get the FP8 block scale GEMM runner module for SM90.""" + return gen_fp8_blockscale_gemm_sm90_module().build_and_load().init() + + +def fp8_blockscale_gemm_swapab( + input: torch.Tensor, + weight: torch.Tensor, + input_scale: Optional[torch.Tensor] = None, + weight_scale: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + """ + Perform FP8 block-scaled GEMM with automatic swapAB optimization. + This function automatically selects between normal and swapAB kernel based on + the M dimension. For small M (< 32), it uses the swapAB kernel for + better performance. + + Supported Dtype Combinations + ----------------------------- + - **BF16 + BF16 → BF16**: Both inputs BF16, internal quantization (no scales needed) + - **BF16 + FP8 → BF16**: BF16 input, FP8 weight + + Parameters + ---------- + input : torch.Tensor + Input activation tensor of shape (M, K). + - BF16 (torch.bfloat16) with internal quantization + weight : torch.Tensor + Weight tensor of shape (N, K). Can be: + - FP8 (torch.float8_e4m3fn) with weight_scale required + - BF16 (torch.bfloat16) for internal quantization + input_scale : torch.Tensor, optional + weight_scale : torch.Tensor, optional + Scaling factors for weight. Required if weight is FP8. + out : torch.Tensor, optional + Output tensor of shape (M, N). If None, will be allocated. + out_dtype : torch.dtype, optional + Output data type. Default is torch.bfloat16. + Returns + ------- + torch.Tensor + Output tensor of shape (M, N) with dtype `out_dtype`. 
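+ + Raises + ------ + ValueError + If tensor shapes, dtypes, devices, or scale layouts are inconsistent, or if the GPU does not support SM90 (Hopper). +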
+ Examples + -------- + >>> import torch + >>> from flashinfer.gemm import fp8_blockscale_gemm_swapab + >>> + >>> M, N, K = 16, 4096, 4096 + >>> device = "cuda" + >>> + >>> # BF16 inputs + >>> input_bf16 = torch.randn(M, K, device=device, dtype=torch.bfloat16) + >>> weight_bf16 = torch.randn(N, K, device=device, dtype=torch.bfloat16) + >>> output = fp8_blockscale_gemm_swapab(input_bf16, weight_bf16) + >>> print(output.shape) # torch.Size([16, 4096]) + >>> + >>> # Mixed: BF16 input + FP8 weight + >>> from flashinfer.testing.utils import per_token_cast_to_fp8 + >>> input_bf16 = torch.randn(M, K, device=device, dtype=torch.bfloat16) + >>> weight_bf16 = torch.randn(N, K, device=device, dtype=torch.bfloat16) + >>> weight_fp8, weight_scale = per_token_cast_to_fp8(weight_bf16) + >>> output = fp8_blockscale_gemm_swapab(input_bf16, weight_fp8, None, weight_scale) + >>> print(output.shape) # torch.Size([16, 4096]) + >>> + >>> # FP8 weight with 128x128 block scales + >>> from flashinfer.testing.utils import per_block_cast_to_fp8 + >>> weight_bf16 = torch.randn(N, K, device=device, dtype=torch.bfloat16) + >>> weight_fp8, weight_scale = per_block_cast_to_fp8(weight_bf16) + >>> # weight_scale has shape (N // 128, K // 128) + >>> input_bf16 = torch.randn(M, K, device=device, dtype=torch.bfloat16) + >>> output = fp8_blockscale_gemm_swapab(input_bf16, weight_fp8, None, weight_scale) + >>> print(output.shape) # torch.Size([16, 4096]) + Notes + ----- + - This function requires NVIDIA Hopper (SM90) architecture and CUDA 12.8+ + - SwapAB kernel is automatically used when M < 32 (threshold) + - For FP8 inputs, scaling factors must be provided + - For BF16 inputs, quantization and scaling happen internally + - Weight scales support two granularities: + * Per-token (1x128 blocks): (N, K//128) + * Per-block (128x128 blocks): (N//128, K//128) + - Input scales only support per-token format: (M, K//128) + - The function uses DeepGEMM backend with JIT compilation + """ + # Validate architecture support + if not _match_sm_version(input.device, ["90", "90a"]): + raise ValueError( + "fp8_blockscale_gemm_swapab is only supported on SM90 (Hopper) architecture." + ) + + # Validate tensor dimensions + if input.ndim != 2: + raise ValueError(f"Input must be 2D (M, K), got shape {input.shape}") + if weight.ndim != 2: + raise ValueError(f"Weight must be 2D (N, K), got shape {weight.shape}") + + M, K = input.shape + N, K_weight = weight.shape + + if K != K_weight: + raise ValueError( + f"K dimension mismatch: input has K={K}, weight has K={K_weight}" + ) + + # Validate K is divisible by block size (128) + BLOCK_SIZE = 128 + if K % BLOCK_SIZE != 0: + raise ValueError( + f"K dimension must be divisible by block size ({BLOCK_SIZE}), got K={K}" + ) + + # Validate dtype combinations + input_is_fp8 = input.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] + weight_is_fp8 = weight.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] + input_is_bf16 = input.dtype == torch.bfloat16 + weight_is_bf16 = weight.dtype == torch.bfloat16 + + # Explicitly reject FP8 input + BF16 weight (missing kernel implementation) + if input_is_fp8 and weight_is_bf16: + raise ValueError( + "FP8 input + BF16 weight is not supported (missing kernel implementation). " + ) + + # Validate scale requirements for FP8 inputs + if input_is_fp8: + if input_scale is None: + raise ValueError("input_scale is required when input is FP8. ") + # Users provide input_scale in shape (M, K//128), matching per_token_cast_to_fp8 output. 
+ # We transpose it internally to (K//128, M) to match TensorRT-LLM kernel expectations. + expected_scale_shape = (M, K // BLOCK_SIZE) + if input_scale.shape != expected_scale_shape: + raise ValueError( + f"input_scale shape mismatch. Expected {expected_scale_shape}, " + f"got {input_scale.shape}" + ) + if input_scale.dtype != torch.float32: + raise ValueError(f"input_scale must be float32, got {input_scale.dtype}") + if input_scale.device != input.device: + raise ValueError( + f"input_scale device mismatch. Expected {input.device}, " + f"got {input_scale.device}" + ) + else: + if not input_is_bf16: + raise ValueError( + f"Input must be either FP8 (torch.float8_e4m3fn) or BF16 (torch.bfloat16), " + f"got {input.dtype}" + ) + if input_scale is not None: + raise ValueError( + "input_scale should not be provided for BF16 inputs. " + "Use FP8 inputs if you want to provide external scales." + ) + + if weight_is_fp8: + if weight_scale is None: + raise ValueError("weight_scale is required when weight is FP8. ") + expected_per_token_shape = (N, K // BLOCK_SIZE) + expected_per_block_shape = ((N + BLOCK_SIZE - 1) // BLOCK_SIZE, K // BLOCK_SIZE) + is_per_token = weight_scale.shape == expected_per_token_shape + is_per_block = weight_scale.shape == expected_per_block_shape + + if not (is_per_token or is_per_block): + raise ValueError( + f"weight_scale shape mismatch. Expected either {expected_per_token_shape} " + f"(per-token, 1x128 blocks) or {expected_per_block_shape} " + f"(per-block, 128x128 blocks), got {weight_scale.shape}" + ) + if weight_scale.dtype != torch.float32: + raise ValueError(f"weight_scale must be float32, got {weight_scale.dtype}") + else: + if not weight_is_bf16: + raise ValueError( + f"Weight must be either FP8 (torch.float8_e4m3fn) or BF16 (torch.bfloat16), " + f"got {weight.dtype}" + ) + if weight_scale is not None: + raise ValueError( + "weight_scale should not be provided for BF16 weights. " + "Use FP8 weights if you want to provide external scales." + ) + + # Validate output tensor if provided + if out is not None: + if out.shape != (M, N): + raise ValueError( + f"Output shape mismatch. Expected ({M}, {N}), got {out.shape}" + ) + if out.device != input.device: + raise ValueError( + f"Output device mismatch. Expected {input.device}, got {out.device}" + ) + if out_dtype is not None and out.dtype != out_dtype: + raise ValueError( + f"Output dtype mismatch. 
Expected {out_dtype}, got {out.dtype}" + ) + out_dtype = out.dtype + else: + # Allocate output + out_dtype = out_dtype or torch.bfloat16 + if out_dtype not in [torch.bfloat16, torch.float16]: + raise ValueError( + f"Output dtype must be torch.bfloat16 or torch.float16, got {out_dtype}" + ) + out = torch.empty(M, N, dtype=out_dtype, device=input.device) + + # Get the runner + runner = get_fp8_blockscale_gemm_runner_sm90() + + # Allocate workspace + workspace_size = runner.get_workspace_size(M, N, K) + workspace = None + if workspace_size > 0: + workspace = torch.empty(workspace_size, dtype=torch.uint8, device=input.device) + runner.configure_workspace(workspace) + + if input_scale is not None: + M_padded = ((M + 3) // 4) * 4 # Round M up to multiple of 4 + K_blocks = K // BLOCK_SIZE + + # Create padded tensor with the stride TRT-LLM expects + input_scale_padded = torch.zeros( + K_blocks, M_padded, dtype=torch.float32, device=input.device + ) + + # Copy scales into the non-padded region: (K//128, M) + # Transpose from (M, K//128) to (K//128, M) and copy + input_scale_padded[:, :M] = input_scale.T + + # Extract view of the actual (K//128, M) region + # This view has stride (M_padded, 1) which matches TRT-LLM's expectations + input_scale_transposed = input_scale_padded[:, :M] + + # Verify stride matches TRT-LLM's expectations + expected_stride_0 = M_padded + if input_scale_transposed.stride(0) != expected_stride_0: + raise ValueError( + f"input_scale stride mismatch: expected stride[0]={expected_stride_0} " + f"(M_padded={M_padded}), got {input_scale_transposed.stride(0)}" + ) + else: + input_scale_transposed = None + + runner.gemm(input, weight, out, input_scale_transposed, weight_scale) + + return out diff --git a/flashinfer/jit/gemm/__init__.py b/flashinfer/jit/gemm/__init__.py index f1681d3bf5..e81d51e15f 100644 --- a/flashinfer/jit/gemm/__init__.py +++ b/flashinfer/jit/gemm/__init__.py @@ -27,6 +27,7 @@ gen_gemm_sm90_module, ) from .deepgemm import gen_deepgemm_sm100_module +from .fp8_blockscale import gen_fp8_blockscale_gemm_sm90_module __all__ = [ "gen_gemm_module", @@ -40,4 +41,5 @@ "gen_tgv_gemm_sm10x_module", "gen_gemm_sm90_module", "gen_deepgemm_sm100_module", + "gen_fp8_blockscale_gemm_sm90_module", ] diff --git a/flashinfer/jit/gemm/fp8_blockscale.py b/flashinfer/jit/gemm/fp8_blockscale.py new file mode 100755 index 0000000000..c2a79dd5b6 --- /dev/null +++ b/flashinfer/jit/gemm/fp8_blockscale.py @@ -0,0 +1,56 @@ +from typing import List + +from .. 
import env as jit_env +from ..core import ( + JitSpec, + gen_jit_spec, + sm90a_nvcc_flags, +) +from ..cpp_ext import is_cuda_version_at_least + + +def gen_fp8_blockscale_gemm_sm90_module(use_fast_build: bool = False) -> JitSpec: + """Generate JIT spec for FP8 block scale GEMM on SM90 (Hopper).""" + nvcc_flags = sm90a_nvcc_flags + [ + "-DCOMPILE_HOPPER_TMA_GEMMS", + "-DENABLE_BF16", + "-DENABLE_FP8", + "-DENABLE_FP8_BLOCK_SCALE" if is_cuda_version_at_least("12.8") else "", + ] + + return gen_jit_spec( + "fp8_blockscale_gemm_90", + [ + jit_env.FLASHINFER_CSRC_DIR + / "nv_internal/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm.cu", + jit_env.FLASHINFER_CSRC_DIR / "fp8_blockscale_gemm_sm90_binding.cu", + jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/envUtils.cpp", + jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/logger.cpp", + jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/stringUtils.cpp", + jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/tllmException.cpp", + jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/memoryUtils.cu", + ], + extra_cuda_cflags=nvcc_flags, + extra_cflags=["-DFAST_BUILD"] if use_fast_build else [], + extra_ldflags=["-lnvrtc", "-lcuda"], + extra_include_paths=[ + jit_env.FLASHINFER_CSRC_DIR / "nv_internal", + jit_env.FLASHINFER_CSRC_DIR / "nv_internal" / "include", + jit_env.FLASHINFER_CSRC_DIR + / "nv_internal" + / "tensorrt_llm" + / "cutlass_extensions" + / "include", + jit_env.FLASHINFER_CSRC_DIR + / "nv_internal" + / "tensorrt_llm" + / "kernels" + / "cutlass_kernels" + / "include", + jit_env.FLASHINFER_CSRC_DIR + / "nv_internal" + / "tensorrt_llm" + / "kernels" + / "cutlass_kernels", + ], + ) diff --git a/tests/gemm/test_fp8_blockscale_gemm.py b/tests/gemm/test_fp8_blockscale_gemm.py new file mode 100755 index 0000000000..02502f4704 --- /dev/null +++ b/tests/gemm/test_fp8_blockscale_gemm.py @@ -0,0 +1,435 @@ +""" +Copyright (c) 2024 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import pytest +import torch +import torch.nn.functional as F + +import flashinfer +from flashinfer.gemm import fp8_blockscale_gemm_swapab +from flashinfer.testing.utils import per_token_cast_to_fp8, per_block_cast_to_fp8 +from flashinfer.utils import ( + get_compute_capability, + has_flashinfer_jit_cache, + is_sm90a_supported, +) +from flashinfer.jit.gemm import gen_fp8_blockscale_gemm_sm90_module + + +def calc_diff(output: torch.Tensor, expected: torch.Tensor) -> float: + """Calculate similarity difference using TensorRT-LLM's metric. + + Returns diff = 1 - sim, where sim = 2* / (||x||² + ||y||²) + This is similar to cosine similarity but uses squared norms in denominator. + + diff < 0.001 corresponds to >99.9% similarity. 
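+ Formally, diff = 1 - 2 * sum(x * y) / (sum(x * x) + sum(y * y)), where x is the output and y the expected tensor, accumulated in float64. +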
+ """ + output_f64 = output.to(torch.float64) + expected_f64 = expected.to(torch.float64) + denominator = (output_f64 * output_f64 + expected_f64 * expected_f64).sum() + sim = 2 * (output_f64 * expected_f64).sum() / denominator + diff = 1 - sim + return diff.item() + + +@pytest.fixture( + autouse=not has_flashinfer_jit_cache(), + scope="module", +) +def warmup_jit(): + """Warm up JIT compilation for FP8 block-scale GEMM if not cached.""" + if is_sm90a_supported(torch.device("cuda:0")): + jit_specs = [gen_fp8_blockscale_gemm_sm90_module()] + flashinfer.jit.build_jit_specs(jit_specs, verbose=False) + yield + + +@pytest.mark.parametrize("m", [1, 16, 32, 64, 128]) +@pytest.mark.parametrize("n", [128, 256, 512, 1024, 4096]) +@pytest.mark.parametrize("k", [256, 512, 1024, 4096]) +@pytest.mark.parametrize("input_dtype", [torch.bfloat16]) +@pytest.mark.parametrize("weight_dtype", [torch.bfloat16]) +def test_fp8_blockscale_gemm_swapab(m, n, k, input_dtype, weight_dtype): + """Test FP8 block-scale GEMM with swapAB optimization. + + This test focuses on the usage: BF16 inputs with internal quantization. + The kernel automatically handles FP8 quantization with proper block-scale computation. + """ + compute_capability = get_compute_capability(torch.device("cuda")) + if compute_capability[0] < 9: + pytest.skip("FP8 block-scale GEMM requires SM90 (Hopper) or later") + + if not is_sm90a_supported(torch.device("cuda")): + pytest.skip("FP8 block-scale GEMM requires SM90a (Hopper) support") + + # K must be divisible by 128 (block size requirement) + if k % 128 != 0: + pytest.skip("K must be divisible by 128 for block-scale GEMM") + + device = "cuda" + torch.manual_seed(42) + + # Create BF16 inputs + input = torch.randn(m, k, device=device, dtype=input_dtype) + weight = torch.randn(n, k, device=device, dtype=weight_dtype) + + # Compute reference result + reference = torch.matmul(input, weight.T) + + # Run FP8 block-scale GEMM + output = fp8_blockscale_gemm_swapab(input, weight) + + # Verify output shape + assert output.shape == (m, n), f"Expected shape {(m, n)}, got {output.shape}" + assert output.dtype == torch.bfloat16, f"Expected BF16 output, got {output.dtype}" + + # Check correctness + cos_sim = F.cosine_similarity( + reference.flatten().float(), output.flatten().float(), dim=0 + ) + assert cos_sim > 0.99, f"Cosine similarity {cos_sim} is too low (expected > 0.99)" + + +@pytest.mark.parametrize("m", [1, 32, 128]) +@pytest.mark.parametrize("n", [1024, 4096]) +@pytest.mark.parametrize("k", [512, 4096]) +@pytest.mark.parametrize( + "input_dtype,weight_dtype", + [ + ( + torch.bfloat16, + torch.bfloat16, + ), # Both BF16 (for testing internal quantization) + (torch.bfloat16, torch.float8_e4m3fn), # BF16 input + FP8 weight + ], +) +def test_fp8_blockscale_gemm_dtypes(m, n, k, input_dtype, weight_dtype): + """Test the 2 recommended dtype combinations with proper FP8 quantization. + + Uses quantization from flashinfer.testing.utils: + - per_token_cast_to_fp8: 1x128 block quantization (for both input and weight) + + Note: Both input and weight use per_token (1x128 blocks). + The API expects scale shape (N, K//128), which per_token provides. + + These utilities return scales in the correct format (reciprocals) that + match TRT-LLM's kernel expectations. 
For kernel reference, + see csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm_kernel.cuh + """ + compute_capability = get_compute_capability(torch.device("cuda")) + if compute_capability[0] < 9: + pytest.skip("FP8 block-scale GEMM requires SM90 (Hopper) or later") + + if not is_sm90a_supported(torch.device("cuda")): + pytest.skip("FP8 block-scale GEMM requires SM90a (Hopper) support") + + if k % 128 != 0: + pytest.skip("K must be divisible by 128 for block-scale GEMM") + + device = "cuda" + torch.manual_seed(42) + + # Create BF16 data for reference + input_bf16 = torch.randn(m, k, device=device, dtype=torch.bfloat16) + weight_bf16 = torch.randn(n, k, device=device, dtype=torch.bfloat16) + + # Quantize input + if input_dtype == torch.float8_e4m3fn: + input_tensor, input_scale = per_token_cast_to_fp8(input_bf16) + else: + input_tensor, input_scale = input_bf16, None + + # Quantize weight + if weight_dtype == torch.float8_e4m3fn: + weight_tensor, weight_scale = per_token_cast_to_fp8(weight_bf16) + else: + weight_tensor, weight_scale = weight_bf16, None + + # Compute reference + reference = torch.matmul(input_bf16, weight_bf16.T) + + # Run FP8 block-scale GEMM + output = fp8_blockscale_gemm_swapab( + input_tensor, weight_tensor, input_scale, weight_scale + ) + + # Verify output properties + assert output.shape == (m, n), f"Expected shape {(m, n)}, got {output.shape}" + assert output.dtype == torch.bfloat16, f"Expected BF16 output, got {output.dtype}" + + # Check correctness + cos_sim = F.cosine_similarity( + reference.flatten().float(), output.flatten().float(), dim=0 + ) + + if input_dtype == torch.bfloat16 and weight_dtype == torch.bfloat16: + threshold = 0.99 + else: + # BF16+FP8: BF16 input quantized internally, FP8 weight pre-quantized + # TODO: check threshold + threshold = 0.967 + + assert cos_sim > threshold, ( + f"Cosine similarity {cos_sim:.4f} too low for " + f"{input_dtype} + {weight_dtype} (expected > {threshold})" + ) + + +@pytest.mark.parametrize("m", [7, 32, 128]) +@pytest.mark.parametrize("n", [1024, 4096]) +@pytest.mark.parametrize("k", [512, 4096]) +def test_fp8_blockscale_gemm_w8a8(m, n, k): + """Test W8A8 (FP8+FP8) GEMM with per-token scales for both input and weight. + + This test demonstrates full FP8 quantization for both activations and weights. 
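+ The reference is computed from the dequantized FP8 tensors, so the comparison measures kernel accuracy rather than quantization error. +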
+ """ + compute_capability = get_compute_capability(torch.device("cuda")) + if compute_capability[0] < 9: + pytest.skip("FP8 block-scale GEMM requires SM90 (Hopper) or later") + + if not is_sm90a_supported(torch.device("cuda")): + pytest.skip("FP8 block-scale GEMM requires SM90a (Hopper) support") + + device = "cuda" + # m, n, k = 64, 2048, 4096 + torch.manual_seed(42) + + # Create BF16 inputs for reference (no normalization) + # Raw randn values work well with FP8 quantization without causing numerical issues + input_bf16 = torch.randn(m, k, device=device, dtype=torch.bfloat16) + weight_bf16 = torch.randn(n, k, device=device, dtype=torch.bfloat16) + + # Quantize both input and weight to FP8 with per-token (1x128) scales + input_fp8, input_scale = per_token_cast_to_fp8(input_bf16) + weight_fp8, weight_scale = per_token_cast_to_fp8(weight_bf16) + + # Verify scale shapes + assert input_scale.shape == (m, k // 128), ( + f"Expected input scale shape ({m}, {k // 128}), got {input_scale.shape}" + ) + assert weight_scale.shape == (n, k // 128), ( + f"Expected weight scale shape ({n}, {k // 128}), got {weight_scale.shape}" + ) + assert input_scale.min() > 0, "Input scale should be positive" + assert weight_scale.min() > 0, "Weight scale should be positive" + + # Run W8A8 GEMM: FP8 input + FP8 weight + output = fp8_blockscale_gemm_swapab( + input_fp8, weight_fp8, input_scale, weight_scale + ) + + # Dequantize FP8 tensors to create reference (tests kernel correctness, not quantization) + # Dequant: bf16 = fp8.to(bf16) * scale (applied per 128-element block) + input_dequant = torch.zeros_like(input_bf16) + for i in range(m): + for k_tile in range(k // 128): + start, end = k_tile * 128, (k_tile + 1) * 128 + input_dequant[i, start:end] = ( + input_fp8[i, start:end].to(torch.bfloat16) * input_scale[i, k_tile] + ) + + weight_dequant = torch.zeros_like(weight_bf16) + for j in range(n): + for k_tile in range(k // 128): + start, end = k_tile * 128, (k_tile + 1) * 128 + weight_dequant[j, start:end] = ( + weight_fp8[j, start:end].to(torch.bfloat16) * weight_scale[j, k_tile] + ) + + reference = torch.matmul(input_dequant, weight_dequant.T) + + # Use cosine similarity (same metric as BF16+FP8 tests) + cos_sim = F.cosine_similarity( + reference.flatten().float(), output.flatten().float(), dim=0 + ) + # W8A8 achieves ~97% cosine similarity against dequantized FP8 reference + assert cos_sim > 0.967, ( + f"W8A8 cosine similarity {cos_sim:.4f} too low (expected > 0.967)" + ) + + print(f"✓ W8A8 (FP8+FP8): cosine similarity = {cos_sim:.4f}") + + +def test_fp8_blockscale_gemm_per_block_weight_scales(): + """Test BF16+FP8 GEMM with per-block (128x128) weight scales. 
+ + This test demonstrates using 128x128 block quantization for weights with BF16 input, + """ + compute_capability = get_compute_capability(torch.device("cuda")) + if compute_capability[0] < 9: + pytest.skip("FP8 block-scale GEMM requires SM90 (Hopper) or later") + + if not is_sm90a_supported(torch.device("cuda")): + pytest.skip("FP8 block-scale GEMM requires SM90a (Hopper) support") + + device = "cuda" + m, n, k = 16, 512, 512 + torch.manual_seed(42) + + # Create inputs + input_bf16 = torch.randn(m, k, device=device, dtype=torch.bfloat16) + weight_bf16 = torch.randn(n, k, device=device, dtype=torch.bfloat16) + + # Quantize weight with per-block (128x128) blocks + weight_fp8, weight_scale = per_block_cast_to_fp8(weight_bf16) + + # Verify scale shape + assert weight_scale.shape == (n // 128, k // 128), ( + f"Expected weight scale shape ({n // 128}, {k // 128}), got {weight_scale.shape}" + ) + assert weight_scale.min() > 0, "Weight scale should be positive (reciprocal format)" + + # Run GEMM: BF16 input (internal quant) + FP8 weight (per-block scales) + output = fp8_blockscale_gemm_swapab(input_bf16, weight_fp8, None, weight_scale) + + # Compare to BF16 reference + reference = torch.matmul(input_bf16, weight_bf16.T) + + cos_sim = F.cosine_similarity( + reference.flatten().float(), output.flatten().float(), dim=0 + ) + # TODO: check threshold + assert cos_sim > 0.967, f"Per-block weight scale accuracy too low: {cos_sim:.4f}" + + print(f"✓ Per-block weight scales: cosine similarity = {cos_sim:.4f}") + + +@pytest.mark.parametrize( + "m,n,k", + [ + (1, 4096, 4096), + (8, 4096, 4096), + (128, 4096, 4096), + (16, 8192, 8192), + (32, 2048, 4096), + ], +) +def test_fp8_blockscale_gemm_shapes(m, n, k): + """Test various common shapes used in LLM inference.""" + compute_capability = get_compute_capability(torch.device("cuda")) + if compute_capability[0] < 9: + pytest.skip("FP8 block-scale GEMM requires SM90 (Hopper) or later") + + if not is_sm90a_supported(torch.device("cuda")): + pytest.skip("FP8 block-scale GEMM requires SM90a (Hopper) support") + + if k % 128 != 0: + pytest.skip("K must be divisible by 128") + + device = "cuda" + torch.manual_seed(42) + + input = torch.randn(m, k, device=device, dtype=torch.bfloat16) + weight = torch.randn(n, k, device=device, dtype=torch.bfloat16) + + reference = torch.matmul(input, weight.T) + output = fp8_blockscale_gemm_swapab(input, weight) + + cos_sim = F.cosine_similarity( + reference.flatten().float(), output.flatten().float(), dim=0 + ) + assert cos_sim > 0.99, f"Shape ({m}, {n}, {k}): cosine similarity {cos_sim} too low" + + +def test_fp8_blockscale_gemm_error_handling(): + """Test that proper errors are raised for invalid inputs.""" + compute_capability = get_compute_capability(torch.device("cuda")) + if compute_capability[0] < 9: + pytest.skip("FP8 block-scale GEMM requires SM90 (Hopper) or later") + + if not is_sm90a_supported(torch.device("cuda")): + pytest.skip("FP8 block-scale GEMM requires SM90a (Hopper) support") + + device = "cuda" + m, n, k = 16, 256, 256 + + # Test: K not divisible by 128 + input = torch.randn(m, 127, device=device, dtype=torch.bfloat16) + weight = torch.randn(n, 127, device=device, dtype=torch.bfloat16) + with pytest.raises(ValueError, match="divisible by block size"): + fp8_blockscale_gemm_swapab(input, weight) + + # Test: FP16 not supported + input = torch.randn(m, k, device=device, dtype=torch.float16) + weight = torch.randn(n, k, device=device, dtype=torch.float16) + with pytest.raises(ValueError, match="FP8.*or BF16"): + 
fp8_blockscale_gemm_swapab(input, weight) + + # Test: FP8 weight without scale (naive conversion) + input_bf16 = torch.randn(m, k, device=device, dtype=torch.bfloat16) + weight_bf16 = torch.randn(n, k, device=device, dtype=torch.bfloat16) + weight_fp8_naive = weight_bf16.to(torch.float8_e4m3fn) + with pytest.raises(ValueError, match="weight_scale is required when weight is FP8"): + fp8_blockscale_gemm_swapab(input_bf16, weight_fp8_naive, None, None) + + # Test: BF16 input with scale (should raise error) + input = torch.randn(m, k, device=device, dtype=torch.bfloat16) + weight = torch.randn(n, k, device=device, dtype=torch.bfloat16) + fake_scale = torch.ones(m, k // 128, device=device, dtype=torch.float32) + with pytest.raises(ValueError, match="input_scale should not be provided for BF16"): + fp8_blockscale_gemm_swapab(input, weight, input_scale=fake_scale) + + # Test: Wrong scale shape for FP8 input + input_bf16 = torch.randn(m, k, device=device, dtype=torch.bfloat16) + input_fp8, _ = per_token_cast_to_fp8(input_bf16) + weight = torch.randn(n, k, device=device, dtype=torch.bfloat16) + wrong_scale = torch.ones(m, k // 64, device=device, dtype=torch.float32) + with pytest.raises(ValueError): + fp8_blockscale_gemm_swapab(input_fp8, weight, input_scale=wrong_scale) + + # Test: FP8 input + BF16 weight is NOT supported + input_bf16 = torch.randn(m, k, device=device, dtype=torch.bfloat16) + input_fp8, input_scale = per_token_cast_to_fp8(input_bf16) + weight = torch.randn(n, k, device=device, dtype=torch.bfloat16) + with pytest.raises(ValueError, match="FP8 input.*BF16 weight.*not supported"): + fp8_blockscale_gemm_swapab(input_fp8, weight, input_scale, None) + + +def test_fp8_blockscale_gemm_output_buffer(): + """Test providing pre-allocated output buffer.""" + compute_capability = get_compute_capability(torch.device("cuda")) + if compute_capability[0] < 9: + pytest.skip("FP8 block-scale GEMM requires SM90 (Hopper) or later") + + if not is_sm90a_supported(torch.device("cuda")): + pytest.skip("FP8 block-scale GEMM requires SM90a (Hopper) support") + + device = "cuda" + m, n, k = 16, 256, 256 + torch.manual_seed(42) + + input = torch.randn(m, k, device=device, dtype=torch.bfloat16) + weight = torch.randn(n, k, device=device, dtype=torch.bfloat16) + + # Pre-allocate output + output = torch.empty(m, n, device=device, dtype=torch.bfloat16) + + # Run GEMM with pre-allocated output + result = fp8_blockscale_gemm_swapab(input, weight, out=output) + + # Verify result is the same buffer + assert result is output + + # Verify correctness + reference = torch.matmul(input, weight.T) + cos_sim = F.cosine_similarity( + reference.flatten().float(), output.flatten().float(), dim=0 + ) + assert cos_sim > 0.99 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])