124 changes: 82 additions & 42 deletions jax/experimental/mosaic/gpu/tcgen05.py
@@ -15,7 +15,6 @@

 from __future__ import annotations

-from collections.abc import Sequence
 import dataclasses
 import itertools
 import math
@@ -384,31 +383,29 @@ def mma(
       raise ValueError(
           f"MMA with block scaling requires N to be divisible by 32, got: {n}"
       )
-    if k_group_elems != 128 or a_swizzle != b_swizzle:
-      assert utils.bitwidth(element_type) <= 8
-      expected_swizzle = 128 // (8 // utils.bitwidth(element_type))
-      raise NotImplementedError(
-          "MMA with block scaling requires swizzle to be"
-          f" {expected_swizzle} for dtype {element_type}, got:"
-          f" {a_swizzle=} and {b_swizzle=}"
-      )
     assert a_scale is not None and b_scale is not None
-    if a_scale.shape != (m, 4):
+    if a_scale.shape != (m, k // 32):
       raise ValueError(
-          f"A scale shape mismatch: expected ({m}, 4), got {a_scale.shape}"
+          f"A scale shape mismatch: expected ({m}, {k // 32}), got {a_scale.shape}"
       )
     if a_scale.dtype != ir.Float8E8M0FNUType.get():
       raise ValueError(
           f"A scale dtype mismatch: expected f8e8m0fnu, got {a_scale.dtype}"
       )
-    if b_scale.shape != (n, 4):
+    if b_scale.shape != (n, k // 32):
       raise ValueError(
-          f"B scale shape mismatch: expected ({n}, 4), got {b_scale.shape}"
+          f"B scale shape mismatch: expected ({n}, {k // 32}), got {b_scale.shape}"
       )
     if b_scale.dtype != ir.Float8E8M0FNUType.get():
       raise ValueError(
           f"B scale dtype mismatch: expected f8e8m0fnu, got {b_scale.dtype}"
       )
+    if k_group_elems % 128:
+      min_swizzle = 16 * utils.bitwidth(element_type)
+      raise NotImplementedError(
+          f"{element_type} MMA with block scaling requires swizzle to be at"
+          f" least {min_swizzle}"
+      )
   if is_sparse:
     a_sparse_metadata = cast(TMEMRef, a_sparse_metadata)
     if collective:
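Note: the new guard above is equivalent to requiring `swizzle >= 16 * bitwidth`. A minimal sketch of the arithmetic, under the assumption that a K group spans `8 * swizzle // bitwidth` elements (swizzle is in bytes) and block scaling needs whole 128-element groups:

```python
# Hypothetical check mirroring the new k_group_elems % 128 guard.
for bitwidth, min_swizzle in [(8, 128), (4, 64)]:  # fp8 and fp4 inputs
    assert min_swizzle == 16 * bitwidth
    k_group_elems = 8 * min_swizzle // bitwidth  # elements per swizzle atom
    assert k_group_elems % 128 == 0  # exactly the condition the patch tests
```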
Expand Down Expand Up @@ -479,6 +476,8 @@ def mma(
assert d.layout.base_tile_shape[0] % 4 == 0
lanes_per_n_group = d.layout.base_tile_shape[0] // 4
a_sparse_addr_base = a_sparse_metadata.address if is_sparse else None # type: ignore
a_scale_addr_base = a_scale.address if is_scaled else None # type: ignore
b_scale_addr_base = b_scale.address if is_scaled else None # type: ignore
for mi, ni, ki in np.ndindex(m_groups, n_groups, k_groups):
if isinstance(a, TMEMRef):
if m_groups != 1:
@@ -498,8 +497,23 @@ def mma(
       a_sparse_addr = arith.addi(a_sparse_addr_base, utils.c(ki * cols_per_k_group, i32))
     else:
       a_sparse_addr = None
-    if is_scaled and (m_groups != 1 or n_groups != 1 or k_groups != 1):
-      raise NotImplementedError("Block-scaled metadata address calculation for multiple tiles")
+    if a_scale_addr_base is not None and b_scale_addr_base is not None:
+      if m_groups != 1:
+        raise NotImplementedError("A scale address calculation for multiple M tiles")
+      if n_groups != 1:
+        raise NotImplementedError("B scale address calculation for multiple N tiles")
+      assert k_group_elems % 128 == 0
+      assert m_group_elems % 32 == 0 and n_group_elems % 32 == 0
+      a_scale_addr = arith.addi(
+          a_scale_addr_base,
+          utils.c(ki * k_group_elems // 128 * m_group_elems // 32, i32),
+      )
+      b_scale_addr = arith.addi(
+          b_scale_addr_base,
+          utils.c(ki * k_group_elems // 128 * n_group_elems // 32, i32),
+      )
+    else:
+      a_scale_addr = b_scale_addr = None
     acc = accumulate if ki == 0 else true
     ni_lane_group, ni_col = ni // n_col_groups, ni % n_col_groups
     d_offset = (
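For context, the new offsets step through TMEM columns: the scales for each 128-element slice of K occupy one 4-byte column per 32 rows of the group, so K group `ki` starts `ki * (k_group_elems // 128) * (group_rows // 32)` columns past the base address. A small sketch in plain Python (the function name and example sizes are illustrative):

```python
def scale_col_offset(ki: int, k_group_elems: int, group_rows: int) -> int:
    # Mirrors utils.c(ki * k_group_elems // 128 * m_group_elems // 32, i32) above.
    assert k_group_elems % 128 == 0 and group_rows % 32 == 0
    return ki * (k_group_elems // 128) * (group_rows // 32)

# With 256-element K groups and 128-row M groups, successive K groups
# start at columns 0, 8, 16, ...
assert [scale_col_offset(ki, 256, 128) for ki in range(3)] == [0, 8, 16]
```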
@@ -521,8 +535,8 @@ def mma(
         b_transpose=b_fastest != mma_utils.Dim.K,
         a_k_strides=a_k_instr_strides,
         b_k_strides=b_k_instr_strides,
-        a_scale_addr=a_scale.address if a_scale is not None else None,
-        b_scale_addr=b_scale.address if b_scale is not None else None,
+        a_scale_addr=a_scale_addr,
+        b_scale_addr=b_scale_addr,
         a_sparse_addr=a_sparse_addr,
         accumulate=acc,
         element_type=mma_element_type,
@@ -561,7 +575,6 @@ def _do_mma(
   instr_k = (1 + is_sparse) * 8 * 32 // elem_bitwidth
   packing = 8 * 4 // elem_bitwidth

-  extra_args: Sequence[object]
   scale_steps = None
   if is_scaled:
     assert not is_sparse
@@ -577,7 +590,6 @@ def _do_mma(
       create_scaled_instr_descriptor = create_scaled_f4_instr_descriptor
     else:
       raise NotImplementedError(f"Unsupported element type for block scaling: {element_type}")
-    extra_args = (a_scale_addr, b_scale_addr)
     extra_ptx = "[$5], [$6], "
     extra_constraints = ",r,r"
   else:
@@ -591,7 +603,6 @@ def _do_mma(
       kind = "i8"
     else:
       raise NotImplementedError(f"Unsupported input element type: {element_type}")
-    extra_args = ()
     extra_constraints = extra_ptx = ""

     def create_scaled_instr_descriptor(*args):
@@ -605,8 +616,8 @@ def create_scaled_instr_descriptor(*args):
   sparse_meta_ptx = "[$5], " if is_sparse else ""
   extra_constraints += ",r" if is_sparse else ""
   sparse_addr: tuple[Any, ...] = ()
+  scales_addrs: tuple[Any, ...] = ()
   assert a_desc_or_addr.type == ir.IntegerType.get_signless(32 if a_in_tmem else 64)
-  assert scale_steps is None or scale_steps == k // instr_k
   def _get_offset(idx: int, idx_tiling: tuple[int, ...], strides: tuple[int, ...]):
     assert len(idx_tiling) + 1 == len(strides)
     idxs = []
@@ -620,10 +631,18 @@ def _get_offset(idx: int, idx_tiling: tuple[int, ...], strides: tuple[int, ...]):
     if is_scaled:
       assert scale_steps is not None
       scale_vec_width = 4 // scale_steps
-      scale_id = k_step * scale_vec_width
+      scale_id = (k_step % scale_steps) * scale_vec_width
       i_desc = create_scaled_instr_descriptor(
           m, n, element_type, element_type, scale_id, scale_id, a_transpose, b_transpose
       )
+      assert m == 128
+      assert n % 128 == 0
+      a_scale_addr_offset = arith.constant(i32, k_step // scale_steps * 4)
+      b_scale_addr_offset = arith.constant(i32, k_step // scale_steps * n // 32)
+      scales_addrs = (
+          arith.addi(a_scale_addr, a_scale_addr_offset),
+          arith.addi(b_scale_addr, b_scale_addr_offset),
+      )
     else:
       sp_selector = None
       if is_sparse:
@@ -657,7 +676,7 @@ def _get_offset(idx: int, idx_tiling: tuple[int, ...], strides: tuple[int, ...]):
     b_desc_instr = arith.addi(b_desc, _get_offset(k_step, b_k_idx_tiling, b_k_strides))
     llvm.inline_asm(
         ir.Type.parse("!llvm.void"),
-        [d_addr, a_desc_or_addr_instr, b_desc_instr, i_desc, accumulate, *extra_args, *sparse_addr],
+        [d_addr, a_desc_or_addr_instr, b_desc_instr, i_desc, accumulate, *scales_addrs, *sparse_addr],
         f"tcgen05.mma{sparse_mod}.cta_group::{num_cta}.kind::{kind} [$0], {a_ptx}, $2, {sparse_meta_ptx}$3, {extra_ptx}$4;",
         f"r,{a_ptx_constraint},l,r,b" + extra_constraints,
         has_side_effects=True,
@@ -831,7 +850,14 @@ def check_type(self, shape: tuple[int, ...], bitwidth: int) -> None:

   def cols_in_shape(self, shape: tuple[int, int], bitwidth: int) -> int:
     self.check_type(shape, bitwidth)
-    return math.prod(shape) // TMEM_ROWS // self.vector_length
+    replication_factor = 1
+    for dim in self.warp_dims:
+      if isinstance(dim, fa.Replicated):
+        replication_factor *= dim.times
+    for dim in self.lane_dims:
+      if isinstance(dim, fa.Replicated):
+        replication_factor *= dim.times
+    return math.prod(shape) // TMEM_ROWS // self.vector_length * replication_factor

   def canonicalize(self) -> TMEMLayout:
     layout = super().canonicalize()
@@ -1396,17 +1422,19 @@ def async_copy_scales_smem_to_tmem(smem_ref: ir.Value, tmem_ref: TMEMRef) -> None:
   MMA issued in the same thread, no additional synchronization is needed.

   At the moment the function requires ``smem_ref`` to be contiguous and have a
-  shape of (MN // 128, 32, 16) for 8-bit scales (here MN stands for the size of
-  the non-contracting dimension which is M or N), matching the scale layout for
-  .scale_vec::1X. See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-a-layout-1x
+  shape of ``(MN // 128, K // 128, 32, 16)`` for 8-bit scales (here MN stands
+  for the size of the non-contracting dimension which is M or N), matching the
+  scale layout for .scale_vec::1X. See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-a-layout-1x
   for more details. Note that we always put the non-contracting dimension first.
-  If you have a (MN, 4) array of scales in JAX (where MN is divisible by 128),
-  you can prepare it for use in the kernel this way::
+  If you have a (MN, K // 32) array of scales in JAX (where MN and K are
+  divisible by 128), you can prepare it for use in the kernel this way::

-    scales.reshape(-1, 4, 32, 4).swapaxes(1, 2).reshape(-1, 32, 16)
+    scales.reshape(mn // 128, 4, 32, k // 4, 4)
+        .transpose(0, 3, 2, 1, 4)
+        .reshape(mn // 128, k // 4, 32, 16)

-  The TMEM ref is expected to have the logical shape of the scales (MN, 4), and
-  the layout created by ``scales_layout()``.
+  The TMEM ref is expected to have the logical shape of the scales
+  ``(MN, K // 32)``, and the layout created by ``scales_layout()``.
   """
   i32 = ir.IntegerType.get_signless(32)
   smem_ty = ir.MemRefType(smem_ref.type)
@@ -1416,30 +1444,42 @@ def async_copy_scales_smem_to_tmem(smem_ref: ir.Value, tmem_ref: TMEMRef) -> None:
     raise NotImplementedError(f"Unsupported dtype: {dtype}, only f8e8m0fnu supported")
   if tmem_ref.shape[0] % TMEM_ROWS:
     raise ValueError(f"TMEM reference must have a multiple of {TMEM_ROWS} rows, but got {tmem_ref.shape[0]}")
-  if tmem_ref.shape[1] != 4:
-    raise ValueError(f"TMEM reference must have 4 colums, but got {tmem_ref.shape[1]}")
+  if tmem_ref.shape[1] % 4:
+    raise ValueError(f"TMEM reference must have a multiple of 4 columns, but got {tmem_ref.shape[1]}")
   if tmem_ref.layout != scales_layout():
     raise ValueError(f"TMEM layout {tmem_ref.layout} is not supported")
   smem_shape = tuple(smem_ty.shape)
-  expected_smem_shape = (tmem_ref.shape[0] // TMEM_ROWS, 32, 16)
+  expected_smem_shape = (tmem_ref.shape[0] // TMEM_ROWS, tmem_ref.shape[1] // 4, 32, 16)
   if smem_shape != expected_smem_shape:
     raise NotImplementedError(
         f"SMEM has {smem_shape}, but expected {expected_smem_shape} for TMEM"
         f" ref shape {tmem_ref.shape}"
     )
   strides, _ = smem_ty.get_strides_and_offset()
+  # TODO(apaszke): This should only matter for the two minor dims.
   if strides != utils.get_contiguous_strides(smem_shape):
     raise ValueError("Only copies from contiguous SMEM references are supported")
-  row_tile_stride = strides[0]
-  if row_tile_stride % 4:
-    raise ValueError("Column tile stride must be a multiple of 4")
-  row_tile_stride_i32 = row_tile_stride // 4
+  mn_tile_stride, k_tile_stride = strides[:2]
+  # One tile of scales has 128 bytes.
+  if mn_tile_stride % 128 or k_tile_stride % 128:
+    raise ValueError("Scale tile strides must be a multiple of 128")
+  mn_tile_stride_i32 = mn_tile_stride // 4
+  k_tile_stride_i32 = k_tile_stride // 4
   smem_base_ptr = utils.memref_ptr(smem_ref, 3)
-  for row_tile in range(expected_smem_shape[0]):
+  # TODO(apaszke): Need to figure out the TMEM layout otherwise and MMA doesn't
+  # support it anyway.
+  if smem_shape[0] > 2:
+    raise NotImplementedError("Only M/N up to 256 supported")
+  for mn_tile, k_tile in np.ndindex(smem_shape[:2]):
     load_ptr = utils.getelementptr(
-        smem_base_ptr, [row_tile * row_tile_stride_i32], i32
+        smem_base_ptr,
+        [mn_tile * mn_tile_stride_i32 + k_tile * k_tile_stride_i32],
+        i32,
     )
+    store_addr = arith.addi(
+        tmem_ref.address,
+        arith.constant(i32, 4 * smem_shape[0] * k_tile + 4 * mn_tile),
+    )
-    store_addr = arith.addi(tmem_ref.address, arith.constant(i32, 4 * row_tile))
     # The "core matrix" here is the same as in MMA: 8x(16 bytes).
     desc = mma_utils.encode_descriptor(load_ptr, 0, 8 * 16, swizzle=None)
     nvvm.tcgen05_cp(
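The docstring's reshuffle is easy to sanity-check in plain NumPy. The snippet below (the sizes and the spot-check indices are arbitrary) confirms that scale `(r, c)` of a `(MN, K // 32)` array lands in tile `(r // 128, c // 4)` at lane `r % 32`, byte `((r % 128) // 32) * 4 + c % 4`, which is the .scale_vec::1X placement the copy assumes:

```python
import numpy as np

mn, kk = 256, 8  # kk = K // 32, i.e. K = 256; MN and K divisible by 128
scales = np.arange(mn * kk, dtype=np.int32).reshape(mn, kk)
tiled = (
    scales.reshape(mn // 128, 4, 32, kk // 4, 4)
    .transpose(0, 3, 2, 1, 4)
    .reshape(mn // 128, kk // 4, 32, 16)
)
r, c = 165, 5  # arbitrary (row, K-block) pair to spot-check
assert tiled[r // 128, c // 4, r % 32, ((r % 128) // 32) * 4 + c % 4] == scales[r, c]
```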
31 changes: 17 additions & 14 deletions tests/mosaic/gpu_test.py
@@ -1567,7 +1567,7 @@ def kernel(ctx, src, out, scratch):
      )._debug_print()
      copy(src, out)

-    shape = (1, 32, 16)
+    shape = (1, 1, 32, 16)
    x = jax.lax.bitcast_convert_type(
        np.arange(math.prod(shape), dtype=np.uint8).reshape(shape), dtype
    )
@@ -1601,8 +1601,8 @@ def kernel(ctx, src, out, scratch):
  def test_mma_block_scaled(self, m, n, in_jax_dtype):
    out_jax_dtype = jnp.float32
    scale_jax_dtype = jnp.float8_e8m0fnu
-    swizzle = 128 // (8 // jnp.finfo(in_jax_dtype).bits)
-    k_steps = 1
+    swizzle = 128
+    k_steps = 2
    if out_jax_dtype == jnp.float16 and in_jax_dtype != jnp.float16:
      self.skipTest("Only f16 input is supported for f16 output.")

@@ -1648,24 +1648,29 @@ def kernel(ctx, lhs, rhs, lhs_scales_gmem, rhs_scales_gmem, out, scratch):
    scratch_shape = [
        jax.ShapeDtypeStruct(tile_shape(x_shape, lhs_tiling), in_jax_dtype),
        jax.ShapeDtypeStruct(tile_shape(y_shape, rhs_tiling), in_jax_dtype),
-        jax.ShapeDtypeStruct((m // 128, 32, 16), scale_jax_dtype),
-        jax.ShapeDtypeStruct((n // 128, 32, 16), scale_jax_dtype),
+        jax.ShapeDtypeStruct((m // 128, k // (32 * 4), 32, 16), scale_jax_dtype),
+        jax.ShapeDtypeStruct((n // 128, k // (32 * 4), 32, 16), scale_jax_dtype),
        mgpu.TMABarrier(4),
        mgpu.Barrier(1),
        mgpu.TMEM((m, n), out_jax_dtype),
-        mgpu.TMEM((m, 4), scale_jax_dtype, layout=tcgen05.scales_layout()),
-        mgpu.TMEM((n, 4), scale_jax_dtype, layout=tcgen05.scales_layout()),
+        mgpu.TMEM((m, k // 32), scale_jax_dtype, layout=tcgen05.scales_layout()),
+        mgpu.TMEM((n, k // 32), scale_jax_dtype, layout=tcgen05.scales_layout()),
    ]
    ka, kb = jax.random.split(jax.random.key(1234), 2)
    a_scales = jax.lax.bitcast_convert_type(
-        jax.random.randint(ka, (m, 4), 122, 132, dtype=jnp.uint8), scale_jax_dtype
+        jax.random.randint(ka, (m, k // 32), 122, 132, dtype=jnp.uint8), scale_jax_dtype
    )
    b_scales = jax.lax.bitcast_convert_type(
-        jax.random.randint(kb, (n, 4), 122, 132, dtype=jnp.uint8), scale_jax_dtype
+        jax.random.randint(kb, (n, k // 32), 122, 132, dtype=jnp.uint8), scale_jax_dtype
    )
    def format_scales(scales):
-      assert scales.shape[0] % 128 == 0 and scales.shape[1] == 4
-      return scales.reshape(-1, 4, 32, 4).swapaxes(1, 2).reshape(-1, 32, 16)
+      mn, k = scales.shape
+      assert mn % 128 == 0 and k % 4 == 0, scales.shape
+      return (
+          scales.reshape(mn // 128, 4, 32, k // 4, 4)
+          .transpose(0, 3, 2, 1, 4)
+          .reshape(mn // 128, k // 4, 32, 16)
+      )
    a_gpu_scales, b_gpu_scales = map(format_scales, (a_scales, b_scales))
    args = (x, y, a_gpu_scales, b_gpu_scales)
    z = mgpu.as_gpu_kernel(
@@ -1675,9 +1680,7 @@ def format_scales(scales):
    a_logical_scales = jnp.repeat(a_scales, 32, axis=1).astype(jnp.float32)
    b_logical_scales = jnp.repeat(b_scales, 32, axis=1).astype(jnp.float32)
    ref = (x32 * a_logical_scales) @ (y32 * b_logical_scales).T
-    atol = 2e-2 if out_jax_dtype == jnp.float16 else 7e-5
-    rtol = 8e-4 if out_jax_dtype == jnp.float16 else 5e-6
-    np.testing.assert_allclose(z, ref, atol=atol, rtol=rtol)
+    np.testing.assert_allclose(z, ref, atol=2e-4, rtol=5e-6)

  @parameterized.product(
      lhs_transpose=(False, True),
23 changes: 15 additions & 8 deletions tests/pallas/mosaic_gpu_test.py
@@ -3632,8 +3632,10 @@ def kernel(a_smem, b_smem, out_ref, _, acc_tmem, barrier_ref):
  )
  def test_simple_scaled_matmul(self, m, n, dtype):
    self.skip_if_wg_semantics()
-    k = 128
-    swizzle = 128 // (8 // jnp.finfo(dtype).bits)
+    # TODO(apaszke): Add support for single-buffering in pallas_call.
+    causes_oom = jnp.finfo(dtype).bits == 8 and n == 256
+    k = 128 if causes_oom else 256
+    swizzle = 128
    transforms = self.default_transforms(swizzle=swizzle, dtype=dtype)
    out_transforms = self.default_transforms(dtype=jnp.float32)

@@ -3656,8 +3658,8 @@ def kernel(a_smem, b_smem, a_scale_smem, b_scale_smem, out_ref,
    scratch_shapes = [
        plgpu.Barrier(orders_tensor_core=True),
        plgpu.TMEM((m, n), jnp.float32),
-        plgpu.TMEM((m, 4), jnp.float8_e8m0fnu, layout=plgpu.TMEMLayout.SCALES_LAYOUT),
-        plgpu.TMEM((n, 4), jnp.float8_e8m0fnu, layout=plgpu.TMEMLayout.SCALES_LAYOUT),
+        plgpu.TMEM((m, k // 32), jnp.float8_e8m0fnu, layout=plgpu.TMEMLayout.SCALES_LAYOUT),
+        plgpu.TMEM((n, k // 32), jnp.float8_e8m0fnu, layout=plgpu.TMEMLayout.SCALES_LAYOUT),
    ]

    f = self.pallas_call(
@@ -3676,16 +3678,21 @@ def kernel(a_smem, b_smem, a_scale_smem, b_scale_smem, out_ref,
    y = jax.random.uniform(jax.random.key(2), shape=(n, k), dtype=jnp.float32).astype(dtype)
    ksx, ksy = jax.random.split(jax.random.key(1234), 2)
    x_scale = jax.lax.bitcast_convert_type(
-        jax.random.randint(ksx, (m, 4), 122, 132, dtype=jnp.uint8),
+        jax.random.randint(ksx, (m, k // 32), 122, 132, dtype=jnp.uint8),
        jnp.float8_e8m0fnu
    )
    y_scale = jax.lax.bitcast_convert_type(
-        jax.random.randint(ksy, (n, 4), 122, 132, dtype=jnp.uint8),
+        jax.random.randint(ksy, (n, k // 32), 122, 132, dtype=jnp.uint8),
        jnp.float8_e8m0fnu
    )
    def format_scales(scales):
-      assert scales.shape[0] % 128 == 0 and scales.shape[1] == 4
-      return scales.reshape(-1, 4, 32, 4).swapaxes(1, 2).reshape(-1, 32, 16)
+      mn, k = scales.shape
+      assert mn % 128 == 0 and k % 4 == 0
+      return (
+          scales.reshape(mn // 128, 4, 32, k // 4, 4)
+          .transpose(0, 3, 2, 1, 4)
+          .reshape(mn // 128, k // 4, 32, 16)
+      )
    result = f(x, y, format_scales(x_scale), format_scales(y_scale))
    x_logical_scale = jnp.repeat(x_scale, 32, axis=1).astype(jnp.float32)
    y_logical_scale = jnp.repeat(y_scale, 32, axis=1).astype(jnp.float32)
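For readers skimming the test changes: the reference computation is unchanged in spirit. Each f8e8m0 scale applies to 32 contiguous K elements of its row, so the expected result is just a dequantized matmul. A sketch mirroring the tests' `x_logical_scale` logic (the function name is illustrative, not part of the patch):

```python
import jax.numpy as jnp

def scaled_matmul_reference(a, b, a_scales, b_scales):
    # a: (M, K), b: (N, K); scales: (M, K // 32) and (N, K // 32).
    a32 = a.astype(jnp.float32) * jnp.repeat(a_scales.astype(jnp.float32), 32, axis=1)
    b32 = b.astype(jnp.float32) * jnp.repeat(b_scales.astype(jnp.float32), 32, axis=1)
    return a32 @ b32.T
```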