Describe the bug
Setting XLA_USE_BF16=1 with NEURON_CC_FLAGS=--auto-cast=none causes neuronx-cc to crash with INTERNAL_ERROR NCC_INLA001: the number of activation tables must be <= 8 when compiling a transformer forward+backward pass on Trainium2.
The crash is silent at graph trace time; it surfaces only when torch_xla.sync() is called, and the message is an internal compiler error with no actionable user guidance.
Root cause: XLA_USE_BF16=1 globally casts all operations to bf16 at the XLA/HLO level. When combined with --auto-cast=none, the compiler receives a fully bf16 backward graph where GELU-like activation functions generate more activation lookup tables than the hardware limit of 8.
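The global cast is easy to miss because it is invisible from the Python side. A minimal sketch of the behavior, assuming the deprecated XLA_USE_BF16 semantics (tensor.dtype still reports float32 while the underlying XLA buffer, and hence the HLO handed to neuronx-cc, is bf16):

import os
os.environ["XLA_USE_BF16"] = "1"  # must be set before torch_xla is imported

import torch
import torch_xla

device = torch_xla.device()
t = torch.randn(4, 4, device=device)
# Reports torch.float32 even though the device buffer is bf16,
# which is why the fully-bf16 graph only becomes visible at compile time.
print(t.dtype)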
Scaling note: a tiny 2-layer model may compile successfully, but 20+ layer DiT architectures (Wan2.2 780M and 5B) crash reliably. The issue was discovered during video diffusion transformer pretraining.
Workaround: Use model.to(torch.bfloat16) instead of XLA_USE_BF16=1, per the migration guide.
Instance Type
trn2.3xlarge
Release version
torch-neuronx 2.9.0.2.13.24727+8e870898
torch-xla 2.9.0
PyTorch 2.9.1
neuronx-cc 2.24.5133.0+58f8de22
NRT runtime 2.x.47730.0
Python 3.12.3
Reproduction Steps
Save as reproduce_bf16_crash.py:
import os
os.environ["XLA_USE_BF16"] = "1"
os.environ["NEURON_CC_FLAGS"] = "--auto-cast=none"

import torch
import torch.nn as nn
import torch_xla

class Block(nn.Module):
    def __init__(self, dim=128):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, 2, batch_first=True)
        self.ffn = nn.Sequential(nn.Linear(dim, 512), nn.GELU(approximate="tanh"), nn.Linear(512, dim))

    def forward(self, x):
        h = self.norm(x)
        h, _ = self.attn(h, h, h)
        x = x + h
        return x + self.ffn(self.norm(x))

class Model(nn.Module):
    def __init__(self, n=20):
        super().__init__()
        self.embed = nn.Linear(16, 128)
        self.blocks = nn.ModuleList([Block() for _ in range(n)])
        self.head = nn.Linear(128, 16)

    def forward(self, x):
        x = self.embed(x)
        for b in self.blocks:
            x = b(x)
        return self.head(x)

device = torch_xla.device()
model = Model(n=20).to(device)  # 20 layers to exceed the activation table limit
x = torch.randn(1, 64, 16, device=device)
loss = nn.functional.mse_loss(model(x), torch.randn(1, 64, 16, device=device))
loss.backward()
torch_xla.sync()  # CRASH: NCC_INLA001 activation tables must be <= 8
Run: python reproduce_bf16_crash.py
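To confirm the traced backward graph is fully bf16 before compilation, the pending lazy IR can be printed just before torch_xla.sync(). This is a debugging sketch; _get_xla_tensors_text is an internal torch_xla helper and may change between releases:

# Add before torch_xla.sync() in the reproducer:
import torch_xla
# Dumps the lazy IR for the pending computation; the bf16 element
# types on the GELU/tanh ops are visible in the output.
print(torch_xla._XLAC._get_xla_tensors_text([loss]))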
Workaround (verified working):
# Do NOT set XLA_USE_BF16=1
os.environ["NEURON_CC_FLAGS"] = "--auto-cast=none"
os.environ["NEURON_RT_STOCHASTIC_ROUNDING_EN"] = "1"
model = Model(n=20).to(torch.bfloat16).to(device)
x = torch.randn(1, 64, 16, device=device, dtype=torch.bfloat16)  # also cast inputs to bf16
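For completeness, an end-to-end sketch of the verified workaround, reusing the Model class from the reproducer (the parameter-dtype assertion is illustrative):

import os
os.environ["NEURON_CC_FLAGS"] = "--auto-cast=none"
os.environ["NEURON_RT_STOCHASTIC_ROUNDING_EN"] = "1"
# XLA_USE_BF16 intentionally left unset

import torch
import torch.nn as nn
import torch_xla

device = torch_xla.device()
model = Model(n=20).to(torch.bfloat16).to(device)
assert all(p.dtype == torch.bfloat16 for p in model.parameters())  # cast took effect

x = torch.randn(1, 64, 16, device=device, dtype=torch.bfloat16)
target = torch.randn(1, 64, 16, device=device, dtype=torch.bfloat16)
loss = nn.functional.mse_loss(model(x), target)
loss.backward()
torch_xla.sync()  # compiles and runs, per the verified workaround above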
Logs/Context/Additional Information
Compiler error output:
[INTERNAL_ERROR] [NCC_INLA001] Unhandled exception with message:
Error from .../neuronxcc/walrus/lower_act/src/lower_act.cpp:296
in function 'generateInstLoadActFuncSet':
(LoadActFuncSet: I-14774-0-PWP), tensorizer(output tensor: bfloat16<128 x 16> $14774, id: 14774)
Instruction LoadActFuncSet I-14774-0-PWP: the number of activation tables must be <= 8
Full reproducer repo: https://github.com/hokindeng/neuron-xla-use-bf16-compiler-crash