81 changes: 58 additions & 23 deletions jax/experimental/mosaic/gpu/tcgen05.py
@@ -16,6 +16,7 @@
from __future__ import annotations

import dataclasses
import functools
import itertools
import math
from typing import Any, Callable, Iterator, cast
@@ -113,8 +114,9 @@ def _create_scaled_instr_descriptor(
b_type: ir.Type,
a_scale_idx: int,
b_scale_idx: int,
transpose_a: bool = False,
transpose_b: bool = False,
transpose_a: bool,
transpose_b: bool,
scale_type: ir.Type,
) -> ir.Value:
desc = 0
# Bits 0, 1 are reserved
@@ -131,7 +133,13 @@
if n % 8 or n > 256:
raise ValueError(f"N must be a multiple of 8 and <= 256, got: {n}")
desc |= (n >> 3) << 17 # N, bits 17-22
desc |= 1 << 23 # Scale matrix type
if scale_type == ir.Float8E8M0FNUType.get():
scale_encoding = 1
elif scale_type == ir.Float8E4M3FNType.get():
scale_encoding = 0
else:
raise NotImplementedError(f"Unsupported scale type: {scale_type}")
desc |= scale_encoding << 23 # Scale matrix type
# Bits 24-26 are reserved
if m % 128 or m > 256:
raise ValueError(f"M must be a multiple of 16 and <= 256, got: {m}")
@@ -375,6 +383,8 @@ def mma(
)

# Check that the shapes and element types are correct for block scaling.
scale_element_type = None
scale_block = None
if is_scaled:
if collective:
raise NotImplementedError("MMA with block scaling does not support collective")
@@ -384,24 +394,32 @@
f"MMA with block scaling requires N to be divisible by 32, got: {n}"
)
assert a_scale is not None and b_scale is not None
if a_scale.shape != (m, k // 32):
scale_element_type = a_scale.dtype
if a_scale.dtype == ir.Float8E8M0FNUType.get():
scale_block = 32
elif a_scale.dtype == ir.Float8E4M3FNType.get():
scale_block = 16
else:
raise ValueError(
f"A scale shape mismatch: expected ({m}, {k // 32}), got {a_scale.shape}"
f"A scale dtype mismatch: expected f8e8m0fnu, got {a_scale.dtype}"
)
if a_scale.dtype != ir.Float8E8M0FNUType.get():
if b_scale.dtype != a_scale.dtype:
raise ValueError(
f"A scale dtype mismatch: expected f8e8m0fnu, got {a_scale.dtype}"
f"B scale dtype mismatch: expected {a_scale.dtype} (same as A), got"
f" {b_scale.dtype}"
)
if b_scale.shape != (n, k // 32):
if a_scale.shape != (m, k // scale_block):
raise ValueError(
f"B scale shape mismatch: expected ({n}, {k // 32}), got {b_scale.shape}"
f"A scale shape mismatch: expected ({m}, {k // scale_block}), got"
f" {a_scale.shape}"
)
if b_scale.dtype != ir.Float8E8M0FNUType.get():
if b_scale.shape != (n, k // scale_block):
raise ValueError(
f"B scale dtype mismatch: expected f8e8m0fnu, got {b_scale.dtype}"
f"B scale shape mismatch: expected ({n}, {k // scale_block}), got"
f" {b_scale.shape}"
)
if k_group_elems % 128:
min_swizzle = 16 * utils.bitwidth(element_type)
if k_group_elems % (scale_block * 4):
min_swizzle = scale_block // 2 * utils.bitwidth(element_type)
raise NotImplementedError(
f"{element_type} MMA with block scaling requires swizzle to be at"
f" least {min_swizzle}"
@@ -502,15 +520,17 @@ def mma(
raise NotImplementedError("A scale address calculation for multiple M tiles")
if n_groups != 1:
raise NotImplementedError("B scale address calculation for multiple N tiles")
assert k_group_elems % 128 == 0
assert scale_block is not None # For type checkers.
assert k_group_elems % (scale_block * 4) == 0
assert m_group_elems % 32 == 0 and n_group_elems % 32 == 0
k_scales_per_group = k_group_elems // (scale_block * 4)
a_scale_addr = arith.addi(
a_scale_addr_base,
utils.c(ki * k_group_elems // 128 * m_group_elems // 32, i32),
utils.c(ki * k_scales_per_group * m_group_elems // 32, i32),
)
b_scale_addr = arith.addi(
b_scale_addr_base,
utils.c(ki * k_group_elems // 128 * n_group_elems // 32, i32),
utils.c(ki * k_scales_per_group * n_group_elems // 32, i32),
)
else:
a_scale_addr = b_scale_addr = None
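For reference, a worked example of the new per-k-group scale-address stride with illustrative numbers, assuming the TMEM scale layout implied above (one address step per 32 rows and per group of 4 scales along K).

k_group_elems, m_group_elems = 128, 128
for scale_block in (32, 16):                               # e8m0 vs e4m3 scales
    k_scales_per_group = k_group_elems // (scale_block * 4)
    a_stride = k_scales_per_group * m_group_elems // 32    # address step per ki
    print(scale_block, k_scales_per_group, a_stride)       # 32 -> (1, 4); 16 -> (2, 8)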
@@ -540,6 +560,7 @@ def mma(
a_sparse_addr=a_sparse_addr,
accumulate=acc,
element_type=mma_element_type,
scale_element_type=scale_element_type,
)


@@ -558,6 +579,7 @@ def _do_mma(
n: int,
k: int,
element_type: ir.Type,
scale_element_type: ir.Type | None,
d_type: ir.Type,
accumulate: ir.Value,
collective: bool,
@@ -580,14 +602,27 @@
assert not is_sparse
if (ir.Float8E5M2Type.isinstance(element_type) or
ir.Float8E4M3FNType.isinstance(element_type)):
if scale_element_type != ir.Float8E8M0FNUType.get():
raise ValueError(
f"Scale element type mismatch: expected f8e8m0fnu, got {scale_element_type}"
)
kind = "mxf8f6f4.block_scale.scale_vec::1X"
scale_steps = 4
create_scaled_instr_descriptor = create_scaled_f8f6f4_instr_descriptor
create_scaled_instr_descriptor = functools.partial(
create_scaled_f8f6f4_instr_descriptor, scale_type=scale_element_type
)
elif ir.Float4E2M1FNType.isinstance(element_type):
assert not a_transpose and not b_transpose
kind = "mxf4.block_scale.scale_vec::2X"
scale_steps = 2
create_scaled_instr_descriptor = create_scaled_f4_instr_descriptor
create_scaled_instr_descriptor = functools.partial(
create_scaled_f4_instr_descriptor,
scale_type=scale_element_type,
)
if scale_element_type == ir.Float8E8M0FNUType.get():
kind = "mxf4.block_scale.scale_vec::2X"
scale_steps = 2
elif scale_element_type == ir.Float8E4M3FNType.get():
kind = "mxf4nvf4.block_scale.scale_vec::4X"
scale_steps = 1
else:
raise NotImplementedError(f"Unsupported element type for block scaling: {element_type}")
extra_ptx = "[$5], [$6], "
@@ -605,7 +640,7 @@
raise NotImplementedError(f"Unsupported input element type: {element_type}")
extra_constraints = extra_ptx = ""

def create_scaled_instr_descriptor(*args):
def create_scaled_instr_descriptor(*args): # type: ignore
raise NotImplementedError

num_cta = 2 if collective else 1
@@ -1440,8 +1475,8 @@ def async_copy_scales_smem_to_tmem(smem_ref: ir.Value, tmem_ref: TMEMRef) -> None
smem_ty = ir.MemRefType(smem_ref.type)
if (dtype := smem_ty.element_type) != tmem_ref.dtype:
raise ValueError(f"Incompatible dtypes: SMEM has {dtype}, TMEM has {tmem_ref.dtype}")
if dtype != ir.Float8E8M0FNUType.get():
raise NotImplementedError(f"Unsupported dtype: {dtype}, only f8e8m0fnu supported")
if dtype not in {ir.Float8E8M0FNUType.get(), ir.Float8E4M3FNType.get()}:
raise NotImplementedError(f"Unsupported dtype: {dtype}, only f8e8m0fnu and f8e4m3fn are supported")
if tmem_ref.shape[0] % TMEM_ROWS:
raise ValueError(f"TMEM reference must have a multiple of {TMEM_ROWS} rows, but got {tmem_ref.shape[0]}")
if tmem_ref.shape[1] % 4:
52 changes: 38 additions & 14 deletions tests/mosaic/gpu_test.py
@@ -1617,14 +1617,22 @@ def kernel(ctx, src, out, scratch):

@parameterized.product(
in_jax_dtype=(jnp.float8_e5m2, jnp.float8_e4m3fn, jnp.float4_e2m1fn),
scale_jax_dtype=(jnp.float8_e8m0fnu, jnp.float8_e4m3fn),
m=(128,), # TODO(apaszke): 256
n=(128, 256), # TODO(apaszke): 192, other non-power-of-2
)
def test_mma_block_scaled(self, m, n, in_jax_dtype):
def test_mma_block_scaled(self, m, n, in_jax_dtype, scale_jax_dtype):
out_jax_dtype = jnp.float32
scale_jax_dtype = jnp.float8_e8m0fnu
swizzle = 128
k_steps = 2
if scale_jax_dtype == jnp.float8_e8m0fnu:
block_size = 32
elif scale_jax_dtype == jnp.float8_e4m3fn:
if in_jax_dtype != jnp.float4_e2m1fn:
self.skipTest("Only float4_e2m1fn input is supported for e4m3fn scale.")
block_size = 16
else:
raise ValueError(f"Unsupported scale dtype: {scale_jax_dtype}")
if out_jax_dtype == jnp.float16 and in_jax_dtype != jnp.float16:
self.skipTest("Only f16 input is supported for f16 output.")

@@ -1670,21 +1678,37 @@ def kernel(ctx, lhs, rhs, lhs_scales_gmem, rhs_scales_gmem, out, scratch):
scratch_shape = [
jax.ShapeDtypeStruct(tile_shape(x_shape, lhs_tiling), in_jax_dtype),
jax.ShapeDtypeStruct(tile_shape(y_shape, rhs_tiling), in_jax_dtype),
jax.ShapeDtypeStruct((m // 128, k // (32 * 4), 32, 16), scale_jax_dtype),
jax.ShapeDtypeStruct((n // 128, k // (32 * 4), 32, 16), scale_jax_dtype),
jax.ShapeDtypeStruct((m // 128, k // (block_size * 4), 32, 16), scale_jax_dtype),
jax.ShapeDtypeStruct((n // 128, k // (block_size * 4), 32, 16), scale_jax_dtype),
mgpu.TMABarrier(4),
mgpu.Barrier(1),
mgpu.TMEM((m, n), out_jax_dtype),
mgpu.TMEM((m, k // 32), scale_jax_dtype, layout=tcgen05.scales_layout()),
mgpu.TMEM((n, k // 32), scale_jax_dtype, layout=tcgen05.scales_layout()),
mgpu.TMEM((m, k // block_size), scale_jax_dtype, layout=tcgen05.scales_layout()),
mgpu.TMEM((n, k // block_size), scale_jax_dtype, layout=tcgen05.scales_layout()),
]
ka, kb = jax.random.split(jax.random.key(1234), 2)
a_scales = jax.lax.bitcast_convert_type(
jax.random.randint(ka, (m, k // 32), 122, 132, dtype=jnp.uint8), scale_jax_dtype
)
b_scales = jax.lax.bitcast_convert_type(
jax.random.randint(kb, (n, k // 32), 122, 132, dtype=jnp.uint8), scale_jax_dtype
)
if scale_jax_dtype == jnp.float8_e8m0fnu:
a_scales = jax.lax.bitcast_convert_type(
jax.random.randint(ka, (m, k // block_size), 122, 132, dtype=jnp.uint8),
scale_jax_dtype
)
b_scales = jax.lax.bitcast_convert_type(
jax.random.randint(kb, (n, k // block_size), 122, 132, dtype=jnp.uint8),
scale_jax_dtype
)
elif scale_jax_dtype == jnp.float8_e4m3fn:
a_scales = jnp.abs(
jax.random.normal(ka, (m, k // block_size), dtype=jnp.float32).astype(
scale_jax_dtype
)
)
b_scales = jnp.abs(
jax.random.normal(kb, (n, k // block_size), dtype=jnp.float32).astype(
scale_jax_dtype
)
)
else:
raise ValueError(f"Unsupported scale dtype: {scale_jax_dtype}")
def format_scales(scales):
mn, k = scales.shape
assert mn % 128 == 0 and k % 4 == 0, scales.shape
@@ -1699,8 +1723,8 @@ def format_scales(scales):
kernel, (1, 1, 1), (128, 1, 1), args, out_shape, scratch_shape
)(*args)
x32, y32 = x.astype(np.float32), y.astype(np.float32)
a_logical_scales = jnp.repeat(a_scales, 32, axis=1).astype(jnp.float32)
b_logical_scales = jnp.repeat(b_scales, 32, axis=1).astype(jnp.float32)
a_logical_scales = jnp.repeat(a_scales, block_size, axis=1).astype(jnp.float32)
b_logical_scales = jnp.repeat(b_scales, block_size, axis=1).astype(jnp.float32)
ref = (x32 * a_logical_scales) @ (y32 * b_logical_scales).T
np.testing.assert_allclose(z, ref, atol=2e-4, rtol=5e-6)
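For reference, a minimal jax.numpy sketch of the reference computation above; the shapes are illustrative, and the point is only that each per-block scale is broadcast over block_size consecutive K elements before the float32 matmul.

import jax.numpy as jnp
block_size = 16                                            # 32 for e8m0 scales
a = jnp.ones((128, 64), jnp.float32)                       # dequantized A, (m, k)
a_scales = jnp.full((128, 64 // block_size), 0.5, jnp.float32)
a_scaled = a * jnp.repeat(a_scales, block_size, axis=1)    # (m, k), ready for @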
