From 6e55e9d7a5905303170f8d3869e7d305ef464566 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Thu, 13 Nov 2025 05:48:12 -0800 Subject: [PATCH] [Mosaic GPU] Add support for block-scaled MMA with all K sizes divisible by 128 PiperOrigin-RevId: 831821303 --- jax/experimental/mosaic/gpu/tcgen05.py | 124 ++++++++++++++++--------- tests/mosaic/gpu_test.py | 31 ++++--- tests/pallas/mosaic_gpu_test.py | 23 +++-- 3 files changed, 114 insertions(+), 64 deletions(-) diff --git a/jax/experimental/mosaic/gpu/tcgen05.py b/jax/experimental/mosaic/gpu/tcgen05.py index ed7ce5312c23..f57efab589c0 100644 --- a/jax/experimental/mosaic/gpu/tcgen05.py +++ b/jax/experimental/mosaic/gpu/tcgen05.py @@ -15,7 +15,6 @@ from __future__ import annotations -from collections.abc import Sequence import dataclasses import itertools import math @@ -384,31 +383,29 @@ def mma( raise ValueError( f"MMA with block scaling requires N to be divisible by 32, got: {n}" ) - if k_group_elems != 128 or a_swizzle != b_swizzle: - assert utils.bitwidth(element_type) <= 8 - expected_swizzle = 128 // (8 // utils.bitwidth(element_type)) - raise NotImplementedError( - "MMA with block scaling requires swizzle to be" - f" {expected_swizzle} for dtype {element_type}, got:" - f" {a_swizzle=} and {b_swizzle=}" - ) assert a_scale is not None and b_scale is not None - if a_scale.shape != (m, 4): + if a_scale.shape != (m, k // 32): raise ValueError( - f"A scale shape mismatch: expected ({m}, 4), got {a_scale.shape}" + f"A scale shape mismatch: expected ({m}, {k // 32}), got {a_scale.shape}" ) if a_scale.dtype != ir.Float8E8M0FNUType.get(): raise ValueError( f"A scale dtype mismatch: expected f8e8m0fnu, got {a_scale.dtype}" ) - if b_scale.shape != (n, 4): + if b_scale.shape != (n, k // 32): raise ValueError( - f"B scale shape mismatch: expected ({n}, 4), got {b_scale.shape}" + f"B scale shape mismatch: expected ({n}, {k // 32}), got {b_scale.shape}" ) if b_scale.dtype != ir.Float8E8M0FNUType.get(): raise ValueError( f"B scale dtype mismatch: expected f8e8m0fnu, got {b_scale.dtype}" ) + if k_group_elems % 128: + min_swizzle = 16 * utils.bitwidth(element_type) + raise NotImplementedError( + f"{element_type} MMA with block scaling requires swizzle to be at" + f" least {min_swizzle}" + ) if is_sparse: a_sparse_metadata = cast(TMEMRef, a_sparse_metadata) if collective: @@ -479,6 +476,8 @@ def mma( assert d.layout.base_tile_shape[0] % 4 == 0 lanes_per_n_group = d.layout.base_tile_shape[0] // 4 a_sparse_addr_base = a_sparse_metadata.address if is_sparse else None # type: ignore + a_scale_addr_base = a_scale.address if is_scaled else None # type: ignore + b_scale_addr_base = b_scale.address if is_scaled else None # type: ignore for mi, ni, ki in np.ndindex(m_groups, n_groups, k_groups): if isinstance(a, TMEMRef): if m_groups != 1: @@ -498,8 +497,23 @@ def mma( a_sparse_addr = arith.addi(a_sparse_addr_base, utils.c(ki * cols_per_k_group, i32)) else: a_sparse_addr = None - if is_scaled and (m_groups != 1 or n_groups != 1 or k_groups != 1): - raise NotImplementedError("Block-scaled metadata address calculation for multiple tiles") + if a_scale_addr_base is not None and b_scale_addr_base is not None: + if m_groups != 1: + raise NotImplementedError("A scale address calculation for multiple M tiles") + if n_groups != 1: + raise NotImplementedError("B scale address calculation for multiple N tiles") + assert k_group_elems % 128 == 0 + assert m_group_elems % 32 == 0 and n_group_elems % 32 == 0 + a_scale_addr = arith.addi( + a_scale_addr_base, + utils.c(ki * 
k_group_elems // 128 * m_group_elems // 32, i32), + ) + b_scale_addr = arith.addi( + b_scale_addr_base, + utils.c(ki * k_group_elems // 128 * n_group_elems // 32, i32), + ) + else: + a_scale_addr = b_scale_addr = None acc = accumulate if ki == 0 else true ni_lane_group, ni_col = ni // n_col_groups, ni % n_col_groups d_offset = ( @@ -521,8 +535,8 @@ def mma( b_transpose=b_fastest != mma_utils.Dim.K, a_k_strides=a_k_instr_strides, b_k_strides=b_k_instr_strides, - a_scale_addr=a_scale.address if a_scale is not None else None, - b_scale_addr=b_scale.address if b_scale is not None else None, + a_scale_addr=a_scale_addr, + b_scale_addr=b_scale_addr, a_sparse_addr=a_sparse_addr, accumulate=acc, element_type=mma_element_type, @@ -561,7 +575,6 @@ def _do_mma( instr_k = (1 + is_sparse) * 8 * 32 // elem_bitwidth packing = 8 * 4 // elem_bitwidth - extra_args: Sequence[object] scale_steps = None if is_scaled: assert not is_sparse @@ -577,7 +590,6 @@ def _do_mma( create_scaled_instr_descriptor = create_scaled_f4_instr_descriptor else: raise NotImplementedError(f"Unsupported element type for block scaling: {element_type}") - extra_args = (a_scale_addr, b_scale_addr) extra_ptx = "[$5], [$6], " extra_constraints = ",r,r" else: @@ -591,7 +603,6 @@ def _do_mma( kind = "i8" else: raise NotImplementedError(f"Unsupported input element type: {element_type}") - extra_args = () extra_constraints = extra_ptx = "" def create_scaled_instr_descriptor(*args): @@ -605,8 +616,8 @@ def create_scaled_instr_descriptor(*args): sparse_meta_ptx = "[$5], " if is_sparse else "" extra_constraints += ",r" if is_sparse else "" sparse_addr: tuple[Any, ...] = () + scales_addrs: tuple[Any, ...] = () assert a_desc_or_addr.type == ir.IntegerType.get_signless(32 if a_in_tmem else 64) - assert scale_steps is None or scale_steps == k // instr_k def _get_offset(idx: int, idx_tiling: tuple[int, ...], strides: tuple[int, ...]): assert len(idx_tiling) + 1 == len(strides) idxs = [] @@ -620,10 +631,18 @@ def _get_offset(idx: int, idx_tiling: tuple[int, ...], strides: tuple[int, ...]) if is_scaled: assert scale_steps is not None scale_vec_width = 4 // scale_steps - scale_id = k_step * scale_vec_width + scale_id = (k_step % scale_steps) * scale_vec_width i_desc = create_scaled_instr_descriptor( m, n, element_type, element_type, scale_id, scale_id, a_transpose, b_transpose ) + assert m == 128 + assert n % 128 == 0 + a_scale_addr_offset = arith.constant(i32, k_step // scale_steps * 4) + b_scale_addr_offset = arith.constant(i32, k_step // scale_steps * n // 32) + scales_addrs = ( + arith.addi(a_scale_addr, a_scale_addr_offset), + arith.addi(b_scale_addr, b_scale_addr_offset), + ) else: sp_selector = None if is_sparse: @@ -657,7 +676,7 @@ def _get_offset(idx: int, idx_tiling: tuple[int, ...], strides: tuple[int, ...]) b_desc_instr = arith.addi(b_desc, _get_offset(k_step, b_k_idx_tiling, b_k_strides)) llvm.inline_asm( ir.Type.parse("!llvm.void"), - [d_addr, a_desc_or_addr_instr, b_desc_instr, i_desc, accumulate, *extra_args, *sparse_addr], + [d_addr, a_desc_or_addr_instr, b_desc_instr, i_desc, accumulate, *scales_addrs, *sparse_addr], f"tcgen05.mma{sparse_mod}.cta_group::{num_cta}.kind::{kind} [$0], {a_ptx}, $2, {sparse_meta_ptx}$3, {extra_ptx}$4;", f"r,{a_ptx_constraint},l,r,b" + extra_constraints, has_side_effects=True, @@ -831,7 +850,14 @@ def check_type(self, shape: tuple[int, ...], bitwidth: int) -> None: def cols_in_shape(self, shape: tuple[int, int], bitwidth: int) -> int: self.check_type(shape, bitwidth) - return math.prod(shape) // TMEM_ROWS 
// self.vector_length + replication_factor = 1 + for dim in self.warp_dims: + if isinstance(dim, fa.Replicated): + replication_factor *= dim.times + for dim in self.lane_dims: + if isinstance(dim, fa.Replicated): + replication_factor *= dim.times + return math.prod(shape) // TMEM_ROWS // self.vector_length * replication_factor def canonicalize(self) -> TMEMLayout: layout = super().canonicalize() @@ -1396,17 +1422,19 @@ def async_copy_scales_smem_to_tmem(smem_ref: ir.Value, tmem_ref: TMEMRef) -> Non MMA issued in the same thread, no additional synchronization is needed. At the moment the function requires ``smem_ref`` to be contiguous and have a - shape of (MN // 128, 32, 16) for 8-bit scales (here MN stands for the size of - the non-contracting dimension which is M or N), matching the scale layout for - .scale_vec::1X. See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-a-layout-1x + shape of ``(MN // 128, K // 128, 32, 16)`` for 8-bit scales (here MN stands + for the size of the non-contracting dimension which is M or N), matching the + scale layout for .scale_vec::1X. See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-a-layout-1x for more details. Note that we always put the non-contracting dimension first. - If you have a (MN, 4) array of scales in JAX (where MN is divisible by 128), - you can prepare it for use in the kernel this way:: + If you have a (MN, K // 32) array of scales in JAX (where MN and K are + divisible by 128), you can prepare it for use in the kernel this way:: - scales.reshape(-1, 4, 32, 4).swapaxes(1, 2).reshape(-1, 32, 16) + scales.reshape(mn // 128, 4, 32, k // 4, 4) + .transpose(0, 3, 2, 1, 4) + .reshape(mn // 128, k // 4, 32, 16) - The TMEM ref is expected to have the logical shape of the scales (MN, 4), and - the layout created by ``scales_layout()``. + The TMEM ref is expected to have the logical shape of the scales + ``(MN, K // 32)``, and the layout created by ``scales_layout()``. """ i32 = ir.IntegerType.get_signless(32) smem_ty = ir.MemRefType(smem_ref.type) @@ -1416,30 +1444,42 @@ def async_copy_scales_smem_to_tmem(smem_ref: ir.Value, tmem_ref: TMEMRef) -> Non raise NotImplementedError(f"Unsupported dtype: {dtype}, only f8e8m0fnu supported") if tmem_ref.shape[0] % TMEM_ROWS: raise ValueError(f"TMEM reference must have a multiple of {TMEM_ROWS} rows, but got {tmem_ref.shape[0]}") - if tmem_ref.shape[1] != 4: - raise ValueError(f"TMEM reference must have 4 colums, but got {tmem_ref.shape[1]}") + if tmem_ref.shape[1] % 4: + raise ValueError(f"TMEM reference must have a multiple of 4 columns, but got {tmem_ref.shape[1]}") if tmem_ref.layout != scales_layout(): raise ValueError(f"TMEM layout {tmem_ref.layout} is not supported") smem_shape = tuple(smem_ty.shape) - expected_smem_shape = (tmem_ref.shape[0] // TMEM_ROWS, 32, 16) + expected_smem_shape = (tmem_ref.shape[0] // TMEM_ROWS, tmem_ref.shape[1] // 4, 32, 16) if smem_shape != expected_smem_shape: raise NotImplementedError( f"SMEM has {smem_shape}, but expected {expected_smem_shape} for TMEM" f" ref shape {tmem_ref.shape}" ) strides, _ = smem_ty.get_strides_and_offset() + # TODO(apaszke): This should only matter for the two minor dims. 
if strides != utils.get_contiguous_strides(smem_shape): raise ValueError("Only copies from contiguous SMEM references are supported") - row_tile_stride = strides[0] - if row_tile_stride % 4: - raise ValueError("Column tile stride must be a multiple of 4") - row_tile_stride_i32 = row_tile_stride // 4 + mn_tile_stride, k_tile_stride = strides[:2] + # One tile of scales has 128 bytes. + if mn_tile_stride % 128 or k_tile_stride % 128: + raise ValueError("Scale tile strides must be a multiple of 128") + mn_tile_stride_i32 = mn_tile_stride // 4 + k_tile_stride_i32 = k_tile_stride // 4 smem_base_ptr = utils.memref_ptr(smem_ref, 3) - for row_tile in range(expected_smem_shape[0]): + # TODO(apaszke): Need to figure out the TMEM layout otherwise and MMA doesn't + # support it anyway. + if smem_shape[0] > 2: + raise NotImplementedError("Only M/N up to 256 supported") + for mn_tile, k_tile in np.ndindex(smem_shape[:2]): load_ptr = utils.getelementptr( - smem_base_ptr, [row_tile * row_tile_stride_i32], i32 + smem_base_ptr, + [mn_tile * mn_tile_stride_i32 + k_tile * k_tile_stride_i32], + i32, + ) + store_addr = arith.addi( + tmem_ref.address, + arith.constant(i32, 4 * smem_shape[0] * k_tile + 4 * mn_tile), ) - store_addr = arith.addi(tmem_ref.address, arith.constant(i32, 4 * row_tile)) # The "core matrix" here is the same as in MMA: 8x(16 bytes). desc = mma_utils.encode_descriptor(load_ptr, 0, 8 * 16, swizzle=None) nvvm.tcgen05_cp( diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index f106cbec8a26..cff36e4d61f4 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -1567,7 +1567,7 @@ def kernel(ctx, src, out, scratch): )._debug_print() copy(src, out) - shape = (1, 32, 16) + shape = (1, 1, 32, 16) x = jax.lax.bitcast_convert_type( np.arange(math.prod(shape), dtype=np.uint8).reshape(shape), dtype ) @@ -1601,8 +1601,8 @@ def kernel(ctx, src, out, scratch): def test_mma_block_scaled(self, m, n, in_jax_dtype): out_jax_dtype = jnp.float32 scale_jax_dtype = jnp.float8_e8m0fnu - swizzle = 128 // (8 // jnp.finfo(in_jax_dtype).bits) - k_steps = 1 + swizzle = 128 + k_steps = 2 if out_jax_dtype == jnp.float16 and in_jax_dtype != jnp.float16: self.skipTest("Only f16 input is supported for f16 output.") @@ -1648,24 +1648,29 @@ def kernel(ctx, lhs, rhs, lhs_scales_gmem, rhs_scales_gmem, out, scratch): scratch_shape = [ jax.ShapeDtypeStruct(tile_shape(x_shape, lhs_tiling), in_jax_dtype), jax.ShapeDtypeStruct(tile_shape(y_shape, rhs_tiling), in_jax_dtype), - jax.ShapeDtypeStruct((m // 128, 32, 16), scale_jax_dtype), - jax.ShapeDtypeStruct((n // 128, 32, 16), scale_jax_dtype), + jax.ShapeDtypeStruct((m // 128, k // (32 * 4), 32, 16), scale_jax_dtype), + jax.ShapeDtypeStruct((n // 128, k // (32 * 4), 32, 16), scale_jax_dtype), mgpu.TMABarrier(4), mgpu.Barrier(1), mgpu.TMEM((m, n), out_jax_dtype), - mgpu.TMEM((m, 4), scale_jax_dtype, layout=tcgen05.scales_layout()), - mgpu.TMEM((n, 4), scale_jax_dtype, layout=tcgen05.scales_layout()), + mgpu.TMEM((m, k // 32), scale_jax_dtype, layout=tcgen05.scales_layout()), + mgpu.TMEM((n, k // 32), scale_jax_dtype, layout=tcgen05.scales_layout()), ] ka, kb = jax.random.split(jax.random.key(1234), 2) a_scales = jax.lax.bitcast_convert_type( - jax.random.randint(ka, (m, 4), 122, 132, dtype=jnp.uint8), scale_jax_dtype + jax.random.randint(ka, (m, k // 32), 122, 132, dtype=jnp.uint8), scale_jax_dtype ) b_scales = jax.lax.bitcast_convert_type( - jax.random.randint(kb, (n, 4), 122, 132, dtype=jnp.uint8), scale_jax_dtype + jax.random.randint(kb, (n, k // 32), 
122, 132, dtype=jnp.uint8), scale_jax_dtype ) def format_scales(scales): - assert scales.shape[0] % 128 == 0 and scales.shape[1] == 4 - return scales.reshape(-1, 4, 32, 4).swapaxes(1, 2).reshape(-1, 32, 16) + mn, k = scales.shape + assert mn % 128 == 0 and k % 4 == 0, scales.shape + return ( + scales.reshape(mn // 128, 4, 32, k // 4, 4) + .transpose(0, 3, 2, 1, 4) + .reshape(mn // 128, k // 4, 32, 16) + ) a_gpu_scales, b_gpu_scales = map(format_scales, (a_scales, b_scales)) args = (x, y, a_gpu_scales, b_gpu_scales) z = mgpu.as_gpu_kernel( @@ -1675,9 +1680,7 @@ def format_scales(scales): a_logical_scales = jnp.repeat(a_scales, 32, axis=1).astype(jnp.float32) b_logical_scales = jnp.repeat(b_scales, 32, axis=1).astype(jnp.float32) ref = (x32 * a_logical_scales) @ (y32 * b_logical_scales).T - atol = 2e-2 if out_jax_dtype == jnp.float16 else 7e-5 - rtol = 8e-4 if out_jax_dtype == jnp.float16 else 5e-6 - np.testing.assert_allclose(z, ref, atol=atol, rtol=rtol) + np.testing.assert_allclose(z, ref, atol=2e-4, rtol=5e-6) @parameterized.product( lhs_transpose=(False, True), diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index db60cdb5ba3b..0997acf39fda 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -3632,8 +3632,10 @@ def kernel(a_smem, b_smem, out_ref, _, acc_tmem, barrier_ref): ) def test_simple_scaled_matmul(self, m, n, dtype): self.skip_if_wg_semantics() - k = 128 - swizzle = 128 // (8 // jnp.finfo(dtype).bits) + # TODO(apaszke): Add support for single-buffering in pallas_call. + causes_oom = jnp.finfo(dtype).bits == 8 and n == 256 + k = 128 if causes_oom else 256 + swizzle = 128 transforms = self.default_transforms(swizzle=swizzle, dtype=dtype) out_transforms = self.default_transforms(dtype=jnp.float32) @@ -3656,8 +3658,8 @@ def kernel(a_smem, b_smem, a_scale_smem, b_scale_smem, out_ref, scratch_shapes = [ plgpu.Barrier(orders_tensor_core=True), plgpu.TMEM((m, n), jnp.float32), - plgpu.TMEM((m, 4), jnp.float8_e8m0fnu, layout=plgpu.TMEMLayout.SCALES_LAYOUT), - plgpu.TMEM((n, 4), jnp.float8_e8m0fnu, layout=plgpu.TMEMLayout.SCALES_LAYOUT), + plgpu.TMEM((m, k // 32), jnp.float8_e8m0fnu, layout=plgpu.TMEMLayout.SCALES_LAYOUT), + plgpu.TMEM((n, k // 32), jnp.float8_e8m0fnu, layout=plgpu.TMEMLayout.SCALES_LAYOUT), ] f = self.pallas_call( @@ -3676,16 +3678,21 @@ def kernel(a_smem, b_smem, a_scale_smem, b_scale_smem, out_ref, y = jax.random.uniform(jax.random.key(2), shape=(n, k), dtype=jnp.float32).astype(dtype) ksx, ksy = jax.random.split(jax.random.key(1234), 2) x_scale = jax.lax.bitcast_convert_type( - jax.random.randint(ksx, (m, 4), 122, 132, dtype=jnp.uint8), + jax.random.randint(ksx, (m, k // 32), 122, 132, dtype=jnp.uint8), jnp.float8_e8m0fnu ) y_scale = jax.lax.bitcast_convert_type( - jax.random.randint(ksy, (n, 4), 122, 132, dtype=jnp.uint8), + jax.random.randint(ksy, (n, k // 32), 122, 132, dtype=jnp.uint8), jnp.float8_e8m0fnu ) def format_scales(scales): - assert scales.shape[0] % 128 == 0 and scales.shape[1] == 4 - return scales.reshape(-1, 4, 32, 4).swapaxes(1, 2).reshape(-1, 32, 16) + mn, k = scales.shape + assert mn % 128 == 0 and k % 4 == 0 + return ( + scales.reshape(mn // 128, 4, 32, k // 4, 4) + .transpose(0, 3, 2, 1, 4) + .reshape(mn // 128, k // 4, 32, 16) + ) result = f(x, y, format_scales(x_scale), format_scales(y_scale)) x_logical_scale = jnp.repeat(x_scale, 32, axis=1).astype(jnp.float32) y_logical_scale = jnp.repeat(y_scale, 32, axis=1).astype(jnp.float32)
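
For readers following the scale-layout change: the host-side reshape documented
in ``async_copy_scales_smem_to_tmem`` and repeated by the ``format_scales``
helpers in both tests can be written as a standalone step. The sketch below is
illustrative only; it assumes a ``(MN, K // 32)`` array of ``float8_e8m0fnu``
scales with MN and K divisible by 128, and the name ``prepare_scales`` is
hypothetical (it is not part of this patch)::

    import jax
    import jax.numpy as jnp

    def prepare_scales(scales: jax.Array) -> jax.Array:
      # Rearranges (MN, K // 32) scales into the (MN // 128, K // 128, 32, 16)
      # SMEM layout expected by async_copy_scales_smem_to_tmem (.scale_vec::1X).
      mn, k = scales.shape  # k is K // 32 here.
      assert mn % 128 == 0 and k % 4 == 0
      return (
          scales.reshape(mn // 128, 4, 32, k // 4, 4)
          .transpose(0, 3, 2, 1, 4)
          .reshape(mn // 128, k // 4, 32, 16)
      )

    # Example with MN == 128 and K == 256, mirroring the test setup.
    raw = jax.random.randint(jax.random.key(1234), (128, 8), 122, 132, dtype=jnp.uint8)
    scales = jax.lax.bitcast_convert_type(raw, jnp.float8_e8m0fnu)  # (128, 8)
    smem_scales = prepare_scales(scales)  # (1, 2, 32, 16)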
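
The tests compare the MMA result against a plain float32 matmul with the scales
expanded along K, since each scale applies to 32 consecutive K elements of its
operand. A minimal restatement of that reference, assuming ``a: (M, K)``,
``b: (N, K)`` and scale arrays of shape ``(M, K // 32)`` and ``(N, K // 32)``
(the name ``scaled_matmul_reference`` is not part of the patch)::

    import jax.numpy as jnp

    def scaled_matmul_reference(a, b, a_scales, b_scales):
      # Expand each scale over its 32-element K block, then do an ordinary
      # matmul in float32, as in test_mma_block_scaled.
      a32 = a.astype(jnp.float32) * jnp.repeat(a_scales.astype(jnp.float32), 32, axis=1)
      b32 = b.astype(jnp.float32) * jnp.repeat(b_scales.astype(jnp.float32), 32, axis=1)
      return a32 @ b32.T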