Commit 0f325d3

[mxfp8 moe training] parallelize along col blocks in scale blocked format kernel for groups along K
stack-info: PR: #3416, branch: danielvegamyhre/stack/85
1 parent a6dbf45 commit 0f325d3
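
In rough terms, the change splits the work for each K group across column blocks of the scale matrix, so that multiple Triton programs can cooperate on a single group instead of one program handling the whole group. The new benchmark below compares the torch.compile reference path against the naive and parallelized Triton kernels. As orientation, here is a minimal standalone sketch of how the benchmarked functions are invoked, using only the call signatures exercised by the benchmark; the final equality check is an assumption that the Triton and torch paths produce the same blocked scale buffer:

import torch

from torchao.prototype.moe_training.kernels.mxfp8 import (
    torch_to_blocked_2d_K_groups,
    triton_mx_block_rearrange_2d_K_groups,
)
from torchao.prototype.moe_training.utils import generate_jagged_offs

# E8M0 scales stored as raw uint8, using one of the benchmarked shapes: (M, total_K // block_size).
scales = torch.randint(0, 256, (8192, 131072 // 32), dtype=torch.uint8, device="cuda")
# Jagged group boundaries along K, aligned to the scale block size.
offs = generate_jagged_offs(16, 131072 // 32, multiple_of=32)

ref_scales, ref_group_offs = torch_to_blocked_2d_K_groups(scales, offs, block_size=32)
out_scales = triton_mx_block_rearrange_2d_K_groups(scales, offs)
# Assumption: both paths produce the same blocked scale layout.
torch.testing.assert_close(out_scales, ref_scales)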

3 files changed: +369 -9 lines changed
Lines changed: 191 additions & 0 deletions
@@ -0,0 +1,191 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import itertools
from dataclasses import dataclass
from typing import List

import torch
from tabulate import tabulate
from tqdm import tqdm

from benchmarks.utils import benchmark_cuda_function_in_microseconds
from torchao.prototype.moe_training.kernels.mxfp8 import (
    torch_to_blocked_2d_K_groups,
    triton_mx_block_rearrange_2d_K_groups,
)
from torchao.prototype.moe_training.kernels.mxfp8.quant import (
    triton_mx_block_rearrange_2d_K_groups_naive,
)
from torchao.prototype.moe_training.utils import generate_jagged_offs

device = torch.device("cuda")

# Needed since changing args to the compiled function causes recompiles
torch._dynamo.config.cache_size_limit = 1000


@dataclass(frozen=True)
class ExperimentConfig:
    input_shape: tuple[int]
    num_groups: int
    version: str  # "naive" or "parallel"


@dataclass(frozen=True)
class ExperimentResult:
    torch_time_us: float
    triton_time_us: float
    torch_mem_bw_gbps: float
    triton_mem_bw_gbps: float


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    result: ExperimentResult


def get_configs() -> List[ExperimentConfig]:
    # Llama4 and DSV3 671b shapes. Scale matrices have shape (M, total_K // block_size);
    # the K dim contains all the token groups, so group offsets are generated along K below.
    block_size = 32
    input_shapes = [
        (5120, 16384 // block_size),
        (5120, 131072 // block_size),
        (8192, 16384 // block_size),
        (8192, 131072 // block_size),
        (7168, 16384 // block_size),
        (7168, 131072 // block_size),
        (2048, 16384 // block_size),
        (2048, 131072 // block_size),
    ]
    num_groups = [16]
    versions = ["naive", "parallel"]
    configs = []
    for shape, groups, version in itertools.product(
        input_shapes,
        num_groups,
        versions,
    ):
        configs.append(
            ExperimentConfig(
                input_shape=shape,
                num_groups=groups,
                version=version,
            )
        )
    return configs


def run_experiment(config: ExperimentConfig) -> ExperimentResult:
    input_shape, num_groups, version = (
        config.input_shape,
        config.num_groups,
        config.version,
    )
    # Raw e8m0 scale bytes, stored as uint8.
    input_tensor = torch.randint(
        low=0,
        high=256,
        size=input_shape,
        dtype=torch.uint8,
        device=device,
    )

    M, Kg = input_shape
    block_size = 32
    # Jagged group boundaries along the K dim, aligned to the scale block size.
    input_group_offsets = generate_jagged_offs(num_groups, Kg, multiple_of=block_size)

    # bench torch
    compiled_run_torch = torch.compile(torch_to_blocked_2d_K_groups)
    torch_out_scales, torch_group_offs = compiled_run_torch(
        input_tensor,
        input_group_offsets,
        block_size=block_size,
    )
    torch_time_us = benchmark_cuda_function_in_microseconds(
        compiled_run_torch,
        input_tensor,
        input_group_offsets,
        block_size=block_size,
    )

    # bench triton (naive or parallel, based on config)
    if version == "naive":
        triton_fn = triton_mx_block_rearrange_2d_K_groups_naive
    else:
        triton_fn = triton_mx_block_rearrange_2d_K_groups

    triton_out_scales = triton_fn(
        input_tensor,
        input_group_offsets,
    )
    triton_time_us = benchmark_cuda_function_in_microseconds(
        triton_fn,
        input_tensor,
        input_group_offsets,
    )

    # mem bw calculations
    bytes_per_input_el = torch.finfo(torch.float8_e8m0fnu).bits / 8
    # Output is the same e8m0 scales rearranged into blocked format (1 byte per element).
    bytes_per_output_el = torch.finfo(torch.float8_e8m0fnu).bits / 8

    read_bytes = input_tensor.numel() * bytes_per_input_el
    write_bytes = triton_out_scales.numel() * bytes_per_output_el

    torch_mem_bw_gbps = ((read_bytes + write_bytes) / 1e9) / (torch_time_us / 1e6)
    triton_mem_bw_gbps = ((read_bytes + write_bytes) / 1e9) / (triton_time_us / 1e6)

    return ExperimentResult(
        torch_time_us=torch_time_us,
        triton_time_us=triton_time_us,
        torch_mem_bw_gbps=torch_mem_bw_gbps,
        triton_mem_bw_gbps=triton_mem_bw_gbps,
    )


def print_results(experiments: List[Experiment]):
    headers = [
        "kernel version",
        "input_shape",
        "torch_time_us",
        "triton_time_us",
        "torch_mem_bw_gbps",
        "triton_mem_bw_gbps",
        "triton_speedup",
    ]
    rows = []
    for experiment in experiments:
        input_shape = (
            f"({experiment.config.input_shape[0]}, {experiment.config.input_shape[1]})"
        )
        rows.append(
            [
                experiment.config.version,
                input_shape,
                experiment.result.torch_time_us,
                experiment.result.triton_time_us,
                round(experiment.result.torch_mem_bw_gbps, 3),
                round(experiment.result.triton_mem_bw_gbps, 3),
                f"{experiment.result.torch_time_us / experiment.result.triton_time_us:.2f}x",
            ]
        )
    print(tabulate(rows, headers=headers))


def main():
    torch.random.manual_seed(123)
    configs = get_configs()
    results = []
    for config in tqdm(configs):
        result = run_experiment(config)
        results.append(Experiment(config=config, result=result))

    # Use tabulate to print results
    print_results(results)


if __name__ == "__main__":
    main()
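
For reference, the memory bandwidth the script reports is simply bytes read plus bytes written, divided by kernel time. A quick worked instance of the same formula with a hypothetical runtime (the real blocked output is padded, so actual write bytes are slightly larger):

# Worked example of the benchmark's mem bw formula, with a hypothetical 100 us runtime.
M, K_blocks = 8192, 131072 // 32   # one of the benchmarked shapes
read_bytes = M * K_blocks * 1      # e8m0 scales are 1 byte each
write_bytes = read_bytes           # assume output ~= input size (padding ignored)
time_us = 100.0                    # hypothetical kernel time
gbps = ((read_bytes + write_bytes) / 1e9) / (time_us / 1e6)
print(f"{gbps:.1f} GB/s")          # ~671 GB/s with these numbers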
