Commit ae9292e

allanrenucci authored and Google-ML-Automation committed
[Mosaic GPU] Add support for integer WGMMA with LHS in registers under WG semantic.
PiperOrigin-RevId: 830827398
1 parent 8c5bfd2 commit ae9292e
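
For context: WGMMA is the Hopper warpgroup matrix-multiply-accumulate primitive, and its 8-bit integer form widens s8/u8 operands and accumulates into 32-bit integers. A minimal NumPy model of that arithmetic (illustrative only, not the Mosaic GPU API; shapes match the tests below):

```python
import numpy as np

# Shapes as used in the tests below; WGMMA itself operates on fixed tiles.
m, k, n = 64, 128, 64
rng = np.random.default_rng(0)
lhs = rng.integers(-128, 128, size=(m, k), dtype=np.int8)
rhs = rng.integers(-128, 128, size=(k, n), dtype=np.int8)

# s8/u8 WGMMA widens the 8-bit operands and accumulates in int32, so the
# dot products cannot overflow the narrow input type.
acc = lhs.astype(np.int32) @ rhs.astype(np.int32)
assert acc.dtype == np.int32 and acc.shape == (m, n)
```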

File tree: 6 files changed (+71 −37 lines)

jax/experimental/mosaic/gpu/dialect_lowering.py

Lines changed: 12 additions & 11 deletions
```diff
@@ -1161,14 +1161,10 @@ def _bitcast_op_lowering_rule(
 def _mgpu_wgmma_op_lowering_rule(
     _: LoweringContext, wgmma_op: mgpu.WGMMAOp
 ) -> Sequence[ir.Value]:
-  fa_layouts = (
-      *inference_utils.in_layouts(wgmma_op),
-      *inference_utils.out_layouts(wgmma_op),
-  )
-  wgmma_layout = layouts.to_layout_attr(fa.WGMMA_LAYOUT)
-  for layout in fa_layouts:
-    if layout != wgmma_layout:
-      raise ValueError("Layout mismatch")
+  in_layouts = inference_utils.in_layouts(wgmma_op)
+  assert in_layouts[0] == layouts.to_layout_attr(fa.WGMMA_LAYOUT)
+  [out_layout] = inference_utils.out_layouts(wgmma_op)
+  assert out_layout == layouts.to_layout_attr(fa.WGMMA_LAYOUT)
 
   # s8/i8 WGMMA expects signed integer accumulator.
   element_type = wgmma_op.a.type.element_type
@@ -1177,7 +1173,7 @@ def _mgpu_wgmma_op_lowering_rule(
   # The associated fence could be a little expensive and is not needed if the
   # result a wgmma feeds into another wgmma (even in another loop step).
   regs = _fragmented_array_from_ir(
-      wgmma_op.accumulator, wgmma_layout, is_signed
+      wgmma_op.accumulator, in_layouts[0], is_signed
   )
   acc = wgmma.WGMMAAccumulator.from_registers(regs)
 
@@ -1200,7 +1196,13 @@ def _mgpu_wgmma_op_lowering_rule(
   )
 
   if ir.VectorType.isinstance(wgmma_op.a.type):
-    a_operand = _fragmented_array_from_ir(wgmma_op.a, wgmma_layout, is_signed)
+    expected_a_layout = (
+        fa.WGMMA_LAYOUT_8BIT
+        if element_type == ir.IntegerType.get_signless(8)
+        else fa.WGMMA_LAYOUT
+    )
+    assert in_layouts[1] == layouts.to_layout_attr(expected_a_layout)
+    a_operand = _fragmented_array_from_ir(wgmma_op.a, in_layouts[1], is_signed)
   else:
     a_swizzle, a_transforms = swizzle_and_transforms_from_transforms_attr(
         a_transforms
@@ -1217,7 +1219,6 @@ def _mgpu_wgmma_op_lowering_rule(
     a_operand = unwrapped_a_ref
 
   new_acc = wgmma.wgmma(acc, a_operand, unwrapped_b_ref, swizzle=b_swizzle)
-
   return [
       fragmented_array_to_ir(
           new_acc.value.to_layout(fa.WGMMA_LAYOUT),
```
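
The substantive change in this rule: when the LHS arrives in registers, its expected layout now depends on the element type, since 8-bit operands use fa.WGMMA_LAYOUT_8BIT rather than fa.WGMMA_LAYOUT. A standalone sketch of that selection, with string stand-ins for the real layout attributes:

```python
# String stand-ins for fa.WGMMA_LAYOUT / fa.WGMMA_LAYOUT_8BIT; the real
# values are FragmentedLayout singletons.
WGMMA_LAYOUT = "WGMMA_LAYOUT"
WGMMA_LAYOUT_8BIT = "WGMMA_LAYOUT_8BIT"

def expected_lhs_register_layout(element_bitwidth: int) -> str:
  # 8-bit operands are packed into registers differently than 16/32-bit
  # ones, so they get their own fragment layout.
  return WGMMA_LAYOUT_8BIT if element_bitwidth == 8 else WGMMA_LAYOUT

assert expected_lhs_register_layout(8) == WGMMA_LAYOUT_8BIT
assert expected_lhs_register_layout(16) == WGMMA_LAYOUT
```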

jax/experimental/mosaic/gpu/inference_utils.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -291,6 +291,7 @@ def is_mma_layout(layout: fa.FragmentedLayout) -> bool:
       fa.WGMMA_LAYOUT_UPCAST_2X,
       fa.WGMMA_LAYOUT_UPCAST_4X,
       fa.WGMMA_TRANSPOSED_LAYOUT,
+      fa.WGMMA_LAYOUT_8BIT,
       fa.TCGEN05_LAYOUT,
       fa.TCGEN05_TRANSPOSED_LAYOUT,
   }:
```
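
is_mma_layout lets the rest of the pipeline recognize MMA-native register layouts; without this entry, fa.WGMMA_LAYOUT_8BIT would presumably not be treated as one. A toy model of the membership check, with stand-in names for the layout singletons:

```python
# Stand-in names for the fa.* layout singletons checked by is_mma_layout.
_MMA_LAYOUTS = frozenset({
    "WGMMA_LAYOUT",
    "WGMMA_LAYOUT_UPCAST_2X",
    "WGMMA_LAYOUT_UPCAST_4X",
    "WGMMA_TRANSPOSED_LAYOUT",
    "WGMMA_LAYOUT_8BIT",  # the entry added by this commit
    "TCGEN05_LAYOUT",
    "TCGEN05_TRANSPOSED_LAYOUT",
})

def is_mma_layout(layout: str) -> bool:
  return layout in _MMA_LAYOUTS
```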

jax/experimental/mosaic/gpu/layout_inference.py

Lines changed: 19 additions & 16 deletions
```diff
@@ -962,31 +962,34 @@ def _wgmma_equation_system(
     op: mgpu.WGMMAOp,
 ) -> tuple[eqns.EquationSystem, ValueSitesForVariable, list[Hint]]:
   assignments: dict[eqns.Variable, eqns.Constant] = {}
-  # Registers
-  vector_operands_or_results = vector_value_sites(op)
-  vec_variable = eqns.Variable(vector_operands_or_results[0])
-  assignments[vec_variable] = eqns.RegisterLayout(fa.WGMMA_LAYOUT)
-  operands_or_results_for_variable = {vec_variable: vector_operands_or_results}
+  value_sites_for_variable: ValueSitesForVariable = {}
+
+  acc_out = ValueSite(op, VariableType.RESULT, 0)
+  acc_in = ValueSite(op, VariableType.OPERAND, 0)
+  acc_var = eqns.Variable(acc_out)
+  assignments[acc_var] = eqns.RegisterLayout(fa.WGMMA_LAYOUT)
+  value_sites_for_variable[acc_var] = [acc_in, acc_out]
 
-  # SMEM
   a_tiling, b_tiling = _infer_wgmma_tiling(op.a.type, op.b.type)
   b = ValueSite(op, VariableType.OPERAND, 2)
   b_var = ctx.producer_ref(b)
-
   assignments[b_var] = eqns.SMEMTiling(lc.TileTransform(b_tiling))
-  operands_or_results_for_variable[b_var] = [b]
+  value_sites_for_variable[b_var] = [b]
 
-  if a_tiling is not None:
-    # a is in SMEM
-    a = ValueSite(op, VariableType.OPERAND, 1)
+  a = ValueSite(op, VariableType.OPERAND, 1)
+  if _is_smem_ref(op.a):
     a_var = ctx.producer_ref(a)
     assignments[a_var] = eqns.SMEMTiling(lc.TileTransform(a_tiling))
-    operands_or_results_for_variable[a_var] = [a]
+  else:
+    assert a_tiling is None
+    a_var = eqns.Variable(a)
+    if ir.IntegerType.get_signless(8) == ir.VectorType(op.a.type).element_type:
+      assignments[a_var] = eqns.RegisterLayout(fa.WGMMA_LAYOUT_8BIT)
+    else:
+      assignments[a_var] = eqns.RegisterLayout(fa.WGMMA_LAYOUT)
+  value_sites_for_variable[a_var] = [a]
 
-  system = eqns.EquationSystem(
-      assignments=assignments,
-  )
-  return system, operands_or_results_for_variable, []
+  return eqns.EquationSystem(assignments), value_sites_for_variable, []
 
 
 @_add_equation_system_derivation_rule(vector.BroadcastOp)
```
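
The rewritten rule pins one layout variable per value site: the accumulator (input operand 0 and result 0 share a variable), the B operand (always an SMEM tiling), and the A operand, which is either an SMEM tiling or, new here, a register layout chosen by element width. A simplified, hypothetical summary of the resulting assignments (the real eqns.EquationSystem carries constraint objects, not strings):

```python
def wgmma_assignments(lhs_in_registers: bool, lhs_is_8bit: bool) -> dict[str, str]:
  """Hypothetical summary of the constants _wgmma_equation_system pins."""
  assignments = {
      # Accumulator-in and accumulator-out share one variable.
      "acc": "RegisterLayout(WGMMA_LAYOUT)",
      # B always lives in SMEM, with a tiling from _infer_wgmma_tiling.
      "b": "SMEMTiling(b_tiling)",
  }
  if lhs_in_registers:
    assignments["a"] = (
        "RegisterLayout(WGMMA_LAYOUT_8BIT)"
        if lhs_is_8bit
        else "RegisterLayout(WGMMA_LAYOUT)"
    )
  else:
    assignments["a"] = "SMEMTiling(a_tiling)"
  return assignments
```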

tests/mosaic/gpu_layout_inference_test.py

Lines changed: 26 additions & 0 deletions
```diff
@@ -1312,6 +1312,32 @@ def test_infer_transforms_for_wgmma_op(self, swizzle, dtype, lhs_in_registers):
         inference_utils.in_transforms(wgmma_op), in_transforms
     )
 
+  @parameterized.product(
+      dtype=(jnp.int8, jnp.uint8),
+      lhs_in_registers=(False, True),
+  )
+  def test_infer_layouts_for_8bits_wgmma_op(self, dtype, lhs_in_registers):
+    shape = (128, 128)
+    with ir.InsertionPoint(self.module.body):
+      elt_ty = mgpu.utils.dtype_to_ir_type(dtype)
+      lhs_ref_ty = ir.MemRefType.get(
+          shape, elt_ty, memory_space=mgpu.utils.smem()
+      )
+      lhs_vec_ty = ir.VectorType.get(shape, elt_ty)
+      lhs_ty = lhs_vec_ty if lhs_in_registers else lhs_ref_ty
+      rhs_ty = ir.MemRefType.get(shape, elt_ty, memory_space=mgpu.utils.smem())
+      acc_ty = ir.VectorType.get(shape, elt_ty)
+      [acc, lhs, rhs] = undefs(acc_ty, lhs_ty, rhs_ty)
+      wgmma_op = mgpu.dialect.WGMMAOp(acc, lhs, rhs)
+
+    mgpu.infer_layout(self.module)
+
+    if lhs_in_registers:
+      self.checkInLayouts(wgmma_op, [mgpu.WGMMA_LAYOUT, mgpu.WGMMA_LAYOUT_8BIT])
+    else:
+      self.checkInLayouts(wgmma_op, [mgpu.WGMMA_LAYOUT])
+    self.checkOutLayouts(wgmma_op, [mgpu.WGMMA_LAYOUT])
+
   @parameterized.product(
       swizzle_lhs=tuple(mgpu.dialect.SwizzlingMode),
       swizzle_rhs=tuple(mgpu.dialect.SwizzlingMode),
```

tests/mosaic/gpu_test.py

Lines changed: 13 additions & 9 deletions
```diff
@@ -4869,41 +4869,45 @@ def matmul(
         rtol=0,
     )
 
-  @parameterized.parameters(jnp.int8, jnp.uint8)
-  def test_integer_wgmma(self, dtype):
+  @parameterized.product(
+      dtype=(jnp.int8, jnp.uint8),
+      lhs_in_smem=(False, True),
+  )
+  def test_integer_wgmma(self, dtype, lhs_in_smem):
     m, k, n = 64, 128, 64
 
     def body(ctx, lhs_gmem, rhs_gmem, result_gmem, scratch):
       del ctx
-      lhs, rhs, tma_barrier = scratch
+      lhs_smem, rhs_smem, tma_barrier = scratch
 
       i32 = ir.IntegerType.get_signless(32)
       zero = arith.constant(i32, 0)
 
       tma_barrier.arrive_expect_tx(m * k + k * n)
       mgpu_dialect.async_load(
           source=lhs_gmem,
-          destination=lhs,
+          destination=lhs_smem,
           barrier=tma_barrier.as_barrier_memref(),
           indices=[zero, zero],
-          slice_lengths=lhs.type.shape,
+          slice_lengths=lhs_smem.type.shape,
           collective=ir.ArrayAttr.get([]),
       )
       mgpu_dialect.async_load(
           source=rhs_gmem,
-          destination=rhs,
+          destination=rhs_smem,
           barrier=tma_barrier.as_barrier_memref(),
           indices=[zero, zero],
-          slice_lengths=rhs.type.shape,
+          slice_lengths=rhs_smem.type.shape,
           collective=ir.ArrayAttr.get([]),
       )
       tma_barrier.wait()
 
       acc_type = ir.VectorType.get((m, n), i32)
       acc = vector.broadcast(acc_type, zero)
+      lhs = lhs_smem if lhs_in_smem else mgpu_dialect.vector_load(lhs_smem)
       # Only f16 WGMMA supports transposes
-      rhs = utils.memref_transpose(rhs, (1, 0))
-      result = mgpu_dialect.wgmma(acc, lhs, rhs)
+      rhs_smem = utils.memref_transpose(rhs_smem, (1, 0))
+      result = mgpu_dialect.wgmma(acc, lhs, rhs_smem)
       nvvm.wgmma_commit_group_sync_aligned()
       nvvm.wgmma_wait_group_sync_aligned(0)
       mgpu_dialect.vector_store(result, result_gmem)
```

tests/pallas/mosaic_gpu_test.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -3024,7 +3024,6 @@ def scope(acc_ref):
     np.testing.assert_allclose(res, a @ b, rtol=1e-3)
 
   def test_wgmma_registers_integer(self):
-    self.skip_if_wg_semantics()  # WGMMA_8BIT layout not supported
     input_dtype = jnp.int8
     out_dtype = jnp.int32
     def kernel(a_ref, b_ref, o_ref):
```
