From ced45d6735e8c4d2a6211d82ca1c6060f140ab47 Mon Sep 17 00:00:00 2001 From: "Lu,Chengjun" Date: Mon, 21 Jul 2025 15:49:00 +0000 Subject: [PATCH 1/2] [LoadStoreOpToLLVM] Improve the 2D block IO lowering for DPAS and DotOp layout. Signed-off-by: Lu,Chengjun --- .../tensor-pointer-load-block-2d.mlir | 51 ++++++++++++- .../LoadStoreOpToLLVM.cpp | 73 ++++++++++++++++++- 2 files changed, 120 insertions(+), 4 deletions(-) diff --git a/test/TritonIntelGPU/tensor-pointer-load-block-2d.mlir b/test/TritonIntelGPU/tensor-pointer-load-block-2d.mlir index 63b0a3cb7a..ed6caad5cf 100644 --- a/test/TritonIntelGPU/tensor-pointer-load-block-2d.mlir +++ b/test/TritonIntelGPU/tensor-pointer-load-block-2d.mlir @@ -369,7 +369,7 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32, "ttg.thr tt.func public @regular_pointer_block_io(%arg0: !tt.ptr) { %a_mask = arith.constant dense : tensor<256x64xi1, #mma> - %a_other = arith.constant dense<0.00e+00> : tensor<256x64xf16, #mma> + %a_other = arith.constant dense<1.00e+00> : tensor<256x64xf16, #mma> // CHECK-NOT: llvm.cond_br %0 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #mma}>> @@ -389,7 +389,6 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32, "ttg.thr // CHECK: %[[TOP_LEFT_MASK_BOOL_64:.*]] = llvm.extractvalue {{.*}}[64] : !llvm.struct<(i1, i1, {{.*}} // CHECK: %[[TOP_LEFT_MASK_BOOL_96:.*]] = llvm.extractvalue {{.*}}[96] : !llvm.struct<(i1, i1, {{.*}} - // CHECK: %[[BLOCK_SHAPE_Y:.*]] = llvm.mlir.constant(16 : i32) : i32 // CHECK: %[[TOP_LEFT_PTR:.*]] = llvm.ptrtoint {{.*}} : !llvm.ptr<1> to i64 // CHECK: %[[VAL_2886:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[TOP_LEFT_PTR]], {{.*}}) {convergent, no_unwind, will_return} : (i64, i32) -> i64 @@ -402,6 +401,18 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32, "ttg.thr // CHECK: %[[PRED_BOOL:.*]] = llvm.trunc %[[PRED]] : i8 to i1 // CHECK: %[[BASE_Y_0:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_1]], %[[BLOCK_SHAPE_Y]] : i1, i32 // CHECK: %[[LOAD_0:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_0]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2 + // CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<32xi16> + // CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16> + // CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16> + // CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [8, 9, 10, 11, 12, 13, 14, 15] : vector<32xi16> + // CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16> + // CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16> + // CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [16, 17, 18, 19, 20, 21, 22, 23] : vector<32xi16> + // CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16> + // CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16> + // CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [24, 25, 26, 27, 28, 29, 30, 31] : vector<32xi16> + // CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16> + // CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16> // CHECK: %[[TOP_LEFT_PTR:.*]] = 
llvm.ptrtoint {{.*}} : !llvm.ptr<1> to i64 // CHECK: %[[VAL_3046:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[TOP_LEFT_PTR]], {{.*}}) {convergent, no_unwind, will_return} : (i64, i32) -> i64 @@ -414,6 +425,18 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32, "ttg.thr // CHECK: %[[PRED_BOOL:.*]] = llvm.trunc %[[PRED]] : i8 to i1 // CHECK: %[[BASE_Y_0:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_1]], %[[BLOCK_SHAPE_Y]] : i1, i32 // CHECK: %[[LOAD_0:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_0]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2 + // CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<32xi16> + // CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16> + // CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16> + // CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [8, 9, 10, 11, 12, 13, 14, 15] : vector<32xi16> + // CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16> + // CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16> + // CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [16, 17, 18, 19, 20, 21, 22, 23] : vector<32xi16> + // CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16> + // CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16> + // CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [24, 25, 26, 27, 28, 29, 30, 31] : vector<32xi16> + // CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16> + // CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16> // CHECK: %[[TOP_LEFT_PTR:.*]] = llvm.ptrtoint {{.*}} : !llvm.ptr<1> to i64 // CHECK: %[[VAL_3046:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[TOP_LEFT_PTR]], {{.*}}) {convergent, no_unwind, will_return} : (i64, i32) -> i64 @@ -426,6 +449,18 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32, "ttg.thr // CHECK: %[[PRED_BOOL:.*]] = llvm.trunc %[[PRED]] : i8 to i1 // CHECK: %[[BASE_Y_0:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_1]], %[[BLOCK_SHAPE_Y]] : i1, i32 // CHECK: %[[LOAD_0:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_0]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2 + // CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<32xi16> + // CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16> + // CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16> + // CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [8, 9, 10, 11, 12, 13, 14, 15] : vector<32xi16> + // CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16> + // CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16> + // CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [16, 17, 18, 19, 20, 21, 22, 23] : vector<32xi16> + // CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16> + // CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16> + // CHECK: 
%[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [24, 25, 26, 27, 28, 29, 30, 31] : vector<32xi16>
+    // CHECK-NEXT:  %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
+    // CHECK:  llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>

     // CHECK: %[[TOP_LEFT_PTR:.*]] = llvm.ptrtoint {{.*}} : !llvm.ptr<1> to i64
     // CHECK: %[[VAL_3046:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[TOP_LEFT_PTR]], {{.*}}) {convergent, no_unwind, will_return} : (i64, i32) -> i64
@@ -438,6 +473,18 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32, "ttg.thr
     // CHECK: %[[PRED_BOOL:.*]] = llvm.trunc %[[PRED]] : i8 to i1
     // CHECK: %[[BASE_Y_0:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_1]], %[[BLOCK_SHAPE_Y]] : i1, i32
     // CHECK: %[[LOAD_0:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_0]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2
+    // CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<32xi16>
+    // CHECK-NEXT:  %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
+    // CHECK:  llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>
+    // CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [8, 9, 10, 11, 12, 13, 14, 15] : vector<32xi16>
+    // CHECK-NEXT:  %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
+    // CHECK:  llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>
+    // CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [16, 17, 18, 19, 20, 21, 22, 23] : vector<32xi16>
+    // CHECK-NEXT:  %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
+    // CHECK:  llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>
+    // CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [24, 25, 26, 27, 28, 29, 30, 31] : vector<32xi16>
+    // CHECK-NEXT:  %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
+    // CHECK:  llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>
     %11 = tt.load %10, %a_mask, %a_other {ttig.block_io = "row_major"} : tensor<256x64x!tt.ptr<f16>, #mma>
     tt.return
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
index fd05a6be5e..da5f11e8f1 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -2682,10 +2682,66 @@ struct LoadOpToBlockIOConversion
     bool useVNNIFormat = false;
     Type packedDPASOperandType;
-    if (hasDotDpasEncoding(tensorType)) {
+    if (hasDpasEncoding(tensorType) || hasDotDpasEncoding(tensorType)) {
+
+      // For the DPAS layout, there are three types of block loads used.
+      // (For non-DPAS layouts, only two types are involved.)
+      // 1. load2DGenXType - the raw vector type returned by the 2D block
+      //    load, e.g. vector<32xi16>.
+      // 2. packedDPASOperandType - the packed vector type consumed by the
+      //    DPAS instruction, e.g. vector<8xi16>. (This is null for non-DPAS
+      //    layouts.)
+      // 3. unpackedType - the vector type in the tensor's element type,
+      //    e.g. vector<8xf16>.
+      //
+      // clang-format off
+      // The `tt.load` operation generates the following block load sequence:
+      //   %0 = load_2d %ptr : vector<32xi16>
+      //   %1 = shufflevector %0, %0, <8 x i32> [0, 1, 2, 3, 4, 5, 6, 7] : vector<8xi16>
+      //   %2 = shufflevector %0, %0, <8 x i32> [8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi16>
+      //   %3 = bitcast %1 : vector<8xi16> -> vector<8xf16>
+      //   %4 = bitcast %2 : vector<8xi16> -> vector<8xf16>
+      // clang-format on
+      //
+      // The `tt.dot` operation generates the DPAS instruction sequence:
+      // clang-format off
+      //   %5 = bitcast %3 : vector<8xf16> -> vector<8xi16>
+      //   %6 = bitcast %4 : vector<8xf16> -> vector<8xi16>
+      //   %7 = dpas %5, %6, %other : vector<8xi16>, vector<8xi16>, vector<8xf32>
+      // clang-format on
+      //
+      // The LLVM optimizer eliminates redundant pack/unpack element pairs
+      // and the corresponding bitcast operations. The final optimized IR
+      // for the dot product becomes:
+      //
+      // clang-format off
+      //   %0 = load_2d %ptr : vector<32xi16>
+      //   %1 = shufflevector %0, %0, <8 x i32> [0, 1, 2, 3, 4, 5, 6, 7] : vector<8xi16>
+      //   %2 = shufflevector %0, %0, <8 x i32> [8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi16>
+      //   %3 = dpas %1, %2, %other : vector<8xi16>, vector<8xi16>, vector<8xf32>
+      // clang-format on
+      //
+      // The `packedDPASOperandType` together with the `shufflevector`
+      // operations defines the computation flow for the dot product.
+
       DpasEncodingAttr::OpIdx opIdx = getOpIdx(tensorType);
       auto dpasLayout = getDpasLayout(tensorType);
-      if (opIdx == DpasEncodingAttr::OpIdx::OperandB) {
+      switch (opIdx) {
+      case DpasEncodingAttr::OpIdx::OperandA: {
+        unsigned elemsPerLanePerDPASInst =
+            product(dpasLayout.getDPASInstShapeA()) / threadsPerWarp;
+        // The 2D block load contains at least one DotOp A operand.
+        if (numElemsPerLoad >= elemsPerLanePerDPASInst) {
+          packedDPASOperandType = LLVM::getVectorType(
+              packedType, elemsPerLanePerDPASInst / numPackedVals);
+          unpackedType = LLVM::getVectorType(eltTy, elemsPerLanePerDPASInst);
+        }
+      } break;
+      case DpasEncodingAttr::OpIdx::OperandB: {
+        assert(numPackedVals == 1 &&
+               "invalid number of packed values for DPAS operand B.");
         unsigned elemsPerLanePerDPASInst =
             product(dpasLayout.getDPASInstShapeB()) / threadsPerWarp;
         // Block 2D contain at least one DotOp B.
@@ -2709,6 +2765,19 @@ struct LoadOpToBlockIOConversion
         }
         unpackedType = LLVM::getVectorType(eltTy, elemsPerLanePerDPASInst);
       }
+      } break;
+      case DpasEncodingAttr::OpIdx::OperandC: {
+        unsigned elemsPerLanePerDPASInst =
+            product(dpasLayout.getDPASInstShapeC()) / threadsPerWarp;
+        // The 2D block load contains at least one DotOp C operand.
+        if (numElemsPerLoad >= elemsPerLanePerDPASInst) {
+          packedDPASOperandType = LLVM::getVectorType(
+              packedType, elemsPerLanePerDPASInst / numPackedVals);
+          unpackedType = LLVM::getVectorType(eltTy, elemsPerLanePerDPASInst);
+        }
+      } break;
+      default:
+        llvm_unreachable("unexpected OpIdx type.");
       }
     }
     SmallVector<Value> unpackedLoadedVals(numElems);

From ae745cf5f77b181523d09deafe535ace3b1b25d6 Mon Sep 17 00:00:00 2001
From: "Lu,Chengjun"
Date: Mon, 21 Jul 2025 15:49:00 +0000
Subject: [PATCH 2/2] [LoadStoreOpToLLVM] Support transposed 2D block loads.
Signed-off-by: Lu,Chengjun --- python/test/unit/intel/test_block_io.py | 21 +- .../LoadStoreOpToLLVM.cpp | 704 ++---------------- 2 files changed, 70 insertions(+), 655 deletions(-) diff --git a/python/test/unit/intel/test_block_io.py b/python/test/unit/intel/test_block_io.py index 6c6d5f1250..f3a6cab669 100644 --- a/python/test/unit/intel/test_block_io.py +++ b/python/test/unit/intel/test_block_io.py @@ -120,8 +120,9 @@ def warps_per_cta(layout): @pytest.mark.parametrize("layout", layouts) @pytest.mark.parametrize("load_block_ptr, store_block_ptr", [(True, True), (False, False), (True, False), (False, True)]) +@pytest.mark.parametrize("transpose", [True, False]) @pytest.mark.skipif(not is_xpu(), reason="Block store tests are specific to the XPU backend") -def test_block_io(M, N, dtype_str, layout, load_block_ptr, store_block_ptr, device, tmp_path: pathlib.Path): +def test_block_io(M, N, dtype_str, layout, load_block_ptr, store_block_ptr, transpose, device, tmp_path: pathlib.Path): warps = warps_per_cta(layout) num_warps = int(np.prod(warps)) @@ -132,16 +133,18 @@ def test_block_io(M, N, dtype_str, layout, load_block_ptr, store_block_ptr, devi support_block_io = torch.xpu.get_device_capability()['has_subgroup_2d_block_io'] + block_io = "\"column_major\"" if transpose else "\"row_major\"" + if load_block_ptr: load_ops = f""" - %src_ptr = tt.make_tensor_ptr %src, [%M_i64, %N_i64], [%N_i64, %c1_i64], [%c0_i32, %c0_i32] {{order = array}} : > - %store_val = tt.load %src_ptr {{ttig.block_io = "row_major", boundaryCheck = array, padding = 1 : i32}} : !tt.ptr> + %src_ptr = tt.make_tensor_ptr %src, [%M_i64, %N_i64], {"[%c1_i64, %M_i64]" if transpose else "[%N_i64, %c1_i64]"}, [%c0_i32, %c0_i32] {{order = array}} : > + %store_val = tt.load %src_ptr {{ttig.block_io = {block_io}, boundaryCheck = array, padding = 1 : i32}} : !tt.ptr> """ else: load_ops = f""" %src_base = tt.splat %src : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #layout> - %src_ptr = tt.addptr %src_base, %row_major_off : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout> - %store_val = tt.load %src_ptr {{ttig.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout> + %src_ptr = tt.addptr %src_base, {"%col_major_off" if transpose else "%row_major_off" } : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout> + %store_val = tt.load %src_ptr {{ttig.block_io = {block_io}}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout> """ if store_block_ptr: store_ops = f""" @@ -175,6 +178,12 @@ def test_block_io(M, N, dtype_str, layout, load_block_ptr, store_block_ptr, devi %7 = tt.broadcast %5 : tensor<1x{N}xi32, #layout> -> tensor<{M}x{N}xi32, #layout> %row_major_off = arith.addi %6, %7 : tensor<{M}x{N}xi32, #layout> + %stride_M = arith.constant dense<{M}> : tensor<1x{N}xi32, #layout> + %col_stride = arith.muli %5, %stride_M : tensor<1x{N}xi32, #layout> + %8 = tt.broadcast %2 : tensor<{M}x1xi32, #layout> -> tensor<{M}x{N}xi32, #layout> + %9 = tt.broadcast %col_stride : tensor<1x{N}xi32, #layout> -> tensor<{M}x{N}xi32, #layout> + %col_major_off = arith.addi %8, %9 : tensor<{M}x{N}xi32, #layout> + {load_ops} {store_ops} @@ -195,6 +204,8 @@ def test_block_io(M, N, dtype_str, layout, load_block_ptr, store_block_ptr, devi temp_file.write_text(ir) kernel = triton.compile(str(temp_file)) + a = a.permute(1, 0).contiguous().permute(1, 0) if transpose else a + kernel[(1, 1, 1)](a, x) assert torch.equal(a, x) diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp 
b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp index da5f11e8f1..4d54faf4fb 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -1828,640 +1828,6 @@ struct LoadOpToBlockIOConversion return success(); } - // FIXME: Temp solution for supporting transpose load. - LogicalResult - matchAndRewriteTranspose(triton::LoadOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Location loc = op.getLoc(); - MLIRContext *ctx = rewriter.getContext(); - auto b = TritonLLVMOpBuilder(loc, rewriter); - Value ptr = op.getPtr(); - Value mask = op.getMask(); - Type resultType = op.getType(); - auto tensorType = cast(resultType); - const bool memoryRowMajor = isMemoryRowMajor(op); - DpasEncodingAttr::OpIdx opIdx = getOpIdx(tensorType); - - Attribute encoding = tensorType.getEncoding(); - std::optional llEncoding = - cast(encoding).toLinearLayout( - tensorType.getShape()); - assert(llEncoding.has_value() && - "unexpected failure when getting linear layout"); - - Type eltTy = getTypeConverter()->convertType(tensorType.getElementType()); - unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth(); - - auto llAttr = LinearEncodingAttr::get(rewriter.getContext(), *llEncoding); - SmallVector threadOrder(llAttr.getThreadOrder()); - size_t rank = threadOrder.size(); - const bool isTransposeRequired = true; - - // Step 2: Right now we only support DPAS related layout to simplify the - // lowering. - DpasEncodingAttr dpasLayout = getDpasLayout(tensorType); - const ArrayRef tensorShape = tensorType.getShape(); - unsigned numElems = getTotalElemsPerThread(resultType); - SmallVector repetitons = - dpasLayout.getDPASRepetitions(tensorShape, opIdx); - assert(repetitons.size() == 3 && - "getDPASRepetitions always return rank 3 size"); - assert(repetitons[0] == 1 && "Only supports rank of 2 for now"); - SmallVector numReps{repetitons[1], repetitons[2]}; - ArrayRef warpsPerCTA = dpasLayout.getWarpsPerCTA(); - SmallVector dpasWarpsOrder = - getMatrixOrder(warpsPerCTA.size(), /*rowMajor*/ true); - unsigned threadsPerWarp = - product(getThreadsPerWarp(dpasLayout, tensorShape)); - - Value warpId = rewriter.create( - loc, i32_ty, - rewriter.create(loc, /*upperBound=*/nullptr)); - - SmallVector multiDimWarpId = - delinearize(rewriter, loc, warpId, warpsPerCTA, dpasWarpsOrder); - - // By default, use the unpacked type for the 2D load result type. - Type loadResultElemType = typeConverter->convertType(eltTy); - bool usePackedType = false; - unsigned packedElemsNum = 1; - // The tensor values are distributed as DotOp layout of DPAS. - // If the element size of the tensor matches the DPAS packed layout, then - // use the packed type for the 2D load result type. For example, - // The intermediate ops generated by ConvertTritonGPUToLLVM: - // %0 = load_2d %ptr : vector<8 x i32> - // %1 = bitcast %0 : vector<8 x i32> -> vector<16 x f16> - // %2 = bitcast %1 : vector<16 x f16> -> vector<8 x i32> - // %3 = dpas %2 - // And the LLVM dialect optimization pass can eliminate the duplicated - // bitcast. Then there is a shortcut to use the load result directly as the - // input operands to DPAS. - // TODO: add support for int4 and int2. - - // OperandA: outer dim -> M, inner dim -> K. - // OperandB: outer dim -> N, inner dim -> K. - // OperandC: outer dim -> M, inner dim -> N. - // Round the warp id fit into the tensor shape. 
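The warp bookkeeping in the function being removed here reduces to a small standalone computation: delinearize the subgroup id into per-dimension warp coordinates, then wrap the outer coordinate into the tensor extent (the urem just below). A minimal sketch with illustrative names; these helpers do not exist in the codebase:

    #include <array>
    #include <cstdint>

    // Decompose a linear warp id into (outer, inner) coordinates under a
    // row-major warp order (last dimension varies fastest), then wrap the
    // outer coordinate so warps beyond the tensor extent reuse tiles,
    // mirroring urem(multiDimWarpId[dimOuter], outerDimWarpNum).
    std::array<uint32_t, 2> roundWarpId(uint32_t warpId,
                                        std::array<uint32_t, 2> warpsPerCTA,
                                        uint32_t outerDimWarpNum) {
      uint32_t inner = warpId % warpsPerCTA[1];
      uint32_t outer = warpId / warpsPerCTA[1];
      return {outer % outerDimWarpNum, inner};
    }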
- unsigned dimOuter; - unsigned dimInner; - SmallVector repCluster(dpasLayout.getRepCluster()); - SmallVector warpShape; - SmallVector dpasInstShape; - - switch (opIdx) { - case DpasEncodingAttr::OpIdx::OperandA: { - warpShape = std::move(dpasLayout.getShapeA()); - dpasInstShape = std::move(dpasLayout.getDPASInstShapeA()); - dimOuter = rank - 2; - dimInner = rank - 1; - repCluster[dimInner] = 1; - - unsigned opsPerChannel = dpasLayout.getOpsPerChannel(); - if ((opsPerChannel == 4 && elemSizeInBits == 8) || - (opsPerChannel == 2 && elemSizeInBits == 16) || - (opsPerChannel == 1 && elemSizeInBits == 32)) { - loadResultElemType = elemSizeInBits == 32 ? i32_ty : i16_ty; - packedElemsNum = opsPerChannel == 4 ? 2 : 1; - usePackedType = true; - } else if (opsPerChannel == 4) { - packedElemsNum = 2; - unsigned packedBitWidht = elemSizeInBits * packedElemsNum; - if (packedBitWidht > 64) { - // Be conservative to avoid the packed type exceeds 64 bits. - return failure(); - } - // Need to pack two column into one to work around vectorization - // limitation. - loadResultElemType = int_ty(packedBitWidht); - usePackedType = true; - } - } break; - case DpasEncodingAttr::OpIdx::OperandB: { - warpShape = std::move(dpasLayout.getShapeB()); - dpasInstShape = std::move(dpasLayout.getDPASInstShapeB()); - dimOuter = rank - 1; - dimInner = rank - 2; - repCluster[dimInner] = 1; - - unsigned opsPerChannel = dpasLayout.getOpsPerChannel(); - if ((opsPerChannel == 4 && elemSizeInBits == 8) || - (opsPerChannel == 2 && elemSizeInBits == 16) || - (opsPerChannel == 1 && elemSizeInBits == 32)) { - loadResultElemType = i32_ty; - packedElemsNum = opsPerChannel; - usePackedType = true; - } - } break; - case DpasEncodingAttr::OpIdx::OperandC: - warpShape = std::move(dpasLayout.getShapeC()); - dpasInstShape = std::move(dpasLayout.getDPASInstShapeC()); - dimOuter = rank - 2; - dimInner = rank - 1; - usePackedType = false; - break; - default: - llvm_unreachable("unknown DPAS operands index type."); - break; - } - unsigned elemsPerLanePerDPASInst = - product(dpasInstShape) / threadsPerWarp; - LLVMTypeConverter *typeConverter = getTypeConverter(); - Type unpackedDPASOperandType = LLVM::getVectorType( - typeConverter->convertType(eltTy), elemsPerLanePerDPASInst); - - unsigned packedElemsPerLanePerDPASInst = - elemsPerLanePerDPASInst / packedElemsNum; - Type packedDPASOperandType = - LLVM::getVectorType(loadResultElemType, packedElemsPerLanePerDPASInst); - - unsigned outerDimTileNum = - mlir::ceil(tensorShape[dimOuter], warpShape[dimOuter]); - unsigned outerDimWarpNum = - std::min(warpsPerCTA[dimOuter], outerDimTileNum); - Value outerDimWarpId = - b.urem(multiDimWarpId[dimOuter], b.i32_val(outerDimWarpNum)); - unsigned innerDimRequiredWarpNum = - mlir::ceil(tensorShape[dimInner], warpShape[dimInner]); - unsigned innerDimWarpNum = - std::min(warpsPerCTA[dimInner], innerDimRequiredWarpNum); - - // Step 3: Get the tile size of load. 
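The tile sizing that follows is bounded by the PVC 2D block-load hardware limits used further down (at most 32 rows per tile, 64 bytes per row, and 4 v_blocks). A hedged sketch of the two clamping rules, with assumed helper names:

    #include <algorithm>

    // How many DPAS operands can be stacked vertically in one 2D block load:
    // the tile height (instHeight * n) may not exceed 32 rows on PVC.
    unsigned maxOperandsByHeight(unsigned instHeight) {
      constexpr unsigned kMaxTileHeight = 32;
      return kMaxTileHeight / instHeight;
    }

    // How many operands can be stacked horizontally as v_blocks: each row is
    // limited to 64 bytes, and the hardware caps v_blocks at 4.
    unsigned maxOperandsByWidth(unsigned instWidthBytes) {
      constexpr unsigned kMaxRowBytes = 64;
      constexpr unsigned kMaxVBlocks = 4;
      return std::min(kMaxRowBytes / instWidthBytes, kMaxVBlocks);
    }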
- unsigned tileWidth = dpasInstShape[threadOrder[rank - 2]]; - unsigned tileHeight = dpasInstShape[threadOrder[rank - 1]]; - unsigned vBlocks = 1; - unsigned numOperandsOuterDimPerLoad = 1; - unsigned numOperandsInnerDimPerLoad = 1; - unsigned maskConstancyHor = 1, maskConstancyVer = 1; - unsigned instWidth = dpasInstShape[threadOrder[rank - 2]]; - unsigned instHeight = dpasInstShape[threadOrder[rank - 1]]; - - std::map, Value> ptrs; - std::map, Value> masks; - std::map, Value> others; - - Value llPtr = adaptor.getPtr(); - Value llMask = adaptor.getMask(); - - SmallVector ptrElems, maskElems, otherElems; - // Get the LLVM values for pointers - ptrElems = unpackLLElements(loc, llPtr, rewriter); - assert(ptrElems.size() == numElems && - "the number of pointer values is not matched with the number of " - "elements"); - - // Get the LLVM values for mask - if (llMask) { - maskElems = unpackLLElements(loc, llMask, rewriter); - assert(maskElems.size() == numElems && - "the number of mask values is not matched with the number of " - "elements"); - auto axisInfo = - const_cast(axisAnalysisPass) - .getAxisInfo(mask); - if (axisInfo) { - maskConstancyHor = axisInfo->getConstancy(rank - 1); - maskConstancyVer = axisInfo->getConstancy(rank - 2); - } else { - maskConstancyHor = 1; - maskConstancyVer = 1; - } - } else { - // no mask - maskConstancyHor = std::numeric_limits::max(); - maskConstancyVer = std::numeric_limits::max(); - } - - // Check the constancy of the mask support to load the memory in 2D block. - if (!(maskConstancyHor >= instWidth && maskConstancyVer >= instHeight)) - return failure(); - - // Get the LLVM values for `other` - Value other = op.getOther(); - Value llOther = adaptor.getOther(); - DenseElementsAttr constAttr; - if (other) { - if (matchPattern(other, m_Constant(&constAttr)) && constAttr.isSplat()) { - Type elemTy = constAttr.getElementType(); - auto handleSplatValue = [&](auto splatVal) { - if (!splatVal.isZero()) { - otherElems = SmallVector( - numElems, - rewriter.create(loc, elemTy, splatVal)); - } - }; - - TypeSwitch(elemTy) - .Case([&](FloatType) { - handleSplatValue(constAttr.getSplatValue()); - }) - .Case([&](IntegerType) { - handleSplatValue(constAttr.getSplatValue()); - }); - } else { - otherElems = unpackLLElements(loc, llOther, rewriter); - } - } - - // re-arrange the ptrs and masks to for large 2D block IO. - // Layout is unrelated to the scalar type. - SmallVector> offsets = - mlir::emitOffsetForLayout(encoding, tensorType); - for (size_t i = 0; i < ptrElems.size(); ++i) { - SmallVector offset = offsets[i]; - ptrs[offset] = ptrElems[i]; - if (llMask) - masks[offset] = maskElems[i]; - if (otherElems.size()) - others[offset] = otherElems[i]; - } - - unsigned numOperandsPer2DLoadM, numOperandsPer2DLoadN; - if (opIdx == DpasEncodingAttr::OpIdx::OperandA) - return failure(); - - if (!usePackedType) - return failure(); - - std::swap(tileHeight, tileWidth); - - // We can decompose the matrix returned by transposed large 2d load - // when threads per warp < column size. Otherwise we have to load one - // operand per inst. - // Note: the tileHeight and numOperandsPer2DLoadM are the column size - // now. - numOperandsPer2DLoadM = - (threadsPerWarp <= tileHeight) ? repCluster[rank - 1] : 1; - // The transpose 2d load only support 1 operand per inst on column. - // (vBlocks = 1) - numOperandsPer2DLoadN = 1; - - // adjust the mask constancy to fit the 2D load. 
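A 2D block load is predicated by a single scalar mask bit, so the mask must be provably uniform over at least one instruction tile; the constancy values come from Triton's axis analysis, and an absent mask counts as unbounded constancy. A sketch of the requirement checked above (illustrative helper, not part of the patch):

    // True when one scalar mask bit can stand in for the whole
    // instWidth x instHeight tile.
    bool maskCoversTile(unsigned maskConstancyHor, unsigned maskConstancyVer,
                        unsigned instWidth, unsigned instHeight) {
      return maskConstancyHor >= instWidth && maskConstancyVer >= instHeight;
    }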
- numOperandsPer2DLoadM = - std::min(numOperandsPer2DLoadM, maskConstancyHor / instWidth); - numOperandsPer2DLoadN = - std::min(numOperandsPer2DLoadN, maskConstancyVer / instHeight); - - // PVC 2D load supports 32 rows at most. Load multiple dot operands in by - // enlarging the tileHeight. - constexpr unsigned MAX_TILE_HEIGHT = 32; - numOperandsPer2DLoadM = - std::min(numOperandsPer2DLoadM, - static_cast(MAX_TILE_HEIGHT / tileHeight)); - - // PVC 2D load supports 64 bytes per row at most. Load multiple dot operands - // by enlarging the vBlocks. - unsigned totalBytesPerRowPerDPASOp = tileWidth * elemSizeInBits / 8; - constexpr int MAX_WIDTH = 64; - if (totalBytesPerRowPerDPASOp > MAX_WIDTH) - return failure(); - numOperandsPer2DLoadN = - std::min(numOperandsPer2DLoadN, MAX_WIDTH / totalBytesPerRowPerDPASOp); - // vBlocks has HW limitation of 4. - numOperandsPer2DLoadN = std::min(numOperandsPer2DLoadN, 4u); - - tileHeight = instHeight * numOperandsPer2DLoadM; - tileWidth = instWidth; - vBlocks = numOperandsPer2DLoadN; - - numOperandsOuterDimPerLoad = opIdx != DpasEncodingAttr::OpIdx::OperandB - ? numOperandsPer2DLoadM - : numOperandsPer2DLoadN; - numOperandsInnerDimPerLoad = opIdx != DpasEncodingAttr::OpIdx::OperandB - ? numOperandsPer2DLoadN - : numOperandsPer2DLoadM; - - std::swap(numOperandsOuterDimPerLoad, numOperandsInnerDimPerLoad); - - unsigned numLoadPerOutRepCluster = - mlir::ceil(repCluster[dimOuter], numOperandsOuterDimPerLoad); - unsigned numLoadPerInnerRepCluster = - mlir::ceil(repCluster[dimInner], numOperandsInnerDimPerLoad); - - unsigned numValuesPerLoad = packedElemsPerLanePerDPASInst * - numOperandsOuterDimPerLoad * - numOperandsInnerDimPerLoad; - Type load2DGenXType = - LLVM::getVectorType(loadResultElemType, numValuesPerLoad); - - // Step 4: Generates the load instruction. - // The stride for the tile replicates. - unsigned numRepOuter; - unsigned numRepInner; - unsigned repOuterStride = warpShape[dimOuter] * outerDimWarpNum; - unsigned repInnerStride; - switch (opIdx) { - case DpasEncodingAttr::OpIdx::OperandA: - case DpasEncodingAttr::OpIdx::OperandB: - numRepOuter = numReps[dimOuter]; - numRepInner = - mlir::ceil(numReps[dimInner], numOperandsInnerDimPerLoad); - repInnerStride = warpShape[dimInner] * numOperandsInnerDimPerLoad; - break; - case DpasEncodingAttr::OpIdx::OperandC: - numRepOuter = numReps[dimOuter]; - numRepInner = numReps[dimInner]; - repInnerStride = warpShape[dimInner] * innerDimWarpNum; - break; - default: - llvm_unreachable("unknown DPAS operands index type."); - break; - } - - Value pitch = - getPitch(rewriter, ptr, elemSizeInBits, memoryRowMajor ? 0 : 1); - if (!pitch) - return failure(); - - // If the stride is 0, we want to load only the first row. - int stride = getStride(ptr, memoryRowMajor ? 0 : 1); - unsigned baseHeightInt = (stride == 0 ? 1 : tileHeight); - Value baseHeight = b.i32_val(baseHeightInt); - Value baseWidth = - b.i32_val(std::max(64u, vBlocks * tileWidth * (elemSizeInBits / 8))); - - StringAttr kRegister = str_attr("register"); - StringAttr kLane = str_attr("lane"); - StringAttr kWarp = str_attr("warp"); - StringAttr kBlock = str_attr("block"); - - const unsigned originalElemBits = elemSizeInBits; - - LDBG("Block io tile shape: [" - << tileHeight << ", " << tileWidth << "], vblocks: " << vBlocks - << ", numOperandsPerLoad: [" - << (opIdx != DpasEncodingAttr::OpIdx::OperandB - ? numOperandsOuterDimPerLoad - : numOperandsInnerDimPerLoad) - << ", " - << (opIdx != DpasEncodingAttr::OpIdx::OperandB - ? 
numOperandsInnerDimPerLoad - : numOperandsOuterDimPerLoad) - << "], number loads per repCluster: [" - << (opIdx != DpasEncodingAttr::OpIdx::OperandB - ? numLoadPerOutRepCluster - : numLoadPerInnerRepCluster) - << ", " - << (opIdx != DpasEncodingAttr::OpIdx::OperandB - ? numLoadPerInnerRepCluster - : numLoadPerOutRepCluster) - << "], number repCluster: [" - << (opIdx != DpasEncodingAttr::OpIdx::OperandB ? numRepOuter - : numRepInner) - << ", " - << (opIdx != DpasEncodingAttr::OpIdx::OperandB ? numRepInner - : numRepOuter) - << "]"); - - ValueTable loadVals; - for (int inner = 0; inner < numRepInner; ++inner) { - for (int outer = 0; outer < numRepOuter; ++outer) { - for (int loadInner = 0; loadInner < numLoadPerInnerRepCluster; - ++loadInner) { - for (int loadOuter = 0; loadOuter < numLoadPerOutRepCluster; - ++loadOuter) { - unsigned offsetOuter = - outer * repOuterStride + loadOuter * dpasInstShape[dimOuter] * - numOperandsOuterDimPerLoad; - unsigned offsetInner = - inner * repInnerStride + loadInner * dpasInstShape[dimInner] * - numOperandsInnerDimPerLoad; - unsigned offsetM = - (opIdx != DpasEncodingAttr::OpIdx::OperandB ? offsetOuter - : offsetInner); - unsigned offsetN = - (opIdx != DpasEncodingAttr::OpIdx::OperandB ? offsetInner - : offsetOuter); - - LDBG("Block load iterator: inner: " - << inner << ", outer:" << outer << ", loadInner:" << loadInner - << ", loadOuter:" << loadOuter << " offset: [" << offsetM - << ", " << offsetN << "]"); - - Value offsetY = b.i32_val(0); - Value pred; - if (llMask) { - assert(masks.size() && "Invalid size of the masks."); - pred = targetInfo.shuffleIdx(rewriter, loc, - masks[{offsetM, offsetN}], 0); - // We leverage the GPU block I/O hardware out-of-bound protection - // feature by setting the offset to an invalid value when 'pred' - // is false (the HW will not read out-of-bounds values). Later on, - // after issuing the 2d block read operation, we will select the - // result of the load only if the mask evaluate to true, otherwise - // we will use 'other'. - offsetY = b.select(pred, offsetY, baseHeight); - } - - // Use the top-left address of the block to load the data. - Value addrElem = - b.bitcast(ptrs[{offsetM, offsetN}], ptr_ty(ctx, 1 /*global*/)); - addrElem = targetInfo.shuffleIdx(rewriter, loc, addrElem, 0); - - Value ret = rewriter.create( - loc, load2DGenXType, - /*ptr*/ addrElem, - /*base_width*/ baseWidth, - /*base_height*/ baseHeight, - /*base_pitch*/ pitch, - /*x*/ b.i32_val(0), - /*y*/ offsetY, - /*elem_size_in_bits*/ elemSizeInBits, - /*tile_width*/ tileWidth, - /*tile_height*/ tileHeight, - /*v_blocks*/ vBlocks, - /*transpose*/ false, - /*vnni_transform*/ - (usePackedType && opIdx == DpasEncodingAttr::OpIdx::OperandB && - !isTransposeRequired && originalElemBits != 32)); - - // When strides[0] is 0, we only want to load the first row, so we - // set the base height to be 1. If tile height is bigger than 1, - // then only the first row contain valid data. To ensure the entire - // tile is filled with valid data, we must replicate the first row - // throughout the tile. - if (baseHeightInt < tileHeight && baseHeightInt == 1) { - unsigned numIndicesPerMatrix = numValuesPerLoad / vBlocks; - SmallVector shuffleIndices(numValuesPerLoad); - - // Create a vector to store the data of the first index of each - // matrix. 
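The replication implemented just below amounts to building a shuffle mask that redirects every element to the first element of its own matrix; for example, numValuesPerLoad = 8 with vBlocks = 2 yields [0, 0, 0, 0, 1, 1, 1, 1]. A standalone sketch (assumed helper name):

    #include <cstdint>
    #include <vector>

    // Shuffle mask that broadcasts the first loaded row: every element of
    // matrix m maps to source index m (that matrix's first element).
    std::vector<int32_t> firstRowBroadcastMask(unsigned numValuesPerLoad,
                                               unsigned vBlocks) {
      unsigned perMatrix = numValuesPerLoad / vBlocks;
      std::vector<int32_t> mask(numValuesPerLoad);
      for (unsigned v = 0; v < numValuesPerLoad; ++v)
        mask[v] = v / perMatrix;
      return mask;
    }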
- VectorType vecTy = vec_ty(loadResultElemType, vBlocks); - Value firstIndexVec = b.undef(vecTy); - - for (unsigned valueIndex = 0; valueIndex < numValuesPerLoad; - ++valueIndex) { - unsigned firstIndexVecIdx = valueIndex / numIndicesPerMatrix; - // Handle case where an index spans two rows. - if (valueIndex % numIndicesPerMatrix == 0) { - Value oldVal = b.extract_element(ret, b.i32_val(valueIndex)); - Value newVal = oldVal; - if (tileWidth < threadsPerWarp) { - assert(tileWidth * 2 == threadsPerWarp && - "Expecting tileWidth to be 2x threadsPerWarp"); - Value threadId = getThreadId(rewriter, loc); - newVal = targetInfo.shuffleIdx( - rewriter, loc, oldVal, - b.urem(threadId, b.i32_val(tileWidth))); - } - firstIndexVec = - b.insert_element(firstIndexVec.getType(), firstIndexVec, - newVal, b.i32_val(firstIndexVecIdx)); - } - - shuffleIndices[valueIndex] = firstIndexVecIdx; - } - DenseI32ArrayAttr attr = - rewriter.getDenseI32ArrayAttr(shuffleIndices); - ret = rewriter.create( - loc, load2DGenXType, firstIndexVec, firstIndexVec, attr); - } - - if (others.size()) { - assert(masks.size() == others.size() && - "The mask value has to be provided when " - "the other value is provided."); - VectorType vecTy = - vec_ty(eltTy, numValuesPerLoad * packedElemsNum); - - Value v = b.undef(vecTy); - unsigned nWords = 0; - for (int vblk = 0; vblk < vBlocks; ++vblk) - for (int i = 0; i < tileHeight; ++i) { - unsigned numColPerPackedValue = - opIdx == DpasEncodingAttr::OpIdx::OperandA - ? packedElemsNum - : 1; - unsigned numPackedValuesPerRow = mlir::ceil( - (tileWidth / numColPerPackedValue), threadsPerWarp); - for (int col = 0; col < numPackedValuesPerRow; ++col) { - for (int packedCol = 0; packedCol < numColPerPackedValue; - ++packedCol) { - unsigned N = packedCol + - col * threadsPerWarp * numColPerPackedValue + - vblk * tileWidth + offsetN; - unsigned M = i + offsetM; - Value falseVal = others[{M, N}]; - Value sVal = createIndexAttrConstant( - rewriter, loc, typeConverter->getIndexType(), - nWords++); - v = b.insert_element(vecTy, v, falseVal, sVal); - } - } - } - Value others = b.bitcast(v, load2DGenXType); - ret = b.select(pred, ret, others); - } - - unsigned numOperandsM = opIdx != DpasEncodingAttr::OpIdx::OperandB - ? numOperandsOuterDimPerLoad - : numOperandsInnerDimPerLoad; - unsigned numOperandsN = opIdx != DpasEncodingAttr::OpIdx::OperandB - ? numOperandsInnerDimPerLoad - : numOperandsOuterDimPerLoad; - - // Split the return matrix by large 2d block io size into multiple - // DPAS operands. 
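The decomposition below computes, for each DPAS operand, a shufflevector index set; extracted as a standalone sketch (the names mirror the variables in the loop that follows, the helper itself is illustrative):

    #include <cstdint>
    #include <vector>

    // Shuffle indices that carve one DPAS operand (row, col) of v-block
    // vblk out of a large 2D-load result. Elements belonging to one operand
    // are strided by numOperandsPerVBlockN in the loaded vector.
    std::vector<int32_t> operandShuffleIndices(unsigned vblk, unsigned row,
                                               unsigned col,
                                               unsigned numOperandsM,
                                               unsigned numOperandsPerVBlockN,
                                               unsigned packedElemsPerOperand) {
      unsigned start = (vblk * numOperandsM + row) * numOperandsPerVBlockN *
                       packedElemsPerOperand;
      std::vector<int32_t> idx(packedElemsPerOperand);
      for (unsigned e = 0; e < packedElemsPerOperand; ++e)
        idx[e] = start + e * numOperandsPerVBlockN + col;
      return idx;
    }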
- assert(numOperandsN >= vBlocks && - "numOperandsN has to be >= vBlocks"); - unsigned numOperandsPerVBlockN = numOperandsN / vBlocks; - for (int vblk = 0; vblk < vBlocks; ++vblk) - for (int row = 0; row < numOperandsM; ++row) - for (int col = 0; col < numOperandsPerVBlockN; ++col) { - - unsigned operandStartOffset = (vblk * numOperandsM + row) * - numOperandsPerVBlockN * - packedElemsPerLanePerDPASInst; - - SmallVector indices(packedElemsPerLanePerDPASInst); - for (int elemIdx = 0; elemIdx < packedElemsPerLanePerDPASInst; - ++elemIdx) { - indices[elemIdx] = operandStartOffset + - elemIdx * numOperandsPerVBlockN + col; - } - - LLVM_DEBUG({ - DBGS() << "shuffle idx: ["; - for (int elemIdx = 0; - elemIdx < packedElemsPerLanePerDPASInst; ++elemIdx) { - llvm::dbgs() << indices[elemIdx] << ", "; - } - llvm::dbgs() << "]\n"; - }); - - DenseI32ArrayAttr attr = - rewriter.getDenseI32ArrayAttr(indices); - Value loadVal = rewriter.create( - loc, packedDPASOperandType, ret, ret, attr); - - // Save the decomposed vals to the map; - switch (opIdx) { - case DpasEncodingAttr::OpIdx::OperandC: - case DpasEncodingAttr::OpIdx::OperandA: { - unsigned o = outer * numLoadPerOutRepCluster * - numOperandsOuterDimPerLoad + - loadOuter * numOperandsOuterDimPerLoad + row; - unsigned i = inner * numLoadPerInnerRepCluster * - numOperandsInnerDimPerLoad + - loadInner * numOperandsInnerDimPerLoad + - vblk * numOperandsPerVBlockN + col; - - LDBG("insert: [" << o << ", " << i << "]"); - loadVals[{o, i}] = - b.bitcast(loadVal, unpackedDPASOperandType); - } break; - case DpasEncodingAttr::OpIdx::OperandB: { - unsigned o = outer * numLoadPerOutRepCluster * - numOperandsOuterDimPerLoad + - loadOuter * numOperandsOuterDimPerLoad + - vblk * numOperandsPerVBlockN + col; - unsigned i = inner * numOperandsInnerDimPerLoad + row; - LDBG("insert: [" << o << ", " << i << "]"); - loadVals[{o, i}] = - b.bitcast(loadVal, unpackedDPASOperandType); - } break; - default: { - llvm_unreachable("unknown DPAS operands index type."); - } break; - } - } - } - } - } - } - - // Step 5: Unpack the load values. - // Extract the value returned by the load ops. And put the values in the - // expected order for the layout. 
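Concretely, the loop below visits the operand table in (outer, inner, repOuter, repInner) order with o = outer * repCluster[dimOuter] + repOuter and i = inner * repCluster[dimInner] + repInner. A compact sketch of just the index walk:

    #include <cstdio>

    // For numReps = {2, 2} and repCluster = {2, 1} this prints the operand
    // coordinates in emission order:
    // (0,0) (1,0) (0,1) (1,1) (2,0) (3,0) (2,1) (3,1).
    void printUnpackOrder(unsigned repsOuter, unsigned repsInner,
                          unsigned clusterOuter, unsigned clusterInner) {
      for (unsigned outer = 0; outer < repsOuter; ++outer)
        for (unsigned inner = 0; inner < repsInner; ++inner)
          for (unsigned ro = 0; ro < clusterOuter; ++ro)
            for (unsigned ri = 0; ri < clusterInner; ++ri)
              std::printf("(%u,%u) ", outer * clusterOuter + ro,
                          inner * clusterInner + ri);
    }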
-    SmallVector<Value> unpackedLoadedVals;
-    for (int outer = 0; outer < numReps[dimOuter]; ++outer) {
-      for (int inner = 0; inner < numReps[dimInner]; ++inner) {
-        for (int repOuter = 0; repOuter < repCluster[dimOuter]; ++repOuter) {
-          for (int repInner = 0; repInner < repCluster[dimInner]; ++repInner) {
-            unsigned o = outer * repCluster[dimOuter] + repOuter;
-            unsigned i = inner * repCluster[dimInner] + repInner;
-            LDBG("extract: [" << o << ", " << i << "]");
-            Value loadVal = loadVals.at({o, i});
-            VectorType loadTy = cast<VectorType>(loadVal.getType());
-            for (int i = 0; i < loadTy.getNumElements(); ++i) {
-              auto val = b.extract_element(loadVal, b.i32_val(i));
-              unpackedLoadedVals.push_back(val);
-            }
-            loadVals.erase({o, i});
-          }
-        }
-      }
-    }
-
-    assert(loadVals.empty() && "not all loaded values is unpacked.");
-
-    Type llvmResultStructTy = typeConverter->convertType(op.getType());
-    Value resultStruct = packLLElements(loc, typeConverter, unpackedLoadedVals,
-                                        rewriter, llvmResultStructTy);
-    rewriter.replaceOp(op, {resultStruct});
-
-    return success();
-  }
-
   LogicalResult matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const final {
@@ -2521,18 +1887,6 @@ struct LoadOpToBlockIOConversion
     if (tileHeight * tileWidth * packedElemSizeInBits / 8 < GRF_SIZE)
       vBlocks = 1;

-    // TODO: use the axis info to general the handling for both regular pointer
-    // and block pointer.
-    const bool memoryRowMajor = isMemoryRowMajor(op);
-    // FIXME: Add support of column major.
-    if (!memoryRowMajor)
-      return failure();
-
-    unsigned contiguousDim = memoryRowMajor ? 1 : 0;
-    const bool isTransposeRequired = contiguousDim != colDim;
-    if (isTransposeRequired)
-      return matchAndRewriteTranspose(op, adaptor, rewriter);
-
     Location loc = op.getLoc();
     auto b = TritonLLVMOpBuilder(loc, rewriter);
     MLIRContext *ctx = rewriter.getContext();
@@ -2661,6 +2015,55 @@ struct LoadOpToBlockIOConversion
       }
     }

+    // TODO: use the axis info to generalize the handling for both regular
+    // pointers and block pointers.
+    const bool memoryRowMajor = isMemoryRowMajor(op);
+    unsigned contiguousDim = memoryRowMajor ? 1 : 0;
+    const bool isTransposeRequired = contiguousDim != colDim;
+
+    if (isTransposeRequired) {
+      if (numPackedVals > 1)
+        return failure();
+      if (elemSizeInBits > 32)
+        return failure();
+      if (tileWidth > 32)
+        return failure(); // tileWidth is limited to 32 for the transposed 2D
+                          // load.
+
+      vBlocks = 1;
+
+      // Use d32 (32-bit elements) for the transposed 2D load.
+      packedElemSizeInBits = 32;
+      numPackedVals = packedElemSizeInBits / elemSizeInBits;
+      if (numPackedVals > 1 && tileWidth != threadsPerWarp)
+        return failure(); // Cannot use the transposed 2D load when the values
+                          // are not packable along the tile height dim.
+      tileHeight = std::min(tileHeight / numPackedVals, 8);
+
+      if (tileHeight * tileWidth < threadsPerWarp)
+        return failure(); // The tile size is not large enough for IGC scalar
+                          // backend vectorization.
+      // Transpose the width and height of the tile.
+      std::swap(tileHeight, tileWidth);
+      // if (oneMatrixPerLoadForBT) {
+      //   // Only load 1 operand per inst on row.
+      //   numOperandsPer2DLoadM = 1;
+      //   tileHeight = elemsPerDPASInst[threadOrder[rank - 2]];
+      // } else {
+      //   // We can decompose the matrix returned by transposed large 2d load
+      //   // when threads per warp < column size. Otherwise we have to load one
+      //   // operand per inst.
+      //   // Note: the tileHeight and numOperandsPer2DLoadM are the column size
+      //   // now.
+      //   numOperandsPer2DLoadM =
+      //       (threadsPerWarp <= tileHeight) ? repCluster[rank - 1] : 1;
+      // }
+      // // The transpose 2d load only supports 1 operand per inst on column.
+      // // (vBlocks = 1)
+      // numOperandsPer2DLoadN = 1;
+      // // TODO: support load column major data.
+      // return failure();
+    }
+
     int64_t numElemsPerLoad = mlir::ceil(
         tileHeight * tileWidth * numPackedVals * vBlocks, (int)threadsPerWarp);
     unsigned numValuesPerLoad = mlir::ceil((int)numElemsPerLoad, numPackedVals);
@@ -2740,8 +2143,6 @@ struct LoadOpToBlockIOConversion
       }
     } break;
     case DpasEncodingAttr::OpIdx::OperandB: {
-      assert(numPackedVals == 1 &&
-             "invalid number of packed values for DPAS operand B.");
       unsigned elemsPerLanePerDPASInst =
           product(dpasLayout.getDPASInstShapeB()) / threadsPerWarp;
       // Block 2D contain at least one DotOp B.
@@ -2751,6 +2152,9 @@ struct LoadOpToBlockIOConversion
       if (tileHeight >= (opsPerChannel * sysDepth) &&
           ((opsPerChannel == 4 && elemSizeInBits == 8) ||
            (opsPerChannel == 2 && elemSizeInBits == 16))) {
+        assert((!isTransposeRequired || opsPerChannel == numPackedVals) &&
+               "invalid opsPerChannel for transposed DotOp B");
         // Use the VNNI packing format for DotOp B layout.
         numValuesPerLoad = numElemsPerLoad / opsPerChannel;
         packedType = i32_ty;
@@ -2814,8 +2218,8 @@ struct LoadOpToBlockIOConversion
         /*tile_width*/ tileWidth,
         /*tile_height*/ tileHeight,
         /*v_blocks*/ vBlocks,
-        /*transpose*/ false,
-        /*vnni_transform*/ useVNNIFormat);
+        /*transpose*/ isTransposeRequired,
+        /*vnni_transform*/ !isTransposeRequired && useVNNIFormat);

     // When strides[0] is 0, we only want to load the first row, so we
     // set the base height to be 1. If tile height is bigger than 1,