[torch-xla 2.9] Training without torch_xla.sync() silently produces no weight updates -- silent correctness failure #1316

@hokindeng

Description

Describe the bug

A standard PyTorch training loop (loss.backward() then optimizer.step()) runs without errors on an XLA device but silently does nothing -- model weights are never updated and the loss never decreases. The XLA computation graph is traced but never executed.

This is a silent correctness failure -- the most dangerous kind of bug. Hours of training produce a model identical to random initialization, with no errors or warnings.

Root cause: XLA uses lazy evaluation. loss.backward() and optimizer.step() add nodes to a computation graph but don't execute it. Execution only happens when torch_xla.sync() (or the deprecated xm.mark_step()) is called.
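
For illustration, a minimal sketch of the lazy-evaluation behavior (shapes and values are arbitrary, not taken from the reproducer):

import torch
import torch_xla

device = torch_xla.device()
a = torch.randn(4, 4, device=device)
b = a @ a           # only traced: a matmul node is added to the pending graph
c = b.sum()         # still traced; nothing has run on the device yet
torch_xla.sync()    # compiles and executes the accumulated graph
print(c.item())     # .item() also forces execution, which is what partially masks this bug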

Version note: In torch-xla 2.9, loss.item() may implicitly trigger a sync, partially masking this issue. However, relying on the implicit .item() sync is dangerous because:

  1. It forces a separate graph compilation and execution per step (no batching of work)
  2. Without .item() calls (e.g., when logging only every N steps), training silently does nothing (see the sketch after this list)
  3. The recommended pattern is always an explicit torch_xla.sync()
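
A minimal sketch of the pattern item 2 warns about (model, x, target as in the reproducer below; the logging interval is a placeholder):

for step in range(1000):
    loss = nn.functional.mse_loss(model(x), target)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # No explicit torch_xla.sync(); only the occasional .item() below can trigger one.
    if step % 100 == 0:
        print(f"Step {step}: loss={loss.item():.4f}")  # implicit sync happens only here

Per this report, the steps in between are only traced, so the model may see no effective weight updates.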

Workaround: Add torch_xla.sync() after loss.backward() and after optimizer.step().

Instance Type

trn2.3xlarge

Release version

torch-neuronx    2.9.0.2.13.24727+8e870898
torch-xla        2.9.0
PyTorch          2.9.1
neuronx-cc        2.24.5133.0+58f8de22
NRT runtime       2.x.47730.0
Python            3.12.3

Reproduction Steps

import os
os.environ["NEURON_CC_FLAGS"] = "--auto-cast=none"

import torch
import torch.nn as nn
import torch_xla

device = torch_xla.device()
model = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 1)).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(16, 32, device=device)
target = torch.ones(16, 1, device=device)

weight_before = model[0].weight.detach().clone()
torch_xla.sync()
weight_before_cpu = weight_before.cpu()

# BROKEN: no torch_xla.sync() calls
for step in range(10):
    loss = nn.functional.mse_loss(model(x), target)
    loss.backward()
    # Missing: torch_xla.sync()
    optimizer.step()
    optimizer.zero_grad()
    # Missing: torch_xla.sync()

torch_xla.sync()
weight_after_cpu = model[0].weight.detach().cpu()
diff = float((weight_after_cpu - weight_before_cpu).abs().max())
print(f"Weight change: {diff}")  # May be 0 if .item() wasn't called

Fixed version:

for step in range(10):
    loss = nn.functional.mse_loss(model(x), target)
    loss.backward()
    torch_xla.sync()       # <-- execute backward graph
    optimizer.step()
    optimizer.zero_grad()
    torch_xla.sync()       # <-- execute optimizer step
    print(f"Step {step}: loss={loss.item():.4f}")  # loss actually decreases

Logs/Context/Additional Information

  • This is not a crash -- the training loop completes normally, with no errors
  • Loss values printed via loss.item() may appear to change (due to the implicit sync from .item())
  • Without .item() or an explicit sync, the XLA graph accumulates but never executes (a diagnostic sketch follows this list)
  • Discovered during Wan2.2 DiT video diffusion pretraining, where the loss appeared flat
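
One way to confirm whether any device execution actually happened is torch_xla's debug metrics (a hedged sketch -- metric names such as ExecuteTime can differ between releases):

import torch_xla.debug.metrics as met

# ... run the training loop above ...

print(met.metrics_report())  # dump of compile/execute counters and timings
if "ExecuteTime" not in met.metric_names():
    print("No device executions recorded -- the traced graph was never run")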

Full reproducer repo: https://github.com/hokindeng/neuron-missing-mark-step-silent-hang
