Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 09bd497

Browse files
committed
Modify store_2d (xetla_store_global) to take its values argument by forwarding reference (auto&&)
1 parent 4e45885 commit 09bd497

File tree

11 files changed

+41
-19
lines changed

11 files changed

+41
-19
lines changed

include/common/core/memory.hpp

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -769,24 +769,24 @@ __XETLA_API void xetla_store_global(
769769
unsigned SurfacePitch,
770770
int X,
771771
int Y,
772-
xetla_vector<T, N> Vals) {
772+
auto&& Vals) {
773773
if constexpr (std::is_same_v<T, bf16>) {
774-
xetla_vector<fp16, N> Vals_fp16 = Vals.xetla_format<fp16>();
775774
xetla_store_global<fp16, BlockWidth, BlockHeight, L1H, L2H>(
776775
reinterpret_cast<fp16*>(Ptr),
777776
SurfaceWidth,
778777
SurfaceHeight,
779778
SurfacePitch,
780779
X,
781780
Y,
782-
Vals_fp16);
781+
Vals.xetla_format<fp16>());
783782
} else {
784783
__ESIMD_ENS::lsc_store_2d<
785784
T,
786785
BlockWidth,
787786
BlockHeight,
788787
gpu::xetla::detail::get_cache_hint(L1H),
789-
gpu::xetla::detail::get_cache_hint(L2H)>(
788+
gpu::xetla::detail::get_cache_hint(L2H),
789+
N>(
790790
Ptr, SurfaceWidth - 1, SurfaceHeight - 1, SurfacePitch - 1, X, Y, Vals);
791791
}
792792
}

