@@ -3689,45 +3689,43 @@ def _helion_layer_norm_bwd(weight, x, grad_out, mean, rstd, grad_x, grad_weight_
36893689 load_2 = tl.load(grad_out + (indices_1[:, None] * 64 + indices_3[None, :] * 1), None)
36903690 v_2 = tl.cast(load_2, tl.float32)
36913691 # src[layer_norm.py:N]: mean_mb = mean[mb].to(torch.float32)
3692- load_3 = tl.load(mean + indices_1 * 1, None)
3693- v_3 = tl.cast(load_3, tl.float32)
3692+ mean_mb = tl.load(mean + indices_1 * 1, None)
36943693 # src[layer_norm.py:N]: rstd_mb = rstd[mb].to(torch.float32)
3695- load_4 = tl.load(rstd + indices_1 * 1, None)
3696- v_4 = tl.cast(load_4, tl.float32)
3694+ rstd_mb = tl.load(rstd + indices_1 * 1, None)
36973695 # src[layer_norm.py:N]: x_hat = (x_mb - mean_mb[:, None]) * rstd_mb[:, None]
3698- subscript = v_3 [:, None]
3699- v_5 = v_1 - subscript
3700- subscript_1 = v_4 [:, None]
3701- v_6 = v_5 * subscript_1
3696+ subscript = mean_mb [:, None]
3697+ v_3 = v_1 - subscript
3698+ subscript_1 = rstd_mb [:, None]
3699+ v_4 = v_3 * subscript_1
37023700 # src[layer_norm.py:N]: grad_w_acc += torch.sum(dy_mb * x_hat, dim=0)
3703- v_7 = v_2 * v_6
3704- sum_1 = tl.cast(tl.sum(v_7 , 0), tl.float32)
3701+ v_5 = v_2 * v_4
3702+ sum_1 = tl.cast(tl.sum(v_5 , 0), tl.float32)
37053703 grad_w_acc = grad_w_acc_copy_0 + sum_1
37063704 # src[layer_norm.py:N]: grad_b_acc += torch.sum(dy_mb, dim=0) # pyright: ignore[reportPossiblyUnboundVariable]
37073705 sum_2 = tl.cast(tl.sum(v_2, 0), tl.float32)
37083706 grad_b_acc = grad_b_acc_copy_0 + sum_2
37093707 # src[layer_norm.py:N]: wdy = weight_cta * dy_mb
3710- v_10 = v_0_copy_0 * v_2
3708+ v_8 = v_0_copy_0 * v_2
37113709 # src[layer_norm.py:N]: c1 = torch.sum(x_hat * wdy, dim=-1) / n
3712- v_11 = v_6 * v_10
3713- sum_3 = tl.cast(tl.sum(v_11 , 1), tl.float32)
3714- v_12 = 0.015625
3715- v_13 = sum_3 * v_12
3710+ v_9 = v_4 * v_8
3711+ sum_3 = tl.cast(tl.sum(v_9 , 1), tl.float32)
3712+ v_10 = 0.015625
3713+ v_11 = sum_3 * v_10
37163714 # src[layer_norm.py:N]: c2 = torch.sum(wdy, dim=-1) / n
3717- sum_4 = tl.cast(tl.sum(v_10 , 1), tl.float32)
3718- v_14 = 0.015625
3719- v_15 = sum_4 * v_14
3715+ sum_4 = tl.cast(tl.sum(v_8 , 1), tl.float32)
3716+ v_12 = 0.015625
3717+ v_13 = sum_4 * v_12
37203718 # src[layer_norm.py:N]: dx = (wdy - (x_hat * c1[:, None] + c2[:, None])) * rstd_mb[:, None]
3721- subscript_2 = v_13 [:, None]
3722- v_16 = v_6 * subscript_2
3723- subscript_3 = v_15 [:, None]
3724- v_17 = v_16 + subscript_3
3725- v_18 = v_10 - v_17
3726- subscript_4 = v_4 [:, None]
3727- v_19 = v_18 * subscript_4
3719+ subscript_2 = v_11 [:, None]
3720+ v_14 = v_4 * subscript_2
3721+ subscript_3 = v_13 [:, None]
3722+ v_15 = v_14 + subscript_3
3723+ v_16 = v_8 - v_15
3724+ subscript_4 = rstd_mb [:, None]
3725+ v_17 = v_16 * subscript_4
37283726 # src[layer_norm.py:N]: grad_x[mb, :] = dx.to(x.dtype)
3729- v_20 = tl.cast(v_19 , tl.float16)
3730- tl.store(grad_x + (indices_1[:, None] * 64 + indices_3[None, :] * 1), v_20 , None)
3727+ v_18 = tl.cast(v_17 , tl.float16)
3728+ tl.store(grad_x + (indices_1[:, None] * 64 + indices_3[None, :] * 1), v_18 , None)
37313729 # src[layer_norm.py:N]: grad_weight_blocks[mb_cta.id, :] = grad_w_acc
37323730 tile_id = offset_0 // _BLOCK_SIZE_0
37333731 tl.store(grad_weight_blocks + (tile_id * 64 + indices_3 * 1), grad_w_acc, None)
0 commit comments