From 6d2389adbc9c69e87295e8eafdc36afc152bed12 Mon Sep 17 00:00:00 2001
From: Adam Paszke
Date: Fri, 14 Nov 2025 02:57:34 -0800
Subject: [PATCH] [Mosaic GPU] Support nvfp4 scaled MMA instructions

PiperOrigin-RevId: 832237310
---
 jax/experimental/mosaic/gpu/tcgen05.py | 81 ++++++++++++++++++--------
 tests/mosaic/gpu_test.py               | 52 ++++++++++++-----
 2 files changed, 96 insertions(+), 37 deletions(-)

diff --git a/jax/experimental/mosaic/gpu/tcgen05.py b/jax/experimental/mosaic/gpu/tcgen05.py
index f57efab589c0..b9d811415855 100644
--- a/jax/experimental/mosaic/gpu/tcgen05.py
+++ b/jax/experimental/mosaic/gpu/tcgen05.py
@@ -16,6 +16,7 @@
 from __future__ import annotations
 
 import dataclasses
+import functools
 import itertools
 import math
 from typing import Any, Callable, Iterator, cast
@@ -113,8 +114,9 @@ def _create_scaled_instr_descriptor(
     b_type: ir.Type,
     a_scale_idx: int,
     b_scale_idx: int,
-    transpose_a: bool = False,
-    transpose_b: bool = False,
+    transpose_a: bool,
+    transpose_b: bool,
+    scale_type: ir.Type,
 ) -> ir.Value:
   desc = 0
   # Bits 0, 1 are reserved
@@ -131,7 +133,13 @@ def _create_scaled_instr_descriptor(
   if n % 8 or n > 256:
     raise ValueError(f"N must be a multiple of 8 and <= 256, got: {n}")
   desc |= (n >> 3) << 17  # N, bits 17-22
-  desc |= 1 << 23  # Scale matrix type
+  if scale_type == ir.Float8E8M0FNUType.get():
+    scale_encoding = 1
+  elif scale_type == ir.Float8E4M3FNType.get():
+    scale_encoding = 0
+  else:
+    raise NotImplementedError(f"Unsupported scale type: {scale_type}")
+  desc |= scale_encoding << 23  # Scale matrix type
   # Bits 24-26 are reserved
   if m % 128 or m > 256:
     raise ValueError(f"M must be a multiple of 128 and <= 256, got: {m}")
@@ -375,6 +383,8 @@ def mma(
     )
 
   # Check that the shapes and element types are correct for block scaling.
+  scale_element_type = None
+  scale_block = None
   if is_scaled:
     if collective:
       raise NotImplementedError("MMA with block scaling does not support collective")
@@ -384,24 +394,32 @@ def mma(
           f"MMA with block scaling requires N to be divisible by 32, got: {n}"
       )
     assert a_scale is not None and b_scale is not None
-    if a_scale.shape != (m, k // 32):
+    scale_element_type = a_scale.dtype
+    if a_scale.dtype == ir.Float8E8M0FNUType.get():
+      scale_block = 32
+    elif a_scale.dtype == ir.Float8E4M3FNType.get():
+      scale_block = 16
+    else:
       raise ValueError(
-          f"A scale shape mismatch: expected ({m}, {k // 32}), got {a_scale.shape}"
+          f"A scale dtype mismatch: expected f8e8m0fnu or f8e4m3fn, got {a_scale.dtype}"
       )
-    if a_scale.dtype != ir.Float8E8M0FNUType.get():
+    if b_scale.dtype != a_scale.dtype:
       raise ValueError(
-          f"A scale dtype mismatch: expected f8e8m0fnu, got {a_scale.dtype}"
+          f"B scale dtype mismatch: expected {a_scale.dtype} (same as A), got"
+          f" {b_scale.dtype}"
       )
-    if b_scale.shape != (n, k // 32):
+    if a_scale.shape != (m, k // scale_block):
       raise ValueError(
-          f"B scale shape mismatch: expected ({n}, {k // 32}), got {b_scale.shape}"
+          f"A scale shape mismatch: expected ({m}, {k // scale_block}), got"
+          f" {a_scale.shape}"
       )
-    if b_scale.dtype != ir.Float8E8M0FNUType.get():
+    if b_scale.shape != (n, k // scale_block):
       raise ValueError(
-          f"B scale dtype mismatch: expected f8e8m0fnu, got {b_scale.dtype}"
+          f"B scale shape mismatch: expected ({n}, {k // scale_block}), got"
+          f" {b_scale.shape}"
       )
-    if k_group_elems % 128:
-      min_swizzle = 16 * utils.bitwidth(element_type)
+    if k_group_elems % (scale_block * 4):
+      min_swizzle = scale_block // 2 * utils.bitwidth(element_type)
       raise NotImplementedError(
           f"{element_type} MMA with block scaling requires swizzle to be at"
           f" least {min_swizzle}"
       )
@@ -502,15 +520,17 @@ def mma(
         raise NotImplementedError("A scale address calculation for multiple M tiles")
       if n_groups != 1:
         raise NotImplementedError("B scale address calculation for multiple N tiles")
-      assert k_group_elems % 128 == 0
+      assert scale_block is not None  # For type checkers.
+      assert k_group_elems % (scale_block * 4) == 0
       assert m_group_elems % 32 == 0 and n_group_elems % 32 == 0
+      k_scales_per_group = k_group_elems // (scale_block * 4)
       a_scale_addr = arith.addi(
           a_scale_addr_base,
-          utils.c(ki * k_group_elems // 128 * m_group_elems // 32, i32),
+          utils.c(ki * k_scales_per_group * m_group_elems // 32, i32),
       )
       b_scale_addr = arith.addi(
           b_scale_addr_base,
-          utils.c(ki * k_group_elems // 128 * n_group_elems // 32, i32),
+          utils.c(ki * k_scales_per_group * n_group_elems // 32, i32),
       )
     else:
       a_scale_addr = b_scale_addr = None
@@ -540,6 +560,7 @@ def mma(
           a_sparse_addr=a_sparse_addr,
           accumulate=acc,
           element_type=mma_element_type,
+          scale_element_type=scale_element_type,
       )


@@ -558,6 +579,7 @@ def _do_mma(
     n: int,
     k: int,
     element_type: ir.Type,
+    scale_element_type: ir.Type | None,
     d_type: ir.Type,
     accumulate: ir.Value,
     collective: bool,
@@ -580,14 +602,27 @@ def _do_mma(
     assert not is_sparse
     if (ir.Float8E5M2Type.isinstance(element_type) or
         ir.Float8E4M3FNType.isinstance(element_type)):
+      if scale_element_type != ir.Float8E8M0FNUType.get():
+        raise ValueError(
+            f"Scale element type mismatch: expected f8e8m0fnu, got {scale_element_type}"
+        )
       kind = "mxf8f6f4.block_scale.scale_vec::1X"
       scale_steps = 4
-      create_scaled_instr_descriptor = create_scaled_f8f6f4_instr_descriptor
+      create_scaled_instr_descriptor = functools.partial(
+          create_scaled_f8f6f4_instr_descriptor, scale_type=scale_element_type
+      )
     elif ir.Float4E2M1FNType.isinstance(element_type):
       assert not a_transpose and not b_transpose
-      kind = "mxf4.block_scale.scale_vec::2X"
-      scale_steps = 2
-      create_scaled_instr_descriptor = create_scaled_f4_instr_descriptor
+      create_scaled_instr_descriptor = functools.partial(
+          create_scaled_f4_instr_descriptor,
+          scale_type=scale_element_type,
+      )
+      if scale_element_type == ir.Float8E8M0FNUType.get():
+        kind = "mxf4.block_scale.scale_vec::2X"
+        scale_steps = 2
+      elif scale_element_type == ir.Float8E4M3FNType.get():
+        kind = "mxf4nvf4.block_scale.scale_vec::4X"
+        scale_steps = 1
     else:
       raise NotImplementedError(f"Unsupported element type for block scaling: {element_type}")
     extra_ptx = "[$5], [$6], "
@@ -605,7 +640,7 @@ def _do_mma(
       raise NotImplementedError(f"Unsupported input element type: {element_type}")
     extra_constraints = extra_ptx = ""

-    def create_scaled_instr_descriptor(*args):
+    def create_scaled_instr_descriptor(*args):  # type: ignore
       raise NotImplementedError

   num_cta = 2 if collective else 1
@@ -1440,8 +1475,8 @@ def async_copy_scales_smem_to_tmem(smem_ref: ir.Value, tmem_ref: TMEMRef) -> Non
   smem_ty = ir.MemRefType(smem_ref.type)
   if (dtype := smem_ty.element_type) != tmem_ref.dtype:
     raise ValueError(f"Incompatible dtypes: SMEM has {dtype}, TMEM has {tmem_ref.dtype}")
-  if dtype != ir.Float8E8M0FNUType.get():
-    raise NotImplementedError(f"Unsupported dtype: {dtype}, only f8e8m0fnu supported")
+  if dtype not in {ir.Float8E8M0FNUType.get(), ir.Float8E4M3FNType.get()}:
+    raise NotImplementedError(f"Unsupported dtype: {dtype}, only f8e8m0fnu and f8e4m3fn are supported")
   if tmem_ref.shape[0] % TMEM_ROWS:
     raise ValueError(f"TMEM reference must have a multiple of {TMEM_ROWS} rows, but got {tmem_ref.shape[0]}")
   if tmem_ref.shape[1] % 4:
diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py
index e80351b25de7..9f2b32b54f4a 100644
--- a/tests/mosaic/gpu_test.py
+++ b/tests/mosaic/gpu_test.py
@@ -1617,14 +1617,22 @@ def kernel(ctx, src, out, scratch):

   @parameterized.product(
       in_jax_dtype=(jnp.float8_e5m2, jnp.float8_e4m3fn, jnp.float4_e2m1fn),
+      scale_jax_dtype=(jnp.float8_e8m0fnu, jnp.float8_e4m3fn),
       m=(128,),  # TODO(apaszke): 256
       n=(128, 256),  # TODO(apaszke): 192, other non-power-of-2
   )
-  def test_mma_block_scaled(self, m, n, in_jax_dtype):
+  def test_mma_block_scaled(self, m, n, in_jax_dtype, scale_jax_dtype):
     out_jax_dtype = jnp.float32
-    scale_jax_dtype = jnp.float8_e8m0fnu
     swizzle = 128
     k_steps = 2
+    if scale_jax_dtype == jnp.float8_e8m0fnu:
+      block_size = 32
+    elif scale_jax_dtype == jnp.float8_e4m3fn:
+      if in_jax_dtype != jnp.float4_e2m1fn:
+        self.skipTest("Only float4_e2m1fn input is supported for e4m3fn scale.")
+      block_size = 16
+    else:
+      raise ValueError(f"Unsupported scale dtype: {scale_jax_dtype}")
     if out_jax_dtype == jnp.float16 and in_jax_dtype != jnp.float16:
       self.skipTest("Only f16 input is supported for f16 output.")

@@ -1670,21 +1678,37 @@ def kernel(ctx, lhs, rhs, lhs_scales_gmem, rhs_scales_gmem, out, scratch):
     scratch_shape = [
         jax.ShapeDtypeStruct(tile_shape(x_shape, lhs_tiling), in_jax_dtype),
         jax.ShapeDtypeStruct(tile_shape(y_shape, rhs_tiling), in_jax_dtype),
-        jax.ShapeDtypeStruct((m // 128, k // (32 * 4), 32, 16), scale_jax_dtype),
-        jax.ShapeDtypeStruct((n // 128, k // (32 * 4), 32, 16), scale_jax_dtype),
+        jax.ShapeDtypeStruct((m // 128, k // (block_size * 4), 32, 16), scale_jax_dtype),
+        jax.ShapeDtypeStruct((n // 128, k // (block_size * 4), 32, 16), scale_jax_dtype),
         mgpu.TMABarrier(4),
         mgpu.Barrier(1),
         mgpu.TMEM((m, n), out_jax_dtype),
-        mgpu.TMEM((m, k // 32), scale_jax_dtype, layout=tcgen05.scales_layout()),
-        mgpu.TMEM((n, k // 32), scale_jax_dtype, layout=tcgen05.scales_layout()),
+        mgpu.TMEM((m, k // block_size), scale_jax_dtype, layout=tcgen05.scales_layout()),
+        mgpu.TMEM((n, k // block_size), scale_jax_dtype, layout=tcgen05.scales_layout()),
     ]
     ka, kb = jax.random.split(jax.random.key(1234), 2)
-    a_scales = jax.lax.bitcast_convert_type(
-        jax.random.randint(ka, (m, k // 32), 122, 132, dtype=jnp.uint8), scale_jax_dtype
-    )
-    b_scales = jax.lax.bitcast_convert_type(
-        jax.random.randint(kb, (n, k // 32), 122, 132, dtype=jnp.uint8), scale_jax_dtype
-    )
+    if scale_jax_dtype == jnp.float8_e8m0fnu:
+      a_scales = jax.lax.bitcast_convert_type(
+          jax.random.randint(ka, (m, k // block_size), 122, 132, dtype=jnp.uint8),
+          scale_jax_dtype
+      )
+      b_scales = jax.lax.bitcast_convert_type(
+          jax.random.randint(kb, (n, k // block_size), 122, 132, dtype=jnp.uint8),
+          scale_jax_dtype
+      )
+    elif scale_jax_dtype == jnp.float8_e4m3fn:
+      a_scales = jnp.abs(
+          jax.random.normal(ka, (m, k // block_size), dtype=jnp.float32).astype(
+              scale_jax_dtype
+          )
+      )
+      b_scales = jnp.abs(
+          jax.random.normal(kb, (n, k // block_size), dtype=jnp.float32).astype(
+              scale_jax_dtype
+          )
+      )
+    else:
+      raise ValueError(f"Unsupported scale dtype: {scale_jax_dtype}")
     def format_scales(scales):
       mn, k = scales.shape
       assert mn % 128 == 0 and k % 4 == 0, scales.shape
@@ -1699,8 +1723,8 @@ def format_scales(scales):
         kernel, (1, 1, 1), (128, 1, 1), args, out_shape, scratch_shape
     )(*args)
     x32, y32 = x.astype(np.float32), y.astype(np.float32)
-    a_logical_scales = jnp.repeat(a_scales, 32, axis=1).astype(jnp.float32)
-    b_logical_scales = jnp.repeat(b_scales, 32, axis=1).astype(jnp.float32)
+    a_logical_scales = jnp.repeat(a_scales, block_size, axis=1).astype(jnp.float32)
+    b_logical_scales = jnp.repeat(b_scales, block_size, axis=1).astype(jnp.float32)
     ref = (x32 * a_logical_scales) @ (y32 * b_logical_scales).T
     np.testing.assert_allclose(z, ref, atol=2e-4, rtol=5e-6)
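
A minimal sketch of the dispatch this patch introduces, in plain Python. It is
illustrative only: scaled_mma_config is a hypothetical helper, not part of the
Mosaic GPU API. For float4_e2m1fn inputs, the scale dtype now selects the
scaling block size, the PTX MMA kind, and the number of scale steps per
instruction, and the expected scale-operand shapes follow from the block size,
mirroring the checks added to mma() and _do_mma() above:

# Illustrative sketch only: mirrors the scale-dtype dispatch added by this
# patch for the float4_e2m1fn input path. Hypothetical helper, not the API.
def scaled_mma_config(scale_dtype: str, m: int, n: int, k: int):
  if scale_dtype == "float8_e8m0fnu":
    # MX scaling: one power-of-two scale per 32 contraction elements.
    scale_block, kind, scale_steps = 32, "mxf4.block_scale.scale_vec::2X", 2
  elif scale_dtype == "float8_e4m3fn":
    # nvfp4 scaling (new in this patch): one e4m3 scale per 16 elements.
    scale_block, kind, scale_steps = 16, "mxf4nvf4.block_scale.scale_vec::4X", 1
  else:
    raise NotImplementedError(f"Unsupported scale type: {scale_dtype}")
  if k % (scale_block * 4):
    raise ValueError(f"K must be a multiple of {scale_block * 4}, got: {k}")
  # One scale factor per scale_block contraction elements, matching the
  # (m, k // scale_block) and (n, k // scale_block) shape checks in mma().
  return kind, scale_steps, (m, k // scale_block), (n, k // scale_block)

# Example: a 128x128x256 nvfp4 MMA with e4m3fn scales uses the mxf4nvf4 4X
# kind, one scale step per instruction, and (128, 16)-shaped A and B scales.
print(scaled_mma_config("float8_e4m3fn", 128, 128, 256))

Note that f8e4m3fn scales are only reachable on the float4_e2m1fn input path;
the mxf8f6f4 path still requires f8e8m0fnu scales, which is why the test above
skips non-f4 input dtypes when the scale dtype is f8e4m3fn.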