|
| 1 | +""" |
| 2 | +Helion squeeze and excitation net Example |
| 3 | +============================ |
| 4 | +This example demonstrates a Helion kernel implementation of squeeze and excitation |
| 5 | +net as those used in https://arxiv.org/abs/1709.01507. |
| 6 | +""" |
| 7 | + |
| 8 | +# %% |
| 9 | +from __future__ import annotations |
| 10 | + |
| 11 | +import torch |
| 12 | +from torch import Tensor |
| 13 | + |
| 14 | +import helion |
| 15 | +from helion._testing import DEVICE |
| 16 | +from helion._testing import run_example |
| 17 | +import helion.language as hl |
| 18 | + |
| 19 | + |
| 20 | +# %% |
| 21 | +@helion.kernel( |
| 22 | + # static_shapes=True gives a performance boost for matmuls |
| 23 | + static_shapes=True, |
| 24 | +) |
| 25 | +def squeeze_and_excitation_net_fwd( |
| 26 | + x: Tensor, a: Tensor, b: Tensor |
| 27 | +) -> tuple[Tensor, Tensor, Tensor]: |
| 28 | + """ |
| 29 | + Performs torch.mul(x, torch.sigmoid(torch.relu((x @ a)) @ b)) |
| 30 | + Args: |
| 31 | + x: 2D tensor of shape [m, n]. |
| 32 | + a: 2D tensor of shape [n, k]. |
| 33 | + b: 2D tensor of shape [k, n]. |
| 34 | + Returns: |
| 35 | + out: Resulting matrix of shape [m, n]. |
| 36 | + c = torch.relu(x @ a) of shape [m, k]. |
| 37 | + d = torch.sigmoid(c @ b) of shape [m, n]. |
| 38 | + """ |
| 39 | + m, n = x.size() |
| 40 | + k = a.size(1) |
| 41 | + |
| 42 | + out = torch.empty([m, n], dtype=x.dtype, device=x.device) |
| 43 | + c = torch.empty([m, k], dtype=x.dtype, device=x.device) |
| 44 | + d = torch.empty([m, n], dtype=x.dtype, device=x.device) |
| 45 | + |
| 46 | + for tile_m in hl.tile(m): |
| 47 | + # Compute c = relu(x @ a) for this tile_m |
| 48 | + for tile_k in hl.tile(k): |
| 49 | + partial_xa = x[tile_m, :] @ a[:, tile_k] |
| 50 | + c[tile_m, tile_k] = torch.relu(partial_xa) |
| 51 | + |
| 52 | + # Compute d = sigmoid(c @ b) and out = x * d for this tile_m |
| 53 | + for tile_n in hl.tile(n): |
| 54 | + acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) |
| 55 | + for tile_k in hl.tile(k): |
| 56 | + acc += c[tile_m, tile_k] @ b[tile_k, tile_n] |
| 57 | + d[tile_m, tile_n] = torch.sigmoid(acc) |
| 58 | + out[tile_m, tile_n] = x[tile_m, tile_n] * d[tile_m, tile_n] |
| 59 | + |
| 60 | + return out, c, d |
| 61 | + |
| 62 | + |
| 63 | +# %% |
| 64 | +@helion.kernel(static_shapes=True) |
| 65 | +def squeeze_and_excitation_net_bwd_dx( |
| 66 | + grad_out: Tensor, x: Tensor, a: Tensor, b: Tensor, c: Tensor, d: Tensor |
| 67 | +) -> Tensor: |
| 68 | + """ |
| 69 | + Compute grad_x for the squeeze and excitation network. |
| 70 | + grad_x = grad_out * d + (grad_out * x * d * (1-d) @ b.T * (c>0)) @ a.T |
| 71 | +
|
| 72 | + The computation is structured to properly accumulate over the k dimension: |
| 73 | + 1. First term: grad_out * d (element-wise, no reduction) |
| 74 | + 2. Second term: chain rule through d->c->x path |
| 75 | + - For each output position (m, n), accumulate over k dimension |
| 76 | + - grad_c[m,k] = (grad_out * x * d * (1-d))[m,:] @ b[k,:].T * (c[m,k] > 0) |
| 77 | + - grad_x[m,n] += grad_c[m,k] @ a[n,k].T |
| 78 | + """ |
| 79 | + m, n = x.size() |
| 80 | + k = a.size(1) |
| 81 | + |
| 82 | + grad_x = torch.empty([m, n], dtype=x.dtype, device=x.device) |
| 83 | + |
| 84 | + # Compute grad_x: grad_out * d + second_term where second_term accumulates over k |
| 85 | + for tile_m, tile_n in hl.tile([m, n]): |
| 86 | + # First term: grad_out * d (element-wise) |
| 87 | + acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) |
| 88 | + acc += grad_out[tile_m, tile_n] * d[tile_m, tile_n] |
| 89 | + |
| 90 | + # Second term: accumulate gradient chain over k dimension |
| 91 | + for tile_k in hl.tile(k): |
| 92 | + # Compute grad_to_d for the full row: shape [tile_m, n] |
| 93 | + grad_to_d = ( |
| 94 | + grad_out[tile_m, :] * x[tile_m, :] * d[tile_m, :] * (1.0 - d[tile_m, :]) |
| 95 | + ) |
| 96 | + |
| 97 | + # Backprop through (c @ b): grad_c = grad_to_d @ b.T |
| 98 | + # [tile_m, n] @ [n, tile_k] = [tile_m, tile_k] |
| 99 | + grad_to_c = grad_to_d @ b[tile_k, :].T |
| 100 | + |
| 101 | + # Apply ReLU mask: shape [tile_m, tile_k] |
| 102 | + grad_c_masked = grad_to_c * (c[tile_m, tile_k] > 0) |
| 103 | + |
| 104 | + # Backprop through (x @ a): grad_x_contribution = grad_c_masked @ a.T |
| 105 | + # [tile_m, tile_k] @ [tile_k, tile_n] = [tile_m, tile_n] |
| 106 | + acc += grad_c_masked @ a[tile_n, tile_k].T |
| 107 | + |
| 108 | + grad_x[tile_m, tile_n] = acc |
| 109 | + |
| 110 | + return grad_x |
| 111 | + |
| 112 | + |
| 113 | +# %% |
| 114 | +@helion.kernel(static_shapes=True) |
| 115 | +def squeeze_and_excitation_net_bwd_da( |
| 116 | + grad_out: Tensor, x: Tensor, b: Tensor, c: Tensor, d: Tensor |
| 117 | +) -> Tensor: |
| 118 | + """ |
| 119 | + Compute grad_a for the squeeze and excitation network. |
| 120 | + grad_a = x.T @ (grad_out * x * d * (1-d) @ b.T * (c>0)) |
| 121 | + """ |
| 122 | + m, n = x.size() |
| 123 | + k = c.size(1) |
| 124 | + |
| 125 | + grad_a = torch.empty([n, k], dtype=x.dtype, device=x.device) |
| 126 | + |
| 127 | + # Compute grad_a: x.T @ grad_c |
| 128 | + for tile_n, tile_k in hl.tile([n, k]): |
| 129 | + acc_a = hl.zeros([tile_n, tile_k], dtype=torch.float32) |
| 130 | + for tile_m in hl.tile(m): |
| 131 | + # Backprop through sigmoid: need full row for matmul with b.T |
| 132 | + grad_to_d = grad_out[tile_m, :] * x[tile_m, :] |
| 133 | + grad_to_cb = grad_to_d * d[tile_m, :] * (1.0 - d[tile_m, :]) |
| 134 | + # Backprop through c @ b: [tile_m, n] @ [n, tile_k] = [tile_m, tile_k] |
| 135 | + grad_to_c = grad_to_cb @ b[tile_k, :].T |
| 136 | + # Backprop through relu |
| 137 | + grad_through_relu = grad_to_c * (c[tile_m, tile_k] > 0) |
| 138 | + # Accumulate x.T @ grad_c: [tile_n, tile_m] @ [tile_m, tile_k] = [tile_n, tile_k] |
| 139 | + acc_a += x[tile_m, tile_n].T @ grad_through_relu |
| 140 | + grad_a[tile_n, tile_k] = acc_a |
| 141 | + |
| 142 | + return grad_a |
| 143 | + |
| 144 | + |
| 145 | +# %% |
| 146 | +@helion.kernel(static_shapes=True) |
| 147 | +def squeeze_and_excitation_net_bwd_db( |
| 148 | + grad_out: Tensor, x: Tensor, d: Tensor, c: Tensor |
| 149 | +) -> Tensor: |
| 150 | + """ |
| 151 | + Compute grad_b by fusing grad_d computation inline. |
| 152 | + grad_b = c.T @ (grad_out * x * d * (1 - d)) |
| 153 | + """ |
| 154 | + m, n = grad_out.size() |
| 155 | + k = c.size(1) |
| 156 | + grad_b = torch.empty([k, n], dtype=grad_out.dtype, device=grad_out.device) |
| 157 | + |
| 158 | + for tile_k, tile_n in hl.tile([k, n]): |
| 159 | + acc = hl.zeros([tile_k, tile_n], dtype=torch.float32) |
| 160 | + for tile_m in hl.tile(m): |
| 161 | + grad_d = ( |
| 162 | + grad_out[tile_m, tile_n] |
| 163 | + * x[tile_m, tile_n] |
| 164 | + * d[tile_m, tile_n] |
| 165 | + * (1.0 - d[tile_m, tile_n]) |
| 166 | + ) |
| 167 | + acc += c[tile_m, tile_k].T @ grad_d |
| 168 | + grad_b[tile_k, tile_n] = acc |
| 169 | + |
| 170 | + return grad_b |
| 171 | + |
| 172 | + |
| 173 | +# %% |
| 174 | +# Reference Implementation |
| 175 | +# -------------------- |
| 176 | +def squeeze_and_excitation_net_pytorch( |
| 177 | + x: torch.Tensor, a: torch.Tensor, b: torch.Tensor |
| 178 | +) -> torch.Tensor: |
| 179 | + """ |
| 180 | + PyTorch reference implementation of squeeze_and_excitation_net. |
| 181 | +
|
| 182 | + Args: |
| 183 | + x, a, b: Input tensors |
| 184 | +
|
| 185 | + Returns: |
| 186 | + tensor of torch.mul(x, torch.sigmoid(torch.relu((x @ a)) @ b)) |
| 187 | + """ |
| 188 | + return torch.mul(x, torch.sigmoid(torch.relu(x @ a) @ b)) |
| 189 | + |
| 190 | + |
| 191 | +# %% |
| 192 | +# Autograd Function |
| 193 | +# ------------------ |
| 194 | +class SqueezeAndExcitationNetFunction(torch.autograd.Function): |
| 195 | + @staticmethod |
| 196 | + def forward( # type: ignore[override] |
| 197 | + ctx: object, |
| 198 | + x: torch.Tensor, |
| 199 | + a: torch.Tensor, |
| 200 | + b: torch.Tensor, |
| 201 | + ) -> torch.Tensor: |
| 202 | + """Forward pass for squeeze and excitation network.""" |
| 203 | + out, c, d = squeeze_and_excitation_net_fwd(x, a, b) |
| 204 | + ctx.save_for_backward(x, a, b, c, d) # type: ignore[attr-defined] |
| 205 | + return out |
| 206 | + |
| 207 | + @staticmethod |
| 208 | + def backward( # type: ignore[override] |
| 209 | + ctx: object, |
| 210 | + grad_out: torch.Tensor, |
| 211 | + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
| 212 | + """Backward pass for squeeze and excitation network.""" |
| 213 | + x, a, b, c, d = ctx.saved_tensors # type: ignore[attr-defined] |
| 214 | + |
| 215 | + grad_x = squeeze_and_excitation_net_bwd_dx(grad_out, x, a, b, c, d) |
| 216 | + grad_a = squeeze_and_excitation_net_bwd_da(grad_out, x, b, c, d) |
| 217 | + grad_b = squeeze_and_excitation_net_bwd_db(grad_out, x, d, c) |
| 218 | + return grad_x, grad_a, grad_b |
| 219 | + |
| 220 | + |
| 221 | +def squeeze_and_excitation_net( |
| 222 | + x: torch.Tensor, a: torch.Tensor, b: torch.Tensor |
| 223 | +) -> torch.Tensor: |
| 224 | + """ |
| 225 | + Squeeze and excitation network with autograd support. |
| 226 | +
|
| 227 | + Args: |
| 228 | + x: Input tensor [m, n] |
| 229 | + a: Weight matrix [n, k] |
| 230 | + b: Weight matrix [k, n] |
| 231 | +
|
| 232 | + Returns: |
| 233 | + Output tensor [m, n] |
| 234 | + """ |
| 235 | + return SqueezeAndExcitationNetFunction.apply(x, a, b) # type: ignore[no-any-return] |
| 236 | + |
| 237 | + |
| 238 | +def check(m: int, k: int, n: int) -> None: |
| 239 | + """ |
| 240 | + Checks the correctness against PyTorch. |
| 241 | + Args: |
| 242 | + m (int): Number of rows in matrix x. |
| 243 | + n (int): Number of columns in matrix x. |
| 244 | + k (int): Number of columns in matrix a. |
| 245 | + """ |
| 246 | + x = torch.randn([m, n], device=DEVICE, dtype=torch.float16, requires_grad=True) |
| 247 | + a = torch.randn([n, k], device=DEVICE, dtype=torch.float16, requires_grad=True) |
| 248 | + b = torch.randn([k, n], device=DEVICE, dtype=torch.float16, requires_grad=True) |
| 249 | + for bwd in [True, False]: |
| 250 | + run_example( |
| 251 | + squeeze_and_excitation_net, |
| 252 | + squeeze_and_excitation_net_pytorch, |
| 253 | + (x, a, b), |
| 254 | + bwd=bwd, |
| 255 | + ) |
| 256 | + |
| 257 | + |
| 258 | +# %% |
| 259 | +def main() -> None: |
| 260 | + """ |
| 261 | + Main function to run correctness checks. |
| 262 | + """ |
| 263 | + check(1024, 1024, 1024) |
| 264 | + |
| 265 | + |
| 266 | +# %% |
| 267 | +if __name__ == "__main__": |
| 268 | + main() |
0 commit comments