Skip to content

Commit 8c5bfd2

Browse files
allanrenucci and Google-ML-Automation
authored and committed
[Mosaic GPU] Add custom vector load/store ops to the dialect.
These new operations replace the use of `vector.load` and `vector.store`. The `vector.load` and `vector.store` verifiers check that the last dimension of the ref is contiguous, but we do not have this restriction in Mosaic GPU.

PiperOrigin-RevId: 830815240
1 parent 626ce6e commit 8c5bfd2

File tree

10 files changed

+404
-369
lines changed

10 files changed

+404
-369
lines changed

jax/_src/pallas/mosaic_gpu/lowering.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1608,22 +1608,16 @@ def _get_lowering_rule_wg(
16081608
x_ref, transforms = _handle_transforms(
16091609
ctx, x_ref, transforms, allow_peer_refs=True
16101610
)
1611-
assert isinstance(x_ref, ir.Value)
1612-
mlir_dtype = ir.MemRefType(x_ref.type).element_type
16131611

16141612
if transforms:
16151613
raise NotImplementedError(
16161614
"Transforms are not yet implemented for warpgroup semantics"
16171615
)
16181616

1617+
assert isinstance(x_ref, ir.Value)
16191618
shape = ctx.avals_out[0].shape
1620-
ty = ir.VectorType.get(shape, mlir_dtype)
16211619
if shape:
1622-
zero_index = arith_dialect.constant(ir.IndexType.get(), 0)
1623-
indices = [zero_index for _ in range(len(shape))]
1624-
op = vector_dialect.LoadOp(ty, x_ref, indices)
1625-
op.attributes["optimized"] = ir.BoolAttr.get(optimized)
1626-
return op.result
1620+
return mgpu.dialect.vector_load(x_ref, optimized=optimized)
16271621
else:
16281622
return memref_dialect.load(x_ref, [])
16291623

@@ -1752,13 +1746,9 @@ def _swap_lowering_rule_wg(
17521746
"Transforms are not yet implemented for warpgroup semantics"
17531747
)
17541748
assert isinstance(x_smem, ir.Value)
1755-
x_mlir_dtype = ir.MemRefType(x_smem.type).element_type
1756-
ty = ir.VectorType.get(shape, x_mlir_dtype)
17571749
if shape:
1758-
zero_index = arith_dialect.constant(ir.IndexType.get(), 0)
1759-
indices = [zero_index] * len(shape)
1760-
old_value = vector_dialect.load(ty, x_smem, indices)
1761-
vector_dialect.store(value, x_smem, indices)
1750+
old_value = mgpu.dialect.vector_load(x_smem)
1751+
mgpu.dialect.vector_store(value, x_smem)
17621752
else:
17631753
old_value = memref_dialect.load(x_smem, [])
17641754
memref_dialect.store(value, x_smem, [])

jax/experimental/mosaic/gpu/dialect_lowering.py

Lines changed: 106 additions & 135 deletions
Original file line numberDiff line numberDiff line change
@@ -416,164 +416,135 @@ def _retry_on_failure(transfer: _Transfer, optimized: bool | None) -> Any:
416416
return transfer(optimized=False)
417417

418418

419-
@_register_lowering(vector.LoadOp)
420-
def _vector_load_op_lowering_rule(
421-
_: LoweringContext, vector_load_op: vector.LoadOp
422-
) -> Sequence[ir.Value]:
423-
(out_layout_attr,) = cast(
424-
ir.ArrayAttr, vector_load_op.attributes["out_layouts"]
425-
)
419+
if jaxlib.version > (0, 8, 0):
426420

427-
for i in vector_load_op.indices:
428-
index_defining_op = i.owner.opview
429-
if (
430-
not isinstance(index_defining_op, arith.ConstantOp)
431-
or index_defining_op.literal_value != 0
432-
):
433-
# TODO(bchetioui,dasenov): support non-zero indices.
434-
raise NotImplementedError(
435-
"Only constants with value 0 are supported as indices "
436-
f"for {vector_load_op}"
437-
)
421+
@_register_lowering(mgpu.VectorLoadOp)
422+
def _vector_load_op_lowering_rule(
423+
_: LoweringContext, op: mgpu.VectorLoadOp
424+
) -> Sequence[ir.Value]:
425+
(out_layout_attr,) = inference_utils.out_layouts(op)
438426

439-
element_type = ir.VectorType(vector_load_op.result.type).element_type
440-
is_signed = _default_is_signed(element_type)
427+
element_type = ir.VectorType(op.result.type).element_type
428+
is_signed = _default_is_signed(element_type)
441429

442-
def _fragmented_array_to_ir(fragmented_array: fa.FragmentedArray) -> ir.Value:
443-
return fragmented_array_to_ir(fragmented_array, vector_load_op.result.type)
430+
def _fragmented_array_to_ir(
431+
fragmented_array: fa.FragmentedArray,
432+
) -> ir.Value:
433+
return fragmented_array_to_ir(fragmented_array, op.result.type)
444434

445-
if layouts.is_strided_fragmented_layout(out_layout_attr):
446-
strided_layout = layouts.from_strided_fragmented_layout_attr(
447-
out_layout_attr
448-
)
449-
# TODO(bchetioui): Process transforms.
450-
fragmented_array = fa.FragmentedArray.load_strided(
451-
vector_load_op.base,
452-
is_signed=is_signed,
453-
vec_size=strided_layout.vec_size,
454-
)
455-
return [_fragmented_array_to_ir(fragmented_array)]
456-
457-
if not layouts.is_tiled_layout(out_layout_attr):
458-
raise ValueError(
459-
f"{vector_load_op} has an unsupported layout: {out_layout_attr}"
460-
)
461-
462-
optimized = (
463-
vector_load_op.attributes["optimized"].value
464-
if "optimized" in vector_load_op.attributes
465-
else None
466-
)
467-
layout = layouts.from_tiled_layout_attr(out_layout_attr)
468-
ref_ty = ir.MemRefType(vector_load_op.base.type)
469-
if ref_ty.memory_space is None: # GMEM
470-
fragmented_array = fa.FragmentedArray.load_untiled(
471-
vector_load_op.base,
472-
layout=layout,
473-
is_signed=is_signed,
474-
optimized=optimized if optimized is not None else False,
475-
)
476-
return [_fragmented_array_to_ir(fragmented_array)]
477-
478-
if ref_ty.memory_space != utils.smem():
479-
raise ValueError(f"Unsupported memory space: {ref_ty.memory_space}")
480-
481-
transforms_attr = inference_utils.in_transforms(vector_load_op)[0]
482-
swizzle, transforms = swizzle_and_transforms_from_transforms_attr(
483-
transforms_attr
484-
)
485-
has_transforms = swizzle != mgpu.SwizzlingMode.kNoSwizzle or transforms
486-
if has_transforms:
487-
_check_transforms_and_swizzle_are_supported(ref_ty, transforms, swizzle)
488-
transformed_ref = unwrap_transformed_memref(
489-
vector_load_op.base, transforms_attr
490-
)
491-
def load_tiled(optimized: bool) -> fa.FragmentedArray:
492-
return fa.FragmentedArray.load_tiled(
493-
transformed_ref,
494-
swizzle,
495-
is_signed=is_signed,
496-
layout=layout,
497-
optimized=optimized,
435+
if layouts.is_strided_fragmented_layout(out_layout_attr):
436+
strided_layout = layouts.from_strided_fragmented_layout_attr(
437+
out_layout_attr
498438
)
499-
500-
fragmented_array = _retry_on_failure(load_tiled, optimized)
501-
else:
502-
def load_untiled(optimized: bool) -> fa.FragmentedArray:
503-
return fa.FragmentedArray.load_untiled(
504-
vector_load_op.base,
505-
layout=layout,
439+
# TODO(bchetioui): Process transforms.
440+
fragmented_array = fa.FragmentedArray.load_strided(
441+
op.source,
506442
is_signed=is_signed,
507-
optimized=optimized,
443+
vec_size=strided_layout.vec_size,
508444
)
445+
return [_fragmented_array_to_ir(fragmented_array)]
509446

510-
fragmented_array = _retry_on_failure(load_untiled, optimized)
511-
512-
return [_fragmented_array_to_ir(fragmented_array)]
447+
if not layouts.is_tiled_layout(out_layout_attr):
448+
raise ValueError(f"{op} has an unsupported layout: {out_layout_attr}")
513449

514-
515-
@_register_lowering(vector.StoreOp)
516-
def _vector_store_op_lowering_rule(
517-
ctx: LoweringContext, vector_store_op: vector.StoreOp
518-
) -> Sequence[ir.Value]:
519-
for i in vector_store_op.indices:
520-
index_defining_op = i.owner.opview
521-
if (
522-
not isinstance(index_defining_op, arith.ConstantOp)
523-
or index_defining_op.literal_value != 0
524-
):
525-
# TODO(bchetioui,dasenov): support non-zero indices.
526-
raise NotImplementedError(
527-
"Only constants with value 0 are supported as indices "
528-
f"for {vector_store_op}"
450+
optimized = op.optimized.value if op.optimized is not None else None
451+
layout = layouts.from_tiled_layout_attr(out_layout_attr)
452+
ref_ty = ir.MemRefType(op.source.type)
453+
if ref_ty.memory_space is None: # GMEM
454+
fragmented_array = fa.FragmentedArray.load_untiled(
455+
op.source,
456+
layout=layout,
457+
is_signed=is_signed,
458+
optimized=bool(optimized),
529459
)
460+
return [_fragmented_array_to_ir(fragmented_array)]
530461

531-
[to_store_layout] = inference_utils.in_layouts(vector_store_op)
532-
fragmented_array = _fragmented_array_from_ir(
533-
vector_store_op.valueToStore, to_store_layout
534-
)
535-
536-
if ctx.auto_barriers:
537-
mgpu_utils.warpgroup_barrier() # Make sure the reads have completed.
538-
539-
ref = vector_store_op.base
540-
ref_type = ir.MemRefType(ref.type)
541-
optimized = (
542-
vector_store_op.attributes["optimized"].value
543-
if "optimized" in vector_store_op.attributes
544-
else None
545-
)
462+
if ref_ty.memory_space != utils.smem():
463+
raise ValueError(f"Unsupported memory space: {ref_ty.memory_space}")
546464

547-
if ref_type.memory_space is None: # GMEM
548-
fragmented_array.store_untiled(
549-
ref, optimized=optimized if optimized is not None else False
550-
)
551-
elif ref_type.memory_space == utils.smem():
552-
transforms_attr = inference_utils.in_transforms(vector_store_op)[0]
465+
transforms_attr = inference_utils.in_transforms(op)[0]
553466
swizzle, transforms = swizzle_and_transforms_from_transforms_attr(
554467
transforms_attr
555468
)
556469
has_transforms = swizzle != mgpu.SwizzlingMode.kNoSwizzle or transforms
557470
if has_transforms:
558-
_check_transforms_and_swizzle_are_supported(ref_type, transforms, swizzle)
559-
unwrapped_ref = unwrap_transformed_memref(ref, transforms_attr)
560-
def store_tiled(optimized: bool):
561-
fragmented_array.store_tiled(unwrapped_ref, swizzle, optimized)
471+
_check_transforms_and_swizzle_are_supported(ref_ty, transforms, swizzle)
472+
transformed_ref = unwrap_transformed_memref(op.source, transforms_attr)
473+
474+
def load_tiled(optimized: bool) -> fa.FragmentedArray:
475+
return fa.FragmentedArray.load_tiled(
476+
transformed_ref,
477+
swizzle,
478+
is_signed=is_signed,
479+
layout=layout,
480+
optimized=optimized,
481+
)
562482

563-
_retry_on_failure(store_tiled, optimized)
483+
fragmented_array = _retry_on_failure(load_tiled, optimized)
564484
else:
565485

566-
def store_untiled(optimized: bool):
567-
fragmented_array.store_untiled(ref, optimized=optimized)
486+
def load_untiled(optimized: bool) -> fa.FragmentedArray:
487+
return fa.FragmentedArray.load_untiled(
488+
op.source,
489+
layout=layout,
490+
is_signed=is_signed,
491+
optimized=optimized,
492+
)
568493

569-
_retry_on_failure(store_untiled, optimized)
570-
else:
571-
raise ValueError(f"Unsupported memory space: {ref_type.memory_space}")
494+
fragmented_array = _retry_on_failure(load_untiled, optimized)
572495

573-
if ctx.auto_barriers:
574-
mgpu_utils.warpgroup_barrier() # Make sure the writes have completed.
496+
return [_fragmented_array_to_ir(fragmented_array)]
575497

576-
return []
498+
499+
if jaxlib.version > (0, 8, 0):
500+
501+
@_register_lowering(mgpu.VectorStoreOp)
502+
def _vector_store_op_lowering_rule(
503+
ctx: LoweringContext, op: mgpu.VectorStoreOp
504+
) -> Sequence[ir.Value]:
505+
[to_store_layout] = inference_utils.in_layouts(op)
506+
fragmented_array = _fragmented_array_from_ir(
507+
op.valueToStore, to_store_layout
508+
)
509+
510+
if ctx.auto_barriers:
511+
mgpu_utils.warpgroup_barrier() # Make sure the reads have completed.
512+
513+
ref = op.destination
514+
ref_type = ir.MemRefType(ref.type)
515+
optimized = op.optimized.value if op.optimized is not None else None
516+
517+
if ref_type.memory_space is None: # GMEM
518+
fragmented_array.store_untiled(ref, optimized=bool(optimized))
519+
elif ref_type.memory_space == utils.smem():
520+
transforms_attr = inference_utils.in_transforms(op)[0]
521+
swizzle, transforms = swizzle_and_transforms_from_transforms_attr(
522+
transforms_attr
523+
)
524+
has_transforms = swizzle != mgpu.SwizzlingMode.kNoSwizzle or transforms
525+
if has_transforms:
526+
_check_transforms_and_swizzle_are_supported(
527+
ref_type, transforms, swizzle
528+
)
529+
unwrapped_ref = unwrap_transformed_memref(ref, transforms_attr)
530+
531+
def store_tiled(optimized: bool):
532+
fragmented_array.store_tiled(unwrapped_ref, swizzle, optimized)
533+
534+
_retry_on_failure(store_tiled, optimized)
535+
else:
536+
537+
def store_untiled(optimized: bool):
538+
fragmented_array.store_untiled(ref, optimized=optimized)
539+
540+
_retry_on_failure(store_untiled, optimized)
541+
else:
542+
raise ValueError(f"Unsupported memory space: {ref_type.memory_space}")
543+
544+
if ctx.auto_barriers:
545+
mgpu_utils.warpgroup_barrier() # Make sure the writes have completed.
546+
547+
return []
577548

578549

579550
@_register_lowering(vector.BroadcastOp)

jax/experimental/mosaic/gpu/equations.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,9 @@ def reduce_expression(
344344
_SUPPORTED_TILED_RELAYOUTS = frozenset([
345345
# Transposed layouts.
346346
(fa.WGMMA_LAYOUT, fa.WGMMA_TRANSPOSED_LAYOUT),
347+
(fa.WGMMA_TRANSPOSED_LAYOUT, fa.WGMMA_LAYOUT),
347348
(fa.TCGEN05_LAYOUT, fa.TCGEN05_TRANSPOSED_LAYOUT),
349+
(fa.TCGEN05_TRANSPOSED_LAYOUT, fa.TCGEN05_LAYOUT),
348350
# "Conversion-optimized" layouts.
349351
(fa.WGMMA_LAYOUT_UPCAST_2X, fa.WGMMA_LAYOUT),
350352
(fa.WGMMA_LAYOUT_UPCAST_4X, fa.WGMMA_LAYOUT_UPCAST_2X),

0 commit comments

Comments (0)