From d2d817d39e522e8146f1e642f291b8ec86c7ccda Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Tue, 30 Sep 2025 18:30:18 -0700 Subject: [PATCH] Update [ghstack-poisoned] --- tests/test_aot_eager.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_aot_eager.py b/tests/test_aot_eager.py index dd4124dd..86d994b7 100644 --- a/tests/test_aot_eager.py +++ b/tests/test_aot_eager.py @@ -53,6 +53,10 @@ def test_aot_eager_bitwise_equivalent(llama3_debug_model): x = torch.randint(0, vocab_size, (batch_size, seqlen), device="cuda") torch.manual_seed(3999) r1 = llama3_debug_model(x) + grads1 = torch.autograd.grad(r1.sum(), llama3_debug_model.parameters()) torch.manual_seed(3999) r2 = torch.compile(backend="aot_eager")(llama3_debug_model)(x) + grads2 = torch.autograd.grad(r2.sum(), llama3_debug_model.parameters()) assert torch.equal(r1, r2) # bitwise equal + for g1, g2 in zip(grads1, grads2): + assert torch.equal(g1, g2)