
Commit 02fc19d

[mxfp8 moe training] parallelize along col blocks in scale blocked format kernel for groups along K
stack-info: PR: #3416, branch: danielvegamyhre/stack/85
1 parent 5977905 commit 02fc19d

File tree

3 files changed: +366 -9 lines changed
Lines changed: 185 additions & 0 deletions
@@ -0,0 +1,185 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import itertools
from dataclasses import dataclass
from typing import List

import torch
from tabulate import tabulate
from tqdm import tqdm

from benchmarks.utils import benchmark_cuda_function_in_microseconds
from torchao.prototype.moe_training.kernels.mxfp8 import (
    torch_to_blocked_2d_K_groups,
    triton_mx_block_rearrange_2d_K_groups,
)
from torchao.prototype.moe_training.kernels.mxfp8.quant import (
    triton_mx_block_rearrange_2d_K_groups_naive,
)
from torchao.prototype.moe_training.utils import generate_jagged_offs

device = torch.device("cuda")

# Needed since changing args to function causes recompiles
torch._dynamo.config.cache_size_limit = 1000


@dataclass(frozen=True)
class ExperimentConfig:
    input_shape: tuple[int]
    num_groups: int
    version: str  # "naive" or "parallel"


@dataclass(frozen=True)
class ExperimentResult:
    torch_time_us: float
    triton_time_us: float
    torch_mem_bw_gbps: float
    triton_mem_bw_gbps: float


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    result: ExperimentResult


def get_configs() -> List[ExperimentConfig]:
    # Llama4 shapes. Input activations are scaled along K dim.
    block_size = 32
    input_shapes = [
        (5120 // block_size, 16640),
        (5120 // block_size, 131072),
    ]
    num_groups = [16]
    versions = ["naive", "parallel"]
    configs = []
    for shape, groups, version in itertools.product(
        input_shapes,
        num_groups,
        versions,
    ):
        configs.append(
            ExperimentConfig(
                input_shape=shape,
                num_groups=groups,
                version=version,
            )
        )
    return configs


def run_experiment(config: ExperimentConfig) -> ExperimentResult:
    input_shape, num_groups, version = (
        config.input_shape,
        config.num_groups,
        config.version,
    )
    input_tensor = torch.randint(
        low=0,
        high=256,
        size=input_shape,
        dtype=torch.uint8,
        device=device,
    )

    M, Kg = input_shape
    block_size = 32
    input_group_offsets = generate_jagged_offs(num_groups, Kg, multiple_of=block_size)

    # bench torch
    compiled_run_torch = torch.compile(torch_to_blocked_2d_K_groups)
    torch_out_scales, torch_group_offs = compiled_run_torch(
        input_tensor,
        input_group_offsets,
        block_size=block_size,
    )
    torch_time_us = benchmark_cuda_function_in_microseconds(
        compiled_run_torch,
        input_tensor,
        input_group_offsets,
        block_size=block_size,
    )

    # bench triton (naive or parallel based on config)
    if version == "naive":
        triton_fn = triton_mx_block_rearrange_2d_K_groups_naive
    else:
        triton_fn = triton_mx_block_rearrange_2d_K_groups

    triton_out_scales = triton_fn(
        input_tensor,
        input_group_offsets,
    )
    triton_time_us = benchmark_cuda_function_in_microseconds(
        triton_fn,
        input_tensor,
        input_group_offsets,
    )

    # mem bw calculations
    bytes_per_input_el = torch.finfo(torch.float8_e8m0fnu).bits / 8
    bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8

    read_bytes = input_tensor.numel() * bytes_per_input_el
    write_bytes = triton_out_scales.numel() * bytes_per_output_el

    torch_mem_bw_gbps = ((read_bytes + write_bytes) / 1e9) / (torch_time_us / 1e6)
    triton_mem_bw_gbps = ((read_bytes + write_bytes) / 1e9) / (triton_time_us / 1e6)

    return ExperimentResult(
        torch_time_us=torch_time_us,
        triton_time_us=triton_time_us,
        torch_mem_bw_gbps=torch_mem_bw_gbps,
        triton_mem_bw_gbps=triton_mem_bw_gbps,
    )


def print_results(experiments: List[Experiment]):
    headers = [
        "version",
        "input_shape",
        "torch_time_us",
        "triton_time_us",
        "torch_mem_bw_gbps",
        "triton_mem_bw_gbps",
        "triton_speedup",
    ]
    rows = []
    for experiment in experiments:
        input_shape = (
            f"({experiment.config.input_shape[0]}, {experiment.config.input_shape[1]})"
        )
        rows.append(
            [
                experiment.config.version,
                input_shape,
                experiment.result.torch_time_us,
                experiment.result.triton_time_us,
                round(experiment.result.torch_mem_bw_gbps, 3),
                round(experiment.result.triton_mem_bw_gbps, 3),
                f"{experiment.result.torch_time_us / experiment.result.triton_time_us:.2f}x",
            ]
        )
    print(tabulate(rows, headers=headers))


def main():
    torch.random.manual_seed(123)
    configs = get_configs()
    results = []
    for config in tqdm(configs):
        result = run_experiment(config)
        results.append(Experiment(config=config, result=result))

    # Use Tabulate to print results
    print_results(results)


if __name__ == "__main__":
    main()
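
For context, a minimal standalone sketch of how the kernels compared by this benchmark can be invoked directly. It uses only the entry points exercised above (torch_to_blocked_2d_K_groups, the naive and parallel Triton variants, and generate_jagged_offs); the shape, group count, and device setup are illustrative assumptions, not part of the commit.

# Minimal usage sketch (assumptions: a CUDA device is available and the same
# torchao import paths as the benchmark above; shapes are illustrative).
import torch

from torchao.prototype.moe_training.kernels.mxfp8 import (
    torch_to_blocked_2d_K_groups,
    triton_mx_block_rearrange_2d_K_groups,
)
from torchao.prototype.moe_training.kernels.mxfp8.quant import (
    triton_mx_block_rearrange_2d_K_groups_naive,
)
from torchao.prototype.moe_training.utils import generate_jagged_offs

block_size = 32
num_groups = 16
M, Kg = 5120 // block_size, 16640  # one of the Llama4 shapes benchmarked above

# Raw e8m0 scales stored as uint8, with groups laid out along the K dim.
scales = torch.randint(0, 256, (M, Kg), dtype=torch.uint8, device="cuda")
offs = generate_jagged_offs(num_groups, Kg, multiple_of=block_size)

# Reference torch implementation and the two Triton variants benchmarked here.
ref_scales, ref_offs = torch_to_blocked_2d_K_groups(scales, offs, block_size=block_size)
naive_scales = triton_mx_block_rearrange_2d_K_groups_naive(scales, offs)
parallel_scales = triton_mx_block_rearrange_2d_K_groups(scales, offs)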

0 commit comments
