[neuronx-cc 2.24] NCC_INLA001: XLA_USE_BF16=1 + --auto-cast=none causes compiler crash on transformer backward pass (activation table limit) #1314


Description

@hokindeng

Describe the bug

Setting XLA_USE_BF16=1 together with NEURON_CC_FLAGS=--auto-cast=none causes neuronx-cc to crash with INTERNAL_ERROR NCC_INLA001 ("the number of activation tables must be <= 8") when compiling a transformer forward+backward pass on Trainium2.

Nothing fails at graph trace time -- the error only surfaces when torch_xla.sync() triggers compilation, and the message is an internal compiler error with no actionable user guidance.

Root cause: XLA_USE_BF16=1 globally casts all operations to bf16 at the XLA/HLO level. Combined with --auto-cast=none, the compiler receives a fully bf16 backward graph in which GELU-like activation functions generate more activation lookup tables than the hardware limit of 8 allows.

Scaling note: A tiny 2-layer model may compile successfully, but 20+ layer DiT architectures (Wan2.2 780M and 5B) reliably crash. The crash was discovered during video diffusion transformer pretraining.
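
For reference, a rough way to locate the depth threshold (a hypothetical helper, not part of the original report) is to run the reproducer below once per layer count, each in a fresh process, assuming the script is modified to read the layer count from an N_BLOCKS environment variable:

# sweep_depth.py -- hypothetical helper; assumes reproduce_bf16_crash.py reads
# the layer count from the N_BLOCKS environment variable instead of hard-coding n=20.
import os
import subprocess
import sys

for n in (2, 4, 8, 12, 20):
    env = dict(os.environ, N_BLOCKS=str(n))
    # Fresh process per depth so a compiler crash cannot affect the next attempt.
    rc = subprocess.run([sys.executable, "reproduce_bf16_crash.py"], env=env).returncode
    print(f"n={n}: {'compiled' if rc == 0 else f'failed (exit code {rc})'}")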

Workaround: Use model.to(torch.bfloat16) instead of XLA_USE_BF16=1, per the migration guide.

Instance Type

trn2.3xlarge

Release version

torch-neuronx     2.9.0.2.13.24727+8e870898
torch-xla         2.9.0
PyTorch           2.9.1
neuronx-cc        2.24.5133.0+58f8de22
NRT runtime       2.x.47730.0
Python            3.12.3

Reproduction Steps

Save as reproduce_bf16_crash.py:

import os
os.environ["XLA_USE_BF16"] = "1"
os.environ["NEURON_CC_FLAGS"] = "--auto-cast=none"

import torch
import torch.nn as nn
import torch_xla

class Block(nn.Module):
    def __init__(self, dim=128):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, 2, batch_first=True)
        self.ffn = nn.Sequential(nn.Linear(dim, 512), nn.GELU(approximate="tanh"), nn.Linear(512, dim))
    def forward(self, x):
        # pre-norm self-attention with residual
        h = self.norm(x)
        h, _ = self.attn(h, h, h)
        x = x + h
        # pre-norm FFN (tanh-approximate GELU) with residual
        return x + self.ffn(self.norm(x))

class Model(nn.Module):
    def __init__(self, n=20):
        super().__init__()
        self.embed = nn.Linear(16, 128)
        self.blocks = nn.ModuleList([Block() for _ in range(n)])
        self.head = nn.Linear(128, 16)
    def forward(self, x):
        x = self.embed(x)
        for b in self.blocks: x = b(x)
        return self.head(x)

device = torch_xla.device()
model = Model(n=20).to(device)  # 20 layers to exceed activation table limit
x = torch.randn(1, 64, 16, device=device)
loss = nn.functional.mse_loss(model(x), torch.randn(1, 64, 16, device=device))
loss.backward()
torch_xla.sync()  # CRASH: NCC_INLA001 activation tables must be <= 8

Run: python reproduce_bf16_crash.py
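
To confirm the root cause described above, the lazily traced IR can be dumped just before the sync call; with XLA_USE_BF16=1 the dumped ops already carry bf16 element types. A debugging sketch -- _get_xla_tensors_text is an internal torch_xla helper, and its availability in this exact build is an assumption:

# Add just before torch_xla.sync() in reproduce_bf16_crash.py (assumes the
# internal helper torch_xla._XLAC._get_xla_tensors_text exists in this build).
ir_text = torch_xla._XLAC._get_xla_tensors_text([loss])
print(ir_text)  # expectation: element types show up as bf16, not f32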

Workaround (verified working):

# Do NOT set XLA_USE_BF16=1
os.environ["NEURON_CC_FLAGS"] = "--auto-cast=none"
os.environ["NEURON_RT_STOCHASTIC_ROUNDING_EN"] = "1"
model = Model(n=20).to(torch.bfloat16).to(device)
x = torch.randn(1, 64, 16, device=device, dtype=torch.bfloat16)  # inputs cast to bf16 as well
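
For completeness, a minimal end-to-end sketch of the workaround (Block and Model repeated from the reproducer above so the file runs on its own; the final print is just a success marker, not original output):

# workaround_bf16.py -- sketch of the verified workaround: leave XLA_USE_BF16
# unset and cast the model and inputs to bf16 directly.
import os
os.environ["NEURON_CC_FLAGS"] = "--auto-cast=none"
os.environ["NEURON_RT_STOCHASTIC_ROUNDING_EN"] = "1"

import torch
import torch.nn as nn
import torch_xla

class Block(nn.Module):
    def __init__(self, dim=128):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, 2, batch_first=True)
        self.ffn = nn.Sequential(nn.Linear(dim, 512), nn.GELU(approximate="tanh"), nn.Linear(512, dim))
    def forward(self, x):
        h = self.norm(x)
        h, _ = self.attn(h, h, h)
        x = x + h
        return x + self.ffn(self.norm(x))

class Model(nn.Module):
    def __init__(self, n=20):
        super().__init__()
        self.embed = nn.Linear(16, 128)
        self.blocks = nn.ModuleList([Block() for _ in range(n)])
        self.head = nn.Linear(128, 16)
    def forward(self, x):
        x = self.embed(x)
        for b in self.blocks:
            x = b(x)
        return self.head(x)

device = torch_xla.device()
model = Model(n=20).to(torch.bfloat16).to(device)  # native bf16 weights
x = torch.randn(1, 64, 16, device=device, dtype=torch.bfloat16)
target = torch.randn(1, 64, 16, device=device, dtype=torch.bfloat16)
loss = nn.functional.mse_loss(model(x), target)
loss.backward()
torch_xla.sync()  # compiles without NCC_INLA001 when bf16 is applied natively
print("compiled OK, loss =", loss.item())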

Logs/Context/Additional Information

Compiler error output:

[INTERNAL_ERROR] [NCC_INLA001] Unhandled exception with message:
Error from .../neuronxcc/walrus/lower_act/src/lower_act.cpp:296
in function 'generateInstLoadActFuncSet':
(LoadActFuncSet: I-14774-0-PWP), tensorizer(output tensor: bfloat16<128 x 16> $14774, id: 14774)
Instruction LoadActFuncSet I-14774-0-PWP: the number of activation tables must be <= 8

Full reproducer repo: https://github.com/hokindeng/neuron-xla-use-bf16-compiler-crash
