
Conversation

@danielvegamyhre (Contributor) commented Dec 2, 2025

Stacked PRs:

  • [mxfp8 moe training] parallelize along col blocks in scale blocked format kernel for groups along K

Changes

  • Add a kernel that converts e8m0 scales to the blocked format required for grouped GEMM and parallelizes across both row blocks and column blocks, instead of only row blocks (see the sketch after this list).
  • The purpose of this strategy is to increase occupancy, which was very low in the loop-based implementation.
  • Memory bandwidth utilization is still low for both implementations. However, since we must operate at a minimum granularity of 128x4 to perform the per-block swizzle, I don't see how we can increase occupancy further.
  • NCU also flags some shared memory bank conflicts and some uncoalesced global accesses when reading and writing the scale factors.
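
A minimal Triton sketch of the parallelization strategy only, not the actual kernel in this PR: it assumes row-major uint8 e8m0 scales and writes each 128x4 tile to a simple block-contiguous output, while the real kernel additionally applies the per-block swizzle within each tile and handles the per-group offsets along K. The names `scales_to_blocked_kernel` and `scales_to_blocked` are hypothetical.

```python
import torch
import triton
import triton.language as tl

BLOCK_ROWS = 128
BLOCK_COLS = 4


@triton.jit
def scales_to_blocked_kernel(
    scales_ptr,        # (rows, cols) e8m0 scales stored as uint8, row-major
    out_ptr,           # one contiguous 128x4 tile per (row block, col block)
    rows,
    cols,
    stride_r,
    stride_c,
    num_col_blocks,
    BLOCK_R: tl.constexpr,
    BLOCK_C: tl.constexpr,
):
    # 2D launch grid: axis 0 walks row blocks, axis 1 walks column blocks,
    # so the number of programs grows with rows * cols instead of only rows.
    pid_r = tl.program_id(0)
    pid_c = tl.program_id(1)

    r_offs = pid_r * BLOCK_R + tl.arange(0, BLOCK_R)
    c_offs = pid_c * BLOCK_C + tl.arange(0, BLOCK_C)
    mask = (r_offs[:, None] < rows) & (c_offs[None, :] < cols)

    # Load one 128x4 tile of scales from the row-major input.
    tile = tl.load(
        scales_ptr + r_offs[:, None] * stride_r + c_offs[None, :] * stride_c,
        mask=mask,
        other=0,
    )

    # Write the tile contiguously at its block slot in the output.
    block_id = pid_r * num_col_blocks + pid_c
    out_offs = (
        block_id * BLOCK_R * BLOCK_C
        + tl.arange(0, BLOCK_R)[:, None] * BLOCK_C
        + tl.arange(0, BLOCK_C)[None, :]
    )
    tl.store(out_ptr + out_offs, tile, mask=mask)


def scales_to_blocked(scales: torch.Tensor) -> torch.Tensor:
    """Copy (rows, cols) e8m0 scales into a 128x4-block-contiguous layout."""
    rows, cols = scales.shape
    n_row_blocks = triton.cdiv(rows, BLOCK_ROWS)
    n_col_blocks = triton.cdiv(cols, BLOCK_COLS)
    out = torch.zeros(
        n_row_blocks * n_col_blocks * BLOCK_ROWS * BLOCK_COLS,
        dtype=scales.dtype,
        device=scales.device,
    )
    grid = (n_row_blocks, n_col_blocks)
    scales_to_blocked_kernel[grid](
        scales,
        out,
        rows,
        cols,
        scales.stride(0),
        scales.stride(1),
        n_col_blocks,
        BLOCK_R=BLOCK_ROWS,
        BLOCK_C=BLOCK_COLS,
    )
    return out.view(n_row_blocks * n_col_blocks, BLOCK_ROWS * BLOCK_COLS)
```

The occupancy gain over the loop-based version comes from the 2D grid: a loop-over-columns kernel launches one program per row block, while this launches one per 128x4 tile.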

Benchmarks


| kernel version | input_shape | torch_time_us | triton_time_us | torch_mem_bw_gbps | triton_mem_bw_gbps | triton_speedup |
|---|---|---|---|---|---|---|
| naive | (5120, 512) | 132.128 | 54.912 | 42.16 | 101.445 | 2.41x |
| parallel | (5120, 512) | 132.096 | 147.84 | 42.171 | 37.68 | 0.89x |
| naive | (5120, 4096) | 5898.82 | 58.368 | 7.166 | 724.211 | 101.06x |
| parallel | (5120, 4096) | 5909.1 | 153.632 | 7.153 | 275.143 | 38.46x |
| naive | (8192, 512) | 1762.96 | 52.224 | 5.056 | 170.667 | 33.76x |
| parallel | (8192, 512) | 1763.68 | 149.776 | 5.054 | 59.508 | 11.78x |
| naive | (8192, 4096) | 11212.8 | 99.232 | 6.032 | 681.566 | 113.00x |
| parallel | (8192, 4096) | 11214.8 | 152.608 | 6.031 | 443.182 | 73.49x |
| naive | (7168, 512) | 1568.77 | 53.44 | 4.971 | 145.935 | 29.36x |
| parallel | (7168, 512) | 1565.7 | 150.176 | 4.981 | 51.931 | 10.43x |
| naive | (7168, 4096) | 9844.88 | 107.392 | 6.011 | 551.056 | 91.67x |
| parallel | (7168, 4096) | 9832.75 | 152.576 | 6.019 | 387.866 | 64.44x |
| naive | (2048, 512) | 602.24 | 51.968 | 3.7 | 42.877 | 11.59x |
| parallel | (2048, 512) | 601.856 | 149.024 | 3.702 | 14.952 | 4.04x |
| naive | (2048, 4096) | 2984.74 | 54.304 | 5.665 | 311.364 | 54.96x |
| parallel | (2048, 4096) | 2988.83 | 148.32 | 5.657 | 113.999 | 20.15x |
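
For context on the bandwidth columns, a hedged sketch of how mem_bw_gbps is presumably derived: total bytes moved (read the 1-byte e8m0 scales once, write the blocked copy once, rounded up to whole 128x4 blocks) divided by kernel time. The benchmark script's exact byte accounting (e.g. per-group padding along K) may differ, so treat this as an approximation.

```python
# Rough bandwidth model for the scale-conversion kernel. This is an assumption
# about the benchmark's accounting, not its actual code.
def mem_bw_gbps(rows: int, cols: int, time_us: float) -> float:
    read_bytes = rows * cols                              # 1 byte per e8m0 scale
    padded = ((rows + 127) // 128) * 128 * (((cols + 3) // 4) * 4)  # blocked output
    return (read_bytes + padded) / (time_us * 1e-6) / 1e9


# (5120, 512) at 54.912 us -> ~95 GB/s, the same ballpark as the ~101 GB/s
# reported for the naive triton kernel above.
print(f"{mem_bw_gbps(5120, 512, 54.912):.1f} GB/s")
```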

@pytorch-bot bot commented Dec 2, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3416

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 0f325d3 with merge base a6dbf45:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

danielvegamyhre added a commit that referenced this pull request Dec 2, 2025
…rmat kernel for groups along K

stack-info: PR: #3416, branch: danielvegamyhre/stack/85
@danielvegamyhre force-pushed the danielvegamyhre/stack/85 branch from 476f86c to 77b0dad on December 2, 2025 at 22:30
@meta-cla bot added the CLA Signed label on Dec 2, 2025
@danielvegamyhre added the mx, topic: not user facing, and moe labels on Dec 2, 2025
danielvegamyhre added a commit that referenced this pull request Dec 2, 2025
…rmat kernel for groups along K

stack-info: PR: #3416, branch: danielvegamyhre/stack/85
@danielvegamyhre force-pushed the danielvegamyhre/stack/85 branch from 77b0dad to 02fc19d on December 2, 2025 at 23:33
@danielvegamyhre marked this pull request as draft on December 2, 2025 at 23:34
danielvegamyhre added a commit that referenced this pull request Dec 3, 2025
…rmat kernel for groups along K

stack-info: PR: #3416, branch: danielvegamyhre/stack/85
@danielvegamyhre force-pushed the danielvegamyhre/stack/85 branch from 02fc19d to fdbe390 on December 3, 2025 at 00:04
danielvegamyhre added a commit that referenced this pull request Dec 3, 2025
…rmat kernel for groups along K

stack-info: PR: #3416, branch: danielvegamyhre/stack/85
@danielvegamyhre force-pushed the danielvegamyhre/stack/85 branch from fdbe390 to 29a3c56 on December 3, 2025 at 00:22
danielvegamyhre added a commit that referenced this pull request Dec 3, 2025
…rmat kernel for groups along K

stack-info: PR: #3416, branch: danielvegamyhre/stack/85
@danielvegamyhre force-pushed the danielvegamyhre/stack/85 branch from 29a3c56 to 0f325d3 on December 3, 2025 at 00:58