
Commit ad16d6f

Add squeeze_and_excitation_net kernel (#870)
1 parent 2572047 commit ad16d6f

3 files changed: +683 -0 lines changed
Lines changed: 268 additions & 0 deletions
@@ -0,0 +1,268 @@
"""
Helion Squeeze and Excitation Net Example
==========================================
This example demonstrates a Helion kernel implementation of a squeeze-and-excitation
network, as used in https://arxiv.org/abs/1709.01507.
"""

# %%
from __future__ import annotations

import torch
from torch import Tensor

import helion
from helion._testing import DEVICE
from helion._testing import run_example
import helion.language as hl


# %%
@helion.kernel(
    # static_shapes=True gives a performance boost for matmuls
    static_shapes=True,
)
def squeeze_and_excitation_net_fwd(
    x: Tensor, a: Tensor, b: Tensor
) -> tuple[Tensor, Tensor, Tensor]:
    """
    Performs torch.mul(x, torch.sigmoid(torch.relu((x @ a)) @ b))
    Args:
        x: 2D tensor of shape [m, n].
        a: 2D tensor of shape [n, k].
        b: 2D tensor of shape [k, n].
    Returns:
        out: Resulting matrix of shape [m, n].
        c = torch.relu(x @ a) of shape [m, k].
        d = torch.sigmoid(c @ b) of shape [m, n].
    """
    m, n = x.size()
    k = a.size(1)

    out = torch.empty([m, n], dtype=x.dtype, device=x.device)
    c = torch.empty([m, k], dtype=x.dtype, device=x.device)
    d = torch.empty([m, n], dtype=x.dtype, device=x.device)

    for tile_m in hl.tile(m):
        # Compute c = relu(x @ a) for this tile_m
        for tile_k in hl.tile(k):
            partial_xa = x[tile_m, :] @ a[:, tile_k]
            c[tile_m, tile_k] = torch.relu(partial_xa)

        # Compute d = sigmoid(c @ b) and out = x * d for this tile_m
        for tile_n in hl.tile(n):
            acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
            for tile_k in hl.tile(k):
                acc += c[tile_m, tile_k] @ b[tile_k, tile_n]
            d[tile_m, tile_n] = torch.sigmoid(acc)
            out[tile_m, tile_n] = x[tile_m, tile_n] * d[tile_m, tile_n]

    return out, c, d


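# %%
# A minimal eager-mode sketch of the same forward computation (illustrative
# only; the helper name `_fwd_eager_sketch` is hypothetical and not part of the
# Helion kernels). It returns the same (out, c, d) triple, up to accumulation
# dtype differences, and is meant as a reading aid for the tiled loops above.
def _fwd_eager_sketch(x: Tensor, a: Tensor, b: Tensor) -> tuple[Tensor, Tensor, Tensor]:
    c = torch.relu(x @ a)  # [m, k]
    d = torch.sigmoid(c @ b)  # [m, n]
    out = x * d  # [m, n]
    return out, c, d

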
# %%
@helion.kernel(static_shapes=True)
def squeeze_and_excitation_net_bwd_dx(
    grad_out: Tensor, x: Tensor, a: Tensor, b: Tensor, c: Tensor, d: Tensor
) -> Tensor:
    """
    Compute grad_x for the squeeze and excitation network.
    grad_x = grad_out * d + (grad_out * x * d * (1-d) @ b.T * (c>0)) @ a.T

    The computation is structured to properly accumulate over the k dimension:
    1. First term: grad_out * d (element-wise, no reduction)
    2. Second term: chain rule through d->c->x path
       - For each output position (m, n), accumulate over k dimension
       - grad_c[m,k] = (grad_out * x * d * (1-d))[m,:] @ b[k,:].T * (c[m,k] > 0)
       - grad_x[m,n] += grad_c[m,k] @ a[n,k].T
    """
    m, n = x.size()
    k = a.size(1)

    grad_x = torch.empty([m, n], dtype=x.dtype, device=x.device)

    # Compute grad_x: grad_out * d + second_term where second_term accumulates over k
    for tile_m, tile_n in hl.tile([m, n]):
        # First term: grad_out * d (element-wise)
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        acc += grad_out[tile_m, tile_n] * d[tile_m, tile_n]

        # Second term: accumulate gradient chain over k dimension
        for tile_k in hl.tile(k):
            # Compute grad_to_d for the full row: shape [tile_m, n]
            grad_to_d = (
                grad_out[tile_m, :] * x[tile_m, :] * d[tile_m, :] * (1.0 - d[tile_m, :])
            )

            # Backprop through (c @ b): grad_c = grad_to_d @ b.T
            # [tile_m, n] @ [n, tile_k] = [tile_m, tile_k]
            grad_to_c = grad_to_d @ b[tile_k, :].T

            # Apply ReLU mask: shape [tile_m, tile_k]
            grad_c_masked = grad_to_c * (c[tile_m, tile_k] > 0)

            # Backprop through (x @ a): grad_x_contribution = grad_c_masked @ a.T
            # [tile_m, tile_k] @ [tile_k, tile_n] = [tile_m, tile_n]
            acc += grad_c_masked @ a[tile_n, tile_k].T

        grad_x[tile_m, tile_n] = acc

    return grad_x


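# %%
# A minimal eager-mode sketch of the same grad_x formula (illustrative only;
# `_bwd_dx_eager_sketch` is a hypothetical helper, not part of the Helion
# kernels). It mirrors the chain rule spelled out in the docstring above.
def _bwd_dx_eager_sketch(
    grad_out: Tensor, x: Tensor, a: Tensor, b: Tensor, c: Tensor, d: Tensor
) -> Tensor:
    grad_to_s = grad_out * x * d * (1.0 - d)  # grad w.r.t. s = c @ b, through out = x * sigmoid(s)
    grad_to_c = (grad_to_s @ b.T) * (c > 0)  # backprop through (c @ b), then the relu mask
    return grad_out * d + grad_to_c @ a.T  # direct path plus the path through c = relu(x @ a)

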
# %%
@helion.kernel(static_shapes=True)
def squeeze_and_excitation_net_bwd_da(
    grad_out: Tensor, x: Tensor, b: Tensor, c: Tensor, d: Tensor
) -> Tensor:
    """
    Compute grad_a for the squeeze and excitation network.
    grad_a = x.T @ (grad_out * x * d * (1-d) @ b.T * (c>0))
    """
    m, n = x.size()
    k = c.size(1)

    grad_a = torch.empty([n, k], dtype=x.dtype, device=x.device)

    # Compute grad_a: x.T @ grad_c
    for tile_n, tile_k in hl.tile([n, k]):
        acc_a = hl.zeros([tile_n, tile_k], dtype=torch.float32)
        for tile_m in hl.tile(m):
            # Backprop through sigmoid: need full row for matmul with b.T
            grad_to_d = grad_out[tile_m, :] * x[tile_m, :]
            grad_to_cb = grad_to_d * d[tile_m, :] * (1.0 - d[tile_m, :])
            # Backprop through c @ b: [tile_m, n] @ [n, tile_k] = [tile_m, tile_k]
            grad_to_c = grad_to_cb @ b[tile_k, :].T
            # Backprop through relu
            grad_through_relu = grad_to_c * (c[tile_m, tile_k] > 0)
            # Accumulate x.T @ grad_c: [tile_n, tile_m] @ [tile_m, tile_k] = [tile_n, tile_k]
            acc_a += x[tile_m, tile_n].T @ grad_through_relu
        grad_a[tile_n, tile_k] = acc_a

    return grad_a


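# %%
# A minimal eager-mode sketch of the same grad_a formula (illustrative only;
# `_bwd_da_eager_sketch` is a hypothetical helper).
def _bwd_da_eager_sketch(
    grad_out: Tensor, x: Tensor, b: Tensor, c: Tensor, d: Tensor
) -> Tensor:
    grad_to_c = ((grad_out * x * d * (1.0 - d)) @ b.T) * (c > 0)  # [m, k], as in bwd_dx
    return x.T @ grad_to_c  # [n, m] @ [m, k] = [n, k]

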
# %%
@helion.kernel(static_shapes=True)
def squeeze_and_excitation_net_bwd_db(
    grad_out: Tensor, x: Tensor, d: Tensor, c: Tensor
) -> Tensor:
    """
    Compute grad_b by fusing grad_d computation inline.
    grad_b = c.T @ (grad_out * x * d * (1 - d))
    """
    m, n = grad_out.size()
    k = c.size(1)
    grad_b = torch.empty([k, n], dtype=grad_out.dtype, device=grad_out.device)

    for tile_k, tile_n in hl.tile([k, n]):
        acc = hl.zeros([tile_k, tile_n], dtype=torch.float32)
        for tile_m in hl.tile(m):
            grad_d = (
                grad_out[tile_m, tile_n]
                * x[tile_m, tile_n]
                * d[tile_m, tile_n]
                * (1.0 - d[tile_m, tile_n])
            )
            acc += c[tile_m, tile_k].T @ grad_d
        grad_b[tile_k, tile_n] = acc

    return grad_b


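# %%
# A hedged cross-check sketch (illustrative only; `_bwd_autograd_sketch` is a
# hypothetical helper, not part of the original kernels). torch.autograd.grad
# on the eager forward expression yields reference values for grad_x, grad_a,
# and grad_b that the three hand-written backward kernels above can be
# compared against.
def _bwd_autograd_sketch(
    grad_out: Tensor, x: Tensor, a: Tensor, b: Tensor
) -> tuple[Tensor, Tensor, Tensor]:
    x_, a_, b_ = (t.detach().clone().requires_grad_(True) for t in (x, a, b))
    out = x_ * torch.sigmoid(torch.relu(x_ @ a_) @ b_)
    grad_x, grad_a, grad_b = torch.autograd.grad(out, (x_, a_, b_), grad_out)
    return grad_x, grad_a, grad_b

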
# %%
# Reference Implementation
# ------------------------
def squeeze_and_excitation_net_pytorch(
    x: torch.Tensor, a: torch.Tensor, b: torch.Tensor
) -> torch.Tensor:
    """
    PyTorch reference implementation of squeeze_and_excitation_net.

    Args:
        x, a, b: Input tensors

    Returns:
        tensor of torch.mul(x, torch.sigmoid(torch.relu((x @ a)) @ b))
    """
    return torch.mul(x, torch.sigmoid(torch.relu(x @ a) @ b))


# %%
# Autograd Function
# ------------------
class SqueezeAndExcitationNetFunction(torch.autograd.Function):
    @staticmethod
    def forward(  # type: ignore[override]
        ctx: object,
        x: torch.Tensor,
        a: torch.Tensor,
        b: torch.Tensor,
    ) -> torch.Tensor:
        """Forward pass for squeeze and excitation network."""
        out, c, d = squeeze_and_excitation_net_fwd(x, a, b)
        ctx.save_for_backward(x, a, b, c, d)  # type: ignore[attr-defined]
        return out

    @staticmethod
    def backward(  # type: ignore[override]
        ctx: object,
        grad_out: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Backward pass for squeeze and excitation network."""
        x, a, b, c, d = ctx.saved_tensors  # type: ignore[attr-defined]

        grad_x = squeeze_and_excitation_net_bwd_dx(grad_out, x, a, b, c, d)
        grad_a = squeeze_and_excitation_net_bwd_da(grad_out, x, b, c, d)
        grad_b = squeeze_and_excitation_net_bwd_db(grad_out, x, d, c)
        return grad_x, grad_a, grad_b


def squeeze_and_excitation_net(
    x: torch.Tensor, a: torch.Tensor, b: torch.Tensor
) -> torch.Tensor:
    """
    Squeeze and excitation network with autograd support.

    Args:
        x: Input tensor [m, n]
        a: Weight matrix [n, k]
        b: Weight matrix [k, n]

    Returns:
        Output tensor [m, n]
    """
    return SqueezeAndExcitationNetFunction.apply(x, a, b)  # type: ignore[no-any-return]


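# A small usage sketch for the autograd wrapper (illustrative only; the helper
# name and shapes are arbitrary). A scalar loss backpropagates through all
# three Helion backward kernels via SqueezeAndExcitationNetFunction.
def _usage_sketch() -> None:
    x = torch.randn([64, 128], device=DEVICE, dtype=torch.float16, requires_grad=True)
    a = torch.randn([128, 32], device=DEVICE, dtype=torch.float16, requires_grad=True)
    b = torch.randn([32, 128], device=DEVICE, dtype=torch.float16, requires_grad=True)
    out = squeeze_and_excitation_net(x, a, b)
    out.sum().backward()  # populates x.grad, a.grad, b.grad

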
def check(m: int, k: int, n: int) -> None:
    """
    Checks the correctness against PyTorch.
    Args:
        m (int): Number of rows in matrix x.
        n (int): Number of columns in matrix x.
        k (int): Number of columns in matrix a.
    """
    x = torch.randn([m, n], device=DEVICE, dtype=torch.float16, requires_grad=True)
    a = torch.randn([n, k], device=DEVICE, dtype=torch.float16, requires_grad=True)
    b = torch.randn([k, n], device=DEVICE, dtype=torch.float16, requires_grad=True)
    for bwd in [True, False]:
        run_example(
            squeeze_and_excitation_net,
            squeeze_and_excitation_net_pytorch,
            (x, a, b),
            bwd=bwd,
        )


# %%
def main() -> None:
    """
    Main function to run correctness checks.
    """
    check(1024, 1024, 1024)


# %%
if __name__ == "__main__":
    main()

0 commit comments
