
Add low precision support to conditional SFNO #800

Draft
mcgibbon wants to merge 6 commits into main from feature/bfloat16_sfno

Conversation

@mcgibbon (Contributor) commented Feb 5, 2026

Short description of why the PR is needed and how it satisfies those requirements, in sentence form.

Changes:

  • symbol (e.g. fme.core.my_function) or script and concise description of changes or added feature

  • Can group multiple related symbols on a single bullet

  • Tests added

  • If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated

Resolves # (delete if none)

"""
wc = torch.view_as_complex(weight)
return torch.einsum("bixy,iox->boxy", xc, wc)
r0 = torch.einsum("bixy,iox->boxy", x[..., 0], w[..., 0])
mcgibbon (Contributor Author):

Definitely some performance being left on the table here by doing these parts separately, but it shouldn't significantly affect the memory reductions from bfloat16/float16.

I would leave optimization of this function to happen alongside profiling tools. Doing so likely requires some manual permutes and matrix multiplications, which could hurt performance if done wrong.
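For context, here is a minimal sketch of the real-valued split this diff performs. The helper name and the shape conventions (activations as [batch, in, x, y, 2] and weights as [in, out, x, 2] with real/imaginary parts in the trailing dimension) are my assumptions, not the PR's actual code; the point is only that the complex contraction becomes four real einsums, which autocast can run in bfloat16/float16 since complex tensors are not cast to low precision.

```python
import torch

def complex_mul_as_real(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Split real/imaginary parts stored in the trailing [2] dimension.
    xr, xi = x[..., 0], x[..., 1]
    wr, wi = w[..., 0], w[..., 1]
    # (a + bi)(c + di) = (ac - bd) + (ad + bc)i, each term a real einsum.
    real = torch.einsum("bixy,iox->boxy", xr, wr) - torch.einsum("bixy,iox->boxy", xi, wi)
    imag = torch.einsum("bixy,iox->boxy", xr, wi) + torch.einsum("bixy,iox->boxy", xi, wr)
    return torch.stack([real, imag], dim=-1)
```

Doing the four einsums separately (rather than one fused real-valued matmul) is the "performance left on the table" mentioned above.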


with torch.amp.autocast("cuda", enabled=False):
x = self.forward_transform(x.float())
x = self.forward_transform(x)
mcgibbon (Contributor Author):

Float casting is now handled exactly where it's needed, inside the forward transform.
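A hypothetical sketch (wrapper name and attributes are assumptions, not the PR's code) of what "casting inside the transform" looks like: the spherical harmonic transform runs in float32 with autocast disabled, and the result comes back as a real tensor with a trailing [2] dimension in the caller's low-precision dtype.

```python
import torch

class ForwardTransformSketch(torch.nn.Module):
    def __init__(self, sht: torch.nn.Module):
        super().__init__()
        self.sht = sht  # e.g. a spherical harmonic transform module (assumption)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        # The SHT itself needs float32, so cast here rather than at call sites.
        with torch.amp.autocast("cuda", enabled=False):
            coeffs = self.sht(x.float())  # complex64 coefficients
        # Return real [..., 2] representation in the caller's (possibly bf16) dtype.
        return torch.view_as_real(coeffs).to(input_dtype)
```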

@@ -0,0 +1,222 @@
# flake8: noqa
mcgibbon (Contributor Author):

This file is added in one commit (as a copy-paste from sht_fix.py) and edited in the second, so I recommend reviewing the second commit separately to see the actual changes.

@mcgibbon changed the title from "Feature/bfloat16 sfno" to "Add low precision support to conditional SFNO" on Feb 5, 2026

with amp.autocast(device_type="cuda", enabled=False):
x = self.forward_transform(x).contiguous()
x = torch.view_as_complex(self.forward_transform(x)).contiguous()
mcgibbon (Contributor Author):

These changes were required because the forward/inverse transform API changed to take and return real arrays with a trailing [2] dimension, which is needed for them to support low precision.
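A hypothetical caller-side helper (not from the PR) showing how code that previously worked on complex tensors can adapt to the new real-valued [..., 2] convention, as in the `view_as_complex(...)` call in the diff above:

```python
import torch

def as_complex(x_real: torch.Tensor) -> torch.Tensor:
    # Cast to float32 (there is no bfloat16 complex dtype) and make the
    # trailing size-2 dimension contiguous, as view_as_complex requires.
    return torch.view_as_complex(x_real.float().contiguous())
```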
