From 759639413abb14f3eb3152fb35e1c574adb3c683 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Tue, 11 Nov 2025 03:13:59 -0800 Subject: [PATCH] [Mosaic GPU] Add support for 8-bit types in sparse tcgen05 MMA PiperOrigin-RevId: 830832935 --- jax/experimental/mosaic/gpu/tcgen05.py | 19 +++++++---------- tests/mosaic/gpu_test.py | 29 +++++++++++++++++++------- 2 files changed, 30 insertions(+), 18 deletions(-) diff --git a/jax/experimental/mosaic/gpu/tcgen05.py b/jax/experimental/mosaic/gpu/tcgen05.py index 81cb10de786b..480f82197278 100644 --- a/jax/experimental/mosaic/gpu/tcgen05.py +++ b/jax/experimental/mosaic/gpu/tcgen05.py @@ -264,7 +264,7 @@ def mma( s32 = ir.IntegerType.get_signless(32) if element_type == f32 or element_type == ir.BF16Type.get(): if element_type == f32 and is_sparse: - raise NotImplementedError("Only 16-bit types supported for sparse MMA") + raise NotImplementedError("Sparse MMA unsupported for f32") if is_scaled: raise ValueError( f"MMA with element type {element_type} does not support block scaling" @@ -288,8 +288,6 @@ def mma( t.isinstance(element_type) for t in {ir.Float8E5M2Type, ir.Float8E4M3FNType} ): - if is_sparse: - raise NotImplementedError("Only 16-bit types supported for sparse MMA") if d.dtype != f16 and d.dtype != f32: raise ValueError( f"MMA with element type {element_type} only supports accumulators of" @@ -304,7 +302,7 @@ def mma( t.isinstance(element_type) for t in {ir.Float4E2M1FNType} ): if is_sparse: - raise NotImplementedError("Only 16-bit types supported for sparse MMA") + raise NotImplementedError("Sparse MMA unsupported for f4e2m1fn") if not is_scaled: raise ValueError( f"MMA with element type {element_type} only supports block scaling" @@ -315,8 +313,6 @@ def mma( f" accumulators, but got: {d.dtype}" ) elif element_type == ir.IntegerType.get_signless(8): - if is_sparse: - raise NotImplementedError("Only 16-bit types supported for sparse MMA") if is_scaled: raise ValueError( f"MMA with element type {element_type} does not support block scaling" @@ -561,11 +557,11 @@ def _do_mma( elem_bitwidth = utils.bitwidth(element_type) instr_k = (1 + is_sparse) * 8 * 32 // elem_bitwidth packing = 8 * 4 // elem_bitwidth - assert not is_sparse or elem_bitwidth == 16 # Only 16-bit supported for now. extra_args: Sequence[object] scale_steps = None if is_scaled: + assert not is_sparse if (ir.Float8E5M2Type.isinstance(element_type) or ir.Float8E4M3FNType.isinstance(element_type)): kind = "mxf8f6f4.block_scale.scale_vec::1X" @@ -628,10 +624,11 @@ def _get_offset(idx: int, idx_tiling: tuple[int, ...], strides: tuple[int, ...]) else: sp_selector = None if is_sparse: - assert (k // instr_k) % 2 == 0 - sp_selector = k_step % 2 - selector_width = 64 - k_steps_for_col_inc = selector_width // instr_k + assert 32 <= instr_k <= 64 + selector_width = instr_k + k_steps_for_col_inc = 64 // selector_width + assert (k // instr_k) % k_steps_for_col_inc == 0 + sp_selector = k_step % k_steps_for_col_inc # If the K group is large, we need to increment the sparse metadata. # TODO(apaszke): At this point the purpose of this function is becoming # less clear, since we end up replicating address arithmetic that's diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index f00c09e293bf..d5af1fad11dd 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -1643,14 +1643,17 @@ def format_scales(scales): @parameterized.product( lhs_transpose=(False, True), rhs_transpose=(False, True), - in_jax_dtype=(jnp.float16, jnp.bfloat16,), + in_jax_dtype=(jnp.float16, jnp.bfloat16, jnp.int8, jnp.float8_e4m3fn), m=(128,), # TODO(apaszke): 256 n=(128, 256), # TODO(apaszke): other non-power-of-2 lhs_swizzle=(32, 64, 128), rhs_swizzle=(64, 128), # 32 is too small and unsuported. ) def test_mma_sparse(self, m, n, in_jax_dtype, lhs_swizzle, rhs_swizzle, lhs_transpose, rhs_transpose): - out_jax_dtype = jnp.float32 + if jnp.issubdtype(in_jax_dtype, jnp.floating): + out_jax_dtype = jnp.float32 + else: + out_jax_dtype = jnp.int32 sparse_meta_dtype = jnp.uint2 in_mlir_dtype = utils.dtype_to_ir_type(in_jax_dtype) @@ -1682,12 +1685,17 @@ def kernel(ctx, lhs, rhs, lhs_sparse_gmem, out, scratch): ) tcgen05.commit_arrive(mma_barrier) mma_barrier.wait(orders_tensor_core=True) - acc.load().store_untiled(out, optimized=False) + is_signed = True if jnp.issubdtype(in_jax_dtype, jnp.integer) else None + acc.load(is_signed=is_signed).store_untiled(out, optimized=False) x_shape = (k // 2, m) if lhs_transpose else (m, k // 2) - x = self.prng.uniform(-1, 1, x_shape).astype(in_jax_dtype) y_shape = (n, k) if rhs_transpose else (k, n) - y = self.prng.uniform(-1, 1, y_shape).astype(in_jax_dtype) + if jnp.issubdtype(in_jax_dtype, jnp.integer): + x = jax.random.randint(jax.random.key(1234), x_shape, -64, 64, dtype=in_jax_dtype) + y = jax.random.randint(jax.random.key(2567), y_shape, -64, 64, dtype=in_jax_dtype) + else: + x = self.prng.uniform(-1, 1, x_shape).astype(in_jax_dtype) + y = self.prng.uniform(-1, 1, y_shape).astype(in_jax_dtype) out_shape = jax.ShapeDtypeStruct((m, n), out_jax_dtype) scratch_shape = [ jax.ShapeDtypeStruct(tile_shape(x_shape, lhs_tiling), in_jax_dtype), @@ -1708,10 +1716,17 @@ def format_sparse_meta(meta): mn, k, _2 = meta.shape assert _2 == 2 k *= 2 - return ( + if jnp.dtype(in_jax_dtype).itemsize == 1: + meta_tiled = ( + meta.reshape(mn // 128, 128, k // 64, 64).transpose(0, 2, 1, 3) + ) + else: + meta_tiled = ( meta.reshape(mn // 128, 8, 2, 8, k // 64, 4, 2, 8) .transpose(0, 4, 1, 6, 3, 5, 2, 7) - .reshape(mn // 128, k // 64, 128, 64) + ) + return ( + meta_tiled.reshape(mn // 128, k // 64, 128, 64) .astype(sparse_meta_dtype) ) x_gpu_sparse = format_sparse_meta(x_sparse)