Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 8 additions & 11 deletions jax/experimental/mosaic/gpu/tcgen05.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def mma(
s32 = ir.IntegerType.get_signless(32)
if element_type == f32 or element_type == ir.BF16Type.get():
if element_type == f32 and is_sparse:
raise NotImplementedError("Only 16-bit types supported for sparse MMA")
raise NotImplementedError("Sparse MMA unsupported for f32")
if is_scaled:
raise ValueError(
f"MMA with element type {element_type} does not support block scaling"
Expand All @@ -288,8 +288,6 @@ def mma(
t.isinstance(element_type)
for t in {ir.Float8E5M2Type, ir.Float8E4M3FNType}
):
if is_sparse:
raise NotImplementedError("Only 16-bit types supported for sparse MMA")
if d.dtype != f16 and d.dtype != f32:
raise ValueError(
f"MMA with element type {element_type} only supports accumulators of"
Expand All @@ -304,7 +302,7 @@ def mma(
t.isinstance(element_type) for t in {ir.Float4E2M1FNType}
):
if is_sparse:
raise NotImplementedError("Only 16-bit types supported for sparse MMA")
raise NotImplementedError("Sparse MMA unsupported for f4e2m1fn")
if not is_scaled:
raise ValueError(
f"MMA with element type {element_type} only supports block scaling"
Expand All @@ -315,8 +313,6 @@ def mma(
f" accumulators, but got: {d.dtype}"
)
elif element_type == ir.IntegerType.get_signless(8):
if is_sparse:
raise NotImplementedError("Only 16-bit types supported for sparse MMA")
if is_scaled:
raise ValueError(
f"MMA with element type {element_type} does not support block scaling"
Expand Down Expand Up @@ -561,11 +557,11 @@ def _do_mma(
elem_bitwidth = utils.bitwidth(element_type)
instr_k = (1 + is_sparse) * 8 * 32 // elem_bitwidth
packing = 8 * 4 // elem_bitwidth
assert not is_sparse or elem_bitwidth == 16 # Only 16-bit supported for now.

extra_args: Sequence[object]
scale_steps = None
if is_scaled:
assert not is_sparse
if (ir.Float8E5M2Type.isinstance(element_type) or
ir.Float8E4M3FNType.isinstance(element_type)):
kind = "mxf8f6f4.block_scale.scale_vec::1X"
Expand Down Expand Up @@ -628,10 +624,11 @@ def _get_offset(idx: int, idx_tiling: tuple[int, ...], strides: tuple[int, ...])
else:
sp_selector = None
if is_sparse:
assert (k // instr_k) % 2 == 0
sp_selector = k_step % 2
selector_width = 64
k_steps_for_col_inc = selector_width // instr_k
assert 32 <= instr_k <= 64
selector_width = instr_k
k_steps_for_col_inc = 64 // selector_width
assert (k // instr_k) % k_steps_for_col_inc == 0
sp_selector = k_step % k_steps_for_col_inc
# If the K group is large, we need to increment the sparse metadata.
# TODO(apaszke): At this point the purpose of this function is becoming
# less clear, since we end up replicating address arithmetic that's
Expand Down
29 changes: 22 additions & 7 deletions tests/mosaic/gpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1643,14 +1643,17 @@ def format_scales(scales):
@parameterized.product(
lhs_transpose=(False, True),
rhs_transpose=(False, True),
in_jax_dtype=(jnp.float16, jnp.bfloat16,),
in_jax_dtype=(jnp.float16, jnp.bfloat16, jnp.int8, jnp.float8_e4m3fn),
m=(128,), # TODO(apaszke): 256
n=(128, 256), # TODO(apaszke): other non-power-of-2
lhs_swizzle=(32, 64, 128),
rhs_swizzle=(64, 128), # 32 is too small and unsupported.
)
def test_mma_sparse(self, m, n, in_jax_dtype, lhs_swizzle, rhs_swizzle, lhs_transpose, rhs_transpose):
out_jax_dtype = jnp.float32
if jnp.issubdtype(in_jax_dtype, jnp.floating):
out_jax_dtype = jnp.float32
else:
out_jax_dtype = jnp.int32
sparse_meta_dtype = jnp.uint2

in_mlir_dtype = utils.dtype_to_ir_type(in_jax_dtype)
Expand Down Expand Up @@ -1682,12 +1685,17 @@ def kernel(ctx, lhs, rhs, lhs_sparse_gmem, out, scratch):
)
tcgen05.commit_arrive(mma_barrier)
mma_barrier.wait(orders_tensor_core=True)
acc.load().store_untiled(out, optimized=False)
is_signed = True if jnp.issubdtype(in_jax_dtype, jnp.integer) else None
acc.load(is_signed=is_signed).store_untiled(out, optimized=False)

x_shape = (k // 2, m) if lhs_transpose else (m, k // 2)
x = self.prng.uniform(-1, 1, x_shape).astype(in_jax_dtype)
y_shape = (n, k) if rhs_transpose else (k, n)
y = self.prng.uniform(-1, 1, y_shape).astype(in_jax_dtype)
if jnp.issubdtype(in_jax_dtype, jnp.integer):
x = jax.random.randint(jax.random.key(1234), x_shape, -64, 64, dtype=in_jax_dtype)
y = jax.random.randint(jax.random.key(2567), y_shape, -64, 64, dtype=in_jax_dtype)
else:
x = self.prng.uniform(-1, 1, x_shape).astype(in_jax_dtype)
y = self.prng.uniform(-1, 1, y_shape).astype(in_jax_dtype)
out_shape = jax.ShapeDtypeStruct((m, n), out_jax_dtype)
scratch_shape = [
jax.ShapeDtypeStruct(tile_shape(x_shape, lhs_tiling), in_jax_dtype),
Expand All @@ -1708,10 +1716,17 @@ def format_sparse_meta(meta):
mn, k, _2 = meta.shape
assert _2 == 2
k *= 2
return (
if jnp.dtype(in_jax_dtype).itemsize == 1:
meta_tiled = (
meta.reshape(mn // 128, 128, k // 64, 64).transpose(0, 2, 1, 3)
)
else:
meta_tiled = (
meta.reshape(mn // 128, 8, 2, 8, k // 64, 4, 2, 8)
.transpose(0, 4, 1, 6, 3, 5, 2, 7)
.reshape(mn // 128, k // 64, 128, 64)
)
return (
meta_tiled.reshape(mn // 128, k // 64, 128, 64)
.astype(sparse_meta_dtype)
)
x_gpu_sparse = format_sparse_meta(x_sparse)
Expand Down
Loading