diff --git a/src/liger_kernel/ops/fused_linear_cross_entropy.py b/src/liger_kernel/ops/fused_linear_cross_entropy.py
index 01f1b5658..513c62a39 100644
--- a/src/liger_kernel/ops/fused_linear_cross_entropy.py
+++ b/src/liger_kernel/ops/fused_linear_cross_entropy.py
@@ -30,6 +30,7 @@ def fused_linear_cross_entropy_forward(
     use_token_scaling=False,
     return_token_accuracy=False,
     return_predicted_tokens=False,
+    num_chunks_override=None,
 ):
     assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
     assert isinstance(return_token_accuracy, bool), (
@@ -53,9 +54,16 @@ def fused_linear_cross_entropy_forward(
     V = weight.shape[0]
     BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
 
-    inc_factor = triton.cdiv(V, H)  # (V + H - 1) // H
-    chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor))  # (BT + inc_factor - 1) // inc_factor
-    num_chunks = triton.cdiv(BT, chunk_size)  # (BT + chunk_size - 1) // chunk_size
+    if num_chunks_override is not None:
+        assert BT % num_chunks_override == 0, (
+            f"num_chunks_override={num_chunks_override} must evenly divide BT={BT}"
+        )
+        chunk_size = max(1, BT // num_chunks_override)  # exact split; next_power_of_2 here could change the chunk count
+        num_chunks = triton.cdiv(BT, chunk_size)  # == num_chunks_override given the divisibility assert
+    else:
+        inc_factor = triton.cdiv(V, H)  # (V + H - 1) // H
+        chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor))  # (BT + inc_factor - 1) // inc_factor
+        num_chunks = triton.cdiv(BT, chunk_size)  # (BT + chunk_size - 1) // chunk_size
 
     grad_input = torch.zeros_like(_input, device=device)
 
@@ -219,6 +227,8 @@ def fused_linear_cross_entropy_forward(
             alpha=1.0,
         )
 
+        del logits_chunk, grad_logits_chunk, _input_chunk
+
     # Need extra calculations for backward if reduction=='none'. Not supporting reduction='none' now.
     # if reduction == "none":
     #     loss = loss_1d
@@ -311,6 +321,7 @@ def forward(
         use_token_scaling: bool = False,
         return_token_accuracy: bool = False,
         return_predicted_tokens: bool = False,
+        num_chunks_override=None,
     ):
         """
         Fusing the last linear layer with cross-entropy loss
@@ -355,6 +366,7 @@ def forward(
                 use_token_scaling=use_token_scaling,
                 return_token_accuracy=return_token_accuracy,
                 return_predicted_tokens=return_predicted_tokens,
+                num_chunks_override=num_chunks_override,
             )
         )
         # downcast to dtype and store for backward
@@ -397,4 +409,5 @@ def backward(ctx, grad_output, grad_output2, grad_output3, grad_output4):
             None,  # use_token_scaling
             None,  # return_token_accuracy
             None,  # return_predicted_tokens
+            None,  # num_chunks_override
         )
diff --git a/src/liger_kernel/transformers/fused_linear_cross_entropy.py b/src/liger_kernel/transformers/fused_linear_cross_entropy.py
index c4a4474ce..ba365f51f 100644
--- a/src/liger_kernel/transformers/fused_linear_cross_entropy.py
+++ b/src/liger_kernel/transformers/fused_linear_cross_entropy.py
@@ -20,6 +20,7 @@ def __init__(
         use_token_scaling: bool = False,
         return_token_accuracy: bool = False,
         return_predicted_tokens: bool = False,
+        num_chunks: Optional[int] = None,
     ):
         super().__init__()
         assert (label_smoothing >= 0) and (label_smoothing <= 1), (
@@ -42,6 +43,7 @@ def __init__(
         self.use_token_scaling = use_token_scaling
         self.return_token_accuracy = return_token_accuracy
         self.return_predicted_tokens = return_predicted_tokens
+        self.num_chunks_override = num_chunks
 
     def forward(self, lin_weight, _input, target, bias=None):
         loss, z_loss, token_accuracy, predicted_tokens = LigerFusedLinearCrossEntropyFunction.apply(
@@ -60,6 +62,7 @@ def forward(self, lin_weight, _input, target, bias=None):
             self.use_token_scaling,
             self.return_token_accuracy,
             self.return_predicted_tokens,
+            self.num_chunks_override,
         )
         if not self.return_z_loss and not self.return_token_accuracy and not self.return_predicted_tokens:
             return loss