Describe the bug
After migrating from the deprecated `XLA_USE_BF16=1` to the recommended `model.to(torch.bfloat16)`, arithmetic involving fp32 intermediate tensors silently upcasts bf16 data to fp32 via PyTorch type promotion. When this fp32 result is passed to a bf16 `Conv3d` or `Linear` layer, it crashes with:

```
RuntimeError: Input type (float) and bias type (c10::BFloat16) should be the same
```
This is a regression from the migration path: code that worked with `XLA_USE_BF16=1` breaks after following the official migration guide, because `XLA_USE_BF16` globally cast everything to bf16 (hiding type mismatches), while `model.to(torch.bfloat16)` only casts model parameters.
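The promotion itself is core PyTorch behavior and reproduces on plain CPU with no Neuron stack involved; a minimal sketch:

```python
import torch

latents = torch.zeros(4, dtype=torch.bfloat16)
sigma = torch.tensor([0.5])             # fp32 by default
print((sigma * latents).dtype)          # torch.float32 -- bf16 * fp32 promotes to fp32
```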
Concrete example (diffusion/flow-matching training):
```python
sigma = compute_sigma(timestep)                 # returns fp32 (timestep.float() / N)
noise = torch.randn_like(latents)               # bf16 (same as latents)
noisy = (1 - sigma) * latents + sigma * noise   # fp32! (sigma promotes everything)
output = model.conv3d(noisy)                    # CRASH: fp32 input vs bf16 weights
```
Workaround: Explicitly cast the result back: `noisy = add_noise(latents, noise, sigma).to(latents.dtype)`
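For reference, `add_noise` here is just the flow-matching interpolation from our training loop; a dtype-preserving sketch of it (the body is our code, not a library API) looks like:

```python
def add_noise(latents, noise, sigma):
    # Sketch of a dtype-preserving helper: keep the schedule math in fp32
    # for accuracy, then cast back so bf16 layers see the expected dtype.
    sigma = sigma.view(-1, 1, 1, 1, 1)          # broadcast over (C, T, H, W)
    noisy = (1 - sigma) * latents.float() + sigma * noise.float()
    return noisy.to(latents.dtype)
```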
Instance Type
trn2.3xlarge
Release version
- torch-neuronx 2.9.0.2.13.24727+8e870898
- torch-xla 2.9.0
- PyTorch 2.9.1
- neuronx-cc 2.24.5133.0+58f8de22
- NRT runtime 2.x.47730.0
- Python 3.12.3
Reproduction Steps
```python
import os
os.environ["NEURON_CC_FLAGS"] = "--auto-cast=none"

import torch
import torch.nn as nn
import torch_xla

device = torch_xla.device()

# bf16 model on XLA (the recommended migration from XLA_USE_BF16)
conv = nn.Conv3d(16, 64, kernel_size=(1, 2, 2), stride=(1, 2, 2)).to(torch.bfloat16).to(device)

# bf16 inputs
latents = torch.randn(1, 16, 1, 8, 8, device=device, dtype=torch.bfloat16)
noise = torch.randn_like(latents)

# fp32 scheduling math (common in flow-matching / diffusion)
sigma = torch.tensor([0.5], device=device)          # fp32
sigma_expanded = sigma.view(1, 1, 1, 1, 1)          # broadcast over (C, T, H, W)

# Type promotion: bf16 * fp32 -> fp32
noisy = (1 - sigma_expanded) * latents + sigma_expanded * noise
print(f"noisy dtype: {noisy.dtype}")                # torch.float32 -- silently upcast!

# CRASH: Input type (float) and bias type (c10::BFloat16) should be the same
out = conv(noisy)
torch_xla.sync()
```
Workaround (verified working):
```python
noisy = ((1 - sigma_expanded) * latents + sigma_expanded * noise).to(latents.dtype)
out = conv(noisy)                                   # works: bf16 input, bf16 weights
```
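An alternative, if slightly less precise in the schedule math, is to cast `sigma` down before the arithmetic so no promotion ever happens (a sketch):

```python
sigma_bf16 = sigma_expanded.to(latents.dtype)       # bf16 schedule coefficients
noisy = (1 - sigma_bf16) * latents + sigma_bf16 * noise
out = conv(noisy)                                   # bf16 throughout
```

Whether the bf16 rounding of `sigma` is acceptable depends on the schedule; the `.to(latents.dtype)` cast after fp32 math above keeps full precision in the interpolation itself.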
Logs/Context/Additional Information
| Tensor | Expected dtype | Actual dtype | Cause |
|---|---|---|---|
| `latents` | bf16 | bf16 | correct |
| `noise` | bf16 | bf16 | `randn_like` preserves dtype |
| `sigma` | fp32 | fp32 | `timestep.float() / N` |
| `noisy_latents` | bf16 | fp32 | `bf16 * fp32 -> fp32` promotion |
- This is standard PyTorch type promotion behavior, but it was hidden by `XLA_USE_BF16=1`, which globally downcast everything to bf16
- The migration guide recommends `model.to(torch.bfloat16)` but doesn't warn that intermediate fp32 computations break the dtype chain
- Affects all diffusion/flow-matching training pipelines that compute noise schedules in fp32 (a cheap dtype guard is sketched after this list)
- Discovered during Wan2.2 DiT video pretraining on Trainium2
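A guard at the model boundary (the helper name here is ours, purely illustrative) turns the opaque kernel error into an actionable one:

```python
def check_dtype(t, ref, name):
    # Hypothetical guard: fail fast with a readable message instead of the
    # "Input type (float) and bias type (c10::BFloat16)" RuntimeError.
    if t.dtype != ref.dtype:
        raise TypeError(f"{name}: expected {ref.dtype}, got {t.dtype} "
                        f"(likely fp32 type promotion in schedule math)")

check_dtype(noisy, latents, "noisy")
out = conv(noisy)
```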
Full reproducer repo: https://github.com/hokindeng/neuron-bf16-dtype-upcast-silent-mismatch