Skip to content

Commit 79d57b7

Browse files
authored
Fix layernorm bwd unit test (#1047)
1 parent: 7db09b7; commit: 79d57b7

File tree

2 files changed

+29
-30
lines changed

2 files changed

+29
-30
lines changed

test/test_examples.expected

Lines changed: 25 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -3689,45 +3689,43 @@ def _helion_layer_norm_bwd(weight, x, grad_out, mean, rstd, grad_x, grad_weight_
36893689
load_2 = tl.load(grad_out + (indices_1[:, None] * 64 + indices_3[None, :] * 1), None)
36903690
v_2 = tl.cast(load_2, tl.float32)
36913691
# src[layer_norm.py:N]: mean_mb = mean[mb].to(torch.float32)
3692-
load_3 = tl.load(mean + indices_1 * 1, None)
3693-
v_3 = tl.cast(load_3, tl.float32)
3692+
mean_mb = tl.load(mean + indices_1 * 1, None)
36943693
# src[layer_norm.py:N]: rstd_mb = rstd[mb].to(torch.float32)
3695-
load_4 = tl.load(rstd + indices_1 * 1, None)
3696-
v_4 = tl.cast(load_4, tl.float32)
3694+
rstd_mb = tl.load(rstd + indices_1 * 1, None)
36973695
# src[layer_norm.py:N]: x_hat = (x_mb - mean_mb[:, None]) * rstd_mb[:, None]
3698-
subscript = v_3[:, None]
3699-
v_5 = v_1 - subscript
3700-
subscript_1 = v_4[:, None]
3701-
v_6 = v_5 * subscript_1
3696+
subscript = mean_mb[:, None]
3697+
v_3 = v_1 - subscript
3698+
subscript_1 = rstd_mb[:, None]
3699+
v_4 = v_3 * subscript_1
37023700
# src[layer_norm.py:N]: grad_w_acc += torch.sum(dy_mb * x_hat, dim=0)
3703-
v_7 = v_2 * v_6
3704-
sum_1 = tl.cast(tl.sum(v_7, 0), tl.float32)
3701+
v_5 = v_2 * v_4
3702+
sum_1 = tl.cast(tl.sum(v_5, 0), tl.float32)
37053703
grad_w_acc = grad_w_acc_copy_0 + sum_1
37063704
# src[layer_norm.py:N]: grad_b_acc += torch.sum(dy_mb, dim=0) # pyright: ignore[reportPossiblyUnboundVariable]
37073705
sum_2 = tl.cast(tl.sum(v_2, 0), tl.float32)
37083706
grad_b_acc = grad_b_acc_copy_0 + sum_2
37093707
# src[layer_norm.py:N]: wdy = weight_cta * dy_mb
3710-
v_10 = v_0_copy_0 * v_2
3708+
v_8 = v_0_copy_0 * v_2
37113709
# src[layer_norm.py:N]: c1 = torch.sum(x_hat * wdy, dim=-1) / n
3712-
v_11 = v_6 * v_10
3713-
sum_3 = tl.cast(tl.sum(v_11, 1), tl.float32)
3714-
v_12 = 0.015625
3715-
v_13 = sum_3 * v_12
3710+
v_9 = v_4 * v_8
3711+
sum_3 = tl.cast(tl.sum(v_9, 1), tl.float32)
3712+
v_10 = 0.015625
3713+
v_11 = sum_3 * v_10
37163714
# src[layer_norm.py:N]: c2 = torch.sum(wdy, dim=-1) / n
3717-
sum_4 = tl.cast(tl.sum(v_10, 1), tl.float32)
3718-
v_14 = 0.015625
3719-
v_15 = sum_4 * v_14
3715+
sum_4 = tl.cast(tl.sum(v_8, 1), tl.float32)
3716+
v_12 = 0.015625
3717+
v_13 = sum_4 * v_12
37203718
# src[layer_norm.py:N]: dx = (wdy - (x_hat * c1[:, None] + c2[:, None])) * rstd_mb[:, None]
3721-
subscript_2 = v_13[:, None]
3722-
v_16 = v_6 * subscript_2
3723-
subscript_3 = v_15[:, None]
3724-
v_17 = v_16 + subscript_3
3725-
v_18 = v_10 - v_17
3726-
subscript_4 = v_4[:, None]
3727-
v_19 = v_18 * subscript_4
3719+
subscript_2 = v_11[:, None]
3720+
v_14 = v_4 * subscript_2
3721+
subscript_3 = v_13[:, None]
3722+
v_15 = v_14 + subscript_3
3723+
v_16 = v_8 - v_15
3724+
subscript_4 = rstd_mb[:, None]
3725+
v_17 = v_16 * subscript_4
37283726
# src[layer_norm.py:N]: grad_x[mb, :] = dx.to(x.dtype)
3729-
v_20 = tl.cast(v_19, tl.float16)
3730-
tl.store(grad_x + (indices_1[:, None] * 64 + indices_3[None, :] * 1), v_20, None)
3727+
v_18 = tl.cast(v_17, tl.float16)
3728+
tl.store(grad_x + (indices_1[:, None] * 64 + indices_3[None, :] * 1), v_18, None)
37313729
# src[layer_norm.py:N]: grad_weight_blocks[mb_cta.id, :] = grad_w_acc
37323730
tile_id = offset_0 // _BLOCK_SIZE_0
37333731
tl.store(grad_weight_blocks + (tile_id * 64 + indices_3 * 1), grad_w_acc, None)

test/test_examples.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -929,7 +929,6 @@ def test_layernorm_no_bias(self):
929929
)
930930
)
931931

932-
@skipIfRocm("accuracy check fails on AMD GPUs")
933932
@skipIfA10G("accuracy check fails on A10G GPUs")
934933
def test_layernorm_bwd(self):
935934
"""Test combined backward pass for layer norm with bias, including regression coverage."""
@@ -966,8 +965,10 @@ def test_layernorm_bwd(self):
966965
[batch_size, dim], device=DEVICE, dtype=torch.float16
967966
)
968967

969-
mean = x.mean(dim=-1)
970-
var = x.var(dim=-1, unbiased=False)
968+
# Compute mean, var, and rstd in fp32 to match Helion forward kernel output
969+
x_fp32 = x.to(torch.float32)
970+
mean = x_fp32.mean(dim=-1)
971+
var = x_fp32.var(dim=-1, unbiased=False)
971972
rstd = torch.rsqrt(var + eps)
972973

973974
x_ref = x.clone().detach().requires_grad_(True)

0 commit comments

Comments
 (0)