@@ -97,7 +97,7 @@ class Test3 : public TestBase {
97
97
static constexpr size_t mat_m = 16 ;
98
98
static constexpr size_t mat_n = 64 ;
99
99
static constexpr size_t mat_k = 32 ;
100
- static constexpr size_t wg_m = 8 ;
100
+ static constexpr size_t wg_m = 16 ;
101
101
static constexpr size_t wg_n = 64 ;
102
102
static constexpr size_t sg_m = 1 ;
103
103
static constexpr size_t sg_n = 64 ;
@@ -205,7 +205,7 @@ class Test8 : public TestBase {
205
205
static constexpr uint32_t global_kslicing = 2 ;
206
206
static constexpr uint32_t local_kslicing = 1 ;
207
207
static constexpr mem_layout layout_a = mem_layout::row_major;
208
- static constexpr mem_layout layout_b = mem_layout::col_major ;
208
+ static constexpr mem_layout layout_b = mem_layout::row_major ;
209
209
using data_type_a = float ;
210
210
using data_type_b = float ;
211
211
using data_type_c = float ;
@@ -227,7 +227,6 @@ class Test9 : public TestBase {
227
227
static constexpr uint32_t local_kslicing = 1 ;
228
228
static constexpr mem_layout layout_a = mem_layout::row_major;
229
229
static constexpr mem_layout layout_b = mem_layout::row_major;
230
- static constexpr mma_engine engine = mma_engine::xmx;
231
230
using data_type_a = float ;
232
231
using data_type_b = float ;
233
232
using data_type_c = float ;
@@ -245,10 +244,10 @@ class Test10 : public TestBase {
245
244
static constexpr size_t sg_m = 32 ;
246
245
static constexpr size_t sg_n = 64 ;
247
246
static constexpr size_t sg_k = 8 ;
248
- static constexpr uint32_t global_kslicing = 2 ;
247
+ static constexpr uint32_t global_kslicing = 1 ;
249
248
static constexpr uint32_t local_kslicing = 1 ;
250
249
static constexpr mem_layout layout_a = mem_layout::row_major;
251
- static constexpr mem_layout layout_b = mem_layout::col_major ;
250
+ static constexpr mem_layout layout_b = mem_layout::row_major ;
252
251
using data_type_a = float ;
253
252
using data_type_b = float ;
254
253
using data_type_c = float ;
@@ -258,9 +257,9 @@ class Test10 : public TestBase {
258
257
class Test11 : public TestBase {
259
258
public:
260
259
static constexpr size_t batch_size = 35 ;
261
- static constexpr size_t mat_m = 4192 ;
262
- static constexpr size_t mat_k = 1136 ;
263
- static constexpr size_t mat_n = 688 ;
260
+ static constexpr size_t mat_m = 4193 ;
261
+ static constexpr size_t mat_k = 1134 ;
262
+ static constexpr size_t mat_n = 686 ;
264
263
static constexpr size_t wg_m = 256 ;
265
264
static constexpr size_t wg_n = 256 ;
266
265
static constexpr size_t sg_m = 32 ;
@@ -314,4 +313,4 @@ class result_validate {
314
313
Test::layout_a,
315
314
Test::layout_b);
316
315
}
317
- };
316
+ };
0 commit comments