From ba5d6d6b69762b542b94cd062ec411b3768f1440 Mon Sep 17 00:00:00 2001 From: Yushi Ogiwara Date: Mon, 7 Jul 2025 16:24:02 +0900 Subject: [PATCH 1/2] Update bitnet_kernels.h specify alignment of A_local --- gpu/bitnet_kernels/bitnet_kernels.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gpu/bitnet_kernels/bitnet_kernels.h b/gpu/bitnet_kernels/bitnet_kernels.h index 1d897908f..a21d5cc8e 100644 --- a/gpu/bitnet_kernels/bitnet_kernels.h +++ b/gpu/bitnet_kernels/bitnet_kernels.h @@ -49,7 +49,7 @@ __global__ void __launch_bounds__(128) ladder_int8xint2_kernel(int8_t* __restric constexpr int wmma_K = 32; constexpr int wmma_N = 16; int in_thread_C_local[1]; - signed char A_local[K_per_loop]; + alignas(16) signed char A_local[K_per_loop]; int B_reshape_local[1]; signed char B_decode_local[K_per_loop]; int red_buf0[1]; @@ -80,4 +80,4 @@ __global__ void __launch_bounds__(128) ladder_int8xint2_kernel(int8_t* __restric int ws_idx = out_idx / (N / ws_num); if (threadIdx.x == 0) dtype_transform[out_idx] = (__nv_bfloat16)(((float)red_buf0[0])/(float)s[0]*(float)ws[ws_idx]); -} \ No newline at end of file +} From ab915c3845b4830f3585fe06bfa1ac7ee6ce291b Mon Sep 17 00:00:00 2001 From: yushiogiwara Date: Sat, 20 Sep 2025 14:21:51 +0900 Subject: [PATCH 2/2] support multi batch on ladder_int8xint2_kernel --- gpu/bitnet_kernels/bitnet_kernels.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gpu/bitnet_kernels/bitnet_kernels.h b/gpu/bitnet_kernels/bitnet_kernels.h index a21d5cc8e..365ad28f7 100644 --- a/gpu/bitnet_kernels/bitnet_kernels.h +++ b/gpu/bitnet_kernels/bitnet_kernels.h @@ -56,7 +56,7 @@ __global__ void __launch_bounds__(128) ladder_int8xint2_kernel(int8_t* __restric in_thread_C_local[0] = 0; #pragma unroll for (int k_0 = 0; k_0 < K/(K_per_loop * K_block_size); ++k_0) { - *(int4*)(A_local + 0) = *(int4*)(A + ((k_0 * K_per_loop * K_block_size) + (((int)threadIdx.x) * K_per_loop))); + *(int4*)(A_local + 0) = *(int4*)(A + blockIdx.y * K + ((k_0 * K_per_loop * K_block_size) + (((int)threadIdx.x) * K_per_loop))); B_reshape_local[0] = *(int*)(B + (((int)blockIdx.x) * N_block_size * K / 4) + (k_0 * K_block_size * K_per_loop * wmma_N / 4) + @@ -76,7 +76,7 @@ __global__ void __launch_bounds__(128) ladder_int8xint2_kernel(int8_t* __restric for (int offset = K_block_size/2; offset > 0; offset /= 2) { red_buf0[0] += __shfl_down_sync(__activemask(), red_buf0[0], offset, K_block_size); } - int out_idx = ((((int)blockIdx.x) * N_block_size) + ((int)threadIdx.y)); + int out_idx = ( blockIdx.y * K + (((int)blockIdx.x) * N_block_size) + ((int)threadIdx.y)); int ws_idx = out_idx / (N / ws_num); if (threadIdx.x == 0) dtype_transform[out_idx] = (__nv_bfloat16)(((float)red_buf0[0])/(float)s[0]*(float)ws[ws_idx]);