Describe the bug
A standard PyTorch training loop (loss.backward() then optimizer.step()) runs without errors on XLA but silently does nothing -- model weights are never updated and the loss never decreases. The XLA computation graph is traced but never executed.
This is a silent correctness failure -- the most dangerous kind of bug. Hours of training produce a model identical to random initialization, with no errors or warnings.
Root cause: XLA uses lazy evaluation. loss.backward() and optimizer.step() add nodes to a computation graph but don't execute it. Execution only happens when torch_xla.sync() (or the deprecated xm.mark_step()) is called.
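For illustration, a minimal toy sketch of the lazy behavior (hypothetical, separate from the reproducer below; assumes a working torch_xla install on a Neuron device):

import torch
import torch_xla

device = torch_xla.device()
a = torch.ones(4, device=device)
b = a + 1                # only traced into the lazy graph, nothing runs on the device yet
torch_xla.sync()         # compiles and executes the pending graph
print(b.cpu())           # the result is materialized only after the sync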
Version note: In torch-xla 2.9, loss.item() may implicitly trigger a sync, partially masking this issue. However, relying on implicit .item() sync is dangerous because:
- It causes a separate graph compilation per step (no batching)
- Without .item() calls (e.g., when logging only every N steps), training silently does nothing (see the sketch below)
- The recommended pattern is always an explicit torch_xla.sync()
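As a sketch of this fragile pattern (using the same model/optimizer/x/target setup as in the Reproduction Steps below), the only synchronization point is the implicit one from loss.item(); per the notes above, removing or gating the print reverts to the broken behavior:

for step in range(10):
    loss = nn.functional.mse_loss(model(x), target)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print(f"step {step}: loss={loss.item():.4f}")  # implicit sync -- do not rely on this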
Workaround: Add torch_xla.sync() after backward and after optimizer step.
Instance Type
trn2.3xlarge
Release version
torch-neuronx 2.9.0.2.13.24727+8e870898
torch-xla 2.9.0
PyTorch 2.9.1
neuronx-cc 2.24.5133.0+58f8de22
NRT runtime 2.x.47730.0
Python 3.12.3
Reproduction Steps
import os
os.environ["NEURON_CC_FLAGS"] = "--auto-cast=none"

import torch
import torch.nn as nn
import torch_xla

device = torch_xla.device()
model = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 1)).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(16, 32, device=device)
target = torch.ones(16, 1, device=device)

weight_before = model[0].weight.detach().clone()
torch_xla.sync()
weight_before_cpu = weight_before.cpu()

# BROKEN: no torch_xla.sync() calls
for step in range(10):
    loss = nn.functional.mse_loss(model(x), target)
    loss.backward()
    # Missing: torch_xla.sync()
    optimizer.step()
    optimizer.zero_grad()
    # Missing: torch_xla.sync()

torch_xla.sync()
weight_after_cpu = model[0].weight.detach().cpu()
diff = float((weight_after_cpu - weight_before_cpu).abs().max())
print(f"Weight change: {diff}")  # May be 0 if .item() wasn't called
Fixed version:
for step in range(10):
    loss = nn.functional.mse_loss(model(x), target)
    loss.backward()
    torch_xla.sync()  # <-- execute backward graph
    optimizer.step()
    optimizer.zero_grad()
    torch_xla.sync()  # <-- execute optimizer step
    print(f"Step {step}: loss={loss.item():.4f}")  # loss actually decreases
Logs/Context/Additional Information
- This is not a crash -- the training loop completes normally with no errors
- Loss values printed via loss.item() may appear to change (due to the implicit sync from .item())
- Without .item() or an explicit sync, the XLA graph accumulates but never executes
- Discovered during Wan2.2 DiT video diffusion pretraining where loss appeared flat
Full reproducer repo: https://github.com/hokindeng/neuron-missing-mark-step-silent-hang