
Commit f25a771
[example] fused_linear_jsd (#494)
1 parent 5dd2ae3

File tree

4 files changed: +263 −0 lines changed

benchmarks/run.py

Lines changed: 5 additions & 0 deletions
@@ -187,6 +187,11 @@ class RunResult:
            ("examples.grouped_gemm", "grouped_gemm_jagged_persistent_tritonbench"),
        ],
    ),
    "fused_linear_jsd": (
        "tritonbench.operators.fused_linear_jsd.operator",
        "examples.fused_linear_jsd",
        "fused_linear_jsd_fwd_tritonbench",
    ),
    # Multiple kernel variants:
    "gemm": (
        "tritonbench.operators.gemm.operator",

examples/fused_linear_jsd.py

Lines changed: 151 additions & 0 deletions
@@ -0,0 +1,151 @@
"""
Fused Linear JSD Example
========================

This example demonstrates how to implement a Jensen-Shannon divergence (JSD)
kernel using Helion and fuse it with a linear layer.
"""

# %%
# Imports
# -------
from __future__ import annotations

from typing import Callable

import torch

import helion
from helion._testing import run_example
import helion.language as hl


# %%
# Helion Kernel
# -------------------
@helion.kernel()
def fused_linear_jsd_kernel(
    beta: float,
    ignore_index: int,
    temperature: float,
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
) -> torch.Tensor:
    # ignore_index is unused here; it is kept for signature parity with the tritonbench operator.
    loss = student_logits.new_empty(student_logits.shape[0], dtype=torch.float)
    # Tile over the batch dimension; each tile computes a per-row JSD loss.
    for batch in hl.tile(student_logits.shape[0]):
        student_prob = torch.log_softmax(student_logits[batch, :] / temperature, dim=-1)
        teacher_prob = torch.log_softmax(teacher_logits[batch, :] / temperature, dim=-1)
        student_prob = student_prob.to(torch.float).view(-1, student_prob.size(-1))
        teacher_prob = teacher_prob.to(torch.float).view(-1, teacher_prob.size(-1))
        # Mixture distribution M = (1 - beta) * P_student + beta * P_teacher, in probability space.
        m = torch.exp(student_prob) + beta * (
            torch.exp(teacher_prob) - torch.exp(student_prob)
        )
        # KL(P_teacher || M) and KL(P_student || M), summed over the vocabulary dimension.
        teacher_div = torch.nn.functional.kl_div(
            torch.log(m), teacher_prob, reduction="none", log_target=True
        ).sum(dim=-1)
        student_div = torch.nn.functional.kl_div(
            torch.log(m), student_prob, reduction="none", log_target=True
        ).sum(dim=-1)
        # Beta-weighted combination: (1 - beta) * KL(P_student || M) + beta * KL(P_teacher || M).
        batch_loss = student_div + beta * (teacher_div - student_div)
        loss[batch] = batch_loss
    return (loss / student_logits.shape[0]).sum()


def fused_linear_jsd_fwd(
    beta: float,
    ignore_index: int,
    temperature: float,
    student_weight: torch.Tensor,
    teacher_weight: torch.Tensor,
    student_input: torch.Tensor,
    teacher_input: torch.Tensor,
) -> torch.Tensor:
    # Project hidden states to logits, then compute the JSD loss with the Helion kernel.
    student_logits = student_input @ student_weight.T
    teacher_logits = teacher_input @ teacher_weight.T
    return fused_linear_jsd_kernel(
        beta, ignore_index, temperature, student_logits, teacher_logits
    )


# %%
# Benchmark Entry Point Function
# -------------------
def fused_linear_jsd_fwd_tritonbench(
    tb_op: object,
    student_input: torch.Tensor,
    teacher_input: torch.Tensor,
    label: torch.Tensor | None = None,
) -> Callable[[], torch.Tensor]:
    assert label is None
    # Pull hyperparameters and projection weights from the tritonbench baseline operator.
    baseline_op = tb_op.baseline_op  # pyright: ignore[reportAttributeAccessIssue]
    beta = baseline_op.jsd.beta
    ignore_index = baseline_op.jsd.ignore_index
    temperature = baseline_op.temperature
    student_weight = baseline_op.student_lin.weight
    teacher_weight = baseline_op.teacher_lin.weight
    return lambda: fused_linear_jsd_fwd(
        beta,
        ignore_index,
        temperature,
        student_weight,
        teacher_weight,
        student_input,
        teacher_input,
    )


# %%
# Reference Implementation
# --------------------
def fused_linear_jsd_pytorch(
    beta: float,
    ignore_index: int,
    temperature: float,
    student_weight: torch.Tensor,
    teacher_weight: torch.Tensor,
    student_input: torch.Tensor,
    teacher_input: torch.Tensor,
) -> torch.Tensor:
    student_logits = student_input @ student_weight.T
    teacher_logits = teacher_input @ teacher_weight.T
    student_prob = torch.log_softmax(student_logits / temperature, dim=-1)
    teacher_prob = torch.log_softmax(teacher_logits / temperature, dim=-1)
    student_prob = student_prob.to(torch.float).view(-1, student_prob.size(-1))
    teacher_prob = teacher_prob.to(torch.float).view(-1, teacher_prob.size(-1))
    m = torch.exp(student_prob) + beta * (
        torch.exp(teacher_prob) - torch.exp(student_prob)
    )
    teacher_div = torch.nn.functional.kl_div(
        torch.log(m), teacher_prob, reduction="none", log_target=True
    ).sum(dim=-1)
    student_div = torch.nn.functional.kl_div(
        torch.log(m), student_prob, reduction="none", log_target=True
    ).sum(dim=-1)
    loss = student_div + beta * (teacher_div - student_div)
    return (loss / student_logits.shape[0]).sum()


# %%
# Verification Function
# -------------------
def check(m: int, n: int, k: int) -> None:
    student_input = torch.rand([m, n], device="cuda", dtype=torch.float)
    teacher_input = torch.rand([m, n], device="cuda", dtype=torch.float)
    student_weight = torch.rand([k, n], device="cuda", dtype=torch.float)
    teacher_weight = torch.rand([k, n], device="cuda", dtype=torch.float)
    run_example(
        fused_linear_jsd_fwd,
        fused_linear_jsd_pytorch,
        (0.5, 1, 1.0, student_weight, teacher_weight, student_input, teacher_input),
    )


# %%
# Main Function
# -----------
def main() -> None:
    check(1024, 4096, 128256)


if __name__ == "__main__":
    main()
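For reference, the loss that both the kernel and the PyTorch reference compute can be written in closed form (read off directly from the code above, with B rows of logits, temperature T, and skew beta):

\[
P = \mathrm{softmax}\!\left(z_{\text{student}}/T\right), \qquad
Q = \mathrm{softmax}\!\left(z_{\text{teacher}}/T\right), \qquad
M = (1-\beta)\,P + \beta\,Q
\]
\[
\mathrm{loss} \;=\; \frac{1}{B}\sum_{i=1}^{B}\Bigl[(1-\beta)\,\mathrm{KL}\!\left(P_i \,\middle\|\, M_i\right) \;+\; \beta\,\mathrm{KL}\!\left(Q_i \,\middle\|\, M_i\right)\Bigr]
\]

At beta = 0.5 this reduces to the symmetric Jensen-Shannon divergence; ignore_index is accepted for API compatibility with the tritonbench operator but does not enter the computation.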

test/test_examples.expected

Lines changed: 70 additions & 0 deletions
@@ -841,6 +841,76 @@ def fp8_gemm(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
    _launcher(_helion_fp8_gemm, (triton.cdiv(256, _BLOCK_SIZE_0) * triton.cdiv(256, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
    return out

--- assertExpectedJournal(TestExamples.test_fused_linear_jsd)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from torch._inductor.runtime.triton_helpers import math as tl_math
from torch._inductor.runtime.triton_compat import libdevice
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_fused_linear_jsd_kernel(student_logits, teacher_logits, loss, student_logits_size_0, teacher_logits_size_1, loss_stride_0, student_logits_stride_0, student_logits_stride_1, teacher_logits_stride_0, teacher_logits_stride_1, temperature, beta, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr):
    pid_0 = tl.program_id(0)
    offset_0 = pid_0 * _BLOCK_SIZE_0
    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
    mask_0 = indices_0 < student_logits_size_0
    indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
    mask_1 = indices_1 < teacher_logits_size_1
    load = tl.load(student_logits + (indices_0[:, None] * student_logits_stride_0 + indices_1[None, :] * student_logits_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
    v_0 = load / temperature
    _mask_to = tl.where(mask_0[:, None] & mask_1[None, :], v_0, tl.full([], float('-inf'), tl.float32))
    amax = tl.cast(tl.reshape(tl.max(_mask_to, 1), [_BLOCK_SIZE_0, 1]), tl.float32)
    v_1 = v_0 - amax
    v_2 = libdevice.exp(v_1)
    _mask_to_1 = tl.where(mask_0[:, None] & mask_1[None, :], v_2, tl.full([], 0, tl.float32))
    sum_1 = tl.cast(tl.reshape(tl.sum(_mask_to_1, 1), [_BLOCK_SIZE_0, 1]), tl.float32)
    v_3 = tl_math.log(sum_1)
    v_4 = v_1 - v_3
    load_1 = tl.load(teacher_logits + (indices_0[:, None] * teacher_logits_stride_0 + indices_1[None, :] * teacher_logits_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
    v_5 = load_1 / temperature
    _mask_to_2 = tl.where(mask_0[:, None] & mask_1[None, :], v_5, tl.full([], float('-inf'), tl.float32))
    amax_1 = tl.cast(tl.reshape(tl.max(_mask_to_2, 1), [_BLOCK_SIZE_0, 1]), tl.float32)
    v_6 = v_5 - amax_1
    v_7 = libdevice.exp(v_6)
    _mask_to_3 = tl.where(mask_0[:, None] & mask_1[None, :], v_7, tl.full([], 0, tl.float32))
    sum_2 = tl.cast(tl.reshape(tl.sum(_mask_to_3, 1), [_BLOCK_SIZE_0, 1]), tl.float32)
    v_8 = tl_math.log(sum_2)
    v_9 = v_6 - v_8
    student_prob_1 = tl.reshape(v_4, [_BLOCK_SIZE_0, _RDIM_SIZE_1])
    teacher_prob_1 = tl.reshape(v_9, [_BLOCK_SIZE_0, _RDIM_SIZE_1])
    v_10 = libdevice.exp(student_prob_1)
    v_11 = libdevice.exp(teacher_prob_1)
    v_12 = libdevice.exp(student_prob_1)
    v_13 = v_11 - v_12
    v_14 = v_13 * beta
    v_15 = v_10 + v_14
    v_16 = tl_math.log(v_15)
    v_17 = teacher_prob_1 - v_16
    v_18 = libdevice.exp(teacher_prob_1)
    v_19 = v_18 * v_17
    _mask_to_4 = tl.where(mask_0[:, None] & mask_1[None, :], v_19, tl.full([], 0, tl.float32))
    teacher_div = tl.cast(tl.sum(_mask_to_4, 1), tl.float32)
    v_20 = tl_math.log(v_15)
    v_21 = student_prob_1 - v_20
    v_22 = libdevice.exp(student_prob_1)
    v_23 = v_22 * v_21
    _mask_to_5 = tl.where(mask_0[:, None] & mask_1[None, :], v_23, tl.full([], 0, tl.float32))
    student_div = tl.cast(tl.sum(_mask_to_5, 1), tl.float32)
    v_24 = teacher_div - student_div
    v_25 = v_24 * beta
    v_26 = student_div + v_25
    tl.store(loss + indices_0 * loss_stride_0, v_26, mask_0)

def fused_linear_jsd_kernel(beta: float, ignore_index: int, temperature: float, student_logits: torch.Tensor, teacher_logits: torch.Tensor, *, _launcher=_default_launcher):
    loss = student_logits.new_empty(student_logits.shape[0], dtype=torch.float)
    _BLOCK_SIZE_0 = 32
    _RDIM_SIZE_1 = triton.next_power_of_2(teacher_logits.size(1))
    _launcher(_helion_fused_linear_jsd_kernel, (triton.cdiv(student_logits.size(0), _BLOCK_SIZE_0),), student_logits, teacher_logits, loss, student_logits.size(0), teacher_logits.size(1), loss.stride(0), student_logits.stride(0), student_logits.stride(1), teacher_logits.stride(0), teacher_logits.stride(1), temperature, beta, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
    return (loss / student_logits.shape[0]).sum()

--- assertExpectedJournal(TestExamples.test_gather_gemv)
from __future__ import annotations
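The v_0..v_4 and v_5..v_9 sequences in the generated Triton code above are the standard numerically stable log-softmax over the vocabulary dimension:

\[
\log \mathrm{softmax}(x)_j \;=\; \bigl(x_j - \max_k x_k\bigr) \;-\; \log \sum_{k} \exp\bigl(x_k - \max_k x_k\bigr)
\]

Out-of-range columns are masked to -inf before the max and to 0 before the sum, so the padding introduced by _RDIM_SIZE_1 (the next power of two above the vocabulary size) does not affect either reduction.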

test/test_examples.py

Lines changed: 37 additions & 0 deletions
@@ -1239,6 +1239,43 @@ def test_int4_gemm(self):
            )
        )

    def test_fused_linear_jsd(self):
        beta = 0.5
        ignore_index = 1
        temperature = 1.0
        m, n, k = 64, 128, 256

        student_input = torch.randn([m, n], device=DEVICE, dtype=torch.float32)
        teacher_input = torch.randn([m, n], device=DEVICE, dtype=torch.float32)
        student_weight = torch.randn([k, n], device=DEVICE, dtype=torch.float32)
        teacher_weight = torch.randn([k, n], device=DEVICE, dtype=torch.float32)
        student_logits = student_input @ student_weight.T
        teacher_logits = teacher_input @ teacher_weight.T

        args = (
            beta,
            ignore_index,
            temperature,
            student_logits,
            teacher_logits,
        )

        # Import and use the reference implementation
        mod = import_path(EXAMPLES_DIR / "fused_linear_jsd.py")
        expected = mod.fused_linear_jsd_pytorch(
            *args[:-2], student_weight, teacher_weight, student_input, teacher_input
        )

        self.assertExpectedJournal(
            check_example(
                "fused_linear_jsd",
                args,
                expected,
                fn_name="fused_linear_jsd_kernel",
                block_sizes=[32],
            )
        )


if __name__ == "__main__":
    unittest.main()
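A minimal way to exercise the new code locally, assuming a CUDA device and the repository root on the import path (illustrative, not part of this commit):

# Run the example's built-in check, which compares fused_linear_jsd_fwd
# against fused_linear_jsd_pytorch via run_example.
from examples.fused_linear_jsd import check

check(1024, 4096, 128256)  # same shapes as the example's main()

# The new unit test can be run with the standard unittest runner, e.g.:
#   python -m unittest test.test_examples.TestExamples.test_fused_linear_jsd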
