Skip to content

Commit bbec947

Browse files
apaszke authored and Google-ML-Automation committed
[Mosaic GPU] Use the packing trick for int4 -> bf16 conversions
PiperOrigin-RevId: 832278748
1 parent b3fc3af commit bbec947

File tree

2 files changed

+61
-56
lines changed

2 files changed

+61
-56
lines changed

jax/experimental/mosaic/gpu/fragmented_array.py

Lines changed: 43 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -1946,7 +1946,11 @@ def upcast_to_f8e4m3fn(reg: ir.Value, part: int):
19461946
if cur_dtype == i4 and self.is_signed and new_dtype == bf16 and vector_len % 2 == 0:
19471947
new_registers = np.empty_like(self.registers)
19481948
out_vec_ty = ir.VectorType.get((vector_len,), new_dtype)
1949-
for idx, reg in np.ndenumerate(self.registers):
1949+
# We use packed_registers for consistency, even though the packing is not
1950+
# really profitable here: the PTX below begins by an op dependent on the
1951+
# extracted part and so there are no ops that can be shared across packed
1952+
# parts.
1953+
for indices, reg in packed_registers(2, if_not_sliced=True):
19501954
# The algorithm here is largely the same as CUTLASS's
19511955
# NumericArrayConverter specialization for int4 -> bf16 casts.
19521956
# We modify it slightly, because we only extract 2 values.
@@ -1962,7 +1966,7 @@ def upcast_to_f8e4m3fn(reg: ir.Value, part: int):
19621966
# bias coming from flipping the sign bit which is 136 (0x4308 as bits).
19631967
def upcast_i4_to_bf16(reg: ir.Value, reg_shr: ir.Value, part: int):
19641968
assert 0 <= part < 4
1965-
return llvm.inline_asm(
1969+
int_reg = llvm.inline_asm(
19661970
i32,
19671971
[reg, reg_shr],
19681972
f"""
@@ -1976,43 +1980,43 @@ def upcast_i4_to_bf16(reg: ir.Value, reg_shr: ir.Value, part: int):
19761980
""",
19771981
"=r,r,r",
19781982
)
1979-
offset = 0
1983+
return utils.bitcast(int_reg, ir.VectorType.get((2,), bf16))
1984+
[group_size] = ir.VectorType(reg.type).shape
1985+
assert group_size % vector_len == 0
1986+
assert group_size * 4 <= 32
1987+
int_ty = ir.IntegerType.get_signless(group_size * 4)
1988+
# If the vector originates from a slice (common after relayouts), we
1989+
# can fuse the slicing into the conversion and prevent LLVM from
1990+
# generating a bunch of shifts to align the vector data to the LSB.
1991+
# This also lets us share the right shift among more vectors.
19801992
out_int_regs: list[ir.Value] = []
1981-
# TODO(apaszke): Use packed_registers here.
1982-
for group_size in (8, 4, 2):
1983-
int_ty = ir.IntegerType.get_signless(group_size * 4)
1984-
while vector_len - offset >= group_size:
1985-
# If the vector originates from a slice (common after relayouts), we
1986-
# can fuse the slicing into the conversion and prevent LLVM from
1987-
# generating a bunch of shifts to align the vector data to the LSB.
1988-
# This also lets us share the right shift among more vectors.
1989-
if (isinstance(slice_op := reg.owner.opview, vector.ExtractStridedSliceOp)
1990-
and utils.bitwidth(slice_op.source.type) == 32
1991-
and slice_op.strides[0].value == 1):
1992-
slice_offset = slice_op.offsets[0].value + offset
1993-
reg_int = utils.bitcast(slice_op.source, i32)
1994-
reg_int_shr = arith.shrui(reg_int, c(4, i32))
1995-
out_int_regs.extend(
1996-
upcast_i4_to_bf16(reg_int, reg_int_shr, part=(slice_offset // 2 + part))
1997-
for part in range(group_size // 2)
1998-
)
1999-
else:
2000-
reg_slice = utils.vector_slice(reg, slice(offset, offset + group_size))
2001-
reg_slice_int = utils.bitcast(reg_slice, int_ty)
2002-
if int_ty != i32:
2003-
reg_slice_int = arith.extsi(i32, reg_slice_int)
2004-
reg_slice_int_shr = arith.shrui(reg_slice_int, c(4, i32))
2005-
out_int_regs.extend(
2006-
upcast_i4_to_bf16(reg_slice_int, reg_slice_int_shr, part=part)
2007-
for part in range(group_size // 2)
2008-
)
2009-
offset += group_size
2010-
assert offset == vector_len
2011-
out_vec_int = utils.vector_concat([
2012-
vector.broadcast(ir.VectorType.get((1,), i32), reg)
2013-
for reg in out_int_regs
2014-
])
2015-
new_registers[idx] = utils.bitcast(out_vec_int, out_vec_ty)
1993+
if regs_from_32bit_slice:
1994+
slice_op = reg.owner.opview
1995+
slice_offset = slice_op.offsets[0].value
1996+
reg_int = utils.bitcast(slice_op.source, i32)
1997+
reg_int_shr = arith.shrui(reg_int, c(4, i32))
1998+
assert slice_offset % 2 == 0
1999+
out_int_regs.extend(
2000+
upcast_i4_to_bf16(reg_int, reg_int_shr, part=slice_offset // 2 + part)
2001+
for part in range(group_size // 2)
2002+
)
2003+
else:
2004+
reg_slice_int = utils.bitcast(reg, int_ty)
2005+
if int_ty != i32:
2006+
reg_slice_int = arith.extsi(i32, reg_slice_int)
2007+
reg_slice_int_shr = arith.shrui(reg_slice_int, c(4, i32))
2008+
out_int_regs.extend(
2009+
upcast_i4_to_bf16(reg_slice_int, reg_slice_int_shr, part=part)
2010+
for part in range(group_size // 2)
2011+
)
2012+
out_reg = utils.vector_concat(out_int_regs)
2013+
offset = 0
2014+
for idx in indices:
2015+
new_registers[idx] = new_reg = utils.vector_slice(
2016+
out_reg, slice(offset, offset + vector_len)
2017+
)
2018+
offset += vector_len
2019+
assert new_reg.type == out_vec_ty
20162020
return FragmentedArray(
20172021
_registers=new_registers, _layout=self.layout, _is_signed=None
20182022
)
@@ -2058,6 +2062,7 @@ def upcast_i4_to_i8(reg: ir.Value, first_valid_nibble: int = 0):
20582062
])
20592063
[group_size] = ir.VectorType(reg.type).shape
20602064
assert group_size % vector_len == 0
2065+
assert group_size * 4 <= 32
20612066
int_ty = ir.IntegerType.get_signless(group_size * 4)
20622067
if regs_from_32bit_slice:
20632068
slice_op = reg.owner.opview

tests/mosaic/gpu_test.py

Lines changed: 18 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -667,21 +667,21 @@ def kernel(ctx, inp, out, smem):
667667
(jnp.int4, jnp.int8),
668668
# TODO(apaszke,bchetioui): bf16/f32 -> f8e4m3fn
669669
),
670-
layout_desc=(
671-
"WGMMA_LAYOUT",
672-
"WGMMA_LAYOUT_8BIT",
673-
"WGMMA_LAYOUT_UPCAST_2X",
674-
"WGMMA_LAYOUT_UPCAST_4X",
670+
layout_descs=(
671+
("WGMMA_LAYOUT", "WGMMA_LAYOUT"),
672+
("WGMMA_LAYOUT_8BIT", "WGMMA_LAYOUT_8BIT"),
673+
("WGMMA_LAYOUT_UPCAST_2X", "WGMMA_LAYOUT_UPCAST_2X"),
674+
("WGMMA_LAYOUT_UPCAST_2X", "WGMMA_LAYOUT"),
675+
("WGMMA_LAYOUT_UPCAST_4X", "WGMMA_LAYOUT_UPCAST_4X"),
676+
("WGMMA_LAYOUT_UPCAST_4X", "WGMMA_LAYOUT_UPCAST_2X"),
677+
("WGMMA_LAYOUT_UPCAST_4X", "WGMMA_LAYOUT"),
675678
),
676-
change_layout=(False, True),
677679
)
678680
@jtu.skip_if_mosaic_gpu_exceeds_shared_memory(device_patterns="RTX PRO 6000 Blackwell")
679-
def test_optimized_conversion(self, jax_dtype_from_to, layout_desc, change_layout):
680-
if change_layout and layout_desc == "WGMMA_LAYOUT":
681-
self.skipTest("No-op relayout")
682-
if change_layout and layout_desc == "WGMMA_LAYOUT_8BIT":
683-
self.skipTest("Unimplemented relayout")
684-
layout: fa.TiledLayout = getattr(fa, layout_desc)
681+
def test_optimized_conversion(self, jax_dtype_from_to, layout_descs):
682+
layout_desc_from, layout_desc_to = layout_descs
683+
layout_from: fa.TiledLayout = getattr(fa, layout_desc_from)
684+
layout_to: fa.TiledLayout = getattr(fa, layout_desc_to)
685685
jax_dtype_from, jax_dtype_to = jax_dtype_from_to
686686
mlir_dtype_from = utils.dtype_to_ir_type(jax_dtype_from)
687687
mlir_dtype_to = utils.dtype_to_ir_type(jax_dtype_to)
@@ -692,16 +692,16 @@ def kernel(ctx, inp, out, smem):
692692
t = mgpu.FragmentedArray.load_untiled(
693693
inp,
694694
is_signed=utils.is_signed(jax_dtype_from),
695-
layout=layout,
695+
layout=layout_from,
696696
optimized=False,
697697
)
698-
if change_layout:
698+
if layout_from != layout_to:
699699
if (
700-
layout == fa.WGMMA_LAYOUT_UPCAST_4X
701-
and utils.bitwidth(mlir_dtype_from) > 4
700+
layout_from == fa.WGMMA_LAYOUT_UPCAST_4X
701+
and utils.bitwidth(mlir_dtype_from) != 4
702702
):
703703
self.skipTest("Unimplemented relayout")
704-
t = t.to_layout(fa.WGMMA_LAYOUT)
704+
t = t.to_layout(layout_to)
705705
t = t.astype(mlir_dtype_to, is_signed=utils.is_signed(jax_dtype_to))
706706
t.store_untiled(out, optimized=False)
707707

@@ -725,7 +725,7 @@ def _maybe_profile():
725725
with open(file_path, "a") as f:
726726
data = (
727727
jnp.dtype(jax_dtype_from).name, jnp.dtype(jax_dtype_to).name,
728-
layout_desc, change_layout, sass().count("\n"),
728+
layout_desc_from, layout_desc_to, sass().count("\n")
729729
)
730730
f.write(",".join(map(str, data)) + "\n")
731731
f.flush()

0 commit comments

Comments
 (0)