[torch-xla 2.9] torch.save() on XLA model state_dict hangs or produces corrupt checkpoints #1315

@hokindeng

Description

Describe the bug

Calling torch.save(model.state_dict(), path) when model parameters reside on the XLA device can hang indefinitely or produce corrupt checkpoint files. XLA tensors are lazy and cannot be directly serialized by pickle's torch.save path.

On small models with torch-xla 2.9, the tensors may be materialized implicitly and the save appears to succeed, but for large models (hundreds of millions of parameters) the save hangs indefinitely with no error message or timeout.

User-visible impact:

  • Checkpoint saves hang forever with no error -- requires kill -9
  • Training runs that take hours lose all progress
  • No warning or indication anything is wrong

Workaround: Call torch_xla.sync() to ensure the lazy graph has been executed, then move all tensors to CPU before saving:

def state_dict_to_cpu(sd):
    # Copy every tensor to host memory so torch.save never touches a lazy XLA tensor.
    return {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in sd.items()}

torch_xla.sync()  # flush pending lazy computation before reading tensor values
torch.save({"model": state_dict_to_cpu(model.state_dict())}, path)
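
For a full training checkpoint the optimizer state needs the same treatment, since its state dict nests tensors inside per-parameter dicts. A minimal sketch, assuming an optimizer object and a recursive to_cpu helper (both hypothetical, not part of the verified workaround above):

def to_cpu(obj):
    # Sketch only: recursively move any nested tensors (e.g. optimizer state) to CPU.
    if isinstance(obj, torch.Tensor):
        return obj.cpu()
    if isinstance(obj, dict):
        return {k: to_cpu(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(to_cpu(v) for v in obj)
    return obj

torch_xla.sync()
torch.save({"model": to_cpu(model.state_dict()),
            "optimizer": to_cpu(optimizer.state_dict())}, path)  # optimizer assumed to exist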

Instance Type

trn2.3xlarge

Release version

torch-neuronx     2.9.0.2.13.24727+8e870898
torch-xla         2.9.0
PyTorch           2.9.1
neuronx-cc        2.24.5133.0+58f8de22
NRT runtime       2.x.47730.0
Python            3.12.3

Reproduction Steps

import os
os.environ["NEURON_CC_FLAGS"] = "--auto-cast=none"

import torch
import torch.nn as nn
import torch_xla

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64)).to(torch_xla.device())
x = torch.randn(4, 64, device=torch_xla.device())
loss = model(x).sum()
loss.backward()
torch_xla.sync()

# This may hang for large models:
print(f"Tensor device: {next(model.parameters()).device}")  # xla:0
torch.save({"model": model.state_dict()}, "/tmp/test.pt")  # HANG on large models

Workaround (verified working):

def state_dict_to_cpu(sd):
    return {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in sd.items()}

torch_xla.sync()
torch.save({"model": state_dict_to_cpu(model.state_dict())}, "/tmp/test.pt")
# Then reload normally with torch.load(..., map_location="cpu")
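
For completeness, a reload sketch (not part of the reported workaround): load the checkpoint on CPU, then move the restored parameters back to the XLA device. It assumes the same model definition as in the reproducer above.

ckpt = torch.load("/tmp/test.pt", map_location="cpu")
model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))
model.load_state_dict(ckpt["model"])       # weights are plain CPU tensors here
model = model.to(torch_xla.device())       # move back to XLA to resume training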

Logs/Context/Additional Information

  • The hang occurs because pickle tries to serialize XLA tensors that have not been materialized to host memory (a defensive check is sketched after this list)
  • On small models, torch-xla 2.9 may implicitly materialize tensors, masking the issue
  • The bug is reliably triggered with models >100M parameters during training checkpointing
  • Discovered during Wan2.2 DiT pretraining on Trainium2
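
Because the hang produces no error, a cheap check before torch.save can turn it into an immediate failure. This is only a sketch, not something torch-xla or torch-neuronx ships; assert_no_xla_tensors is a hypothetical helper, and it reuses state_dict_to_cpu from the workaround above.

def assert_no_xla_tensors(sd):
    # Fail fast if any tensor in the state_dict still lives on an XLA device.
    offending = [k for k, v in sd.items()
                 if isinstance(v, torch.Tensor) and v.device.type == "xla"]
    if offending:
        raise RuntimeError(f"State dict still contains XLA tensors: {offending}")

cpu_sd = state_dict_to_cpu(model.state_dict())
assert_no_xla_tensors(cpu_sd)
torch.save({"model": cpu_sd}, "/tmp/test.pt")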

Full reproducer repo: https://github.com/hokindeng/neuron-checkpoint-xla-tensor-crash
