From 624f5b5b2d65e68e5e98b5039fed5a2be9ac4255 Mon Sep 17 00:00:00 2001 From: Rain Jiang Date: Thu, 18 Sep 2025 22:10:59 +0000 Subject: [PATCH 1/3] fix the dequantize_block in the trtllm_cutlass fuse moe test --- tests/test_trtllm_cutlass_fused_moe.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/tests/test_trtllm_cutlass_fused_moe.py b/tests/test_trtllm_cutlass_fused_moe.py index 5f9f04dc63..45ca3ed47f 100644 --- a/tests/test_trtllm_cutlass_fused_moe.py +++ b/tests/test_trtllm_cutlass_fused_moe.py @@ -215,7 +215,7 @@ def compute_with_experts( 1, ] HIDDEN_SIZES = [ - 128, + 256, ] NUM_EXPERTS = [2] TOP_K_VALUES = [2] @@ -884,6 +884,7 @@ def dequantize_block( scales: torch.Tensor, dtype: torch.dtype, original_shape: tuple, + block_size_n: int = 128, ) -> torch.Tensor: """ Dequantize a block-quantized tensor. @@ -904,8 +905,8 @@ def transform_dim(a: torch.Tensor, dim: int = -1) -> torch.Tensor: if dim != -1: a = a.transpose(dim, -1) # Broadcast and reshape - a_broadcasted = a.unsqueeze(-1).expand(*a.shape, 128) - a_reshaped = a_broadcasted.reshape(*a.shape[:-1], a.shape[-1] * 128) + a_broadcasted = a.unsqueeze(-1).expand(*a.shape, block_size_n) + a_reshaped = a_broadcasted.reshape(*a.shape[:-1], a.shape[-1] * block_size_n) # Move back if needed if dim != -1: a_reshaped = a_reshaped.transpose(dim, -1) @@ -913,9 +914,13 @@ def transform_dim(a: torch.Tensor, dim: int = -1) -> torch.Tensor: if x_quant.dim() == 2: # For activation tensors [batch_size, hidden_size] batch_size, hidden_size = x_quant.shape - num_blocks = (hidden_size + 127) // 128 - scales = scales.view(batch_size, num_blocks, 1).expand(-1, -1, 128) - scales = scales[:, :, : hidden_size % 128] if hidden_size % 128 != 0 else scales + num_blocks = (hidden_size + block_size_n - 1) // block_size_n + scales = ( + scales.view(batch_size, num_blocks, 1) + .expand(-1, -1, block_size_n) + .reshape(batch_size, -1) + ) + scales = scales[:, :hidden_size] else: # For weight tensors [..., in_dim, out_dim] *_dims, in_dim, out_dim = x_quant.shape @@ -924,10 +929,10 @@ def transform_dim(a: torch.Tensor, dim: int = -1) -> torch.Tensor: scales = transform_dim(scales, -2) # Second-to-last dim # Handle padding - if in_dim % 128 != 0: - scales = scales[..., : in_dim % 128, :] - if out_dim % 128 != 0: - scales = scales[..., :, : out_dim % 128] + if in_dim % block_size_n != 0: + scales = scales[..., : in_dim % block_size_n, :] + if out_dim % block_size_n != 0: + scales = scales[..., :, : out_dim % block_size_n] x_dequant = x_quant.to(dtype) * scales.to(dtype) return x_dequant.view(original_shape) From bffc2e951a46e8892ca85d50e825f2da4449b6be Mon Sep 17 00:00:00 2001 From: Rain Jiang Date: Thu, 18 Sep 2025 22:14:53 +0000 Subject: [PATCH 2/3] add docs for block_size_n --- tests/test_trtllm_cutlass_fused_moe.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_trtllm_cutlass_fused_moe.py b/tests/test_trtllm_cutlass_fused_moe.py index 45ca3ed47f..e1be4f3b5c 100644 --- a/tests/test_trtllm_cutlass_fused_moe.py +++ b/tests/test_trtllm_cutlass_fused_moe.py @@ -894,6 +894,7 @@ def dequantize_block( scales: Block scaling factors dtype: Target dtype for dequantization original_shape: Original shape of the tensor before padding + block_size_n: Block size Returns: torch.Tensor: Dequantized tensor From 640b36fadb2cb010f8ab674c86ab78f7a07b8c1e Mon Sep 17 00:00:00 2001 From: Rain Jiang <96632942+rainj-me@users.noreply.github.com> Date: Thu, 18 Sep 2025 16:03:46 -0700 Subject: [PATCH 3/3] Update tests/test_trtllm_cutlass_fused_moe.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- tests/test_trtllm_cutlass_fused_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_trtllm_cutlass_fused_moe.py b/tests/test_trtllm_cutlass_fused_moe.py index e1be4f3b5c..84b2eddae7 100644 --- a/tests/test_trtllm_cutlass_fused_moe.py +++ b/tests/test_trtllm_cutlass_fused_moe.py @@ -915,7 +915,7 @@ def transform_dim(a: torch.Tensor, dim: int = -1) -> torch.Tensor: if x_quant.dim() == 2: # For activation tensors [batch_size, hidden_size] batch_size, hidden_size = x_quant.shape - num_blocks = (hidden_size + block_size_n - 1) // block_size_n + num_blocks = ceil_div(hidden_size, block_size_n) scales = ( scales.view(batch_size, num_blocks, 1) .expand(-1, -1, block_size_n)