include/subgroup/tile/impl/load_xe.hpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -89,7 +89,7 @@ tile_load(tile_t& tile, payload_t& payload) {
8989

9090
static constexpr uint32_t num_block_x = tile_desc::num_block_x;
9191
static constexpr uint32_t num_block_y = tile_desc::num_block_y;
92-
// static constexpr uint32_t num_block = tile_desc::num_block;
92+
// static constexpr uint32_t num_block = tile_desc::num_block;
9393

9494
static constexpr gpu_arch arch_tag = payload_t::arch_tag;
9595

@@ -329,7 +329,7 @@ tile_load(tile_t& tile, payload_t& payload) {
329329
reg_tmp.xetla_format<native_type_t<load_dtype>>() = xetla_load_global<
330330
native_type_t<load_dtype>,
331331
block_size_x / scale_factor,
332-
block_size_y,
332+
ld_blk_height,
333333
arr_len,
334334
trans,
335335
mem_transform,

include/subgroup/tile/impl/store_xe.hpp

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -98,7 +98,7 @@ tile_store(tile_t& tile, payload_t& payload) {
9898

9999
static constexpr uint32_t num_block_x = tile_desc::num_block_x;
100100
static constexpr uint32_t num_block_y = tile_desc::num_block_y;
101-
// static constexpr uint32_t num_block = tile_desc::num_block;
101+
// static constexpr uint32_t num_block = tile_desc::num_block;
102102

103103
using load_store_attr = typename arch_attr_t<
104104
payload_t::arch_tag>::template load_store_attr<msg_type::block_2d>;
@@ -145,7 +145,7 @@ tile_store(tile_t& tile, payload_t& payload) {
145145
#pragma unroll
146146
for (uint32_t j = 0; j < num_block_x; j += arr_len) {
147147
int32_t offset_x = j * block_size_x;
148-
// xetla_tdescriptor tdesc = payload_row.row(j);
148+
// xetla_tdescriptor tdesc = payload_row.row(j);
149149
auto reg_blk = tile.reg.xetla_select<store_block_elems, 1>(
150150
(i * num_block_x + j) * block_elems);
151151
xetla_vector<dtype, store_block_elems> combine_blk;
@@ -163,7 +163,7 @@ tile_store(tile_t& tile, payload_t& payload) {
163163
for (uint32_t ii = 0; ii < block_size_y / st_block_size_y; ++ii) {
164164
constexpr uint32_t store_elems =
165165
st_block_size_y * block_size_x * arr_len;
166-
xetla_vector<dtype, store_elems> st_blk =
166+
auto st_blk =
167167
combine_blk.xetla_select<store_elems, 1>(ii * store_elems);
168168
// xetla_tstore_global<dtype, store_elems, L1, L2, payload_t::arch_tag>(
169169
// tdesc, st_blk);
@@ -173,7 +173,7 @@ tile_store(tile_t& tile, payload_t& payload) {
173173
st_block_size_y,
174174
L1,
175175
L2>(
176-
payload.base_ptr,
176+
reinterpret_cast<dtype*>(payload.base_ptr),
177177
payload.surface_width,
178178
payload.surface_height,
179179
payload.surface_pitch,
@@ -210,7 +210,7 @@ tile_store(tile_t& tile, payload_t& payload) {
210210
blk_remained_y,
211211
L1,
212212
L2>(
213-
payload.base_ptr,
213+
reinterpret_cast<dtype*>(payload.base_ptr),
214214
payload.surface_width,
215215
payload.surface_height,
216216
payload.surface_pitch,
@@ -240,7 +240,7 @@ tile_store(tile_t& tile, payload_t& payload) {
240240
#pragma unroll
241241
for (uint32_t j = 0; j < num_block_x; j += arr_len) {
242242
int offset_x = j * block_size_x;
243-
// xetla_tdescriptor tdesc = payload_row.row(j);
243+
// xetla_tdescriptor tdesc = payload_row.row(j);
244244
auto reg_blk = tile.reg.xetla_select<remained_block_elems * arr_len, 1>(
245245
processed_elems + j * remained_block_elems);
246246
// Do combination
@@ -271,7 +271,7 @@ tile_store(tile_t& tile, payload_t& payload) {
271271
remained_st_blk_size_y,
272272
L1,
273273
L2>(
274-
payload.base_ptr,
274+
reinterpret_cast<dtype*>(payload.base_ptr),
275275
payload.surface_width,
276276
payload.surface_height,
277277
payload.surface_pitch,
@@ -308,7 +308,7 @@ tile_store(tile_t& tile, payload_t& payload) {
308308
final_st_blk_size_y,
309309
L1,
310310
L2>(
311-
payload.base_ptr,
311+
reinterpret_cast<dtype*>(payload.base_ptr),
312312
payload.surface_width,
313313
payload.surface_height,
314314
payload.surface_pitch,

tests/integration/default_config/group_gemm/kernel_func.hpp

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -108,6 +108,9 @@ struct default_config_group_gemm_test_func {
108108

109109
using gemm_op_t = gemm_universal_t<dispatch_policy, gemm_t, epilogue_t>;
110110

111+
static constexpr uint32_t barrier_count = gemm_op_t::get_barrier_count();
112+
static constexpr uint32_t slm_size = gemm_op_t::get_slm_size();
113+
111114
static const char* func_name() {
112115
return "default_config_group_gemm_test_func";
113116
}

tests/integration/default_config/kernel_gemm/kernel_func.hpp

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -65,6 +65,9 @@ struct default_config_kernel_gemm_test_func {
6565
gpu_arch::XeHpc, // GPU arch
6666
tune_option>;
6767

68+
static constexpr uint32_t barrier_count = gemm_op_t::get_barrier_count();
69+
static constexpr uint32_t slm_size = gemm_op_t::get_slm_size();
70+
6871
static const char* func_name() {
6972
return "default_config_kernel_gemm_test_func";
7073
}

tests/integration/gemm/bf16/kernel_func.hpp

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -76,6 +76,9 @@ struct bf16_gemm_test_func {
7676

7777
using gemm_op_t = gemm_universal_t<dispatch_policy, gemm_t, epilogue_t>;
7878

79+
static constexpr uint32_t barrier_count = gemm_op_t::get_barrier_count();
80+
static constexpr uint32_t slm_size = gemm_op_t::get_slm_size();
81+
7982
static const char* func_name() {
8083
return "bf16_gemm_test_func";
8184
}

tests/integration/gemm/fp32/kernel_func.hpp

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -77,6 +77,9 @@ struct fp32_gemm_test_func {
7777

7878
using gemm_op_t = gemm_universal_t<dispatch_policy, gemm_t, epilogue_t>;
7979

80+
static constexpr uint32_t barrier_count = gemm_op_t::get_barrier_count();
81+
static constexpr uint32_t slm_size = gemm_op_t::get_slm_size();
82+
8083
static const char* func_name() {
8184
return "fp32_gemm_test_func";
8285
}

tests/integration/gemm/int8/kernel_func.hpp

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -72,6 +72,9 @@ struct int8gemm_test_func {
7272

7373
using gemm_op_t = gemm_universal_t<dispatch_policy, gemm_t, epilogue_t>;
7474

75+
static constexpr uint32_t barrier_count = gemm_op_t::get_barrier_count();
76+
static constexpr uint32_t slm_size = gemm_op_t::get_slm_size();
77+
7578
static const char* func_name() {
7679
return "int8gemm_test_func";
7780
}

tests/integration/gemm/tf32/kernel_func.hpp

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -71,6 +71,9 @@ struct tf32_gemm_test_func {
7171

7272
using gemm_op_t = gemm_universal_t<dispatch_policy, gemm_t, epilogue_t>;
7373

74+
static constexpr uint32_t barrier_count = gemm_op_t::get_barrier_count();
75+
static constexpr uint32_t slm_size = gemm_op_t::get_slm_size();
76+
7477
static const char* func_name() {
7578
return "tf32_gemm_test_func";
7679
}

tests/integration/gemm/unaligned_bf16/kernel_func.hpp

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -68,13 +68,20 @@ struct unaligned_gemm_test_func {
6868
using epilogue_t = epilogue_t<
6969
epilogue_policy_unaligned<arch_tag>,
7070
tile_shape,
71-
mem_desc_t<dtype_c, mem_layout::row_major, mem_space::global, ldc_alignment>>;
71+
mem_desc_t<
72+
dtype_c,
73+
mem_layout::row_major,
74+
mem_space::global,
75+
ldc_alignment>>;
7276

7377
using group_swizzle = gpu::xetla::kernel::group_swizzle_default<arch_tag>;
7478
using dispatch_policy =
7579
dispatch_policy_kslicing<group_swizzle, global_kslicing, local_kslicing>;
7680
using gemm_op_t = gemm_universal_t<dispatch_policy, gemm_t, epilogue_t>;
7781

82+
static constexpr uint32_t barrier_count = gemm_op_t::get_barrier_count();
83+
static constexpr uint32_t slm_size = gemm_op_t::get_slm_size();
84+
7885
static const char* func_name() {
7986
return "unaligned_gemm_test_func";
8087
}

0 commit comments

Comments
 (0)