Describe the bug
Calling torch.save(model.state_dict(), path) when model parameters reside on the XLA device can hang indefinitely or produce corrupt checkpoint files. XLA tensors are lazy and cannot be directly serialized by pickle's torch.save path.
On small models with torch-xla 2.9, the tensors may be materialized implicitly and the save appears to succeed, but for large models (hundreds of millions of parameters) the save hangs indefinitely with no error message or timeout.
User-visible impact:
- Checkpoint saves hang forever with no error; recovery requires kill -9
- Training runs that take hours lose all progress
- No warning or indication anything is wrong
Workaround: call torch_xla.sync() to ensure the pending graph has been executed, then move all tensors to CPU before saving:
def state_dict_to_cpu(sd):
    return {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in sd.items()}
torch_xla.sync()
torch.save({"model": state_dict_to_cpu(model.state_dict())}, path)
Instance Type
trn2.3xlarge
Release version
torch-neuronx 2.9.0.2.13.24727+8e870898
torch-xla 2.9.0
PyTorch 2.9.1
neuronx-cc 2.24.5133.0+58f8de22
NRT runtime 2.x.47730.0
Python 3.12.3
Reproduction Steps
import os
os.environ["NEURON_CC_FLAGS"] = "--auto-cast=none"
import torch
import torch.nn as nn
import torch_xla
model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64)).to(torch_xla.device())
x = torch.randn(4, 64, device=torch_xla.device())
loss = model(x).sum()
loss.backward()
torch_xla.sync()
# This may hang for large models:
print(f"Tensor device: {next(model.parameters()).device}") # xla:0
torch.save({"model": model.state_dict()}, "/tmp/test.pt") # HANG on large models
Workaround (verified working):
def state_dict_to_cpu(sd):
    return {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in sd.items()}
torch_xla.sync()
torch.save({"model": state_dict_to_cpu(model.state_dict())}, "/tmp/test.pt")
# Then reload normally with torch.load(..., map_location="cpu")
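Reloading sketch (assumes the same model object and /tmp/test.pt path as above):
ckpt = torch.load("/tmp/test.pt", map_location="cpu")  # load checkpoint into host memory
model.load_state_dict(ckpt["model"])  # copies the CPU tensors back into the XLA-resident parameters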
Logs/Context/Additional Information
- The hang occurs because pickle tries to serialize XLA tensors that haven't been materialized to host memory (a fail-fast check sketch follows this list)
- On small models, torch-xla 2.9 may implicitly materialize tensors, masking the issue
- The bug is reliably triggered with models >100M parameters during training checkpointing
- Discovered during Wan2.2 DiT pretraining on Trainium2
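A minimal fail-fast sketch (assert_cpu_state_dict is a hypothetical helper, not a torch or torch_xla API) that raises immediately instead of hanging when a state dict still holds XLA tensors:
def assert_cpu_state_dict(sd):
    # Hypothetical guard: refuse to save while any tensor is still on an XLA device.
    for k, v in sd.items():
        if isinstance(v, torch.Tensor) and v.device.type == "xla":
            raise RuntimeError(f"state_dict entry '{k}' is on {v.device}; move it to CPU before torch.save")
assert_cpu_state_dict(model.state_dict())  # raises here instead of hanging later in torch.save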
Full reproducer repo: https://github.com/hokindeng/neuron-checkpoint-xla-tensor-crash