Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 4e45885

Browse files
committed
Remove tdesc (tensor descriptor) usage from load_xe and store_xe
1 parent 71c2c1d commit 4e45885

File tree

3 files changed

+124
-116
lines changed

3 files changed

+124
-116
lines changed

include/subgroup/tile/impl/load_xe.hpp

Lines changed: 67 additions & 68 deletions
Original file line number | Diff line number | Diff line change
@@ -89,7 +89,7 @@ tile_load(tile_t& tile, payload_t& payload) {
8989

9090
static constexpr uint32_t num_block_x = tile_desc::num_block_x;
9191
static constexpr uint32_t num_block_y = tile_desc::num_block_y;
92-
static constexpr uint32_t num_block = tile_desc::num_block;
92+
// static constexpr uint32_t num_block = tile_desc::num_block;
9393

9494
static constexpr gpu_arch arch_tag = payload_t::arch_tag;
9595

@@ -176,24 +176,24 @@ tile_load(tile_t& tile, payload_t& payload) {
176176
((block_size_y * sizeof(dtype)) % sizeof(load_dtype) == 0),
177177
"check vnni limitation for DW transpose");
178178

179-
auto payload_2d = payload.payloads.xetla_format<uint32_t, num_block, 16>();
179+
// auto payload_2d = payload.payloads.xetla_format<uint32_t, num_block, 16>();
180180
#pragma unroll
181181
for (uint32_t i = 0; i < num_block_y; ++i) {
182182
constexpr uint32_t load_block_elems = block_elems * arr_len;
183183
int offset_y = i * block_size_y;
184-
auto payload_row =
185-
payload_2d.xetla_select<num_block_x, 1, 16, 1>(i * num_block_x, 0);
186-
detail::reset_tile_desc_core<
187-
num_block_x,
188-
block_size_x,
189-
ld_blk_size_y,
190-
scale_factor,
191-
arr_len,
192-
mem_transpose>(payload_row);
184+
// auto payload_row =
185+
// payload_2d.xetla_select<num_block_x, 1, 16, 1>(i * num_block_x, 0);
186+
// detail::reset_tile_desc_core<
187+
// num_block_x,
188+
// block_size_x,
189+
// ld_blk_size_y,
190+
// scale_factor,
191+
// arr_len,
192+
// mem_transpose>(payload_row);
193193
#pragma unroll
194194
for (uint32_t j = 0; j < num_block_x; j += arr_len) {
195-
uint32_t offset_x = j * block_size_x;
196-
xetla_tdescriptor tdesc = payload_row.row(j);
195+
int32_t offset_x = j * block_size_x;
196+
// xetla_tdescriptor tdesc = payload_row.row(j);
197197
auto reg_blk = tile.reg.xetla_select<load_block_elems, 1>(
198198
(i * num_block_x + j) * block_elems);
199199
constexpr uint32_t ld_blk_height = (reg_transpose && trans)
@@ -203,7 +203,6 @@ tile_load(tile_t& tile, payload_t& payload) {
203203
xetla_vector<dtype, tmp_size> reg_tmp;
204204
#pragma unroll
205205
for (uint32_t ii = 0; ii < block_size_y / ld_blk_size_y; ++ii) {
206-
// offset_y += ld_blk_size_y;
207206
constexpr uint32_t load_elems = ld_blk_size_y * block_size_x * arr_len;
208207
reg_tmp.xetla_format<native_type_t<load_dtype>>() = xetla_load_global<
209208
native_type_t<load_dtype>,
@@ -220,12 +219,14 @@ tile_load(tile_t& tile, payload_t& payload) {
220219
payload.surface_width,
221220
payload.surface_height,
222221
payload.surface_pitch,
223-
mem_transpose
224-
// ? (payload.offset_x + offset_y / scale_factor)
225-
? ::gpu::xetla::detail::xetla_get_tensor_offset_x(tdesc)
226-
: (payload.offset_x + offset_x / scale_factor),
227-
228-
payload.offset_y + (mem_transpose ? offset_x : offset_y));
222+
payload.offset_x +
223+
(mem_transpose ? (offset_y / (int)scale_factor +
224+
ii * ld_blk_size_y / (int)scale_factor)
225+
: (offset_x / scale_factor)),
226+
227+
payload.offset_y +
228+
(mem_transpose ? offset_x : (offset_y + ii * ld_blk_size_y)));
229+
229230
if constexpr (reg_transpose && trans) {
230231
reg_blk.xetla_select<load_elems, 1>(ii * load_elems)
231232
.xetla_format<native_type_t<load_dtype>>() =
@@ -242,13 +243,6 @@ tile_load(tile_t& tile, payload_t& payload) {
242243
} else {
243244
reg_blk.xetla_select<tmp_size, 1>(ii * tmp_size) = reg_tmp;
244245
}
245-
if constexpr (mem_transpose) {
246-
xetla_update_tdesc_offsetx(
247-
tdesc.xetla_format<uint32_t>(), ld_blk_size_y / scale_factor);
248-
} else {
249-
xetla_update_tdesc_offsety(
250-
tdesc.xetla_format<uint32_t>(), ld_blk_size_y);
251-
}
252246
}
253247
// exceed HW limitation
254248
if constexpr (block_size_y % ld_blk_size_y != 0) {
@@ -258,22 +252,22 @@ tile_load(tile_t& tile, payload_t& payload) {
258252
remained_start_y * block_size_x * arr_len;
259253
constexpr uint32_t remained_blk_size_y = block_size_y % ld_blk_size_y;
260254
constexpr uint32_t load_elems =
261-
remained_blk_size_y * block_size_x * arr_len;
255+
remained_blk_size_y * block_size_x * arr_len / scale_factor;
262256

263257
constexpr uint8_t block_width =
264-
mem_transpose ? (remained_blk_size_y / scale_factor) : block_size_x;
258+
(mem_transpose ? remained_blk_size_y : block_size_x) / scale_factor;
265259
constexpr uint8_t block_height =
266-
trans ? block_size_x : remained_blk_size_y;
267-
constexpr uint32_t block_widthx_widthy_arrlen =
268-
(block_width - 1) | ((block_height - 1) << 8);
269-
gpu::xetla::detail::xetla_set_block_widthx_widthy_arrlen(
270-
tdesc.xetla_format<uint32_t>(), block_widthx_widthy_arrlen);
260+
mem_transpose ? block_size_x : remained_blk_size_y;
261+
// constexpr uint32_t block_widthx_widthy_arrlen =
262+
// (block_width - 1) | ((block_height - 1) << 8);
263+
// gpu::xetla::detail::xetla_set_block_widthx_widthy_arrlen(
264+
// tdesc.xetla_format<uint32_t>(), block_widthx_widthy_arrlen);
271265

272266
reg_blk.xetla_select<load_elems, 1>(remained_start)
273267
.xetla_format<native_type_t<load_dtype>>() = xetla_load_global<
274268
native_type_t<load_dtype>,
275-
block_size_x / scale_factor,
276-
remained_blk_size_y,
269+
block_width,
270+
block_height,
277271
arr_len,
278272
trans,
279273
mem_transform,
@@ -283,8 +277,9 @@ tile_load(tile_t& tile, payload_t& payload) {
283277
payload.surface_width,
284278
payload.surface_height,
285279
payload.surface_pitch,
286-
::gpu::xetla::detail::xetla_get_tensor_offset_x(tdesc),
287-
::gpu::xetla::detail::xetla_get_tensor_offset_y(tdesc));
280+
payload.offset_x + offset_x / scale_factor,
281+
payload.offset_y + offset_y + remained_start_y);
282+
288283
// xetla_tload_global<
289284
// load_dtype,
290285
// (load_elems / scale_factor),
@@ -305,18 +300,19 @@ tile_load(tile_t& tile, payload_t& payload) {
305300
(!reg_transpose && (remained_size_y > ld_blk_size_y_limit))
306301
? ld_blk_size_y_limit
307302
: remained_size_y;
308-
auto payload_row = payload_2d.xetla_select<num_block_x, 1, 16, 1>(
309-
num_block_y * num_block_x, 0);
310-
detail::reset_tile_desc_core<
311-
num_block_x,
312-
block_size_x,
313-
remained_ld_blk_size_y,
314-
scale_factor,
315-
arr_len,
316-
mem_transpose>(payload_row);
303+
// auto payload_row = payload_2d.xetla_select<num_block_x, 1, 16, 1>(
304+
// num_block_y * num_block_x, 0);
305+
// detail::reset_tile_desc_core<
306+
// num_block_x,
307+
// block_size_x,
308+
// remained_ld_blk_size_y,
309+
// scale_factor,
310+
// arr_len,
311+
// mem_transpose>(payload_row);
317312
#pragma unroll
318313
for (uint32_t j = 0; j < num_block_x; j += arr_len) {
319-
xetla_tdescriptor tdesc = payload_row.row(j);
314+
int32_t offset_x = j * block_size_x;
315+
// xetla_tdescriptor tdesc = payload_row.row(j);
320316
auto reg_blk = tile.reg.xetla_select<remained_block_elems * arr_len, 1>(
321317
processed_elems + j * remained_block_elems);
322318
constexpr uint32_t ld_blk_height = (reg_transpose && trans)
@@ -343,8 +339,9 @@ tile_load(tile_t& tile, payload_t& payload) {
343339
payload.surface_width,
344340
payload.surface_height,
345341
payload.surface_pitch,
346-
::gpu::xetla::detail::xetla_get_tensor_offset_x(tdesc),
347-
::gpu::xetla::detail::xetla_get_tensor_offset_y(tdesc));
342+
payload.offset_x + offset_x / scale_factor,
343+
payload.offset_y + num_block_y * block_size_y +
344+
ii * remained_ld_blk_size_y);
348345
// xetla_tload_global<
349346
// load_dtype,
350347
// (ld_blk_height * block_size_x * arr_len / scale_factor),
@@ -370,14 +367,14 @@ tile_load(tile_t& tile, payload_t& payload) {
370367
} else {
371368
reg_blk.xetla_select<tmp_size, 1>(ii * tmp_size) = reg_tmp;
372369
}
373-
if constexpr (mem_transpose) {
374-
xetla_update_tdesc_offsetx(
375-
tdesc.xetla_format<uint32_t>(),
376-
remained_ld_blk_size_y / scale_factor);
377-
} else {
378-
xetla_update_tdesc_offsety(
379-
tdesc.xetla_format<uint32_t>(), remained_ld_blk_size_y);
380-
}
370+
// if constexpr (mem_transpose) {
371+
// xetla_update_tdesc_offsetx(
372+
// tdesc.xetla_format<uint32_t>(),
373+
// remained_ld_blk_size_y / scale_factor);
374+
// } else {
375+
// xetla_update_tdesc_offsety(
376+
// tdesc.xetla_format<uint32_t>(), remained_ld_blk_size_y);
377+
// }
381378
}
382379
constexpr uint32_t final_ld_blk_size_y =
383380
remained_size_y % remained_ld_blk_size_y;
@@ -388,18 +385,18 @@ tile_load(tile_t& tile, payload_t& payload) {
388385
constexpr uint32_t final_load_elems =
389386
final_ld_blk_size_y * block_size_x * arr_len;
390387
constexpr uint8_t block_width =
391-
mem_transpose ? (final_ld_blk_size_y / scale_factor) : block_size_x;
388+
(mem_transpose ? final_ld_blk_size_y : block_size_x) / scale_factor;
392389
constexpr uint8_t block_height =
393-
trans ? block_size_x : final_ld_blk_size_y;
394-
constexpr uint32_t block_widthx_widthy_arrlen =
395-
(block_width - 1) | ((block_height - 1) << 8);
396-
gpu::xetla::detail::xetla_set_block_widthx_widthy_arrlen(
397-
tdesc.xetla_format<uint32_t>(), block_widthx_widthy_arrlen);
390+
mem_transpose ? block_size_x : final_ld_blk_size_y;
391+
// constexpr uint32_t block_widthx_widthy_arrlen =
392+
// (block_width - 1) | ((block_height - 1) << 8);
393+
// gpu::xetla::detail::xetla_set_block_widthx_widthy_arrlen(
394+
// tdesc.xetla_format<uint32_t>(), block_widthx_widthy_arrlen);
398395
reg_blk.xetla_select<final_load_elems, 1>(final_start)
399396
.xetla_format<native_type_t<load_dtype>>() = xetla_load_global<
400397
native_type_t<load_dtype>,
401-
block_size_x / scale_factor,
402-
final_ld_blk_size_y,
398+
block_width,
399+
block_height,
403400
arr_len,
404401
trans,
405402
mem_transform,
@@ -409,8 +406,10 @@ tile_load(tile_t& tile, payload_t& payload) {
409406
payload.surface_width,
410407
payload.surface_height,
411408
payload.surface_pitch,
412-
::gpu::xetla::detail::xetla_get_tensor_offset_x(tdesc),
413-
::gpu::xetla::detail::xetla_get_tensor_offset_y(tdesc));
409+
payload.offset_x + offset_x / scale_factor,
410+
payload.offset_y + num_block_y * block_size_y +
411+
remained_size_y / remained_ld_blk_size_y *
412+
remained_ld_blk_size_y);
414413
// xetla_tload_global<
415414
// load_dtype,
416415
// final_load_elems / scale_factor,

include/subgroup/tile/impl/payload_xe.hpp

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -111,8 +111,8 @@ struct mem_payload_t<
111111
this->surface_height =
112112
(mem_transpose ? mem_desc.shape.x : mem_desc.shape.y);
113113
this->surface_pitch = mem_desc.shape.stride * sizeof(dtype);
114-
this->offset_x =
115-
(mem_transpose ? mem_desc.coord.y : mem_desc.coord.x) / scale_factor;
114+
this->offset_x = (mem_transpose ? mem_desc.coord.y : mem_desc.coord.x) /
115+
int32_t(scale_factor);
116116
this->offset_y = mem_transpose ? mem_desc.coord.x : mem_desc.coord.y;
117117

118118
xetla_tdescriptor base_tdesc = mem_desc.get_tdesc();
@@ -134,7 +134,7 @@ struct mem_payload_t<
134134
this->surface_width = surface_width * sizeof(dtype);
135135
this->surface_height = surface_height;
136136
this->surface_pitch = surface_pitch * sizeof(dtype);
137-
this->offset_x = surface_offset_x / scale_factor;
137+
this->offset_x = surface_offset_x / int32_t(scale_factor);
138138
this->offset_y = surface_offset_y;
139139

140140
xetla_tdescriptor base_tdesc;
@@ -157,8 +157,8 @@ struct mem_payload_t<
157157
this->surface_height =
158158
(mem_transpose ? mem_desc.shape.x : mem_desc.shape.y);
159159
this->surface_pitch = mem_desc.shape.stride * sizeof(dtype);
160-
this->offset_x =
161-
(mem_transpose ? mem_desc.coord.y : mem_desc.coord.x) / scale_factor;
160+
this->offset_x = (mem_transpose ? mem_desc.coord.y : mem_desc.coord.x) /
161+
int32_t(scale_factor);
162162
this->offset_y = (mem_transpose ? mem_desc.coord.x : mem_desc.coord.y);
163163

164164
xetla_tdescriptor base_tdesc = mem_desc.get_tdesc();
@@ -188,7 +188,7 @@ struct mem_payload_t<
188188
this->surface_width = surface_width * sizeof(dtype);
189189
this->surface_height = surface_height;
190190
this->surface_pitch = surface_pitch * sizeof(dtype);
191-
this->offset_x = surface_offset_x / scale_factor;
191+
this->offset_x = surface_offset_x / int32_t(scale_factor);
192192
this->offset_y = surface_offset_y;
193193

194194
xetla_tdescriptor base_tdesc;

0 commit comments

Comments
 (0)