Add use_linear option to replace Conv3d tokenizer with Linear layers#48

Open
nicholasmalaya wants to merge 1 commit into ORNL:main from nicholasmalaya:feature/use-linear-tokenizer
Conversation

@nicholasmalaya

When kernel_size == stride (non-overlapping patches), Conv3d is mathematically equivalent to reshape + nn.Linear. This avoids the im2col/col2im overhead and replaces MIOpen's implicit GEMM backward-weight path with standard rocBLAS matmul backward.

Profiling on MI355X (gfx950) shows the backward-weight GEMM (kernel_batched_gemm_xdlops_bwd_weight) consumed 79.3% of compute time. With use_linear=True, this kernel is eliminated entirely, yielding a 2.87x end-to-end training speedup with identical loss convergence.

Enabled via config: use_linear: !!bool True (default False, fully backward compatible).
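The equivalence claimed above can be checked numerically. This is a minimal sketch (not the PR's actual code; shapes and names are illustrative): a Conv3d tokenizer with kernel_size == stride is reproduced exactly by extracting non-overlapping patches and applying an nn.Linear that shares the flattened Conv3d weights.

```python
import torch
import torch.nn as nn

B, C, D, H, W = 2, 4, 8, 8, 8
patch = 4            # kernel_size == stride: non-overlapping patches
embed_dim = 16

conv = nn.Conv3d(C, embed_dim, kernel_size=patch, stride=patch)
linear = nn.Linear(C * patch**3, embed_dim)
with torch.no_grad():
    # Conv3d weight (embed_dim, C, p, p, p) flattens to the Linear weight.
    linear.weight.copy_(conv.weight.reshape(embed_dim, -1))
    linear.bias.copy_(conv.bias)

x = torch.randn(B, C, D, H, W)

# Conv path: (B, embed_dim, D/p, H/p, W/p) -> (B, tokens, embed_dim)
y_conv = conv(x).flatten(2).transpose(1, 2)

# Linear path: unfold into patches, flatten each patch, then a plain matmul
# (dispatched to rocBLAS/cuBLAS instead of MIOpen's implicit GEMM).
patches = (
    x.unfold(2, patch, patch).unfold(3, patch, patch).unfold(4, patch, patch)
)  # (B, C, D/p, H/p, W/p, p, p, p)
patches = patches.permute(0, 2, 3, 4, 1, 5, 6, 7).reshape(B, -1, C * patch**3)
y_linear = linear(patches)

print(torch.allclose(y_conv, y_linear, atol=1e-5))  # True
```

The permute puts the channel and intra-patch dims last in the same (C, kD, kH, kW) order used by the flattened Conv3d weight, so the two paths agree to floating-point tolerance.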

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@pzhanggit pzhanggit requested review from TsChala and pzhanggit April 7, 2026 18:36
Collaborator

@pzhanggit pzhanggit left a comment

@nicholasmalaya thank you very much for the optimization, Nick!
@TsChala the PR looks good to me. Could you do a test run on Frontier when it's back from maintenance? We should extend the changes to other models for better performance as well. Thanks

@TsChala
Collaborator

TsChala commented Apr 8, 2026

Thanks for the edits @nicholasmalaya !

@pzhanggit I ran some tests on the JHUTDB dataset today. Using the Turbulence Transformer I see around a 2x speed-up! This is only from the hMLP_stem and hMLP_output. Further speed-up can probably be achieved if we replace the Conv3d's in the upsampling parts as well.
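The same trick should carry over to the upsampling path: a ConvTranspose3d with kernel_size == stride writes each input voxel into a disjoint p³ output block, which is exactly a Linear followed by a reshape. A hedged sketch (illustrative shapes, not the repository's code):

```python
import torch
import torch.nn as nn

B, C_in, C_out = 2, 16, 4
d, p = 2, 4                        # token grid side, patch size

tconv = nn.ConvTranspose3d(C_in, C_out, kernel_size=p, stride=p)
lin = nn.Linear(C_in, C_out * p**3)
with torch.no_grad():
    # ConvTranspose3d weight is (C_in, C_out, p, p, p): flatten and transpose.
    lin.weight.copy_(tconv.weight.reshape(C_in, -1).t())
    # The per-channel bias is replicated across each p^3 output block.
    lin.bias.copy_(tconv.bias.repeat_interleave(p**3))

x = torch.randn(B, C_in, d, d, d)
y_conv = tconv(x)                  # (B, C_out, d*p, d*p, d*p)

tokens = x.flatten(2).transpose(1, 2)               # (B, d^3, C_in)
y = lin(tokens).reshape(B, d, d, d, C_out, p, p, p)
# Interleave grid and patch dims back into dense spatial axes.
y_lin = y.permute(0, 4, 1, 5, 2, 6, 3, 7).reshape(B, C_out, d*p, d*p, d*p)

print(torch.allclose(y_conv, y_lin, atol=1e-5))
```

As in the stem, this replaces the implicit-GEMM backward-weight kernel with a standard batched matmul backward.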

For the vit_all2all model I see similar runtimes so far with and without use_linear. I'm not exactly sure why; I can look more into it, but we mainly use the TurbT anyway.

