From 0c5e16c539935ae09a9ea880c67a4e18c1cdad8d Mon Sep 17 00:00:00 2001 From: WyldeCat Date: Tue, 14 Apr 2026 08:14:21 +0000 Subject: [PATCH 1/4] feat: add num_chunks_override to FusedLinearCrossEntropyLoss Allow users to override the auto-computed chunk count in FLCE. Default auto-calculation yields ~32 chunks for large vocab (V=220k), causing excessive elementwise kernel launches between chunks. Overriding to fewer chunks (e.g. 4-8) reduces kernel launch overhead with minimal memory impact since peak is dominated by activations. Also free chunk tensors (del logits_chunk, grad_logits_chunk, _input_chunk) at end of each loop iteration to prevent two logits chunks co-existing in GPU memory between iterations. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../ops/fused_linear_cross_entropy.py | 15 ++++++++++++--- .../transformers/fused_linear_cross_entropy.py | 2 ++ 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/src/liger_kernel/ops/fused_linear_cross_entropy.py b/src/liger_kernel/ops/fused_linear_cross_entropy.py index 01f1b5658..fc0da7a52 100644 --- a/src/liger_kernel/ops/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/fused_linear_cross_entropy.py @@ -30,6 +30,7 @@ def fused_linear_cross_entropy_forward( use_token_scaling=False, return_token_accuracy=False, return_predicted_tokens=False, + num_chunks_override=None, ): assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}" assert isinstance(return_token_accuracy, bool), ( @@ -53,9 +54,13 @@ def fused_linear_cross_entropy_forward( V = weight.shape[0] BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) - inc_factor = triton.cdiv(V, H) # (V + H - 1) // H - chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor)) # (BT + inc_factor - 1) // inc_factor - num_chunks = triton.cdiv(BT, chunk_size) # (BT + chunk_size - 1) // chunk_size + if num_chunks_override is not None: + chunk_size = triton.next_power_of_2(max(1, BT // num_chunks_override)) + num_chunks = triton.cdiv(BT, chunk_size) + else: + inc_factor = triton.cdiv(V, H) # (V + H - 1) // H + chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor)) # (BT + inc_factor - 1) // inc_factor + num_chunks = triton.cdiv(BT, chunk_size) # (BT + chunk_size - 1) // chunk_size grad_input = torch.zeros_like(_input, device=device) @@ -219,6 +224,8 @@ def fused_linear_cross_entropy_forward( alpha=1.0, ) + del logits_chunk, grad_logits_chunk, _input_chunk + # Need extra calculations for backward if reduction=='none'. Not supporting reduction='none' now. # if reduction == "none": # loss = loss_1d @@ -311,6 +318,7 @@ def forward( use_token_scaling: bool = False, return_token_accuracy: bool = False, return_predicted_tokens: bool = False, + num_chunks_override=None, ): """ Fusing the last linear layer with cross-entropy loss @@ -355,6 +363,7 @@ def forward( use_token_scaling=use_token_scaling, return_token_accuracy=return_token_accuracy, return_predicted_tokens=return_predicted_tokens, + num_chunks_override=num_chunks_override, ) ) # downcast to dtype and store for backward diff --git a/src/liger_kernel/transformers/fused_linear_cross_entropy.py b/src/liger_kernel/transformers/fused_linear_cross_entropy.py index c4a4474ce..08b41eb47 100644 --- a/src/liger_kernel/transformers/fused_linear_cross_entropy.py +++ b/src/liger_kernel/transformers/fused_linear_cross_entropy.py @@ -42,6 +42,7 @@ def __init__( self.use_token_scaling = use_token_scaling self.return_token_accuracy = return_token_accuracy self.return_predicted_tokens = return_predicted_tokens + self.num_chunks_override = None def forward(self, lin_weight, _input, target, bias=None): loss, z_loss, token_accuracy, predicted_tokens = LigerFusedLinearCrossEntropyFunction.apply( @@ -60,6 +61,7 @@ def forward(self, lin_weight, _input, target, bias=None): self.use_token_scaling, self.return_token_accuracy, self.return_predicted_tokens, + self.num_chunks_override, ) if not self.return_z_loss and not self.return_token_accuracy and not self.return_predicted_tokens: return loss From 35ef85b3416dd6e262092d8f69dab4a466ecd723 Mon Sep 17 00:00:00 2001 From: WyldeCat Date: Tue, 14 Apr 2026 08:16:35 +0000 Subject: [PATCH 2/4] fix: assert num_chunks_override evenly divides BT Co-Authored-By: Claude Opus 4.6 (1M context) --- src/liger_kernel/ops/fused_linear_cross_entropy.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/liger_kernel/ops/fused_linear_cross_entropy.py b/src/liger_kernel/ops/fused_linear_cross_entropy.py index fc0da7a52..5dd9f46f0 100644 --- a/src/liger_kernel/ops/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/fused_linear_cross_entropy.py @@ -55,6 +55,9 @@ def fused_linear_cross_entropy_forward( BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) if num_chunks_override is not None: + assert BT % num_chunks_override == 0, ( + f"num_chunks_override={num_chunks_override} must evenly divide BT={BT}" + ) chunk_size = triton.next_power_of_2(max(1, BT // num_chunks_override)) num_chunks = triton.cdiv(BT, chunk_size) else: From 2ba5af6ca3ba4c024fad0eadb29ef069990541e7 Mon Sep 17 00:00:00 2001 From: WyldeCat Date: Tue, 14 Apr 2026 08:33:46 +0000 Subject: [PATCH 3/4] feat: add num_chunks as __init__ parameter to LigerFusedLinearCrossEntropyLoss Co-Authored-By: Claude Opus 4.6 (1M context) --- src/liger_kernel/transformers/fused_linear_cross_entropy.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/liger_kernel/transformers/fused_linear_cross_entropy.py b/src/liger_kernel/transformers/fused_linear_cross_entropy.py index 08b41eb47..ba365f51f 100644 --- a/src/liger_kernel/transformers/fused_linear_cross_entropy.py +++ b/src/liger_kernel/transformers/fused_linear_cross_entropy.py @@ -20,6 +20,7 @@ def __init__( use_token_scaling: bool = False, return_token_accuracy: bool = False, return_predicted_tokens: bool = False, + num_chunks: Optional[int] = None, ): super().__init__() assert (label_smoothing >= 0) and (label_smoothing <= 1), ( @@ -42,7 +43,7 @@ def __init__( self.use_token_scaling = use_token_scaling self.return_token_accuracy = return_token_accuracy self.return_predicted_tokens = return_predicted_tokens - self.num_chunks_override = None + self.num_chunks_override = num_chunks def forward(self, lin_weight, _input, target, bias=None): loss, z_loss, token_accuracy, predicted_tokens = LigerFusedLinearCrossEntropyFunction.apply( From ae52b0c680095139ed07cba6f26ad434e6113b9b Mon Sep 17 00:00:00 2001 From: WyldeCat Date: Tue, 14 Apr 2026 08:53:37 +0000 Subject: [PATCH 4/4] fix: add missing None return for num_chunks_override in backward Co-Authored-By: Claude Opus 4.6 (1M context) --- src/liger_kernel/ops/fused_linear_cross_entropy.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/liger_kernel/ops/fused_linear_cross_entropy.py b/src/liger_kernel/ops/fused_linear_cross_entropy.py index 5dd9f46f0..513c62a39 100644 --- a/src/liger_kernel/ops/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/fused_linear_cross_entropy.py @@ -409,4 +409,5 @@ def backward(ctx, grad_output, grad_output2, grad_output3, grad_output4): None, # use_token_scaling None, # return_token_accuracy None, # return_predicted_tokens + None, # num_chunks_override )