Add low precision support to conditional SFNO #800
Conversation
| """ | ||
| wc = torch.view_as_complex(weight) | ||
| return torch.einsum("bixy,iox->boxy", xc, wc) | ||
| r0 = torch.einsum("bixy,iox->boxy", x[..., 0], w[..., 0]) |
There is definitely some performance left on the table by computing these pieces separately, but it shouldn't significantly affect the memory reductions from bfloat16/float16.
I would leave optimizing this function until profiling tools are in place; doing it properly likely requires manual permutes and matrix multiplications, which could hurt performance if done wrong.
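For reference, here is a minimal sketch of the real/imaginary decomposition, assuming the shape convention implied by the einsum string; the function name, docstring, and exact shapes are illustrative rather than the PR's actual code:

```python
import torch

def complex_mul_einsum(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """Illustrative sketch (not the PR's code) of the real/imaginary split.

    x: real tensor of shape [batch, in_chan, nx, ny, 2], last dim = (real, imag)
    w: real tensor of shape [in_chan, out_chan, nx, 2], last dim = (real, imag)

    (a + bi)(c + di) = (ac - bd) + (ad + bc)i is expanded into four real
    einsums so the contraction can stay in bfloat16/float16, where
    complex-valued matmuls are not supported.
    """
    r0 = torch.einsum("bixy,iox->boxy", x[..., 0], w[..., 0])  # a*c
    r1 = torch.einsum("bixy,iox->boxy", x[..., 1], w[..., 1])  # b*d
    i0 = torch.einsum("bixy,iox->boxy", x[..., 0], w[..., 1])  # a*d
    i1 = torch.einsum("bixy,iox->boxy", x[..., 1], w[..., 0])  # b*c
    return torch.stack([r0 - r1, i0 + i1], dim=-1)
```

A fused implementation would combine these into fewer contractions, but as noted above that optimization is better done with profiling in hand.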
```diff
      with torch.amp.autocast("cuda", enabled=False):
-         x = self.forward_transform(x.float())
+         x = self.forward_transform(x)
```
Float casting is now handled exactly where it is needed, inside the forward transform.
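For illustration, a minimal sketch of what "casting inside the transform" could look like, assuming the transform wraps a full-precision spherical harmonic module; the class name and `sht` attribute are hypothetical, not the PR's implementation:

```python
import torch

class LowPrecisionForwardTransform(torch.nn.Module):
    """Hypothetical wrapper: casts to float32 only for the spectral op,
    so callers no longer need x.float() or an autocast-disabled region."""

    def __init__(self, sht: torch.nn.Module):
        super().__init__()
        self.sht = sht  # assumed to require float32 input and return complex output

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        coeffs = self.sht(x.float())  # cast up only where the transform needs it
        # Return a real tensor with a trailing [2] (real, imag) dim,
        # cast back to the caller's (possibly bf16/fp16) dtype.
        return torch.view_as_real(coeffs).to(input_dtype)
```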
```diff
@@ -0,0 +1,222 @@
+# flake8: noqa
```
This file was added (copy-pasted from sht_fix.py) in one commit and edited in the second; I recommend reviewing the second commit separately.
```diff
      with amp.autocast(device_type="cuda", enabled=False):
-         x = self.forward_transform(x).contiguous()
+         x = torch.view_as_complex(self.forward_transform(x)).contiguous()
```
These changes were required by the API change to the forward/inverse transforms, which now take real arrays with a trailing [2] dimension; that is needed for them to support low precision.
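To make the new contract concrete, here is a small self-contained sketch using `torch.fft.rfft` as a stand-in for the actual forward transform; the function name and shapes are hypothetical:

```python
import torch

def forward_transform_stub(x: torch.Tensor) -> torch.Tensor:
    """Stand-in illustrating the new contract: real input in, real output
    with a trailing [2] (real, imag) dimension out."""
    coeffs = torch.fft.rfft(x, dim=-1)  # placeholder for the spectral transform
    return torch.view_as_real(coeffs)   # complex [..., n] -> real [..., n, 2]

x = torch.randn(2, 4, 8)
out = forward_transform_stub(x)  # shape [2, 4, 5, 2], real dtype
# Call sites that still need complex values rebuild them explicitly,
# as in the diff above:
out_c = torch.view_as_complex(out.contiguous())
```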
Short description of why the PR is needed and how it satisfies those requirements, in sentence form.

Changes:

- symbol (e.g. `fme.core.my_function`) or script and concise description of changes or added feature. Can group multiple related symbols on a single bullet.
- Tests added
- If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated

Resolves # (delete if none)