
[torch-xla 2.9] bf16 model + fp32 intermediate arithmetic causes dtype mismatch crash after XLA_USE_BF16 migration #1318

@hokindeng

Description


Describe the bug

After migrating from the deprecated XLA_USE_BF16=1 to the recommended model.to(torch.bfloat16), arithmetic involving fp32 intermediate tensors silently upcasts bf16 data to fp32 via PyTorch type promotion. When this fp32 result is passed to a bf16 Conv3d or Linear layer, it crashes with:

RuntimeError: Input type (float) and bias type (c10::BFloat16) should be the same

This is a regression along the migration path: code that worked with XLA_USE_BF16=1 breaks after following the official migration guide, because XLA_USE_BF16 globally cast everything to bf16 (hiding type mismatches), while model.to(torch.bfloat16) only casts the model's parameters and buffers.
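
The promotion itself is standard PyTorch behavior and reproduces on CPU, with no XLA or Neuron device involved; a minimal check:

import torch

latents = torch.randn(2, 4, dtype=torch.bfloat16)
sigma = torch.tensor([0.5])                      # defaults to fp32
mixed = (1 - sigma) * latents + sigma * latents  # bf16 mixed with fp32 -> fp32
print(mixed.dtype)                               # torch.float32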

Concrete example (diffusion/flow-matching training):

sigma = compute_sigma(timestep)          # returns fp32 (timestep.float() / N)
noise = torch.randn_like(latents)        # bf16 (same as latents)
noisy = (1 - sigma) * latents + sigma * noise   # fp32! (sigma promotes everything)
output = model.conv3d(noisy)             # CRASH: fp32 input vs bf16 weights

Workaround: Explicitly cast the result back: noisy = add_noise(latents, noise, sigma).to(latents.dtype)
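
For context, a minimal sketch of what such an add_noise helper might look like; the body below is an assumption (only the call and the trailing .to(latents.dtype) are taken from this report):

import torch

def add_noise(latents: torch.Tensor, noise: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    # Broadcast sigma over the non-batch dims of the latents.
    sigma = sigma.view(-1, *([1] * (latents.dim() - 1)))
    # bf16 latents mixed with fp32 sigma: the whole expression is promoted to fp32.
    return (1 - sigma) * latents + sigma * noise

latents = torch.randn(1, 16, 1, 8, 8, dtype=torch.bfloat16)
noise = torch.randn_like(latents)
sigma = torch.tensor([0.5])                                # fp32 schedule value

# Casting at the call site restores the model's dtype before the next layer.
noisy = add_noise(latents, noise, sigma).to(latents.dtype)
print(noisy.dtype)                                         # torch.bfloat16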

Instance Type

trn2.3xlarge

Release version

torch-neuronx    2.9.0.2.13.24727+8e870898
torch-xla        2.9.0
PyTorch           2.9.1
neuronx-cc        2.24.5133.0+58f8de22
NRT runtime       2.x.47730.0
Python            3.12.3

Reproduction Steps

import os
os.environ["NEURON_CC_FLAGS"] = "--auto-cast=none"

import torch
import torch.nn as nn
import torch_xla

device = torch_xla.device()

# bf16 model on XLA (the recommended migration from XLA_USE_BF16)
conv = nn.Conv3d(16, 64, kernel_size=(1,2,2), stride=(1,2,2)).to(torch.bfloat16).to(device)

# bf16 inputs
latents = torch.randn(1, 16, 1, 8, 8, device=device, dtype=torch.bfloat16)
noise = torch.randn_like(latents)

# fp32 scheduling math (common in flow-matching / diffusion)
sigma = torch.tensor([0.5], device=device)   # fp32
sigma_expanded = sigma.view(1, 1, 1, 1, 1)   # broadcastable against (1, 16, 1, 8, 8)

# Type promotion: bf16 * fp32 -> fp32
noisy = (1 - sigma_expanded) * latents + sigma_expanded * noise
print(f"noisy dtype: {noisy.dtype}")  # torch.float32 -- silently upcast

# CRASH: Input type (float) and bias type (c10::BFloat16) should be the same
out = conv(noisy)
torch_xla.sync()

Workaround (verified working):

noisy = ((1 - sigma_expanded) * latents + sigma_expanded * noise).to(latents.dtype)
out = conv(noisy)  # works: bf16 input, bf16 weights

Logs/Context/Additional Information

Tensor          Expected dtype   Actual dtype   Cause
latents         bf16             bf16           correct
noise           bf16             bf16           randn_like preserves dtype
sigma           fp32             fp32           timestep.float() / N
noisy_latents   bf16             fp32           bf16 * fp32 -> fp32 promotion
  • This is standard PyTorch type promotion behavior, but it was hidden by XLA_USE_BF16=1 which globally downcast everything
  • The migration guide recommends model.to(torch.bfloat16) but doesn't warn about intermediate fp32 computations breaking the dtype chain
  • Affects any diffusion/flow-matching training pipeline that computes noise schedules in fp32 (a debug guard that localizes the leak is sketched after this list)
  • Discovered during Wan2.2 DiT video pretraining on Trainium2
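
Since the failure only surfaces inside the first bf16 layer that receives the promoted tensor, a small debug guard can localize the leak earlier. The sketch below uses a plain PyTorch forward pre-hook; install_dtype_guard is a hypothetical helper, not part of torch-xla or torch-neuronx:

import torch
import torch.nn as nn

def install_dtype_guard(module: nn.Module) -> None:
    # Register forward pre-hooks that fail fast, naming the layer, whenever a
    # floating-point input does not match that layer's weight dtype.
    def make_hook(name, weight_dtype):
        def hook(mod, inputs):
            for t in inputs:
                if torch.is_tensor(t) and t.is_floating_point() and t.dtype != weight_dtype:
                    raise TypeError(
                        f"{name}: input dtype {t.dtype} != weight dtype {weight_dtype} "
                        "(an fp32 intermediate likely leaked into a bf16 layer)"
                    )
        return hook

    for name, mod in module.named_modules():
        if isinstance(mod, (nn.Linear, nn.Conv2d, nn.Conv3d)):
            mod.register_forward_pre_hook(make_hook(name, mod.weight.dtype))

Calling install_dtype_guard(model) right after model.to(torch.bfloat16) replaces the opaque Conv3d error with one that names the exact layer receiving the promoted fp32 tensor.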

Full reproducer repo: https://github.com/hokindeng/neuron-bf16-dtype-upcast-silent-mismatch
