diff --git a/tests/test_aot_eager.py b/tests/test_aot_eager.py index dd4124dd..86d994b7 100644 --- a/tests/test_aot_eager.py +++ b/tests/test_aot_eager.py @@ -53,6 +53,10 @@ def test_aot_eager_bitwise_equivalent(llama3_debug_model): x = torch.randint(0, vocab_size, (batch_size, seqlen), device="cuda") torch.manual_seed(3999) r1 = llama3_debug_model(x) + grads1 = torch.autograd.grad(r1.sum(), llama3_debug_model.parameters()) torch.manual_seed(3999) r2 = torch.compile(backend="aot_eager")(llama3_debug_model)(x) + grads2 = torch.autograd.grad(r2.sum(), llama3_debug_model.parameters()) assert torch.equal(r1, r2) # bitwise equal + for g1, g2 in zip(grads1, grads2): + assert torch.equal(g1, g2)