From ae9292e3834dd79b53b496cf2072e4c28cfd5c08 Mon Sep 17 00:00:00 2001 From: Allan Renucci Date: Tue, 11 Nov 2025 02:53:01 -0800 Subject: [PATCH] [Mosaic GPU] Add support for integer WGMMA with LHS in registers under WG semantic. PiperOrigin-RevId: 830827398 --- .../mosaic/gpu/dialect_lowering.py | 23 ++++++------ .../mosaic/gpu/inference_utils.py | 1 + .../mosaic/gpu/layout_inference.py | 35 ++++++++++--------- tests/mosaic/gpu_layout_inference_test.py | 26 ++++++++++++++ tests/mosaic/gpu_test.py | 22 +++++++----- tests/pallas/mosaic_gpu_test.py | 1 - 6 files changed, 71 insertions(+), 37 deletions(-) diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index 1d0e61a693a0..a619c377b707 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -1161,14 +1161,10 @@ def _bitcast_op_lowering_rule( def _mgpu_wgmma_op_lowering_rule( _: LoweringContext, wgmma_op: mgpu.WGMMAOp ) -> Sequence[ir.Value]: - fa_layouts = ( - *inference_utils.in_layouts(wgmma_op), - *inference_utils.out_layouts(wgmma_op), - ) - wgmma_layout = layouts.to_layout_attr(fa.WGMMA_LAYOUT) - for layout in fa_layouts: - if layout != wgmma_layout: - raise ValueError("Layout mismatch") + in_layouts = inference_utils.in_layouts(wgmma_op) + assert in_layouts[0] == layouts.to_layout_attr(fa.WGMMA_LAYOUT) + [out_layout] = inference_utils.out_layouts(wgmma_op) + assert out_layout == layouts.to_layout_attr(fa.WGMMA_LAYOUT) # s8/i8 WGMMA expects signed integer accumulator. element_type = wgmma_op.a.type.element_type @@ -1177,7 +1173,7 @@ def _mgpu_wgmma_op_lowering_rule( # The associated fence could be a little expensive and is not needed if the # result a wgmma feeds into another wgmma (even in another loop step). regs = _fragmented_array_from_ir( - wgmma_op.accumulator, wgmma_layout, is_signed + wgmma_op.accumulator, in_layouts[0], is_signed ) acc = wgmma.WGMMAAccumulator.from_registers(regs) @@ -1200,7 +1196,13 @@ def _mgpu_wgmma_op_lowering_rule( ) if ir.VectorType.isinstance(wgmma_op.a.type): - a_operand = _fragmented_array_from_ir(wgmma_op.a, wgmma_layout, is_signed) + expected_a_layout = ( + fa.WGMMA_LAYOUT_8BIT + if element_type == ir.IntegerType.get_signless(8) + else fa.WGMMA_LAYOUT + ) + assert in_layouts[1] == layouts.to_layout_attr(expected_a_layout) + a_operand = _fragmented_array_from_ir(wgmma_op.a, in_layouts[1], is_signed) else: a_swizzle, a_transforms = swizzle_and_transforms_from_transforms_attr( a_transforms @@ -1217,7 +1219,6 @@ def _mgpu_wgmma_op_lowering_rule( a_operand = unwrapped_a_ref new_acc = wgmma.wgmma(acc, a_operand, unwrapped_b_ref, swizzle=b_swizzle) - return [ fragmented_array_to_ir( new_acc.value.to_layout(fa.WGMMA_LAYOUT), diff --git a/jax/experimental/mosaic/gpu/inference_utils.py b/jax/experimental/mosaic/gpu/inference_utils.py index 604636d357c8..752372368c36 100644 --- a/jax/experimental/mosaic/gpu/inference_utils.py +++ b/jax/experimental/mosaic/gpu/inference_utils.py @@ -291,6 +291,7 @@ def is_mma_layout(layout: fa.FragmentedLayout) -> bool: fa.WGMMA_LAYOUT_UPCAST_2X, fa.WGMMA_LAYOUT_UPCAST_4X, fa.WGMMA_TRANSPOSED_LAYOUT, + fa.WGMMA_LAYOUT_8BIT, fa.TCGEN05_LAYOUT, fa.TCGEN05_TRANSPOSED_LAYOUT, }: diff --git a/jax/experimental/mosaic/gpu/layout_inference.py b/jax/experimental/mosaic/gpu/layout_inference.py index ac3d3bb60bff..b91cc17f8b03 100644 --- a/jax/experimental/mosaic/gpu/layout_inference.py +++ b/jax/experimental/mosaic/gpu/layout_inference.py @@ -962,31 +962,34 @@ def _wgmma_equation_system( op: mgpu.WGMMAOp, ) -> tuple[eqns.EquationSystem, ValueSitesForVariable, list[Hint]]: assignments: dict[eqns.Variable, eqns.Constant] = {} - # Registers - vector_operands_or_results = vector_value_sites(op) - vec_variable = eqns.Variable(vector_operands_or_results[0]) - assignments[vec_variable] = eqns.RegisterLayout(fa.WGMMA_LAYOUT) - operands_or_results_for_variable = {vec_variable: vector_operands_or_results} + value_sites_for_variable: ValueSitesForVariable = {} + + acc_out = ValueSite(op, VariableType.RESULT, 0) + acc_in = ValueSite(op, VariableType.OPERAND, 0) + acc_var = eqns.Variable(acc_out) + assignments[acc_var] = eqns.RegisterLayout(fa.WGMMA_LAYOUT) + value_sites_for_variable[acc_var] = [acc_in, acc_out] - # SMEM a_tiling, b_tiling = _infer_wgmma_tiling(op.a.type, op.b.type) b = ValueSite(op, VariableType.OPERAND, 2) b_var = ctx.producer_ref(b) - assignments[b_var] = eqns.SMEMTiling(lc.TileTransform(b_tiling)) - operands_or_results_for_variable[b_var] = [b] + value_sites_for_variable[b_var] = [b] - if a_tiling is not None: - # a is in SMEM - a = ValueSite(op, VariableType.OPERAND, 1) + a = ValueSite(op, VariableType.OPERAND, 1) + if _is_smem_ref(op.a): a_var = ctx.producer_ref(a) assignments[a_var] = eqns.SMEMTiling(lc.TileTransform(a_tiling)) - operands_or_results_for_variable[a_var] = [a] + else: + assert a_tiling is None + a_var = eqns.Variable(a) + if ir.IntegerType.get_signless(8) == ir.VectorType(op.a.type).element_type: + assignments[a_var] = eqns.RegisterLayout(fa.WGMMA_LAYOUT_8BIT) + else: + assignments[a_var] = eqns.RegisterLayout(fa.WGMMA_LAYOUT) + value_sites_for_variable[a_var] = [a] - system = eqns.EquationSystem( - assignments=assignments, - ) - return system, operands_or_results_for_variable, [] + return eqns.EquationSystem(assignments), value_sites_for_variable, [] @_add_equation_system_derivation_rule(vector.BroadcastOp) diff --git a/tests/mosaic/gpu_layout_inference_test.py b/tests/mosaic/gpu_layout_inference_test.py index f20f2ba4311d..b330e7a3a20e 100644 --- a/tests/mosaic/gpu_layout_inference_test.py +++ b/tests/mosaic/gpu_layout_inference_test.py @@ -1312,6 +1312,32 @@ def test_infer_transforms_for_wgmma_op(self, swizzle, dtype, lhs_in_registers): inference_utils.in_transforms(wgmma_op), in_transforms ) + @parameterized.product( + dtype=(jnp.int8, jnp.uint8), + lhs_in_registers=(False, True), + ) + def test_infer_layouts_for_8bits_wgmma_op(self, dtype, lhs_in_registers): + shape = (128, 128) + with ir.InsertionPoint(self.module.body): + elt_ty = mgpu.utils.dtype_to_ir_type(dtype) + lhs_ref_ty = ir.MemRefType.get( + shape, elt_ty, memory_space=mgpu.utils.smem() + ) + lhs_vec_ty = ir.VectorType.get(shape, elt_ty) + lhs_ty = lhs_vec_ty if lhs_in_registers else lhs_ref_ty + rhs_ty = ir.MemRefType.get(shape, elt_ty, memory_space=mgpu.utils.smem()) + acc_ty = ir.VectorType.get(shape, elt_ty) + [acc, lhs, rhs] = undefs(acc_ty, lhs_ty, rhs_ty) + wgmma_op = mgpu.dialect.WGMMAOp(acc, lhs, rhs) + + mgpu.infer_layout(self.module) + + if lhs_in_registers: + self.checkInLayouts(wgmma_op, [mgpu.WGMMA_LAYOUT, mgpu.WGMMA_LAYOUT_8BIT]) + else: + self.checkInLayouts(wgmma_op, [mgpu.WGMMA_LAYOUT]) + self.checkOutLayouts(wgmma_op, [mgpu.WGMMA_LAYOUT]) + @parameterized.product( swizzle_lhs=tuple(mgpu.dialect.SwizzlingMode), swizzle_rhs=tuple(mgpu.dialect.SwizzlingMode), diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index d1b8c50a253d..f00c09e293bf 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -4869,13 +4869,16 @@ def matmul( rtol=0, ) - @parameterized.parameters(jnp.int8, jnp.uint8) - def test_integer_wgmma(self, dtype): + @parameterized.product( + dtype=(jnp.int8, jnp.uint8), + lhs_in_smem=(False, True), + ) + def test_integer_wgmma(self, dtype, lhs_in_smem): m, k, n = 64, 128, 64 def body(ctx, lhs_gmem, rhs_gmem, result_gmem, scratch): del ctx - lhs, rhs, tma_barrier = scratch + lhs_smem, rhs_smem, tma_barrier = scratch i32 = ir.IntegerType.get_signless(32) zero = arith.constant(i32, 0) @@ -4883,27 +4886,28 @@ def body(ctx, lhs_gmem, rhs_gmem, result_gmem, scratch): tma_barrier.arrive_expect_tx(m * k + k * n) mgpu_dialect.async_load( source=lhs_gmem, - destination=lhs, + destination=lhs_smem, barrier=tma_barrier.as_barrier_memref(), indices=[zero, zero], - slice_lengths=lhs.type.shape, + slice_lengths=lhs_smem.type.shape, collective=ir.ArrayAttr.get([]), ) mgpu_dialect.async_load( source=rhs_gmem, - destination=rhs, + destination=rhs_smem, barrier=tma_barrier.as_barrier_memref(), indices=[zero, zero], - slice_lengths=rhs.type.shape, + slice_lengths=rhs_smem.type.shape, collective=ir.ArrayAttr.get([]), ) tma_barrier.wait() acc_type = ir.VectorType.get((m, n), i32) acc = vector.broadcast(acc_type, zero) + lhs = lhs_smem if lhs_in_smem else mgpu_dialect.vector_load(lhs_smem) # Only f16 WGMMA supports transposes - rhs = utils.memref_transpose(rhs, (1, 0)) - result = mgpu_dialect.wgmma(acc, lhs, rhs) + rhs_smem = utils.memref_transpose(rhs_smem, (1, 0)) + result = mgpu_dialect.wgmma(acc, lhs, rhs_smem) nvvm.wgmma_commit_group_sync_aligned() nvvm.wgmma_wait_group_sync_aligned(0) mgpu_dialect.vector_store(result, result_gmem) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index b1086b7fdff0..58535ca4297a 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -3024,7 +3024,6 @@ def scope(acc_ref): np.testing.assert_allclose(res, a @ b, rtol=1e-3) def test_wgmma_registers_integer(self): - self.skip_if_wg_semantics() # WGMMA_8BIT layout not supported input_dtype = jnp.int8 out_dtype = jnp.int32 def kernel(a_ref, b_ref, o_ref):