@@ -416,164 +416,135 @@ def _retry_on_failure(transfer: _Transfer, optimized: bool | None) -> Any:
     return transfer(optimized=False)


-@_register_lowering(vector.LoadOp)
-def _vector_load_op_lowering_rule(
-    _: LoweringContext, vector_load_op: vector.LoadOp
-) -> Sequence[ir.Value]:
-  (out_layout_attr,) = cast(
-      ir.ArrayAttr, vector_load_op.attributes["out_layouts"]
-  )
+if jaxlib.version > (0, 8, 0):

-  for i in vector_load_op.indices:
-    index_defining_op = i.owner.opview
-    if (
-        not isinstance(index_defining_op, arith.ConstantOp)
-        or index_defining_op.literal_value != 0
-    ):
-      # TODO(bchetioui,dasenov): support non-zero indices.
-      raise NotImplementedError(
-          "Only constants with value 0 are supported as indices "
-          f"for {vector_load_op}"
-      )
+  @_register_lowering(mgpu.VectorLoadOp)
+  def _vector_load_op_lowering_rule(
+      _: LoweringContext, op: mgpu.VectorLoadOp
+  ) -> Sequence[ir.Value]:
+    (out_layout_attr,) = inference_utils.out_layouts(op)

-  element_type = ir.VectorType(vector_load_op.result.type).element_type
-  is_signed = _default_is_signed(element_type)
+    element_type = ir.VectorType(op.result.type).element_type
+    is_signed = _default_is_signed(element_type)

-  def _fragmented_array_to_ir(fragmented_array: fa.FragmentedArray) -> ir.Value:
-    return fragmented_array_to_ir(fragmented_array, vector_load_op.result.type)
+    def _fragmented_array_to_ir(
+        fragmented_array: fa.FragmentedArray,
+    ) -> ir.Value:
+      return fragmented_array_to_ir(fragmented_array, op.result.type)

-  if layouts.is_strided_fragmented_layout(out_layout_attr):
-    strided_layout = layouts.from_strided_fragmented_layout_attr(
-        out_layout_attr
-    )
-    # TODO(bchetioui): Process transforms.
-    fragmented_array = fa.FragmentedArray.load_strided(
-        vector_load_op.base,
-        is_signed=is_signed,
-        vec_size=strided_layout.vec_size,
-    )
-    return [_fragmented_array_to_ir(fragmented_array)]
-
-  if not layouts.is_tiled_layout(out_layout_attr):
-    raise ValueError(
-        f"{vector_load_op} has an unsupported layout: {out_layout_attr}"
-    )
-
-  optimized = (
-      vector_load_op.attributes["optimized"].value
-      if "optimized" in vector_load_op.attributes
-      else None
-  )
-  layout = layouts.from_tiled_layout_attr(out_layout_attr)
-  ref_ty = ir.MemRefType(vector_load_op.base.type)
-  if ref_ty.memory_space is None:  # GMEM
-    fragmented_array = fa.FragmentedArray.load_untiled(
-        vector_load_op.base,
-        layout=layout,
-        is_signed=is_signed,
-        optimized=optimized if optimized is not None else False,
-    )
-    return [_fragmented_array_to_ir(fragmented_array)]
-
-  if ref_ty.memory_space != utils.smem():
-    raise ValueError(f"Unsupported memory space: {ref_ty.memory_space}")
-
-  transforms_attr = inference_utils.in_transforms(vector_load_op)[0]
-  swizzle, transforms = swizzle_and_transforms_from_transforms_attr(
-      transforms_attr
-  )
-  has_transforms = swizzle != mgpu.SwizzlingMode.kNoSwizzle or transforms
-  if has_transforms:
-    _check_transforms_and_swizzle_are_supported(ref_ty, transforms, swizzle)
-    transformed_ref = unwrap_transformed_memref(
-        vector_load_op.base, transforms_attr
-    )
-    def load_tiled(optimized: bool) -> fa.FragmentedArray:
-      return fa.FragmentedArray.load_tiled(
-          transformed_ref,
-          swizzle,
-          is_signed=is_signed,
-          layout=layout,
-          optimized=optimized,
+    if layouts.is_strided_fragmented_layout(out_layout_attr):
+      strided_layout = layouts.from_strided_fragmented_layout_attr(
+          out_layout_attr
       )
-
-    fragmented_array = _retry_on_failure(load_tiled, optimized)
-  else:
-    def load_untiled(optimized: bool) -> fa.FragmentedArray:
-      return fa.FragmentedArray.load_untiled(
-          vector_load_op.base,
-          layout=layout,
+      # TODO(bchetioui): Process transforms.
+      fragmented_array = fa.FragmentedArray.load_strided(
+          op.source,
           is_signed=is_signed,
-          optimized=optimized,
+          vec_size=strided_layout.vec_size,
       )
+      return [_fragmented_array_to_ir(fragmented_array)]

-    fragmented_array = _retry_on_failure(load_untiled, optimized)
-
-  return [_fragmented_array_to_ir(fragmented_array)]
+    if not layouts.is_tiled_layout(out_layout_attr):
+      raise ValueError(f"{op} has an unsupported layout: {out_layout_attr}")

-
-@_register_lowering(vector.StoreOp)
-def _vector_store_op_lowering_rule(
-    ctx: LoweringContext, vector_store_op: vector.StoreOp
-) -> Sequence[ir.Value]:
-  for i in vector_store_op.indices:
-    index_defining_op = i.owner.opview
-    if (
-        not isinstance(index_defining_op, arith.ConstantOp)
-        or index_defining_op.literal_value != 0
-    ):
-      # TODO(bchetioui,dasenov): support non-zero indices.
-      raise NotImplementedError(
-          "Only constants with value 0 are supported as indices "
-          f"for {vector_store_op}"
+    optimized = op.optimized.value if op.optimized is not None else None
+    layout = layouts.from_tiled_layout_attr(out_layout_attr)
+    ref_ty = ir.MemRefType(op.source.type)
+    if ref_ty.memory_space is None:  # GMEM
+      fragmented_array = fa.FragmentedArray.load_untiled(
+          op.source,
+          layout=layout,
+          is_signed=is_signed,
+          optimized=bool(optimized),
       )
+      return [_fragmented_array_to_ir(fragmented_array)]

-  [to_store_layout] = inference_utils.in_layouts(vector_store_op)
-  fragmented_array = _fragmented_array_from_ir(
-      vector_store_op.valueToStore, to_store_layout
-  )
-
-  if ctx.auto_barriers:
-    mgpu_utils.warpgroup_barrier()  # Make sure the reads have completed.
-
-  ref = vector_store_op.base
-  ref_type = ir.MemRefType(ref.type)
-  optimized = (
-      vector_store_op.attributes["optimized"].value
-      if "optimized" in vector_store_op.attributes
-      else None
-  )
+    if ref_ty.memory_space != utils.smem():
+      raise ValueError(f"Unsupported memory space: {ref_ty.memory_space}")

-  if ref_type.memory_space is None:  # GMEM
-    fragmented_array.store_untiled(
-        ref, optimized=optimized if optimized is not None else False
-    )
-  elif ref_type.memory_space == utils.smem():
-    transforms_attr = inference_utils.in_transforms(vector_store_op)[0]
+    transforms_attr = inference_utils.in_transforms(op)[0]
     swizzle, transforms = swizzle_and_transforms_from_transforms_attr(
         transforms_attr
     )
     has_transforms = swizzle != mgpu.SwizzlingMode.kNoSwizzle or transforms
     if has_transforms:
-      _check_transforms_and_swizzle_are_supported(ref_type, transforms, swizzle)
-      unwrapped_ref = unwrap_transformed_memref(ref, transforms_attr)
-      def store_tiled(optimized: bool):
-        fragmented_array.store_tiled(unwrapped_ref, swizzle, optimized)
+      _check_transforms_and_swizzle_are_supported(ref_ty, transforms, swizzle)
+      transformed_ref = unwrap_transformed_memref(op.source, transforms_attr)
+
+      def load_tiled(optimized: bool) -> fa.FragmentedArray:
+        return fa.FragmentedArray.load_tiled(
+            transformed_ref,
+            swizzle,
+            is_signed=is_signed,
+            layout=layout,
+            optimized=optimized,
+        )

-      _retry_on_failure(store_tiled, optimized)
+      fragmented_array = _retry_on_failure(load_tiled, optimized)
     else:

-      def store_untiled(optimized: bool):
-        fragmented_array.store_untiled(ref, optimized=optimized)
+      def load_untiled(optimized: bool) -> fa.FragmentedArray:
+        return fa.FragmentedArray.load_untiled(
+            op.source,
+            layout=layout,
+            is_signed=is_signed,
+            optimized=optimized,
+        )

-      _retry_on_failure(store_untiled, optimized)
-  else:
-    raise ValueError(f"Unsupported memory space: {ref_type.memory_space}")
+      fragmented_array = _retry_on_failure(load_untiled, optimized)

-  if ctx.auto_barriers:
-    mgpu_utils.warpgroup_barrier()  # Make sure the writes have completed.
+    return [_fragmented_array_to_ir(fragmented_array)]

-  return []
+
+if jaxlib.version > (0, 8, 0):
+
+  @_register_lowering(mgpu.VectorStoreOp)
+  def _vector_store_op_lowering_rule(
+      ctx: LoweringContext, op: mgpu.VectorStoreOp
+  ) -> Sequence[ir.Value]:
+    [to_store_layout] = inference_utils.in_layouts(op)
+    fragmented_array = _fragmented_array_from_ir(
+        op.valueToStore, to_store_layout
+    )
+
+    if ctx.auto_barriers:
+      mgpu_utils.warpgroup_barrier()  # Make sure the reads have completed.
+
+    ref = op.destination
+    ref_type = ir.MemRefType(ref.type)
+    optimized = op.optimized.value if op.optimized is not None else None
+
+    if ref_type.memory_space is None:  # GMEM
+      fragmented_array.store_untiled(ref, optimized=bool(optimized))
+    elif ref_type.memory_space == utils.smem():
+      transforms_attr = inference_utils.in_transforms(op)[0]
+      swizzle, transforms = swizzle_and_transforms_from_transforms_attr(
+          transforms_attr
+      )
+      has_transforms = swizzle != mgpu.SwizzlingMode.kNoSwizzle or transforms
+      if has_transforms:
+        _check_transforms_and_swizzle_are_supported(
+            ref_type, transforms, swizzle
+        )
+        unwrapped_ref = unwrap_transformed_memref(ref, transforms_attr)
+
+        def store_tiled(optimized: bool):
+          fragmented_array.store_tiled(unwrapped_ref, swizzle, optimized)
+
+        _retry_on_failure(store_tiled, optimized)
+      else:
+
+        def store_untiled(optimized: bool):
+          fragmented_array.store_untiled(ref, optimized=optimized)
+
+        _retry_on_failure(store_untiled, optimized)
+    else:
+      raise ValueError(f"Unsupported memory space: {ref_type.memory_space}")
+
+    if ctx.auto_barriers:
+      mgpu_utils.warpgroup_barrier()  # Make sure the writes have completed.
+
+    return []


 @_register_lowering(vector.BroadcastOp)
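
Note on `_retry_on_failure`: both rules above funnel their SMEM transfers through this helper, whose body lies outside the hunk; only its signature in the hunk header and the trailing `return transfer(optimized=False)` context line are visible. The following is a minimal sketch of the fallback pattern those two lines suggest; the exception type and the exact control flow are assumptions, not the actual implementation.

from typing import Any, Callable

# Assumed alias; only the name _Transfer appears in the hunk header.
_Transfer = Callable[..., Any]

def _retry_on_failure(transfer: _Transfer, optimized: bool | None) -> Any:
  if optimized is not None:
    # The op pinned `optimized` explicitly; honor it and let failures surface.
    return transfer(optimized=optimized)
  try:
    # No preference recorded on the op: attempt the optimized transfer first.
    return transfer(optimized=True)
  except NotImplementedError:  # assumed failure mode
    # Fall back to the unoptimized path, matching the context line at the
    # top of the hunk.
    return transfer(optimized=False)

Under this reading, `optimized=None` (the value both rules compute when the op carries no `optimized` attribute) means "best effort", while an explicit `True` or `False` bypasses the retry.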