Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions src/liger_kernel/ops/fused_linear_cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def fused_linear_cross_entropy_forward(
use_token_scaling=False,
return_token_accuracy=False,
return_predicted_tokens=False,
num_chunks_override=None,
):
assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
assert isinstance(return_token_accuracy, bool), (
Expand All @@ -53,9 +54,16 @@ def fused_linear_cross_entropy_forward(
V = weight.shape[0]
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))

inc_factor = triton.cdiv(V, H) # (V + H - 1) // H
chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor)) # (BT + inc_factor - 1) // inc_factor
num_chunks = triton.cdiv(BT, chunk_size) # (BT + chunk_size - 1) // chunk_size
if num_chunks_override is not None:
assert BT % num_chunks_override == 0, (
f"num_chunks_override={num_chunks_override} must evenly divide BT={BT}"
)
chunk_size = triton.next_power_of_2(max(1, BT // num_chunks_override))
num_chunks = triton.cdiv(BT, chunk_size)
else:
inc_factor = triton.cdiv(V, H) # (V + H - 1) // H
chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor)) # (BT + inc_factor - 1) // inc_factor
num_chunks = triton.cdiv(BT, chunk_size) # (BT + chunk_size - 1) // chunk_size

grad_input = torch.zeros_like(_input, device=device)

Expand Down Expand Up @@ -219,6 +227,8 @@ def fused_linear_cross_entropy_forward(
alpha=1.0,
)

del logits_chunk, grad_logits_chunk, _input_chunk

# reduction='none' is not supported yet: it would require extra calculations in the backward pass.
# if reduction == "none":
# loss = loss_1d
Expand Down Expand Up @@ -311,6 +321,7 @@ def forward(
use_token_scaling: bool = False,
return_token_accuracy: bool = False,
return_predicted_tokens: bool = False,
num_chunks_override=None,
):
"""
Fusing the last linear layer with cross-entropy loss
Expand Down Expand Up @@ -355,6 +366,7 @@ def forward(
use_token_scaling=use_token_scaling,
return_token_accuracy=return_token_accuracy,
return_predicted_tokens=return_predicted_tokens,
num_chunks_override=num_chunks_override,
)
)
# downcast to dtype and store for backward
Expand Down Expand Up @@ -397,4 +409,5 @@ def backward(ctx, grad_output, grad_output2, grad_output3, grad_output4):
None, # use_token_scaling
None, # return_token_accuracy
None, # return_predicted_tokens
None, # num_chunks_override
)
3 changes: 3 additions & 0 deletions src/liger_kernel/transformers/fused_linear_cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def __init__(
use_token_scaling: bool = False,
return_token_accuracy: bool = False,
return_predicted_tokens: bool = False,
num_chunks: Optional[int] = None,
):
super().__init__()
assert (label_smoothing >= 0) and (label_smoothing <= 1), (
Expand All @@ -42,6 +43,7 @@ def __init__(
self.use_token_scaling = use_token_scaling
self.return_token_accuracy = return_token_accuracy
self.return_predicted_tokens = return_predicted_tokens
self.num_chunks_override = num_chunks

def forward(self, lin_weight, _input, target, bias=None):
loss, z_loss, token_accuracy, predicted_tokens = LigerFusedLinearCrossEntropyFunction.apply(
Expand All @@ -60,6 +62,7 @@ def forward(self, lin_weight, _input, target, bias=None):
self.use_token_scaling,
self.return_token_accuracy,
self.return_predicted_tokens,
self.num_chunks_override,
)
if not self.return_z_loss and not self.return_token_accuracy and not self.return_predicted_tokens:
return loss
Expand Down
Loading