python/test/unit/intel/test_block_io.py (21 changes: 16 additions & 5 deletions)
@@ -120,8 +120,9 @@ def warps_per_cta(layout):
 @pytest.mark.parametrize("layout", layouts)
 @pytest.mark.parametrize("load_block_ptr, store_block_ptr", [(True, True), (False, False), (True, False),
 (False, True)])
+@pytest.mark.parametrize("transpose", [True, False])
 @pytest.mark.skipif(not is_xpu(), reason="Block store tests are specific to the XPU backend")
-def test_block_io(M, N, dtype_str, layout, load_block_ptr, store_block_ptr, device, tmp_path: pathlib.Path):
+def test_block_io(M, N, dtype_str, layout, load_block_ptr, store_block_ptr, transpose, device, tmp_path: pathlib.Path):
 
 warps = warps_per_cta(layout)
 num_warps = int(np.prod(warps))
@@ -132,16 +133,18 @@ def test_block_io(M, N, dtype_str, layout, load_block_ptr, store_block_ptr, devi
 
 support_block_io = torch.xpu.get_device_capability()['has_subgroup_2d_block_io']
 
+block_io = "\"column_major\"" if transpose else "\"row_major\""
+
 if load_block_ptr:
 load_ops = f"""
-%src_ptr = tt.make_tensor_ptr %src, [%M_i64, %N_i64], [%N_i64, %c1_i64], [%c0_i32, %c0_i32] {{order = array<i32: 1, 0>}} : <tensor<{M}x{N}x{ty}, #layout>>
-%store_val = tt.load %src_ptr {{ttig.block_io = "row_major", boundaryCheck = array<i32: 0, 1>, padding = 1 : i32}} : !tt.ptr<tensor<{M}x{N}x{ty}, #layout>>
+%src_ptr = tt.make_tensor_ptr %src, [%M_i64, %N_i64], {"[%c1_i64, %M_i64]" if transpose else "[%N_i64, %c1_i64]"}, [%c0_i32, %c0_i32] {{order = array<i32: 1, 0>}} : <tensor<{M}x{N}x{ty}, #layout>>
+%store_val = tt.load %src_ptr {{ttig.block_io = {block_io}, boundaryCheck = array<i32: 0, 1>, padding = 1 : i32}} : !tt.ptr<tensor<{M}x{N}x{ty}, #layout>>
 """
 else:
 load_ops = f"""
 %src_base = tt.splat %src : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
-%src_ptr = tt.addptr %src_base, %row_major_off : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout>
-%store_val = tt.load %src_ptr {{ttig.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
+%src_ptr = tt.addptr %src_base, {"%col_major_off" if transpose else "%row_major_off" } : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout>
+%store_val = tt.load %src_ptr {{ttig.block_io = {block_io}}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
 """
 if store_block_ptr:
 store_ops = f"""
@@ -175,6 +178,12 @@ def test_block_io(M, N, dtype_str, layout, load_block_ptr, store_block_ptr, devi
 %7 = tt.broadcast %5 : tensor<1x{N}xi32, #layout> -> tensor<{M}x{N}xi32, #layout>
 %row_major_off = arith.addi %6, %7 : tensor<{M}x{N}xi32, #layout>
 
+%stride_M = arith.constant dense<{M}> : tensor<1x{N}xi32, #layout>
+%col_stride = arith.muli %5, %stride_M : tensor<1x{N}xi32, #layout>
+%8 = tt.broadcast %2 : tensor<{M}x1xi32, #layout> -> tensor<{M}x{N}xi32, #layout>
+%9 = tt.broadcast %col_stride : tensor<1x{N}xi32, #layout> -> tensor<{M}x{N}xi32, #layout>
+%col_major_off = arith.addi %8, %9 : tensor<{M}x{N}xi32, #layout>
+
 {load_ops}
 {store_ops}
 
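Note on the added offset IR above: %row_major_off and %col_major_off encode the two addressing schemes the test exercises. A minimal NumPy sketch of the same address math (shapes, names, and values here are illustrative, not from the PR):

```python
import numpy as np

# Illustrative tile shape; the test parametrizes M and N.
M, N = 4, 8
i = np.arange(M).reshape(M, 1)   # row index column-vector, like the IR's %2
j = np.arange(N).reshape(1, N)   # column index row-vector, like the IR's %5
row_major_off = i * N + j        # element (i, j) in a row-major buffer
col_major_off = i + j * M        # element (i, j) in a column-major buffer

a = np.arange(M * N, dtype=np.float32).reshape(M, N)
col = np.asfortranarray(a)       # same logical values, column-major storage

# Both offset tensors reach element (i, j) of the same logical tensor.
assert np.array_equal(a.ravel(order="C")[row_major_off],
                      col.ravel(order="F")[col_major_off])
```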
@@ -195,6 +204,8 @@ def test_block_io(M, N, dtype_str, layout, load_block_ptr, store_block_ptr, devi
 temp_file.write_text(ir)
 kernel = triton.compile(str(temp_file))
 
+a = a.permute(1, 0).contiguous().permute(1, 0) if transpose else a
+
 kernel[(1, 1, 1)](a, x)
 assert torch.equal(a, x)
 
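Note on the permute round-trip added above: permute, contiguous, then permute back keeps the logical shape and values of `a` but leaves its storage column-major, which is what the column_major block_io path reads. A small sketch of the assumed effect:

```python
import torch

a = torch.arange(12.0).reshape(3, 4)
b = a.permute(1, 0).contiguous().permute(1, 0)

assert torch.equal(a, b)     # identical logical values
assert a.stride() == (4, 1)  # row-major strides
assert b.stride() == (1, 3)  # column-major (transposed) strides
```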
test/TritonIntelGPU/tensor-pointer-load-block-2d.mlir (51 changes: 49 additions & 2 deletions)
@@ -369,7 +369,7 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32, "ttg.thr
 tt.func public @regular_pointer_block_io(%arg0: !tt.ptr<f16>) {
 
 %a_mask = arith.constant dense<true> : tensor<256x64xi1, #mma>
-%a_other = arith.constant dense<0.00e+00> : tensor<256x64xf16, #mma>
+%a_other = arith.constant dense<1.00e+00> : tensor<256x64xf16, #mma>
 // CHECK-NOT: llvm.cond_br
 
 %0 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #mma}>>
@@ -389,7 +389,6 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32, "ttg.thr
 // CHECK: %[[TOP_LEFT_MASK_BOOL_64:.*]] = llvm.extractvalue {{.*}}[64] : !llvm.struct<(i1, i1, {{.*}}
 // CHECK: %[[TOP_LEFT_MASK_BOOL_96:.*]] = llvm.extractvalue {{.*}}[96] : !llvm.struct<(i1, i1, {{.*}}
-
 
 // CHECK: %[[BLOCK_SHAPE_Y:.*]] = llvm.mlir.constant(16 : i32) : i32
 // CHECK: %[[TOP_LEFT_PTR:.*]] = llvm.ptrtoint {{.*}} : !llvm.ptr<1> to i64
 // CHECK: %[[VAL_2886:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[TOP_LEFT_PTR]], {{.*}}) {convergent, no_unwind, will_return} : (i64, i32) -> i64
@@ -402,6 +401,18 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32, "ttg.thr
 // CHECK: %[[PRED_BOOL:.*]] = llvm.trunc %[[PRED]] : i8 to i1
 // CHECK: %[[BASE_Y_0:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_1]], %[[BLOCK_SHAPE_Y]] : i1, i32
 // CHECK: %[[LOAD_0:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_0]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2
+// CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<32xi16>
+// CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
+// CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>
+// CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [8, 9, 10, 11, 12, 13, 14, 15] : vector<32xi16>
+// CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
+// CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>
+// CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [16, 17, 18, 19, 20, 21, 22, 23] : vector<32xi16>
+// CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
+// CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>
+// CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [24, 25, 26, 27, 28, 29, 30, 31] : vector<32xi16>
+// CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
+// CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>
 
 // CHECK: %[[TOP_LEFT_PTR:.*]] = llvm.ptrtoint {{.*}} : !llvm.ptr<1> to i64
 // CHECK: %[[VAL_3046:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[TOP_LEFT_PTR]], {{.*}}) {convergent, no_unwind, will_return} : (i64, i32) -> i64
@@ -414,6 +425,18 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32, "ttg.thr
 // CHECK: %[[PRED_BOOL:.*]] = llvm.trunc %[[PRED]] : i8 to i1
 // CHECK: %[[BASE_Y_0:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_1]], %[[BLOCK_SHAPE_Y]] : i1, i32
 // CHECK: %[[LOAD_0:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_0]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2
+// CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<32xi16>
+// CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
+// CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>
+// CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [8, 9, 10, 11, 12, 13, 14, 15] : vector<32xi16>
+// CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
+// CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>
+// CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [16, 17, 18, 19, 20, 21, 22, 23] : vector<32xi16>
+// CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
+// CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>
+// CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [24, 25, 26, 27, 28, 29, 30, 31] : vector<32xi16>
+// CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
+// CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>
 
 // CHECK: %[[TOP_LEFT_PTR:.*]] = llvm.ptrtoint {{.*}} : !llvm.ptr<1> to i64
 // CHECK: %[[VAL_3046:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[TOP_LEFT_PTR]], {{.*}}) {convergent, no_unwind, will_return} : (i64, i32) -> i64
@@ -426,6 +449,18 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32, "ttg.thr
 // CHECK: %[[PRED_BOOL:.*]] = llvm.trunc %[[PRED]] : i8 to i1
 // CHECK: %[[BASE_Y_0:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_1]], %[[BLOCK_SHAPE_Y]] : i1, i32
 // CHECK: %[[LOAD_0:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_0]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2
+// CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<32xi16>
+// CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
+// CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>
+// CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [8, 9, 10, 11, 12, 13, 14, 15] : vector<32xi16>
+// CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
+// CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>
+// CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [16, 17, 18, 19, 20, 21, 22, 23] : vector<32xi16>
+// CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
+// CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>
+// CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [24, 25, 26, 27, 28, 29, 30, 31] : vector<32xi16>
+// CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
+// CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>
 
 // CHECK: %[[TOP_LEFT_PTR:.*]] = llvm.ptrtoint {{.*}} : !llvm.ptr<1> to i64
 // CHECK: %[[VAL_3046:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[TOP_LEFT_PTR]], {{.*}}) {convergent, no_unwind, will_return} : (i64, i32) -> i64
@@ -438,6 +473,18 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32, "ttg.thr
 // CHECK: %[[PRED_BOOL:.*]] = llvm.trunc %[[PRED]] : i8 to i1
 // CHECK: %[[BASE_Y_0:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_1]], %[[BLOCK_SHAPE_Y]] : i1, i32
 // CHECK: %[[LOAD_0:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_0]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2
+// CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<32xi16>
+// CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
+// CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>
+// CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [8, 9, 10, 11, 12, 13, 14, 15] : vector<32xi16>
+// CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
+// CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>
+// CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [16, 17, 18, 19, 20, 21, 22, 23] : vector<32xi16>
+// CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
+// CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>
+// CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [24, 25, 26, 27, 28, 29, 30, 31] : vector<32xi16>
+// CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
+// CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>
 %11 = tt.load %10, %a_mask, %a_other {ttig.block_io = "row_major"} : tensor<256x64x!tt.ptr<f16>, #mma>
 
 tt.return
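
Note on the four new CHECK groups: each asserts the same decomposition, where one 32-lane i16 2D block load is split into four vector<8xi16> slices (llvm.shufflevector), reinterpreted as vector<8xf16> (llvm.bitcast), and only kept when the load predicate holds (llvm.select), falling back to %a_other otherwise. A rough Python model of that unpacking, with illustrative values (this is not the lowering itself):

```python
import numpy as np

loaded = np.arange(32, dtype=np.int16)     # packed payload of one 2Dblockload
other = np.full(8, 1.0, dtype=np.float16)  # %a_other = dense<1.00e+00>
pred = True                                # %a_mask = dense<true>

unpacked = []
for k in range(4):
    chunk = loaded[8 * k:8 * (k + 1)]      # shufflevector lanes [8k .. 8k+7]
    as_f16 = chunk.view(np.float16)        # bitcast vector<8xi16> -> vector<8xf16>
    unpacked.append(as_f16 if pred else other)  # select on the load predicate
```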