|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD 3-Clause license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import itertools |
| 8 | +from dataclasses import dataclass |
| 9 | +from typing import List |
| 10 | + |
| 11 | +import torch |
| 12 | +from tabulate import tabulate |
| 13 | +from tqdm import tqdm |
| 14 | + |
| 15 | +from benchmarks.utils import benchmark_cuda_function_in_microseconds |
| 16 | +from torchao.prototype.moe_training.kernels.mxfp8 import ( |
| 17 | + torch_to_blocked_2d_K_groups, |
| 18 | + triton_mx_block_rearrange_2d_K_groups, |
| 19 | +) |
| 20 | +from torchao.prototype.moe_training.kernels.mxfp8.quant import ( |
| 21 | + triton_mx_block_rearrange_2d_K_groups_naive, |
| 22 | +) |
| 23 | +from torchao.prototype.moe_training.utils import generate_jagged_offs |
| 24 | + |
| 25 | +device = torch.device("cuda") |
| 26 | + |
| 27 | +# Needed since changing args to function causes recompiles |
| 28 | +torch._dynamo.config.cache_size_limit = 1000 |
| 29 | + |
| 30 | + |
| 31 | +@dataclass(frozen=True) |
| 32 | +class ExperimentConfig: |
| 33 | + input_shape: tuple[int] |
| 34 | + num_groups: int |
| 35 | + version: str # "naive" or "parallel" |
| 36 | + |
| 37 | + |
| 38 | +@dataclass(frozen=True) |
| 39 | +class ExperimentResult: |
| 40 | + torch_time_us: float |
| 41 | + triton_time_us: float |
| 42 | + torch_mem_bw_gbps: float |
| 43 | + triton_mem_bw_gbps: float |
| 44 | + |
| 45 | + |
| 46 | +@dataclass(frozen=True) |
| 47 | +class Experiment: |
| 48 | + config: ExperimentConfig |
| 49 | + result: ExperimentResult |
| 50 | + |
| 51 | + |
| 52 | +def get_configs() -> List[ExperimentConfig]: |
| 53 | + # Llama4 shapes. Input activations are scaled along K dim. |
| 54 | + block_size = 32 |
| 55 | + input_shapes = [ |
| 56 | + (5120 // block_size, 16640), |
| 57 | + (5120 // block_size, 131072), |
| 58 | + ] |
| 59 | + num_groups = [16] |
| 60 | + versions = ["naive", "parallel"] |
| 61 | + configs = [] |
| 62 | + for shape, groups, version in itertools.product( |
| 63 | + input_shapes, |
| 64 | + num_groups, |
| 65 | + versions, |
| 66 | + ): |
| 67 | + configs.append( |
| 68 | + ExperimentConfig( |
| 69 | + input_shape=shape, |
| 70 | + num_groups=groups, |
| 71 | + version=version, |
| 72 | + ) |
| 73 | + ) |
| 74 | + return configs |
| 75 | + |
| 76 | + |
| 77 | +def run_experiment(config: ExperimentConfig) -> ExperimentResult: |
| 78 | + input_shape, num_groups, version = ( |
| 79 | + config.input_shape, |
| 80 | + config.num_groups, |
| 81 | + config.version, |
| 82 | + ) |
| 83 | + input_tensor = torch.randint( |
| 84 | + low=0, |
| 85 | + high=256, |
| 86 | + size=input_shape, |
| 87 | + dtype=torch.uint8, |
| 88 | + device=device, |
| 89 | + ) |
| 90 | + |
| 91 | + M, Kg = input_shape |
| 92 | + block_size = 32 |
| 93 | + input_group_offsets = generate_jagged_offs(num_groups, Kg, multiple_of=block_size) |
| 94 | + |
| 95 | + # bench torch |
| 96 | + compiled_run_torch = torch.compile(torch_to_blocked_2d_K_groups) |
| 97 | + torch_out_scales, torch_group_offs = compiled_run_torch( |
| 98 | + input_tensor, |
| 99 | + input_group_offsets, |
| 100 | + block_size=block_size, |
| 101 | + ) |
| 102 | + torch_time_us = benchmark_cuda_function_in_microseconds( |
| 103 | + compiled_run_torch, |
| 104 | + input_tensor, |
| 105 | + input_group_offsets, |
| 106 | + block_size=block_size, |
| 107 | + ) |
| 108 | + |
| 109 | + # bench triton (naive or parallel based on config) |
| 110 | + if version == "naive": |
| 111 | + triton_fn = triton_mx_block_rearrange_2d_K_groups_naive |
| 112 | + else: |
| 113 | + triton_fn = triton_mx_block_rearrange_2d_K_groups |
| 114 | + |
| 115 | + triton_out_scales = triton_fn( |
| 116 | + input_tensor, |
| 117 | + input_group_offsets, |
| 118 | + ) |
| 119 | + triton_time_us = benchmark_cuda_function_in_microseconds( |
| 120 | + triton_fn, |
| 121 | + input_tensor, |
| 122 | + input_group_offsets, |
| 123 | + ) |
| 124 | + |
| 125 | + # mem bw calculations |
| 126 | + bytes_per_input_el = torch.finfo(torch.float8_e8m0fnu).bits / 8 |
| 127 | + bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8 |
| 128 | + |
| 129 | + read_bytes = input_tensor.numel() * bytes_per_input_el |
| 130 | + write_bytes = triton_out_scales.numel() * bytes_per_output_el |
| 131 | + |
| 132 | + torch_mem_bw_gbps = ((read_bytes + write_bytes) / 1e9) / (torch_time_us / 1e6) |
| 133 | + triton_mem_bw_gbps = ((read_bytes + write_bytes) / 1e9) / (triton_time_us / 1e6) |
| 134 | + |
| 135 | + return ExperimentResult( |
| 136 | + torch_time_us=torch_time_us, |
| 137 | + triton_time_us=triton_time_us, |
| 138 | + torch_mem_bw_gbps=torch_mem_bw_gbps, |
| 139 | + triton_mem_bw_gbps=triton_mem_bw_gbps, |
| 140 | + ) |
| 141 | + |
| 142 | + |
| 143 | +def print_results(experiments: List[Experiment]): |
| 144 | + headers = [ |
| 145 | + "version", |
| 146 | + "input_shape", |
| 147 | + "torch_time_us", |
| 148 | + "triton_time_us", |
| 149 | + "torch_mem_bw_gbps", |
| 150 | + "triton_mem_bw_gbps", |
| 151 | + "triton_speedup", |
| 152 | + ] |
| 153 | + rows = [] |
| 154 | + for experiment in experiments: |
| 155 | + input_shape = ( |
| 156 | + f"({experiment.config.input_shape[0]}, {experiment.config.input_shape[1]})" |
| 157 | + ) |
| 158 | + rows.append( |
| 159 | + [ |
| 160 | + experiment.config.version, |
| 161 | + input_shape, |
| 162 | + experiment.result.torch_time_us, |
| 163 | + experiment.result.triton_time_us, |
| 164 | + round(experiment.result.torch_mem_bw_gbps, 3), |
| 165 | + round(experiment.result.triton_mem_bw_gbps, 3), |
| 166 | + f"{experiment.result.torch_time_us / experiment.result.triton_time_us:.2f}x", |
| 167 | + ] |
| 168 | + ) |
| 169 | + print(tabulate(rows, headers=headers)) |
| 170 | + |
| 171 | + |
| 172 | +def main(): |
| 173 | + torch.random.manual_seed(123) |
| 174 | + configs = get_configs() |
| 175 | + results = [] |
| 176 | + for config in tqdm(configs): |
| 177 | + result = run_experiment(config) |
| 178 | + results.append(Experiment(config=config, result=result)) |
| 179 | + |
| 180 | + # Use Tabulate to print results |
| 181 | + print_results(results) |
| 182 | + |
| 183 | + |
| 184 | +if __name__ == "__main__": |
| 185 | + main() |
0 commit comments