Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions jax/experimental/mosaic/gpu/dialect_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -1161,14 +1161,10 @@ def _bitcast_op_lowering_rule(
def _mgpu_wgmma_op_lowering_rule(
_: LoweringContext, wgmma_op: mgpu.WGMMAOp
) -> Sequence[ir.Value]:
fa_layouts = (
*inference_utils.in_layouts(wgmma_op),
*inference_utils.out_layouts(wgmma_op),
)
wgmma_layout = layouts.to_layout_attr(fa.WGMMA_LAYOUT)
for layout in fa_layouts:
if layout != wgmma_layout:
raise ValueError("Layout mismatch")
in_layouts = inference_utils.in_layouts(wgmma_op)
assert in_layouts[0] == layouts.to_layout_attr(fa.WGMMA_LAYOUT)
[out_layout] = inference_utils.out_layouts(wgmma_op)
assert out_layout == layouts.to_layout_attr(fa.WGMMA_LAYOUT)

# s8/i8 WGMMA expects signed integer accumulator.
element_type = wgmma_op.a.type.element_type
Expand All @@ -1177,7 +1173,7 @@ def _mgpu_wgmma_op_lowering_rule(
# The associated fence could be a little expensive and is not needed if the
# result of a wgmma feeds into another wgmma (even in another loop step).
regs = _fragmented_array_from_ir(
wgmma_op.accumulator, wgmma_layout, is_signed
wgmma_op.accumulator, in_layouts[0], is_signed
)
acc = wgmma.WGMMAAccumulator.from_registers(regs)

Expand All @@ -1200,7 +1196,13 @@ def _mgpu_wgmma_op_lowering_rule(
)

if ir.VectorType.isinstance(wgmma_op.a.type):
a_operand = _fragmented_array_from_ir(wgmma_op.a, wgmma_layout, is_signed)
expected_a_layout = (
fa.WGMMA_LAYOUT_8BIT
if element_type == ir.IntegerType.get_signless(8)
else fa.WGMMA_LAYOUT
)
assert in_layouts[1] == layouts.to_layout_attr(expected_a_layout)
a_operand = _fragmented_array_from_ir(wgmma_op.a, in_layouts[1], is_signed)
else:
a_swizzle, a_transforms = swizzle_and_transforms_from_transforms_attr(
a_transforms
Expand All @@ -1217,7 +1219,6 @@ def _mgpu_wgmma_op_lowering_rule(
a_operand = unwrapped_a_ref

new_acc = wgmma.wgmma(acc, a_operand, unwrapped_b_ref, swizzle=b_swizzle)

return [
fragmented_array_to_ir(
new_acc.value.to_layout(fa.WGMMA_LAYOUT),
Expand Down
1 change: 1 addition & 0 deletions jax/experimental/mosaic/gpu/inference_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ def is_mma_layout(layout: fa.FragmentedLayout) -> bool:
fa.WGMMA_LAYOUT_UPCAST_2X,
fa.WGMMA_LAYOUT_UPCAST_4X,
fa.WGMMA_TRANSPOSED_LAYOUT,
fa.WGMMA_LAYOUT_8BIT,
fa.TCGEN05_LAYOUT,
fa.TCGEN05_TRANSPOSED_LAYOUT,
}:
Expand Down
35 changes: 19 additions & 16 deletions jax/experimental/mosaic/gpu/layout_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,31 +962,34 @@ def _wgmma_equation_system(
op: mgpu.WGMMAOp,
) -> tuple[eqns.EquationSystem, ValueSitesForVariable, list[Hint]]:
assignments: dict[eqns.Variable, eqns.Constant] = {}
# Registers
vector_operands_or_results = vector_value_sites(op)
vec_variable = eqns.Variable(vector_operands_or_results[0])
assignments[vec_variable] = eqns.RegisterLayout(fa.WGMMA_LAYOUT)
operands_or_results_for_variable = {vec_variable: vector_operands_or_results}
value_sites_for_variable: ValueSitesForVariable = {}

acc_out = ValueSite(op, VariableType.RESULT, 0)
acc_in = ValueSite(op, VariableType.OPERAND, 0)
acc_var = eqns.Variable(acc_out)
assignments[acc_var] = eqns.RegisterLayout(fa.WGMMA_LAYOUT)
value_sites_for_variable[acc_var] = [acc_in, acc_out]

# SMEM
a_tiling, b_tiling = _infer_wgmma_tiling(op.a.type, op.b.type)
b = ValueSite(op, VariableType.OPERAND, 2)
b_var = ctx.producer_ref(b)

assignments[b_var] = eqns.SMEMTiling(lc.TileTransform(b_tiling))
operands_or_results_for_variable[b_var] = [b]
value_sites_for_variable[b_var] = [b]

if a_tiling is not None:
# a is in SMEM
a = ValueSite(op, VariableType.OPERAND, 1)
a = ValueSite(op, VariableType.OPERAND, 1)
if _is_smem_ref(op.a):
a_var = ctx.producer_ref(a)
assignments[a_var] = eqns.SMEMTiling(lc.TileTransform(a_tiling))
operands_or_results_for_variable[a_var] = [a]
else:
assert a_tiling is None
a_var = eqns.Variable(a)
if ir.IntegerType.get_signless(8) == ir.VectorType(op.a.type).element_type:
assignments[a_var] = eqns.RegisterLayout(fa.WGMMA_LAYOUT_8BIT)
else:
assignments[a_var] = eqns.RegisterLayout(fa.WGMMA_LAYOUT)
value_sites_for_variable[a_var] = [a]

system = eqns.EquationSystem(
assignments=assignments,
)
return system, operands_or_results_for_variable, []
return eqns.EquationSystem(assignments), value_sites_for_variable, []


@_add_equation_system_derivation_rule(vector.BroadcastOp)
Expand Down
26 changes: 26 additions & 0 deletions tests/mosaic/gpu_layout_inference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1312,6 +1312,32 @@ def test_infer_transforms_for_wgmma_op(self, swizzle, dtype, lhs_in_registers):
inference_utils.in_transforms(wgmma_op), in_transforms
)

# Layout-inference test for a WGMMA op whose operands use an 8-bit integer
# element type. Parameterized over int8/uint8 and over whether the LHS is a
# vector (lhs_in_registers=True) or an SMEM memref (lhs_in_registers=False).
@parameterized.product(
dtype=(jnp.int8, jnp.uint8),
lhs_in_registers=(False, True),
)
def test_infer_layouts_for_8bits_wgmma_op(self, dtype, lhs_in_registers):
# Accumulator, LHS and RHS all share this shape and the same element type.
shape = (128, 128)
with ir.InsertionPoint(self.module.body):
elt_ty = mgpu.utils.dtype_to_ir_type(dtype)
# Two candidate LHS types: a memref in the SMEM memory space, or a vector.
lhs_ref_ty = ir.MemRefType.get(
shape, elt_ty, memory_space=mgpu.utils.smem()
)
lhs_vec_ty = ir.VectorType.get(shape, elt_ty)
lhs_ty = lhs_vec_ty if lhs_in_registers else lhs_ref_ty
# The RHS is always an SMEM memref; the accumulator is always a vector.
rhs_ty = ir.MemRefType.get(shape, elt_ty, memory_space=mgpu.utils.smem())
acc_ty = ir.VectorType.get(shape, elt_ty)
# NOTE(review): `undefs` presumably yields one placeholder value per type,
# so only the operand types matter to this test; confirm against the helper.
[acc, lhs, rhs] = undefs(acc_ty, lhs_ty, rhs_ty)
wgmma_op = mgpu.dialect.WGMMAOp(acc, lhs, rhs)

# Run layout inference over the module that now holds the WGMMA op.
mgpu.infer_layout(self.module)

if lhs_in_registers:
# Two in-layouts are expected when the LHS is a vector; presumably ordered
# as (accumulator, LHS), with the 8-bit layout belonging to the LHS.
self.checkInLayouts(wgmma_op, [mgpu.WGMMA_LAYOUT, mgpu.WGMMA_LAYOUT_8BIT])
else:
# With a memref LHS, only a single in-layout is expected (presumably the
# accumulator's, since the memref operands are not listed).
self.checkInLayouts(wgmma_op, [mgpu.WGMMA_LAYOUT])
# The result is expected to carry WGMMA_LAYOUT in both configurations.
self.checkOutLayouts(wgmma_op, [mgpu.WGMMA_LAYOUT])

@parameterized.product(
swizzle_lhs=tuple(mgpu.dialect.SwizzlingMode),
swizzle_rhs=tuple(mgpu.dialect.SwizzlingMode),
Expand Down
22 changes: 13 additions & 9 deletions tests/mosaic/gpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4869,41 +4869,45 @@ def matmul(
rtol=0,
)

@parameterized.parameters(jnp.int8, jnp.uint8)
def test_integer_wgmma(self, dtype):
@parameterized.product(
dtype=(jnp.int8, jnp.uint8),
lhs_in_smem=(False, True),
)
def test_integer_wgmma(self, dtype, lhs_in_smem):
m, k, n = 64, 128, 64

def body(ctx, lhs_gmem, rhs_gmem, result_gmem, scratch):
del ctx
lhs, rhs, tma_barrier = scratch
lhs_smem, rhs_smem, tma_barrier = scratch

i32 = ir.IntegerType.get_signless(32)
zero = arith.constant(i32, 0)

tma_barrier.arrive_expect_tx(m * k + k * n)
mgpu_dialect.async_load(
source=lhs_gmem,
destination=lhs,
destination=lhs_smem,
barrier=tma_barrier.as_barrier_memref(),
indices=[zero, zero],
slice_lengths=lhs.type.shape,
slice_lengths=lhs_smem.type.shape,
collective=ir.ArrayAttr.get([]),
)
mgpu_dialect.async_load(
source=rhs_gmem,
destination=rhs,
destination=rhs_smem,
barrier=tma_barrier.as_barrier_memref(),
indices=[zero, zero],
slice_lengths=rhs.type.shape,
slice_lengths=rhs_smem.type.shape,
collective=ir.ArrayAttr.get([]),
)
tma_barrier.wait()

acc_type = ir.VectorType.get((m, n), i32)
acc = vector.broadcast(acc_type, zero)
lhs = lhs_smem if lhs_in_smem else mgpu_dialect.vector_load(lhs_smem)
# Only f16 WGMMA supports transposes
rhs = utils.memref_transpose(rhs, (1, 0))
result = mgpu_dialect.wgmma(acc, lhs, rhs)
rhs_smem = utils.memref_transpose(rhs_smem, (1, 0))
result = mgpu_dialect.wgmma(acc, lhs, rhs_smem)
nvvm.wgmma_commit_group_sync_aligned()
nvvm.wgmma_wait_group_sync_aligned(0)
mgpu_dialect.vector_store(result, result_gmem)
Expand Down
1 change: 0 additions & 1 deletion tests/pallas/mosaic_gpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3024,7 +3024,6 @@ def scope(acc_ref):
np.testing.assert_allclose(res, a @ b, rtol=1e-3)

def test_wgmma_registers_integer(self):
self.skip_if_wg_semantics() # WGMMA_8BIT layout not supported
input_dtype = jnp.int8
out_dtype = jnp.int32
def kernel(a_ref, b_ref, o_ref):
Expand Down
Loading