Skip to content

Commit 990a6b0

Browse files
committed
Add test for gradients in the bitwise equivalent test
Signed-off-by: Edward Z. Yang <ezyang@meta.com> ghstack-source-id: a9c1298 ghstack-comment-id: 3354350922 Pull-Request: #176
1 parent 5dcba70 commit 990a6b0

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

tests/test_aot_eager.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,10 @@ def test_aot_eager_bitwise_equivalent(llama3_debug_model):
5353
x = torch.randint(0, vocab_size, (batch_size, seqlen), device="cuda")
5454
torch.manual_seed(3999)
5555
r1 = llama3_debug_model(x)
56+
grads1 = torch.autograd.grad(r1.sum(), llama3_debug_model.parameters())
5657
torch.manual_seed(3999)
5758
r2 = torch.compile(backend="aot_eager")(llama3_debug_model)(x)
59+
grads2 = torch.autograd.grad(r2.sum(), llama3_debug_model.parameters())
5860
assert torch.equal(r1, r2) # bitwise equal
61+
for g1, g2 in zip(grads1, grads2):
62+
assert torch.equal(g1, g2)

0 commit comments

Comments
 (0)