NKI: trn1 verifier rejects shared-memory instruction in fused XLA graph spanning ao_to_mo_transform + mp2_energy #1311

@scttfrdmn

Description

Summary

When two NKI kernels are dispatched back-to-back with all operands pre-pinned to the XLA device and no xm.mark_step() barrier between them, the XLA lazy evaluator fuses the two NKI programs into a single compiled graph. The resulting NEFF contains shared-memory instructions that are only valid on trn2, causing the trn1 Neuron Runtime verifier to reject it.

Environment

  • Hardware: trn1.2xlarge
  • Neuron SDK: 2.29.18.0
  • NKI: 0.3.0 (Stable)
  • PyTorch / torch_xla: 2.9 (from AWS Deep Learning AMI Neuron PyTorch 2.9, Ubuntu 24.04)
  • trntensor: 0.5.0

Reproduction script

import sys, os, torch
sys.path.insert(0, '/home/ubuntu/trntensor')
os.environ['TRNTENSOR_REQUIRE_NKI'] = '1'
os.environ['XLA_IR_DEBUG'] = '1'
os.environ['XLA_HLO_DEBUG'] = '1'
import trntensor
from trntensor.nki.dispatch import to_xla

nbasis, nocc, nvir, naux = 128, 4, 8, 16
eri    = torch.randn(nbasis, nbasis, naux)
C_occ  = torch.randn(nbasis, nocc)
C_vir  = torch.randn(nbasis, nvir)
eps_occ = torch.randn(nocc)
eps_vir = torch.randn(nvir)

# Pre-pin all inputs to XLA device
eri_x     = to_xla(eri)
C_occ_x   = to_xla(C_occ)
C_vir_x   = to_xla(C_vir)
eps_occ_x = to_xla(eps_occ)
eps_vir_x = to_xla(eps_vir)

# Kernel 1 — compiles and runs fine in isolation
B_x = trntensor.ao_to_mo_transform(eri_x, C_occ_x, C_vir_x)
print("ao_to_mo_transform OK")

# Kernel 2 — with B_x still on XLA (no from_xla / mark_step between calls),
# the lazy evaluator fuses both kernels into one graph and the combined NEFF
# triggers trn2-only shared-memory instructions, failing the trn1 verifier.
E = trntensor.mp2_energy(B_x, eps_occ_x, eps_vir_x)
print(f"mp2_energy OK: E={E.item()}")

Error

The trn1 Neuron Runtime verifier rejects the fused NEFF with an error referencing shared-memory instructions that are specific to trn2 / NeuronCore v3. The exact error string is of the form:

RuntimeError: [NRT] Failed to load NEFF: Unsupported instruction ...
  (shared memory instruction not supported on NeuronCore v2)

(The exact instruction name varies; the common pattern is a shared-memory addressing mode introduced for trn2 that the trn1 verifier does not recognise.)

Workaround

Insert an xm.mark_step() (or a from_xla/to_xla round-trip on B_x) between the two kernel calls to force graph materialization before the second dispatch:

B_x = trntensor.ao_to_mo_transform(eri_x, C_occ_x, C_vir_x)
import torch_xla.core.xla_model as xm
xm.mark_step()          # or: B = trntensor.from_xla(B_x); B_x = to_xla(B)
E = trntensor.mp2_energy(B_x, eps_occ_x, eps_vir_x)

This prevents lazy-graph fusion and keeps each kernel in its own NEFF.
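For pipelines with more than two chained kernels, the workaround can be packaged as a small wrapper that inserts a barrier after every dispatch. Note this is a hypothetical helper sketch: with_graph_break is not part of trntensor or torch_xla. On real trn1 hardware the barrier argument would be xm.mark_step; here the barrier is injected as a plain callable so the pattern can be exercised without Neuron hardware.

```python
# Hypothetical helper (NOT part of trntensor / torch_xla): wrap a kernel so
# every call is followed by a barrier, preventing the XLA lazy evaluator
# from fusing consecutive NKI programs into a single NEFF.
import functools

def with_graph_break(kernel, barrier):
    """Return a wrapper that invokes `barrier()` after each `kernel` call.

    On a real trn1 setup, `barrier` would be torch_xla's xm.mark_step;
    it is injected here so the pattern is testable off-device.
    """
    @functools.wraps(kernel)
    def wrapped(*args, **kwargs):
        out = kernel(*args, **kwargs)
        barrier()  # force materialization: next kernel starts a fresh graph
        return out
    return wrapped

# Demonstration with stand-in kernels and a counting barrier:
calls = []
barrier = lambda: calls.append("mark_step")
k1 = with_graph_break(lambda x: x + 1, barrier)
k2 = with_graph_break(lambda x: x * 2, barrier)
print(k2(k1(3)), calls)  # → 8 ['mark_step', 'mark_step']
```

On hardware, wrapping trntensor.ao_to_mo_transform and trntensor.mp2_energy this way (with barrier=xm.mark_step) keeps each kernel in its own NEFF without sprinkling barriers through the calling code.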

Expected behavior

Two independently-valid NKI kernels, each of which compiles and runs correctly on trn1 individually, should not produce a NEFF containing trn2-only instructions when fused. Either:

  • The XLA→Neuron lowering should avoid emitting trn2-specific shared-memory ops when targeting a trn1 device, or
  • The fusion should be suppressed when it would produce a NEFF that is not compatible with the target NeuronCore version.

Impact

Prevents end-to-end pipeline execution (e.g., DF-MP2: ao_to_mo_transform → mp2_energy) with pre-pinned operands, which is the primary performance pattern for chained NKI kernels. Tracked in trnsci/trntensor#39.

Metadata

Labels: Trn1, Trn2, bug