fix: force float32 precision on CPU to prevent distorted structures#670

Open
joshuasteier wants to merge 1 commit into jwohlwend:main from joshuasteier:fix/cpu-float32-precision

Conversation

@joshuasteier joshuasteier commented Apr 6, 2026

This updates PR #670, replacing the blanket CPU float32 workaround in main.py
with the underlying fix described in #662.

On CPU, PyTorch Lightning wraps the forward pass with
torch.autocast("cpu", dtype=torch.bfloat16), so
torch.autocast("cuda", enabled=False) is a no-op there. As a result,
the structure and affinity blocks in boltz2.py still run under CPU bf16
autocast even though they explicitly cast inputs to float32.
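The no-op behavior described above can be reproduced in a few lines. This is a minimal sketch (not code from the PR) showing that disabling CUDA autocast inside a CPU bf16 autocast region leaves matmul in bfloat16, while disabling autocast for the tensors' own device type restores float32:

```python
import torch

a = torch.randn(4, 4)
b = torch.randn(4, 4)

with torch.autocast("cpu", dtype=torch.bfloat16):
    # Disabling CUDA autocast does not touch the CPU autocast state,
    # so the matmul still runs under bf16 autocast:
    with torch.autocast("cuda", enabled=False):
        out_old = a @ b

    # Disabling autocast for the tensors' actual device type works:
    with torch.autocast(a.device.type, enabled=False):
        out_new = a @ b

print(out_old.dtype)  # torch.bfloat16
print(out_new.dtype)  # torch.float32
```

This mirrors the failure mode: the hardcoded `"cuda"` device type in the existing autocast-disable blocks silently does nothing on CPU.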

This patch:

  • removes the main.py CPU precision workaround from this PR
  • adds an autocast_device_type() helper in model/modules/utils.py
  • replaces the three hardcoded torch.autocast("cuda", enabled=False)
    blocks in src/boltz/model/models/boltz2.py with
    torch.autocast(autocast_device_type(s.device.type), enabled=False)
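One possible shape of the `autocast_device_type()` helper, sketched here as an assumption (the actual implementation in the PR may differ): map the tensor's device type to a device type that `torch.autocast` accepts, falling back to `"cpu"` for devices without their own autocast support.

```python
import torch

def autocast_device_type(device_type: str) -> str:
    """Map a tensor's device type to a valid torch.autocast device type.

    torch.autocast only accepts device types with autocast support, so
    anything other than "cuda" or "cpu" (e.g. "mps") falls back to "cpu".
    This is an illustrative sketch, not the PR's exact implementation.
    """
    return device_type if device_type in ("cuda", "cpu") else "cpu"

# Call-site pattern used in boltz2.py, where `s` is an input tensor:
s = torch.randn(2, 2)
with torch.autocast(autocast_device_type(s.device.type), enabled=False):
    out = s @ s  # runs at the tensors' native precision
```

Passing the tensor's own device type (rather than a hardcoded `"cuda"`) is what makes `enabled=False` actually disable autocast on CPU.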

Local validation:

  • under torch.autocast("cpu", dtype=torch.bfloat16), the old path keeps
    matmul in torch.bfloat16
  • the new autocast_device_type(...) path correctly yields torch.float32

Credit:

Closes #653
Closes #662

@joshuasteier joshuasteier force-pushed the fix/cpu-float32-precision branch from 0820f97 to 83bb04c on April 10, 2026 at 19:32