From 465492fb5b8870acb6c7f42353d5723f18011ac1 Mon Sep 17 00:00:00 2001 From: llcnt Date: Mon, 8 Dec 2025 11:11:51 +0000 Subject: [PATCH 01/26] feat: draft kernel tuning based on vllm fns --- src/pruna/algorithms/moe_kernel_tuner.py | 120 +++++++++++++++++++++++ 1 file changed, 120 insertions(+) create mode 100644 src/pruna/algorithms/moe_kernel_tuner.py diff --git a/src/pruna/algorithms/moe_kernel_tuner.py b/src/pruna/algorithms/moe_kernel_tuner.py new file mode 100644 index 00000000..8a558f7c --- /dev/null +++ b/src/pruna/algorithms/moe_kernel_tuner.py @@ -0,0 +1,120 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import functools +from collections.abc import Iterable +from typing import Any, Dict, Optional, Tuple + +import torch +import ray +import ray.experimental.tqdm_ray as tqdm_ray + +from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase +from pruna.algorithms.base.tags import AlgorithmTag as tags +from pruna.config.smash_config import SmashConfigPrefixWrapper +from pruna.engine.model_checks import is_moe_lm, is_transformers_pipeline_with_moe_lm +from pruna.logging.logger import pruna_logger + + +class MoeKernelTuner(PrunaAlgorithmBase): + """ + Tune the MoE kernel for the model. + + Uses vLLM to tune the MoE kernel of the model. + """ + + algorithm_name: str = "moe_kernel_tuner" + group_tags: list[str] = [tags.KERNEL] + save_fn :None = None + references: dict[str, str] = { + "GitHub": "https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py", + } + tokenizer_required: bool = False + processor_required: bool = False + runs_on: list[str] = ["cuda", "accelerate"] + dataset_required: bool = False + compatible_before: Iterable[str] = [tags.KERNEL, tags.QUANTIZER, tags.PRUNER, tags.CACHER, tags.FACTORIZER, tags.BATCHER, tags.COMPILER] + compatible_after: Iterable[str] = [tags.KERNEL, tags.QUANTIZER, tags.PRUNER, tags.CACHER, tags.FACTORIZER, tags.BATCHER, tags.COMPILER] + required_install = "``uv pip install vllm``" + + def model_check_fn(self, model: Any) -> bool: + """ + Check if the model is a MoE model. + + Parameters + ---------- + model : Any + The model to check. + + Returns + ------- + bool + True if the model is a valid model for the algorithm, False otherwise. + """ + # Hunyuan3-image is a MoE model, but not depending on mixtral + if model.__class__.__name__ == "HunyuanImage3ForCausalMM": + return True + else: + return is_moe_lm(model) or is_transformers_pipeline_with_moe_lm(model) + + def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: + """ + Wrap the model to use flash_attn3 where possible. + + Parameters + ---------- + model : Any + The model to wrap. + smash_config : SmashConfigPrefixWrapper + The configuration for the application of the algorithm. + + Returns + ------- + Any + The wrapped model. + """ + imported_packages = self.import_algorithm_packages() + + #TODO: Implement the MoE kernel tuning. 
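        # Illustrative usage sketch (assumption, not part of this patch): once the
        # TODO above is implemented, the tuner is expected to be driven through the
        # regular pruna entry point. The "kernel" SmashConfig key is inferred from
        # the KERNEL group tag and may differ in the final API.
        #
        #   from pruna import SmashConfig, smash
        #
        #   smash_config = SmashConfig()
        #   smash_config["kernel"] = "moe_kernel_tuner"
        #   smashed_model = smash(model=moe_model, smash_config=smash_config)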
+ return model + + def import_algorithm_packages(self) -> Dict[str, Any]: + """ + Import the algorithm packages. + + Returns + ------- + Dict[str, Any] + The algorithm packages. + """ + from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig, + _get_config_dtype_str, + ) + import vllm.model_executor.layers.fused_moe.fused_moe as fused_moe + import vllm.platforms as vllm_platforms + from vllm.transformers_utils.config import get_config + from vllm.triton_utils import triton + from vllm.utils.argparse_utils import FlexibleArgumentParser + return dict( + FusedMoEQuantConfig=FusedMoEQuantConfig, + _get_config_dtype_str=_get_config_dtype_str, + FusedMoE=fused_moe, + vllm_platforms=vllm_platforms, + get_config=get_config, + triton=triton, + FlexibleArgumentParser=FlexibleArgumentParser, + tqdm_ray=tqdm_ray, + ) From c655e076609475e703d57405caf8415880987852 Mon Sep 17 00:00:00 2001 From: llcnt Date: Mon, 8 Dec 2025 11:12:28 +0000 Subject: [PATCH 02/26] feat: draft kernel tuning based on vllm fns --- src/pruna/algorithms/moe_kernel_tuner.py | 37 ++++++++++++++++-------- 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/src/pruna/algorithms/moe_kernel_tuner.py b/src/pruna/algorithms/moe_kernel_tuner.py index 8a558f7c..59e794e4 100644 --- a/src/pruna/algorithms/moe_kernel_tuner.py +++ b/src/pruna/algorithms/moe_kernel_tuner.py @@ -13,19 +13,15 @@ # limitations under the License. from __future__ import annotations -import functools from collections.abc import Iterable -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict -import torch -import ray import ray.experimental.tqdm_ray as tqdm_ray from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase from pruna.algorithms.base.tags import AlgorithmTag as tags from pruna.config.smash_config import SmashConfigPrefixWrapper from pruna.engine.model_checks import is_moe_lm, is_transformers_pipeline_with_moe_lm -from pruna.logging.logger import pruna_logger class MoeKernelTuner(PrunaAlgorithmBase): @@ -37,7 +33,7 @@ class MoeKernelTuner(PrunaAlgorithmBase): algorithm_name: str = "moe_kernel_tuner" group_tags: list[str] = [tags.KERNEL] - save_fn :None = None + save_fn: None = None references: dict[str, str] = { "GitHub": "https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py", } @@ -45,8 +41,24 @@ class MoeKernelTuner(PrunaAlgorithmBase): processor_required: bool = False runs_on: list[str] = ["cuda", "accelerate"] dataset_required: bool = False - compatible_before: Iterable[str] = [tags.KERNEL, tags.QUANTIZER, tags.PRUNER, tags.CACHER, tags.FACTORIZER, tags.BATCHER, tags.COMPILER] - compatible_after: Iterable[str] = [tags.KERNEL, tags.QUANTIZER, tags.PRUNER, tags.CACHER, tags.FACTORIZER, tags.BATCHER, tags.COMPILER] + compatible_before: Iterable[str] = [ + tags.KERNEL, + tags.QUANTIZER, + tags.PRUNER, + tags.CACHER, + tags.FACTORIZER, + tags.BATCHER, + tags.COMPILER, + ] + compatible_after: Iterable[str] = [ + tags.KERNEL, + tags.QUANTIZER, + tags.PRUNER, + tags.CACHER, + tags.FACTORIZER, + tags.BATCHER, + tags.COMPILER, + ] required_install = "``uv pip install vllm``" def model_check_fn(self, model: Any) -> bool: @@ -87,7 +99,7 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: """ imported_packages = self.import_algorithm_packages() - #TODO: Implement the MoE kernel tuning. + # TODO: Implement the MoE kernel tuning. 
return model def import_algorithm_packages(self) -> Dict[str, Any]: @@ -99,16 +111,17 @@ def import_algorithm_packages(self) -> Dict[str, Any]: Dict[str, Any] The algorithm packages. """ + import vllm.model_executor.layers.fused_moe.fused_moe as fused_moe + import vllm.platforms as vllm_platforms from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, _get_config_dtype_str, ) - import vllm.model_executor.layers.fused_moe.fused_moe as fused_moe - import vllm.platforms as vllm_platforms from vllm.transformers_utils.config import get_config from vllm.triton_utils import triton from vllm.utils.argparse_utils import FlexibleArgumentParser - return dict( + + return dict( FusedMoEQuantConfig=FusedMoEQuantConfig, _get_config_dtype_str=_get_config_dtype_str, FusedMoE=fused_moe, From 5112b3646c503d2b5dfb8ce6cd23a699013689b7 Mon Sep 17 00:00:00 2001 From: llcnt Date: Mon, 8 Dec 2025 16:39:09 +0000 Subject: [PATCH 03/26] feat: add benchmark fn and saving draft --- src/pruna/algorithms/moe_kernel_tuner.py | 452 ++++++++++++++++++++++- 1 file changed, 446 insertions(+), 6 deletions(-) diff --git a/src/pruna/algorithms/moe_kernel_tuner.py b/src/pruna/algorithms/moe_kernel_tuner.py index 59e794e4..14b2a0df 100644 --- a/src/pruna/algorithms/moe_kernel_tuner.py +++ b/src/pruna/algorithms/moe_kernel_tuner.py @@ -14,14 +14,24 @@ from __future__ import annotations from collections.abc import Iterable -from typing import Any, Dict +from typing import Any, Dict, TypedDict import ray.experimental.tqdm_ray as tqdm_ray +import ray +import time +import torch +from ConfigSpace import CategoricalHyperparameter, OrdinalHyperparameter +from contextlib import nullcontext +from datetime import datetime +import os +import json +from itertools import product from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase from pruna.algorithms.base.tags import AlgorithmTag as tags from pruna.config.smash_config import SmashConfigPrefixWrapper from pruna.engine.model_checks import is_moe_lm, is_transformers_pipeline_with_moe_lm +from pruna.logging.logger import pruna_logger class MoeKernelTuner(PrunaAlgorithmBase): @@ -61,6 +71,36 @@ class MoeKernelTuner(PrunaAlgorithmBase): ] required_install = "``uv pip install vllm``" + def get_hyperparameters(self) -> list: + """ + Configure all algorithm-specific hyperparameters with ConfigSpace. + + Returns + ------- + list + The hyperparameters. + """ + return [ + CategoricalHyperparameter( + "compute_dtype", + choices=["bfloat16", "float16"], + default_value="bfloat16", + meta=dict(desc="Compute dtype to use."), + ), + CategoricalHyperparameter( + "weight_dtype", + choices=["fp8_w8a8", "int8_w8a16"], + default_value="fp8_w8a8", + meta=dict(desc="Dtype to use for the weights (and activations)."), + ), + OrdinalHyperparameter( + "tensor_parallel_size", + sequence=[1, 2, 4, 8, 16, 32], + default_value=1, + meta=dict(desc="Tensor parallel size to use if the model can not fit on a single GPU."), + ), + ] + def model_check_fn(self, model: Any) -> bool: """ Check if the model is a MoE model. @@ -99,7 +139,147 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: """ imported_packages = self.import_algorithm_packages() - # TODO: Implement the MoE kernel tuning. 
+ @ray.remote(num_gpus=1) + class BenchmarkWorker: + def __init__(self, seed: int) -> None: + torch.set_default_device("cuda") + imported_packages["vllm_platforms"].current_platform.seed_everything(seed) + self.seed = seed + self.device_id = int(ray.get_gpu_ids()[0]) + + def tune( + self, + num_tokens: int, + num_experts: int, + shard_intermediate_size: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + search_space: list[dict[str, int]], + block_quant_shape: list[int], + use_deep_gemm: bool, + ) -> dict[str, int]: + best_config = None + best_time = float("inf") + + need_device_guard = False + + with torch.cuda.device(self.device_id) if need_device_guard else nullcontext(): + for config in tqdm_ray(search_space): + try: + kernel_time = benchmark_config( + config, + num_tokens, + num_experts, + shard_intermediate_size, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + num_iters=20, + block_quant_shape=block_quant_shape, + use_deep_gemm=use_deep_gemm, + ) + except imported_packages["triton"].runtime.autotuner.OutOfResources: + # Some configurations may be invalid and fail to compile. + continue + + if kernel_time < best_time: + best_time = kernel_time + best_config = config + now = datetime.now() + pruna_logger.info(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}") + assert best_config is not None + return best_config + + E = model.num_experts if is_moe_lm(model)or is_transformers_pipeline_with_moe_lm(model) else model.num_experts # number of experts + topk = model.num_experts_per_tok # number of active experts per token + intermediate_size = model.intermediate_size # 3072 # FFN intermediate size + hidden_size = model.hidden_size #4096 # model hidden dim + assert intermediate_size % smash_config["tensor_parallel_size"] == 0, ( + f"intermediate_size {intermediate_size} is not divisible by tp " + f"{smash_config['tensor_parallel_size']}." 
+ ) + shard_intermediate_size = 2 * intermediate_size // smash_config["tensor_parallel_size"] + dtype = smash_config["compute_dtype"] + use_fp8_w8a8 = smash_config["weight_dtype"] == "fp8_w8a8" + use_int8_w8a16 = smash_config["weight_dtype"] == "int8_w8a16" + FP8_DTYPE = imported_packages["vllm_platforms"].current_platform.fp8_dtype() + batch_sizes = [ + 1, + 2, + 4, + 8, + 16, + 24, + 32, + 48, + 64, + 96, + 128, + 256, + 512, + 1024, + 1536, + 2048, + 3072, + 4096, + ] + + ray.init() + num_gpus = int(ray.available_resources()["GPU"]) + workers = [BenchmarkWorker.remote(0) for _ in range(num_gpus)] + + is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16) + search_space = get_configs_compute_bound(is_fp16, None) + pruna_logger.info(f"Start tuning over {len(search_space)} configurations...") + + start = time.time() + outputs = [] + worker_idx = 0 + for batch_size in batch_sizes: + input_args = ( + batch_size, + E, + shard_intermediate_size, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + search_space, + None, + False, + ) + worker = workers[worker_idx] + worker_method = getattr(worker, "tune") + output = worker_method.remote(*input_args) + outputs.append(output) + worker_idx = (worker_idx + 1) % num_gpus + configs = ray.get(outputs) + + best_configs = { + M: sort_config(config) for M, config in zip(batch_sizes, configs) + } + self.save_configs( + best_configs, + E, + shard_intermediate_size, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + None, + args.save_dir, + imported_packages, + ) + end = time.time() + pruna_logger.info(f"Tuning took {end - start:.2f} seconds") + return model def import_algorithm_packages(self) -> Dict[str, Any]: @@ -117,17 +297,277 @@ def import_algorithm_packages(self) -> Dict[str, Any]: FusedMoEQuantConfig, _get_config_dtype_str, ) - from vllm.transformers_utils.config import get_config + from vllm.model_executor.layers.fused_moe import override_config from vllm.triton_utils import triton - from vllm.utils.argparse_utils import FlexibleArgumentParser return dict( FusedMoEQuantConfig=FusedMoEQuantConfig, _get_config_dtype_str=_get_config_dtype_str, FusedMoE=fused_moe, vllm_platforms=vllm_platforms, - get_config=get_config, triton=triton, - FlexibleArgumentParser=FlexibleArgumentParser, tqdm_ray=tqdm_ray, + override_config=override_config, + ) + + def save_configs( + configs: dict[int, BenchmarkConfig], + num_experts: int, + shard_intermediate_size: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + block_quant_shape: list[int], + save_dir: str, + imported_packages: Dict[str, Any], + ) -> None: + dtype_str = imported_packages["_get_config_dtype_str"]( + dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8 + ) + + # NOTE(woosuk): The current naming convention uses w2.shape[2], which + # is the intermediate size after silu_and_mul. + filename = imported_packages["fused_moe"].get_config_file_name( + num_experts, shard_intermediate_size // 2, dtype_str, block_quant_shape + ) + + # We want to save at 3 different places: + # 1. The cache of vllm + # 2. The cache of kernels hub + # 3. The smashconfig (to be reused once mode is smashed and saved). 
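        # Illustrative layout of the JSON written below (values made up): the keys
        # are the tuned batch sizes M and each value is the winning Triton config.
        #
        #   {
        #     "triton_version": "3.1.0",
        #     "1":  {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128,
        #            "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
        #     "64": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128,
        #            "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 4}
        #   }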
+ + + os.makedirs(save_dir, exist_ok=True) + filename = os.path.join(save_dir, filename) + pruna_logger.info(f"Writing best config to {filename}...") + with open(filename, "w") as f: + json.dump({"triton_version": imported_packages["triton"].__version__, **configs}, f, indent=4) + f.write("\n") + + +class BenchmarkConfig(TypedDict): + BLOCK_SIZE_M: int + BLOCK_SIZE_N: int + BLOCK_SIZE_K: int + GROUP_SIZE_M: int + num_warps: int + num_stages: int + +def sort_config(config: BenchmarkConfig) -> BenchmarkConfig: + return { + "BLOCK_SIZE_M": config["BLOCK_SIZE_M"], + "BLOCK_SIZE_N": config["BLOCK_SIZE_N"], + "BLOCK_SIZE_K": config["BLOCK_SIZE_K"], + "GROUP_SIZE_M": config["GROUP_SIZE_M"], + "num_warps": config["num_warps"], + "num_stages": config["num_stages"], + **( + {"waves_per_eu": config["waves_per_eu"]} if "waves_per_eu" in config else {} + ), + **( + {"matrix_instr_nonkdim": config["matrix_instr_nonkdim"]} + if "matrix_instr_nonkdim" in config + else {} + ), + **({"kpack": config["kpack"]} if "kpack" in config else {}), + } + + +def get_configs_compute_bound(use_fp16, block_quant_shape) -> list[dict[str, int]]: + configs: list[BenchmarkConfig] = [] + + # Reduced search space for faster tuning. + block_m_range = [16, 32, 64, 128, 256] + block_n_range = [32, 64, 128, 256] + block_k_range = [64, 128, 256] + num_warps_range = [4, 8] + group_m_range = [1, 16, 32, 64] + num_stage_range = [2, 3, 4, 5] + + param_ranges = { + "BLOCK_SIZE_M": block_m_range, + "BLOCK_SIZE_N": block_n_range, + "BLOCK_SIZE_K": block_k_range, + "GROUP_SIZE_M": group_m_range, + "num_warps": num_warps_range, + "num_stages": num_stage_range, + } + + keys, values = zip(*param_ranges.items()) + for config_values in product(*values): + config = dict(zip(keys, config_values)) + configs.append(config) + + # Remove configs that are not compatible with fp8 block quantization + # BLOCK_SIZE_K must be a multiple of block_k + # BLOCK_SIZE_N must be a multiple of block_n + if block_quant_shape is not None and not use_fp16: + block_n, block_k = block_quant_shape[0], block_quant_shape[1] + for config in configs[:]: + if ( + config["BLOCK_SIZE_K"] % block_k != 0 + or config["BLOCK_SIZE_N"] % block_n != 0 + ): + configs.remove(config) + return configs + + +def benchmark_config( + config: BenchmarkConfig, + num_tokens: int, + num_experts: int, + shard_intermediate_size: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + num_iters: int = 100, + block_quant_shape: list[int] = None, + use_deep_gemm: bool = False, + imported_packages: Dict[str, Any] = None, +) -> float: + init_dtype = torch.float16 if use_fp8_w8a8 else dtype + x = torch.randn(num_tokens, hidden_size, dtype=dtype) + if use_int8_w8a16: + w1 = torch.randint( + -127, + 127, + ( + num_experts, + shard_intermediate_size, + hidden_size, + ), + dtype=torch.int8, + ) + w2 = torch.randint( + -127, + 127, + ( + num_experts, + hidden_size, + shard_intermediate_size // 2, + ), + dtype=torch.int8, + ) + else: + w1 = torch.randn( + num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype + ) + w2 = torch.randn( + num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype + ) + gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32) + + w1_scale = None + w2_scale = None + a1_scale = None + a2_scale = None + if use_int8_w8a16: + w1_scale = torch.randn( + (num_experts, 2 * shard_intermediate_size), dtype=torch.float32 + ) + w2_scale = torch.randn((hidden_size, num_experts), 
dtype=torch.float32) + if use_deep_gemm: + # we use the default block shape for deepgemm + block_quant_shape = [128, 128] + if use_fp8_w8a8: + if block_quant_shape: + block_n, block_k = block_quant_shape[0], block_quant_shape[1] + E = num_experts + N = shard_intermediate_size // 2 + K = hidden_size + factor_for_scale = 1e-2 + n_tiles_w1 = (2 * N + block_n - 1) // block_n + n_tiles_w2 = (K + block_n - 1) // block_n + k_tiles_w1 = (K + block_k - 1) // block_k + k_tiles_w2 = (N + block_k - 1) // block_k + w1_scale = ( + torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) + * factor_for_scale + ) + w2_scale = ( + torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) + * factor_for_scale + ) + else: + w1_scale = torch.randn(num_experts, dtype=torch.float32) + w2_scale = torch.randn(num_experts, dtype=torch.float32) + + a1_scale = torch.randn(1, dtype=torch.float32) + a2_scale = torch.randn(1, dtype=torch.float32) + + w1 = w1.to(imported_packages["vllm_platforms"].current_platform.fp8_dtype()) + w2 = w2.to(imported_packages["vllm_platforms"].current_platform.fp8_dtype()) + + input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32) + + def prepare(i: int): + input_gating.copy_(gating_output[i]) + + def run(): + if use_fp8_w8a8: + quant_dtype = torch.float8_e4m3fn + elif use_int8_w8a16: + quant_dtype = torch.int8 + else: + quant_dtype = None + + quant_config = imported_packages["FusedMoEQuantConfig"].make( + quant_dtype=quant_dtype, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + block_shape=block_quant_shape, ) + + with imported_packages["override_config"](config): + topk_weights, topk_ids, token_expert_indices = imported_packages["FusedMoE"].fused_topk( + x, input_gating, topk, renormalize=not use_deep_gemm + ) + return imported_packages["FusedMoE"].fused_experts( + x, + w1, + w2, + topk_weights, + topk_ids, + inplace=True, + quant_config=quant_config, + allow_deep_gemm=use_deep_gemm, + ) + + # JIT compilation & warmup + run() + torch.cuda.synchronize() + + # Capture 10 invocations with CUDA graph + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + for _ in range(10): + run() + torch.cuda.synchronize() + + # Warmup + for _ in range(5): + graph.replay() + torch.cuda.synchronize() + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + latencies: list[float] = [] + for i in range(num_iters): + prepare(i) + torch.cuda.synchronize() + + start_event.record() + graph.replay() + end_event.record() + end_event.synchronize() + latencies.append(start_event.elapsed_time(end_event)) + avg = sum(latencies) / (num_iters * 10) * 1000 # us + graph.reset() + return avg From 617247df276bb2261802d19b05637d74d48de54a Mon Sep 17 00:00:00 2001 From: llcnt Date: Thu, 18 Dec 2025 16:27:04 +0000 Subject: [PATCH 04/26] feat: clean and simplify tuning and config saving --- src/pruna/algorithms/moe_kernel_tuner.py | 432 ++++++++++++++++------- 1 file changed, 305 insertions(+), 127 deletions(-) diff --git a/src/pruna/algorithms/moe_kernel_tuner.py b/src/pruna/algorithms/moe_kernel_tuner.py index 14b2a0df..6daaac57 100644 --- a/src/pruna/algorithms/moe_kernel_tuner.py +++ b/src/pruna/algorithms/moe_kernel_tuner.py @@ -21,22 +21,24 @@ import time import torch from ConfigSpace import CategoricalHyperparameter, OrdinalHyperparameter -from contextlib import nullcontext from datetime import datetime import os import json from itertools import product +from importlib.util import find_spec from 
pruna.algorithms.base.pruna_base import PrunaAlgorithmBase from pruna.algorithms.base.tags import AlgorithmTag as tags +from pruna.config.hyperparameters import UnconstrainedHyperparameter from pruna.config.smash_config import SmashConfigPrefixWrapper +from pruna.engine.load import LOAD_FUNCTIONS from pruna.engine.model_checks import is_moe_lm, is_transformers_pipeline_with_moe_lm from pruna.logging.logger import pruna_logger class MoeKernelTuner(PrunaAlgorithmBase): """ - Tune the MoE kernel for the model. + Tune the MoE Triton kernel for the model. Uses vLLM to tune the MoE kernel of the model. """ @@ -99,6 +101,28 @@ def get_hyperparameters(self) -> list: default_value=1, meta=dict(desc="Tensor parallel size to use if the model can not fit on a single GPU."), ), + UnconstrainedHyperparameter( + "path_to_huggingface_hub_cache", + default_value="~", + meta=dict( + desc=( + "Path to the Hugging Face Hub cache directory " + "(that contains `kernels` configs). If not provided, " + "the cache will be saved in the current working directory." + ) + ), + ), + UnconstrainedHyperparameter( + "path_to_vllm_cache", + default_value="vllm/model_executor/layers/fused_moe/configs", + meta=dict(desc="Path to the vLLM MoE configs directory."), + ), + OrdinalHyperparameter( + "num_iters", + sequence=[1, 20, 50, 100], + default_value=20, + meta=dict(desc="Number of iterations to average the kernel times on."), + ), ] def model_check_fn(self, model: Any) -> bool: @@ -123,7 +147,7 @@ def model_check_fn(self, model: Any) -> bool: def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: """ - Wrap the model to use flash_attn3 where possible. + Tune the MoE Triton kernel for the model. Parameters ---------- @@ -135,79 +159,33 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: Returns ------- Any - The wrapped model. + The untouched model. """ + if is_transformers_pipeline_with_moe_lm(model): + return self._apply_to_model_within_transformers_pipeline(model, smash_config) + imported_packages = self.import_algorithm_packages() - @ray.remote(num_gpus=1) - class BenchmarkWorker: - def __init__(self, seed: int) -> None: - torch.set_default_device("cuda") - imported_packages["vllm_platforms"].current_platform.seed_everything(seed) - self.seed = seed - self.device_id = int(ray.get_gpu_ids()[0]) - - def tune( - self, - num_tokens: int, - num_experts: int, - shard_intermediate_size: int, - hidden_size: int, - topk: int, - dtype: torch.dtype, - use_fp8_w8a8: bool, - use_int8_w8a16: bool, - search_space: list[dict[str, int]], - block_quant_shape: list[int], - use_deep_gemm: bool, - ) -> dict[str, int]: - best_config = None - best_time = float("inf") - - need_device_guard = False - - with torch.cuda.device(self.device_id) if need_device_guard else nullcontext(): - for config in tqdm_ray(search_space): - try: - kernel_time = benchmark_config( - config, - num_tokens, - num_experts, - shard_intermediate_size, - hidden_size, - topk, - dtype, - use_fp8_w8a8, - use_int8_w8a16, - num_iters=20, - block_quant_shape=block_quant_shape, - use_deep_gemm=use_deep_gemm, - ) - except imported_packages["triton"].runtime.autotuner.OutOfResources: - # Some configurations may be invalid and fail to compile. 
- continue - - if kernel_time < best_time: - best_time = kernel_time - best_config = config - now = datetime.now() - pruna_logger.info(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}") - assert best_config is not None - return best_config - - E = model.num_experts if is_moe_lm(model)or is_transformers_pipeline_with_moe_lm(model) else model.num_experts # number of experts - topk = model.num_experts_per_tok # number of active experts per token - intermediate_size = model.intermediate_size # 3072 # FFN intermediate size - hidden_size = model.hidden_size #4096 # model hidden dim + # (i) Get the MoE parameters + model_config = model.config + if model_config is None: + raise ValueError(f"Model {model.__class__.__name__} has no config.") + E = model_config.num_experts # number of experts + topk = model_config.num_experts_per_tok if is_moe_lm(model) else model_config.moe_topk[0] # number of active experts per token + intermediate_size = model_config.intermediate_size # 3072 # FFN intermediate size + hidden_size = model_config.hidden_size #4096 # model hidden dim assert intermediate_size % smash_config["tensor_parallel_size"] == 0, ( f"intermediate_size {intermediate_size} is not divisible by tp " f"{smash_config['tensor_parallel_size']}." ) shard_intermediate_size = 2 * intermediate_size // smash_config["tensor_parallel_size"] + + # (ii) Get the compute parameters dtype = smash_config["compute_dtype"] use_fp8_w8a8 = smash_config["weight_dtype"] == "fp8_w8a8" use_int8_w8a16 = smash_config["weight_dtype"] == "int8_w8a16" - FP8_DTYPE = imported_packages["vllm_platforms"].current_platform.fp8_dtype() + + # (iii) Tune the kernel over a range of batch sizes batch_sizes = [ 1, 2, @@ -230,20 +208,17 @@ def tune( ] ray.init() - num_gpus = int(ray.available_resources()["GPU"]) - workers = [BenchmarkWorker.remote(0) for _ in range(num_gpus)] is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16) - search_space = get_configs_compute_bound(is_fp16, None) + search_space = get_configs_compute_bound(is_fp16) pruna_logger.info(f"Start tuning over {len(search_space)} configurations...") start = time.time() outputs = [] - worker_idx = 0 for batch_size in batch_sizes: - input_args = ( - batch_size, - E, + output = tune.remote( + batch_size, # num_tokens + E, # num_experts shard_intermediate_size, hidden_size, topk, @@ -251,35 +226,47 @@ def tune( use_fp8_w8a8, use_int8_w8a16, search_space, - None, - False, + None, # we don't suport block quantization for now + False, # not use_deep_gemm + imported_packages, + 0, # random seed + smash_config["num_iters"], ) - worker = workers[worker_idx] - worker_method = getattr(worker, "tune") - output = worker_method.remote(*input_args) outputs.append(output) - worker_idx = (worker_idx + 1) % num_gpus + configs = ray.get(outputs) + # (iv) Sort the configs by batch size and save the best configs best_configs = { M: sort_config(config) for M, config in zip(batch_sizes, configs) } - self.save_configs( + # save configs in caches (for hf and vllm) + save_configs( best_configs, E, shard_intermediate_size, - hidden_size, - topk, dtype, use_fp8_w8a8, use_int8_w8a16, None, - args.save_dir, + smash_config["path_to_huggingface_hub_cache"], + smash_config["path_to_vllm_cache"], imported_packages, ) + # attached configs to the smash config + smash_config["best_configs_moe_kernel"] = best_configs + # attached hyperparameters to the smash config for loading + smash_config["num_experts"] = E + smash_config["shard_intermediate_size"] = shard_intermediate_size + smash_config["dtype"] = dtype + 
smash_config["use_fp8_w8a8"] = use_fp8_w8a8 + smash_config["use_int8_w8a16"] = use_int8_w8a16 + # attached load function to the smash config for loading + smash_config.load_fns.append(LOAD_FUNCTIONS.moe_kernel_tuner.name) end = time.time() pruna_logger.info(f"Tuning took {end - start:.2f} seconds") + # (v) Return the model (untouched; only the configs are saved) return model def import_algorithm_packages(self) -> Dict[str, Any]: @@ -299,6 +286,7 @@ def import_algorithm_packages(self) -> Dict[str, Any]: ) from vllm.model_executor.layers.fused_moe import override_config from vllm.triton_utils import triton + import vllm.envs as envs return dict( FusedMoEQuantConfig=FusedMoEQuantConfig, @@ -306,12 +294,26 @@ def import_algorithm_packages(self) -> Dict[str, Any]: FusedMoE=fused_moe, vllm_platforms=vllm_platforms, triton=triton, - tqdm_ray=tqdm_ray, override_config=override_config, + envs=envs, ) - def save_configs( - configs: dict[int, BenchmarkConfig], +class BenchmarkConfig(TypedDict): + """ + The configuration for the matrix multiplication (tiling and warp scheduling). + """ + BLOCK_SIZE_M: int + BLOCK_SIZE_N: int + BLOCK_SIZE_K: int + GROUP_SIZE_M: int + num_warps: int + num_stages: int + +# Converts the function into a Ray actor and requests one GPU per actor instance +@ray.remote(num_gpus=1) +def tune( + self, + num_tokens: int, num_experts: int, shard_intermediate_size: int, hidden_size: int, @@ -319,43 +321,100 @@ def save_configs( dtype: torch.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool, + search_space: list[dict[str, int]], block_quant_shape: list[int], - save_dir: str, + use_deep_gemm: bool, imported_packages: Dict[str, Any], - ) -> None: - dtype_str = imported_packages["_get_config_dtype_str"]( - dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8 - ) - - # NOTE(woosuk): The current naming convention uses w2.shape[2], which - # is the intermediate size after silu_and_mul. - filename = imported_packages["fused_moe"].get_config_file_name( - num_experts, shard_intermediate_size // 2, dtype_str, block_quant_shape - ) - - # We want to save at 3 different places: - # 1. The cache of vllm - # 2. The cache of kernels hub - # 3. The smashconfig (to be reused once mode is smashed and saved). + seed: int, + num_iters: int, + ) -> dict[str, int]: + """ + Tune a given Triton kernel. + + Parameters + ---------- + num_tokens: int + The number of tokens in the batch. + num_experts: int + The number of experts. + shard_intermediate_size: int + The intermediate size of the model in the shard (if using tensor parallelism). + hidden_size: int + The hidden size of the model. + topk: int + The number of active experts per token. + dtype: torch.dtype + The dtype to use for the weights and activations. + use_fp8_w8a8: bool + Whether to use fp8_w8a8. + use_int8_w8a16: bool + Whether to use int8_w8a16. + search_space: list[dict[str, int]] + The search space for the kernel (tiling and warp scheduling). + block_quant_shape: list[int] + The block shape for the kernel (None here). + use_deep_gemm: bool + Whether to use deep gemm (False here). + imported_packages: Dict[str, Any] + The imported packages (vllm, triton, etc.). + seed: int + The random seed. + num_iters: int + The number of iterations to average the kernel time on. + + Returns + ------- + dict[str, int] + The best config. 
+ """ + imported_packages["vllm_platforms"].current_platform.seed_everything(seed) + best_config = None + best_time = float("inf") + + for config in tqdm_ray(search_space): + try: + kernel_time = benchmark_config( + config, + num_tokens, + num_experts, + shard_intermediate_size, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + num_iters=num_iters, + block_quant_shape=block_quant_shape, + use_deep_gemm=use_deep_gemm, + imported_packages=imported_packages, + num_iters=num_iters, + ) + except imported_packages["triton"].runtime.autotuner.OutOfResources: + # Some configurations may be invalid and fail to compile. + continue + if kernel_time < best_time: + best_time, best_config = kernel_time, config - os.makedirs(save_dir, exist_ok=True) - filename = os.path.join(save_dir, filename) - pruna_logger.info(f"Writing best config to {filename}...") - with open(filename, "w") as f: - json.dump({"triton_version": imported_packages["triton"].__version__, **configs}, f, indent=4) - f.write("\n") + now = datetime.now() + pruna_logger.info(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}") + assert best_config is not None + return best_config +def sort_config(config: BenchmarkConfig) -> BenchmarkConfig: + """ + Sort the configuration (tiling and warp scheduling). -class BenchmarkConfig(TypedDict): - BLOCK_SIZE_M: int - BLOCK_SIZE_N: int - BLOCK_SIZE_K: int - GROUP_SIZE_M: int - num_warps: int - num_stages: int + Parameters + ---------- + config: BenchmarkConfig + The configuration to sort. -def sort_config(config: BenchmarkConfig) -> BenchmarkConfig: + Returns + ------- + BenchmarkConfig + The sorted configuration. + """ return { "BLOCK_SIZE_M": config["BLOCK_SIZE_M"], "BLOCK_SIZE_N": config["BLOCK_SIZE_N"], @@ -375,7 +434,20 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig: } -def get_configs_compute_bound(use_fp16, block_quant_shape) -> list[dict[str, int]]: +def get_configs_compute_bound(use_fp16: bool) -> list[dict[str, int]]: + """ + Get the gridsearch space for the kernel (tiling and warp scheduling). + + Parameters + ---------- + use_fp16: bool + Whether to use fp16. + + Returns + ------- + list[dict[str, int]] + The search space for the kernel (tiling and warp scheduling). + """ configs: list[BenchmarkConfig] = [] # Reduced search space for faster tuning. @@ -400,17 +472,6 @@ def get_configs_compute_bound(use_fp16, block_quant_shape) -> list[dict[str, int config = dict(zip(keys, config_values)) configs.append(config) - # Remove configs that are not compatible with fp8 block quantization - # BLOCK_SIZE_K must be a multiple of block_k - # BLOCK_SIZE_N must be a multiple of block_n - if block_quant_shape is not None and not use_fp16: - block_n, block_k = block_quant_shape[0], block_quant_shape[1] - for config in configs[:]: - if ( - config["BLOCK_SIZE_K"] % block_k != 0 - or config["BLOCK_SIZE_N"] % block_n != 0 - ): - configs.remove(config) return configs @@ -429,6 +490,46 @@ def benchmark_config( use_deep_gemm: bool = False, imported_packages: Dict[str, Any] = None, ) -> float: + """ + Benchmark a given Triton kernel using CUDAGraph. + + This function is copied from the vllm repository. + https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py + + Parameters + ---------- + config: BenchmarkConfig + The configuration to benchmark. + num_tokens: int + The number of tokens in the batch. + num_experts: int + The number of experts. 
+ shard_intermediate_size: int + The intermediate size of the model in the shard (if using tensor parallelism). + hidden_size: int + The hidden size of the model. + topk: int + The number of active experts per token. + dtype: torch.dtype + The dtype to use for the weights and activations. + use_fp8_w8a8: bool + Whether to use fp8_w8a8. + use_int8_w8a16: bool + Whether to use int8_w8a16. + num_iters: int + The number of iterations to run the benchmark. + block_quant_shape: list[int] + The block shape for the kernel (None here). + use_deep_gemm: bool + Whether to use deep gemm (False here). + imported_packages: Dict[str, Any] + The imported packages (vllm, triton, etc.). + + Returns + ------- + float + The average latency of the kernel. + """ init_dtype = torch.float16 if use_fp8_w8a8 else dtype x = torch.randn(num_tokens, hidden_size, dtype=dtype) if use_int8_w8a16: @@ -571,3 +672,80 @@ def run(): avg = sum(latencies) / (num_iters * 10) * 1000 # us graph.reset() return avg + + +def save_configs( + configs: dict[int, BenchmarkConfig], + num_experts: int, + shard_intermediate_size: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + block_quant_shape: list[int], + path_to_huggingface_hub_cache: str, + path_to_vllm_cache: str, + imported_packages: Dict[str, Any], +) -> None: + """ + Save the best configs to the hf cache and vllm cache. + + Parameters + ---------- + configs: dict[int, BenchmarkConfig] + The best configs. + num_experts: int + The number of experts. + shard_intermediate_size: int + The intermediate size of the model in the shard (if using tensor parallelism). + hidden_size: int + The hidden size of the model. + topk: int + The number of active experts per token. + dtype: torch.dtype + The dtype to use for the weights and activations. + use_fp8_w8a8: bool + Whether to use fp8_w8a8. + use_int8_w8a16: bool + Whether to use int8_w8a16. + block_quant_shape: list[int] + The block shape for the kernel (None here). + path_to_huggingface_hub_cache: str + The path to the huggingface hub cache. + path_to_vllm_cache: str + The path to the vllm cache. + imported_packages: Dict[str, Any] + The imported packages (vllm, triton, etc.). + """ + dtype_str = imported_packages["_get_config_dtype_str"]( + dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8 + ) + + # (i) Get the name of the config file + # NB from vllm: The current naming convention uses w2.shape[2], which + # is the intermediate size after silu_and_mul. 
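    # The generated name is roughly of the form (illustrative, device dependent):
    #   E=<num_experts>,N=<shard_intermediate_size // 2>,device_name=<gpu>,dtype=<dtype_str>.json
    # e.g. "E=64,N=1536,device_name=NVIDIA_H100,dtype=fp8_w8a8.json".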
+ filename = imported_packages["fused_moe"].get_config_file_name( + num_experts, shard_intermediate_size // 2, dtype_str, block_quant_shape + ) + + # (ii) Save the config to the hf cache (where `kernels` lib expects to find it) + path_to_kernel_configs = os.path.join(path_to_huggingface_hub_cache, ".cache/huggingface/hub/models--RedHatAI--moe/blobs/configs") + os.makedirs(path_to_kernel_configs, exist_ok=True) + filename = os.path.join(path_to_kernel_configs, filename) + if not os.path.exists(filename): + pruna_logger.info(f"Writing best config to {filename}...") + with open(filename, "w") as f: + json.dump({"triton_version": imported_packages["triton"].__version__, **configs}, f, indent=4) + f.write("\n") + + # (iii) Save the config to the vllm cache (where `vllm` expects to find it) + path_to_vllm_configs = imported_packages["envs"].VLLM_TUNED_CONFIG_FOLDER + if path_to_vllm_configs is None: + path_where_vllm_is_installed = find_spec("vllm").submodule_search_locations[0] + path_to_vllm_configs = os.path.join(os.path.dirname(path_where_vllm_is_installed), path_to_vllm_cache, "configs") + os.makedirs(path_to_vllm_configs, exist_ok=True) + filename = os.path.join(path_to_vllm_configs, filename) + if not os.path.exists(filename): + pruna_logger.info(f"Writing best config to {filename}...") + with open(filename, "w") as f: + json.dump({"triton_version": imported_packages["triton"].__version__, **configs}, f, indent=4) + f.write("\n") From 30fe4eeac3a59184dd78cac502ad74ce6ee91f8d Mon Sep 17 00:00:00 2001 From: llcnt Date: Thu, 18 Dec 2025 16:27:45 +0000 Subject: [PATCH 05/26] feat: add custom loading fn --- src/pruna/engine/load.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/src/pruna/engine/load.py b/src/pruna/engine/load.py index 7679b09a..2aac6d14 100644 --- a/src/pruna/engine/load.py +++ b/src/pruna/engine/load.py @@ -557,6 +557,39 @@ def load_hqq_diffusers(path: str | Path, smash_config: SmashConfig, **kwargs) -> return model +def load_moe_kernel_tuner(path: str | Path, smash_config: SmashConfig, **kwargs) -> Any: + """ + Load a tuned kernel config inside the hf/vllm cache, then load the model. + + Parameters + ---------- + path: str | Path + The path to the model directory. + smash_config: SmashConfig + The SmashConfig object containing the best configs for the MoE kernel tuner. + **kwargs: Any + Additional keyword arguments to pass to the model loading function. + + Returns + ------- + Any + The loaded model. + """ + from pruna.algorithms.moe_kernel_tuner import MoEKernelTuner, save_configs + imported_packages = MoEKernelTuner().import_algorithm_packages() + save_configs(smash_config["best_configs_moe_kernel"], + smash_config["num_experts"], + smash_config["shard_intermediate_size"], + smash_config["dtype"], + smash_config["use_fp8_w8a8"], + smash_config["use_int8_w8a16"], + smash_config["block_quant_shape"], + smash_config["path_to_huggingface_hub_cache"], + smash_config["path_to_vllm_cache"], + imported_packages) + return load_transformers_model(path, smash_config, **kwargs) + + class LOAD_FUNCTIONS(Enum): # noqa: N801 """ Enumeration of load functions for different model types. 
@@ -593,6 +626,7 @@ class LOAD_FUNCTIONS(Enum): # noqa: N801 pickled = partial(load_pickled) hqq = partial(load_hqq) hqq_diffusers = partial(load_hqq_diffusers) + moe_kernel_tuner = partial(load_moe_kernel_tuner) def __call__(self, *args, **kwargs) -> Any: """ From ef498027d1e75a0405cdc94c1fa3f87f79df6fb5 Mon Sep 17 00:00:00 2001 From: llcnt Date: Thu, 18 Dec 2025 16:28:25 +0000 Subject: [PATCH 06/26] feat: add custom loading fn --- src/pruna/engine/load.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/src/pruna/engine/load.py b/src/pruna/engine/load.py index 2aac6d14..407d16fe 100644 --- a/src/pruna/engine/load.py +++ b/src/pruna/engine/load.py @@ -576,17 +576,20 @@ def load_moe_kernel_tuner(path: str | Path, smash_config: SmashConfig, **kwargs) The loaded model. """ from pruna.algorithms.moe_kernel_tuner import MoEKernelTuner, save_configs + imported_packages = MoEKernelTuner().import_algorithm_packages() - save_configs(smash_config["best_configs_moe_kernel"], - smash_config["num_experts"], - smash_config["shard_intermediate_size"], - smash_config["dtype"], - smash_config["use_fp8_w8a8"], - smash_config["use_int8_w8a16"], - smash_config["block_quant_shape"], - smash_config["path_to_huggingface_hub_cache"], - smash_config["path_to_vllm_cache"], - imported_packages) + save_configs( + smash_config["best_configs_moe_kernel"], + smash_config["num_experts"], + smash_config["shard_intermediate_size"], + smash_config["dtype"], + smash_config["use_fp8_w8a8"], + smash_config["use_int8_w8a16"], + smash_config["block_quant_shape"], + smash_config["path_to_huggingface_hub_cache"], + smash_config["path_to_vllm_cache"], + imported_packages, + ) return load_transformers_model(path, smash_config, **kwargs) From 2e191c771dbdc6214774b922a0a283064ae4adb6 Mon Sep 17 00:00:00 2001 From: llcnt Date: Fri, 19 Dec 2025 16:58:11 +0000 Subject: [PATCH 07/26] feat: add unit test --- src/pruna/algorithms/moe_kernel_tuner.py | 9 +++++---- src/pruna/engine/load.py | 2 +- tests/algorithms/testers/moe_kernel_tuner.py | 13 +++++++++++++ 3 files changed, 19 insertions(+), 5 deletions(-) create mode 100644 tests/algorithms/testers/moe_kernel_tuner.py diff --git a/src/pruna/algorithms/moe_kernel_tuner.py b/src/pruna/algorithms/moe_kernel_tuner.py index 6daaac57..3c421201 100644 --- a/src/pruna/algorithms/moe_kernel_tuner.py +++ b/src/pruna/algorithms/moe_kernel_tuner.py @@ -40,7 +40,7 @@ class MoeKernelTuner(PrunaAlgorithmBase): """ Tune the MoE Triton kernel for the model. - Uses vLLM to tune the MoE kernel of the model. + Uses vLLM to tune the MoE kernel. """ algorithm_name: str = "moe_kernel_tuner" @@ -172,8 +172,8 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: raise ValueError(f"Model {model.__class__.__name__} has no config.") E = model_config.num_experts # number of experts topk = model_config.num_experts_per_tok if is_moe_lm(model) else model_config.moe_topk[0] # number of active experts per token - intermediate_size = model_config.intermediate_size # 3072 # FFN intermediate size - hidden_size = model_config.hidden_size #4096 # model hidden dim + intermediate_size = model_config.intermediate_size # FFN intermediate size + hidden_size = model_config.hidden_size # model hidden dim assert intermediate_size % smash_config["tensor_parallel_size"] == 0, ( f"intermediate_size {intermediate_size} is not divisible by tp " f"{smash_config['tensor_parallel_size']}." 
@@ -207,6 +207,7 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: 4096, ] + # use ray to parallelize the tuning ray.init() is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16) @@ -261,7 +262,7 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: smash_config["dtype"] = dtype smash_config["use_fp8_w8a8"] = use_fp8_w8a8 smash_config["use_int8_w8a16"] = use_int8_w8a16 - # attached load function to the smash config for loading + # attach load function to the smash config for loading smash_config.load_fns.append(LOAD_FUNCTIONS.moe_kernel_tuner.name) end = time.time() pruna_logger.info(f"Tuning took {end - start:.2f} seconds") diff --git a/src/pruna/engine/load.py b/src/pruna/engine/load.py index 407d16fe..d703ec2d 100644 --- a/src/pruna/engine/load.py +++ b/src/pruna/engine/load.py @@ -559,7 +559,7 @@ def load_hqq_diffusers(path: str | Path, smash_config: SmashConfig, **kwargs) -> def load_moe_kernel_tuner(path: str | Path, smash_config: SmashConfig, **kwargs) -> Any: """ - Load a tuned kernel config inside the hf/vllm cache, then load the model. + Load a tuned kernel config inside the hf/vllm caches, then load the model. Parameters ---------- diff --git a/tests/algorithms/testers/moe_kernel_tuner.py b/tests/algorithms/testers/moe_kernel_tuner.py new file mode 100644 index 00000000..5c4da5c8 --- /dev/null +++ b/tests/algorithms/testers/moe_kernel_tuner.py @@ -0,0 +1,13 @@ +from pruna.algorithms.moe_kernel_tuner import MoeKernelTuner + +from .base_tester import AlgorithmTesterBase + + +class TestMoeKernelTuner(AlgorithmTesterBase): + """Test the MoeKernelTuner.""" + + models = ["qwen3_coder_tiny"] + reject_models = ["sd_tiny_random"] + allow_pickle_files = False + algorithm_class = MoeKernelTuner + metrics = ["perplexity"] From 35a827dd19822d64fe0aff8fa5f9df79bac4c9a2 Mon Sep 17 00:00:00 2001 From: llcnt Date: Mon, 22 Dec 2025 18:09:13 +0000 Subject: [PATCH 08/26] feat: add vllm dep and upd torch version --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 18193b70..3c1b42cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -137,7 +137,8 @@ dependencies = [ "vbench-pruna; sys_platform != 'darwin'", "imageio-ffmpeg", "jaxtyping", - "peft>=0.17.1", + "peft>=0.17.1",, + "vllm>=0.11.0", ] [project.optional-dependencies] From fb7404ba5f1df71f6d45be14616e047d81da3757 Mon Sep 17 00:00:00 2001 From: llcnt Date: Mon, 22 Dec 2025 18:10:51 +0000 Subject: [PATCH 09/26] feat: change smashconfig to save artifacts and reload it --- src/pruna/algorithms/moe_kernel_tuner.py | 149 +++++++++++++---------- src/pruna/config/smash_config.py | 46 ++++++- src/pruna/engine/load.py | 34 ++++-- 3 files changed, 160 insertions(+), 69 deletions(-) diff --git a/src/pruna/algorithms/moe_kernel_tuner.py b/src/pruna/algorithms/moe_kernel_tuner.py index 3c421201..e03f2a20 100644 --- a/src/pruna/algorithms/moe_kernel_tuner.py +++ b/src/pruna/algorithms/moe_kernel_tuner.py @@ -91,8 +91,8 @@ def get_hyperparameters(self) -> list: ), CategoricalHyperparameter( "weight_dtype", - choices=["fp8_w8a8", "int8_w8a16"], - default_value="fp8_w8a8", + choices=["fp16", "fp8_w8a8", "int8_w8a16"], + default_value="fp16", meta=dict(desc="Dtype to use for the weights (and activations)."), ), OrdinalHyperparameter( @@ -123,6 +123,24 @@ def get_hyperparameters(self) -> list: default_value=20, meta=dict(desc="Number of iterations to average the kernel times on."), ), + OrdinalHyperparameter( + "block_size_m_max", + 
sequence=[4, 5, 6, 7, 8, 9, 10], + default_value=8, + meta=dict(desc="Maximum (log) block size for tiling through input dimension."), + ), + OrdinalHyperparameter( + "block_size_n_max", + sequence=[5, 6, 7, 8, 9, 10], + default_value=8, + meta=dict(desc="Maximum (log) block size for tiling through output dimension."), + ), + OrdinalHyperparameter( + "block_size_k_max", + sequence=[6, 7, 8, 9, 10], + default_value=8, + meta=dict(desc="Maximum (log) block size for tiling through intermediate dimension."), + ), ] def model_check_fn(self, model: Any) -> bool: @@ -182,36 +200,33 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: # (ii) Get the compute parameters dtype = smash_config["compute_dtype"] + if dtype == "bfloat16": + dtype = torch.bfloat16 + else: # default to float16 + dtype = torch.float16 use_fp8_w8a8 = smash_config["weight_dtype"] == "fp8_w8a8" use_int8_w8a16 = smash_config["weight_dtype"] == "int8_w8a16" # (iii) Tune the kernel over a range of batch sizes batch_sizes = [ 1, - 2, - 4, - 8, - 16, - 24, - 32, - 48, - 64, - 96, - 128, - 256, - 512, - 1024, - 1536, - 2048, - 3072, - 4096, - ] + 2,] + # 4, + # 8, + # 16, + # 24, + # 32, + # 48, + # 64, + # 96, + # 128, + # ] # use ray to parallelize the tuning ray.init() is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16) - search_space = get_configs_compute_bound(is_fp16) + search_space = get_configs_compute_bound(is_fp16, smash_config) pruna_logger.info(f"Start tuning over {len(search_space)} configurations...") start = time.time() @@ -254,14 +269,18 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: smash_config["path_to_vllm_cache"], imported_packages, ) - # attached configs to the smash config - smash_config["best_configs_moe_kernel"] = best_configs - # attached hyperparameters to the smash config for loading - smash_config["num_experts"] = E - smash_config["shard_intermediate_size"] = shard_intermediate_size - smash_config["dtype"] = dtype - smash_config["use_fp8_w8a8"] = use_fp8_w8a8 - smash_config["use_int8_w8a16"] = use_int8_w8a16 + # stash results in the SmashConfig for later loading (cannot add new hyperparams to ConfigSpace here) + payload = dict( + best_configs_moe_kernel=best_configs, + num_experts=E, + shard_intermediate_size=shard_intermediate_size, + dtype=dtype, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16, + block_quant_shape=None, + ) + # store artifacts in SmashConfig so they persist across save/load + smash_config.artifacts["moe_kernel_tuner"] = payload # attach load function to the smash config for loading smash_config.load_fns.append(LOAD_FUNCTIONS.moe_kernel_tuner.name) end = time.time() @@ -313,7 +332,6 @@ class BenchmarkConfig(TypedDict): # Converts the function into a Ray actor and requests one GPU per actor instance @ray.remote(num_gpus=1) def tune( - self, num_tokens: int, num_experts: int, shard_intermediate_size: int, @@ -372,7 +390,7 @@ def tune( best_config = None best_time = float("inf") - for config in tqdm_ray(search_space): + for config in tqdm_ray.tqdm(search_space): try: kernel_time = benchmark_config( config, @@ -388,7 +406,6 @@ def tune( block_quant_shape=block_quant_shape, use_deep_gemm=use_deep_gemm, imported_packages=imported_packages, - num_iters=num_iters, ) except imported_packages["triton"].runtime.autotuner.OutOfResources: # Some configurations may be invalid and fail to compile. 
@@ -435,7 +452,7 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig: } -def get_configs_compute_bound(use_fp16: bool) -> list[dict[str, int]]: +def get_configs_compute_bound(use_fp16: bool, smash_config: SmashConfigPrefixWrapper) -> list[dict[str, int]]: """ Get the gridsearch space for the kernel (tiling and warp scheduling). @@ -443,7 +460,8 @@ def get_configs_compute_bound(use_fp16: bool) -> list[dict[str, int]]: ---------- use_fp16: bool Whether to use fp16. - + smash_config: SmashConfigPrefixWrapper + The Smash configuration. Returns ------- list[dict[str, int]] @@ -452,9 +470,9 @@ def get_configs_compute_bound(use_fp16: bool) -> list[dict[str, int]]: configs: list[BenchmarkConfig] = [] # Reduced search space for faster tuning. - block_m_range = [16, 32, 64, 128, 256] - block_n_range = [32, 64, 128, 256] - block_k_range = [64, 128, 256] + block_m_range = [2**i for i in range(4, smash_config["block_size_m_max"]+1)] + block_n_range = [2**i for i in range(5, smash_config["block_size_n_max"]+1)] + block_k_range = [2**i for i in range(6, smash_config["block_size_k_max"]+1)] num_warps_range = [4, 8] group_m_range = [1, 16, 32, 64] num_stage_range = [2, 3, 4, 5] @@ -531,8 +549,14 @@ def benchmark_config( float The average latency of the kernel. """ + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required for MoeKernelTuner.") + # Ray sets CUDA_VISIBLE_DEVICES per worker to the GPU it scheduled + torch.cuda.set_device(0) + device = torch.device("cuda") + init_dtype = torch.float16 if use_fp8_w8a8 else dtype - x = torch.randn(num_tokens, hidden_size, dtype=dtype) + x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device) if use_int8_w8a16: w1 = torch.randint( -127, @@ -543,6 +567,7 @@ def benchmark_config( hidden_size, ), dtype=torch.int8, + device=device, ) w2 = torch.randint( -127, @@ -553,15 +578,16 @@ def benchmark_config( shard_intermediate_size // 2, ), dtype=torch.int8, + device=device, ) else: w1 = torch.randn( - num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype + num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype, device=device ) w2 = torch.randn( - num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype + num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype, device=device ) - gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32) + gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32, device=device) w1_scale = None w2_scale = None @@ -569,9 +595,9 @@ def benchmark_config( a2_scale = None if use_int8_w8a16: w1_scale = torch.randn( - (num_experts, 2 * shard_intermediate_size), dtype=torch.float32 + (num_experts, 2 * shard_intermediate_size), dtype=torch.float32, device=device ) - w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32) + w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32, device=device) if use_deep_gemm: # we use the default block shape for deepgemm block_quant_shape = [128, 128] @@ -587,24 +613,24 @@ def benchmark_config( k_tiles_w1 = (K + block_k - 1) // block_k k_tiles_w2 = (N + block_k - 1) // block_k w1_scale = ( - torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) + torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32, device=device) * factor_for_scale ) w2_scale = ( - torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) + torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32, device=device) * factor_for_scale ) else: - 
w1_scale = torch.randn(num_experts, dtype=torch.float32) - w2_scale = torch.randn(num_experts, dtype=torch.float32) + w1_scale = torch.randn(num_experts, dtype=torch.float32, device=device) + w2_scale = torch.randn(num_experts, dtype=torch.float32, device=device) - a1_scale = torch.randn(1, dtype=torch.float32) - a2_scale = torch.randn(1, dtype=torch.float32) + a1_scale = torch.randn(1, dtype=torch.float32, device=device) + a2_scale = torch.randn(1, dtype=torch.float32, device=device) - w1 = w1.to(imported_packages["vllm_platforms"].current_platform.fp8_dtype()) - w2 = w2.to(imported_packages["vllm_platforms"].current_platform.fp8_dtype()) + w1 = w1.to(device=device, dtype=imported_packages["vllm_platforms"].current_platform.fp8_dtype()) + w2 = w2.to(device=device, dtype=imported_packages["vllm_platforms"].current_platform.fp8_dtype()) - input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32) + input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32, device=device) def prepare(i: int): input_gating.copy_(gating_output[i]) @@ -662,6 +688,7 @@ def run(): latencies: list[float] = [] for i in range(num_iters): + print(f"Iteration {i} of {num_iters}") prepare(i) torch.cuda.synchronize() @@ -724,17 +751,17 @@ def save_configs( # (i) Get the name of the config file # NB from vllm: The current naming convention uses w2.shape[2], which # is the intermediate size after silu_and_mul. - filename = imported_packages["fused_moe"].get_config_file_name( + filename = imported_packages["FusedMoE"].get_config_file_name( num_experts, shard_intermediate_size // 2, dtype_str, block_quant_shape ) # (ii) Save the config to the hf cache (where `kernels` lib expects to find it) path_to_kernel_configs = os.path.join(path_to_huggingface_hub_cache, ".cache/huggingface/hub/models--RedHatAI--moe/blobs/configs") os.makedirs(path_to_kernel_configs, exist_ok=True) - filename = os.path.join(path_to_kernel_configs, filename) - if not os.path.exists(filename): - pruna_logger.info(f"Writing best config to {filename}...") - with open(filename, "w") as f: + filename_hf = os.path.join(path_to_kernel_configs, filename) + if not os.path.exists(filename_hf): + pruna_logger.info(f"Writing best config to {filename_hf}...") + with open(filename_hf, "w") as f: json.dump({"triton_version": imported_packages["triton"].__version__, **configs}, f, indent=4) f.write("\n") @@ -744,9 +771,9 @@ def save_configs( path_where_vllm_is_installed = find_spec("vllm").submodule_search_locations[0] path_to_vllm_configs = os.path.join(os.path.dirname(path_where_vllm_is_installed), path_to_vllm_cache, "configs") os.makedirs(path_to_vllm_configs, exist_ok=True) - filename = os.path.join(path_to_vllm_configs, filename) - if not os.path.exists(filename): - pruna_logger.info(f"Writing best config to {filename}...") - with open(filename, "w") as f: + filename_vllm = os.path.join(path_to_vllm_configs, filename) + if not os.path.exists(filename_vllm): + pruna_logger.info(f"Writing best config to {filename_vllm}...") + with open(filename_vllm, "w") as f: json.dump({"triton_version": imported_packages["triton"].__version__, **configs}, f, indent=4) f.write("\n") diff --git a/src/pruna/config/smash_config.py b/src/pruna/config/smash_config.py index 8b0f640d..dc68962f 100644 --- a/src/pruna/config/smash_config.py +++ b/src/pruna/config/smash_config.py @@ -39,6 +39,7 @@ "device", "device_map", "cache_dir", + "artifacts", "save_fns", "save_artifacts_fns", "load_fns", @@ -105,6 +106,9 @@ def __init__( # internal variable to indicated 
that a model has been smashed for a specific batch size self.__locked_batch_size = False + # generic container for algorithm-produced artifacts that should be saved with the config + self.artifacts: dict[str, Any] = {} + # ensure the cache directory is deleted on program exit atexit.register(self.cleanup_cache_dir) @@ -321,7 +325,10 @@ def save_to_json(self, path: str | Path) -> None: config_dict[key] = convert_numpy_types(value) for name in ADDITIONAL_ARGS: - config_dict[name] = getattr(self, name) + value = getattr(self, name) + if name == "artifacts": + value = convert_artifacts_for_json(value) + config_dict[name] = value # do not save the old cache directory or device if "cache_dir" in config_dict: @@ -757,3 +764,40 @@ def convert_numpy_types(input_value: Any) -> Any: return input_value.item() else: return input_value + + +def convert_artifacts_for_json(value: Any) -> Any: + """ + Convert artifacts to JSON-serializable forms. + + - torch.dtype -> its string name (e.g., 'float16', 'bfloat16') + - Path -> str + - sets/tuples -> lists + - recursively handle dicts/lists + + Parameters + ---------- + value : Any + The value to convert. + + Returns + ------- + Any + The converted value. + """ + if isinstance(value, dict): + return {k: convert_artifacts_for_json(v) for k, v in value.items()} + if isinstance(value, list): + return [convert_artifacts_for_json(v) for v in value] + if isinstance(value, tuple) or isinstance(value, set): + return [convert_artifacts_for_json(v) for v in value] + if isinstance(value, torch.dtype): + # map to canonical string + if value == getattr(torch, "float16", None): + return "float16" + if value == getattr(torch, "bfloat16", None): + return "bfloat16" + return str(value) + if isinstance(value, Path): + return str(value) + return value diff --git a/src/pruna/engine/load.py b/src/pruna/engine/load.py index d703ec2d..5efe4aea 100644 --- a/src/pruna/engine/load.py +++ b/src/pruna/engine/load.py @@ -578,18 +578,38 @@ def load_moe_kernel_tuner(path: str | Path, smash_config: SmashConfig, **kwargs) from pruna.algorithms.moe_kernel_tuner import MoEKernelTuner, save_configs imported_packages = MoEKernelTuner().import_algorithm_packages() + payload = getattr(smash_config, "artifacts", {}).get("moe_kernel_tuner") + if not payload: + pruna_logger.error( + "MoE kernel tuner artifacts not found in SmashConfig. " + "Ensure the tuner ran successfully before saving/loading." 
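Because torch.dtype and Path values are not JSON-serializable, the artifacts payload is stored with string stand-ins and converted back when loading. A simplified sketch of that round-trip, mirroring convert_artifacts_for_json above (the payload keys are illustrative):

import json
from pathlib import Path

import torch

def to_jsonable(value):
    # Containers are handled recursively; dtypes and paths become strings.
    if isinstance(value, dict):
        return {k: to_jsonable(v) for k, v in value.items()}
    if isinstance(value, (list, tuple, set)):
        return [to_jsonable(v) for v in value]
    if isinstance(value, torch.dtype):
        return "bfloat16" if value == torch.bfloat16 else "float16"
    if isinstance(value, Path):
        return str(value)
    return value

payload = {"dtype": torch.bfloat16, "cache": Path("/tmp/moe"), "num_experts": 8}
restored = json.loads(json.dumps(to_jsonable(payload)))
dtype = torch.bfloat16 if restored["dtype"] == "bfloat16" else torch.float16  # same rule as the loader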
+ ) + best_configs = payload["best_configs_moe_kernel"] + num_experts = payload["num_experts"] + shard_intermediate_size = payload["shard_intermediate_size"] + dtype = payload["dtype"] + # Convert dtype string back to torch.dtype if needed + if dtype == "bfloat16": + dtype = torch.bfloat16 + else: + dtype = torch.float16 + use_fp8_w8a8 = payload["use_fp8_w8a8"] + use_int8_w8a16 = payload["use_int8_w8a16"] + block_quant_shape = payload["block_quant_shape"] + save_configs( - smash_config["best_configs_moe_kernel"], - smash_config["num_experts"], - smash_config["shard_intermediate_size"], - smash_config["dtype"], - smash_config["use_fp8_w8a8"], - smash_config["use_int8_w8a16"], - smash_config["block_quant_shape"], + best_configs, + num_experts, + shard_intermediate_size, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + block_quant_shape, smash_config["path_to_huggingface_hub_cache"], smash_config["path_to_vllm_cache"], imported_packages, ) + smash_config.load_fns.remove(LOAD_FUNCTIONS.moe_kernel_tuner.name) return load_transformers_model(path, smash_config, **kwargs) From 9e55dae195c43b17882204e89a742f0994f4b4d0 Mon Sep 17 00:00:00 2001 From: llcnt Date: Tue, 23 Dec 2025 10:34:11 +0000 Subject: [PATCH 10/26] fix: adapt parameter names inside smashconfig --- src/pruna/engine/load.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/pruna/engine/load.py b/src/pruna/engine/load.py index 5efe4aea..626ff9f7 100644 --- a/src/pruna/engine/load.py +++ b/src/pruna/engine/load.py @@ -575,9 +575,9 @@ def load_moe_kernel_tuner(path: str | Path, smash_config: SmashConfig, **kwargs) Any The loaded model. """ - from pruna.algorithms.moe_kernel_tuner import MoEKernelTuner, save_configs + from pruna.algorithms.moe_kernel_tuner import MoeKernelTuner, save_configs - imported_packages = MoEKernelTuner().import_algorithm_packages() + imported_packages = MoeKernelTuner().import_algorithm_packages() payload = getattr(smash_config, "artifacts", {}).get("moe_kernel_tuner") if not payload: pruna_logger.error( @@ -597,6 +597,7 @@ def load_moe_kernel_tuner(path: str | Path, smash_config: SmashConfig, **kwargs) use_int8_w8a16 = payload["use_int8_w8a16"] block_quant_shape = payload["block_quant_shape"] + # save the config attached to smash_config, inside the hf and vllm caches. 
save_configs( best_configs, num_experts, @@ -605,8 +606,8 @@ def load_moe_kernel_tuner(path: str | Path, smash_config: SmashConfig, **kwargs) use_fp8_w8a8, use_int8_w8a16, block_quant_shape, - smash_config["path_to_huggingface_hub_cache"], - smash_config["path_to_vllm_cache"], + smash_config["moe_kernel_tuner_path_to_huggingface_hub_cache"], + smash_config["moe_kernel_tuner_path_to_vllm_cache"], imported_packages, ) smash_config.load_fns.remove(LOAD_FUNCTIONS.moe_kernel_tuner.name) From d008de83bb210b0d26dfe94aa9f4dd531b946ef9 Mon Sep 17 00:00:00 2001 From: llcnt Date: Tue, 23 Dec 2025 10:35:57 +0000 Subject: [PATCH 11/26] fix: moe intermediate size can differ from model intermediate size --- src/pruna/algorithms/moe_kernel_tuner.py | 23 ++++++++--------------- 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/src/pruna/algorithms/moe_kernel_tuner.py b/src/pruna/algorithms/moe_kernel_tuner.py index e03f2a20..27ed73cf 100644 --- a/src/pruna/algorithms/moe_kernel_tuner.py +++ b/src/pruna/algorithms/moe_kernel_tuner.py @@ -190,7 +190,12 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: raise ValueError(f"Model {model.__class__.__name__} has no config.") E = model_config.num_experts # number of experts topk = model_config.num_experts_per_tok if is_moe_lm(model) else model_config.moe_topk[0] # number of active experts per token - intermediate_size = model_config.intermediate_size # FFN intermediate size + # qwen_moe can use different intermediate size compared to mixtral. + intermediate_size = ( + model_config.moe_intermediate_size + if model_config.moe_intermediate_size is not None + else model_config.intermediate_size + ) hidden_size = model_config.hidden_size # model hidden dim assert intermediate_size % smash_config["tensor_parallel_size"] == 0, ( f"intermediate_size {intermediate_size} is not divisible by tp " @@ -208,19 +213,7 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: use_int8_w8a16 = smash_config["weight_dtype"] == "int8_w8a16" # (iii) Tune the kernel over a range of batch sizes - batch_sizes = [ - 1, - 2,] - # 4, - # 8, - # 16, - # 24, - # 32, - # 48, - # 64, - # 96, - # 128, - # ] + batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192] # use ray to parallelize the tuning ray.init() @@ -769,7 +762,7 @@ def save_configs( path_to_vllm_configs = imported_packages["envs"].VLLM_TUNED_CONFIG_FOLDER if path_to_vllm_configs is None: path_where_vllm_is_installed = find_spec("vllm").submodule_search_locations[0] - path_to_vllm_configs = os.path.join(os.path.dirname(path_where_vllm_is_installed), path_to_vllm_cache, "configs") + path_to_vllm_configs = os.path.join(os.path.dirname(path_where_vllm_is_installed), path_to_vllm_cache) os.makedirs(path_to_vllm_configs, exist_ok=True) filename_vllm = os.path.join(path_to_vllm_configs, filename) if not os.path.exists(filename_vllm): From 9849c5eb8ca8926b7ef757947cc3de54c7f08736 Mon Sep 17 00:00:00 2001 From: llcnt Date: Tue, 23 Dec 2025 17:43:34 +0000 Subject: [PATCH 12/26] feat: ruff linting --- src/pruna/algorithms/moe_kernel_tuner.py | 95 +++++++++++++----------- src/pruna/config/smash_config.py | 2 +- src/pruna/engine/load.py | 5 +- 3 files changed, 53 insertions(+), 49 deletions(-) diff --git a/src/pruna/algorithms/moe_kernel_tuner.py b/src/pruna/algorithms/moe_kernel_tuner.py index 27ed73cf..ea710591 100644 --- a/src/pruna/algorithms/moe_kernel_tuner.py +++ b/src/pruna/algorithms/moe_kernel_tuner.py @@ -13,19 +13,19 @@ # limitations under the 
License. from __future__ import annotations +import json +import pathlib +import time from collections.abc import Iterable +from datetime import datetime +from importlib.util import find_spec +from itertools import product from typing import Any, Dict, TypedDict -import ray.experimental.tqdm_ray as tqdm_ray import ray -import time +import ray.experimental.tqdm_ray as tqdm_ray import torch from ConfigSpace import CategoricalHyperparameter, OrdinalHyperparameter -from datetime import datetime -import os -import json -from itertools import product -from importlib.util import find_spec from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase from pruna.algorithms.base.tags import AlgorithmTag as tags @@ -188,27 +188,29 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: model_config = model.config if model_config is None: raise ValueError(f"Model {model.__class__.__name__} has no config.") - E = model_config.num_experts # number of experts - topk = model_config.num_experts_per_tok if is_moe_lm(model) else model_config.moe_topk[0] # number of active experts per token + nb_experts = model_config.num_experts # number of experts + # number of active experts per token + topk = ( + model_config.num_experts_per_tok + if is_moe_lm(model) + else model_config.moe_topk[0] + ) # qwen_moe can use different intermediate size compared to mixtral. intermediate_size = ( model_config.moe_intermediate_size if model_config.moe_intermediate_size is not None else model_config.intermediate_size ) - hidden_size = model_config.hidden_size # model hidden dim + hidden_size = model_config.hidden_size # model hidden dim assert intermediate_size % smash_config["tensor_parallel_size"] == 0, ( f"intermediate_size {intermediate_size} is not divisible by tp " f"{smash_config['tensor_parallel_size']}." ) shard_intermediate_size = 2 * intermediate_size // smash_config["tensor_parallel_size"] - + # (ii) Get the compute parameters dtype = smash_config["compute_dtype"] - if dtype == "bfloat16": - dtype = torch.bfloat16 - else: # default to float16 - dtype = torch.float16 + dtype = torch.bfloat16 if dtype == "bfloat16" else torch.float16 use_fp8_w8a8 = smash_config["weight_dtype"] == "fp8_w8a8" use_int8_w8a16 = smash_config["weight_dtype"] == "int8_w8a16" @@ -227,7 +229,7 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: for batch_size in batch_sizes: output = tune.remote( batch_size, # num_tokens - E, # num_experts + nb_experts, # num_experts per block shard_intermediate_size, hidden_size, topk, @@ -252,7 +254,7 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: # save configs in caches (for hf and vllm) save_configs( best_configs, - E, + nb_experts, shard_intermediate_size, dtype, use_fp8_w8a8, @@ -265,7 +267,7 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: # stash results in the SmashConfig for later loading (cannot add new hyperparams to ConfigSpace here) payload = dict( best_configs_moe_kernel=best_configs, - num_experts=E, + num_experts=nb_experts, shard_intermediate_size=shard_intermediate_size, dtype=dtype, use_fp8_w8a8=use_fp8_w8a8, @@ -291,15 +293,15 @@ def import_algorithm_packages(self) -> Dict[str, Any]: Dict[str, Any] The algorithm packages. 
""" + import vllm.envs as envs import vllm.model_executor.layers.fused_moe.fused_moe as fused_moe import vllm.platforms as vllm_platforms + from vllm.model_executor.layers.fused_moe import override_config from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, _get_config_dtype_str, ) - from vllm.model_executor.layers.fused_moe import override_config from vllm.triton_utils import triton - import vllm.envs as envs return dict( FusedMoEQuantConfig=FusedMoEQuantConfig, @@ -311,10 +313,10 @@ def import_algorithm_packages(self) -> Dict[str, Any]: envs=envs, ) + class BenchmarkConfig(TypedDict): - """ - The configuration for the matrix multiplication (tiling and warp scheduling). - """ + """The configuration for the matrix multiplication (tiling and warp scheduling).""" + BLOCK_SIZE_M: int BLOCK_SIZE_N: int BLOCK_SIZE_K: int @@ -322,6 +324,7 @@ class BenchmarkConfig(TypedDict): num_warps: int num_stages: int + # Converts the function into a Ray actor and requests one GPU per actor instance @ray.remote(num_gpus=1) def tune( @@ -412,6 +415,7 @@ def tune( assert best_config is not None return best_config + def sort_config(config: BenchmarkConfig) -> BenchmarkConfig: """ Sort the configuration (tiling and warp scheduling). @@ -455,6 +459,7 @@ def get_configs_compute_bound(use_fp16: bool, smash_config: SmashConfigPrefixWra Whether to use fp16. smash_config: SmashConfigPrefixWrapper The Smash configuration. + Returns ------- list[dict[str, int]] @@ -463,9 +468,9 @@ def get_configs_compute_bound(use_fp16: bool, smash_config: SmashConfigPrefixWra configs: list[BenchmarkConfig] = [] # Reduced search space for faster tuning. - block_m_range = [2**i for i in range(4, smash_config["block_size_m_max"]+1)] - block_n_range = [2**i for i in range(5, smash_config["block_size_n_max"]+1)] - block_k_range = [2**i for i in range(6, smash_config["block_size_k_max"]+1)] + block_m_range = [2**i for i in range(4, smash_config["block_size_m_max"] + 1)] + block_n_range = [2**i for i in range(5, smash_config["block_size_n_max"] + 1)] + block_k_range = [2**i for i in range(6, smash_config["block_size_k_max"] + 1)] num_warps_range = [4, 8] group_m_range = [1, 16, 32, 64] num_stage_range = [2, 3, 4, 5] @@ -597,20 +602,20 @@ def benchmark_config( if use_fp8_w8a8: if block_quant_shape: block_n, block_k = block_quant_shape[0], block_quant_shape[1] - E = num_experts - N = shard_intermediate_size // 2 - K = hidden_size + e = num_experts + n = shard_intermediate_size // 2 + k = hidden_size factor_for_scale = 1e-2 - n_tiles_w1 = (2 * N + block_n - 1) // block_n - n_tiles_w2 = (K + block_n - 1) // block_n - k_tiles_w1 = (K + block_k - 1) // block_k - k_tiles_w2 = (N + block_k - 1) // block_k + n_tiles_w1 = (2 * n + block_n - 1) // block_n + n_tiles_w2 = (k + block_n - 1) // block_n + k_tiles_w1 = (k + block_k - 1) // block_k + k_tiles_w2 = (n + block_k - 1) // block_k w1_scale = ( - torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32, device=device) + torch.rand((e, n_tiles_w1, k_tiles_w1), dtype=torch.float32, device=device) * factor_for_scale ) w2_scale = ( - torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32, device=device) + torch.rand((e, n_tiles_w2, k_tiles_w2), dtype=torch.float32, device=device) * factor_for_scale ) else: @@ -681,7 +686,6 @@ def run(): latencies: list[float] = [] for i in range(num_iters): - print(f"Iteration {i} of {num_iters}") prepare(i) torch.cuda.synchronize() @@ -749,10 +753,13 @@ def save_configs( ) # (ii) Save the config to the hf cache (where `kernels` lib 
expects to find it) - path_to_kernel_configs = os.path.join(path_to_huggingface_hub_cache, ".cache/huggingface/hub/models--RedHatAI--moe/blobs/configs") - os.makedirs(path_to_kernel_configs, exist_ok=True) - filename_hf = os.path.join(path_to_kernel_configs, filename) - if not os.path.exists(filename_hf): + path_to_kernel_configs = ( + pathlib.Path(path_to_huggingface_hub_cache) / + ".cache/huggingface/hub/models--RedHatAI--moe/blobs/configs" + ) + pathlib.Path(path_to_kernel_configs).mkdir(exist_ok=True, parents=True) + filename_hf = path_to_kernel_configs / filename + if not pathlib.Path(filename_hf).exists(): pruna_logger.info(f"Writing best config to {filename_hf}...") with open(filename_hf, "w") as f: json.dump({"triton_version": imported_packages["triton"].__version__, **configs}, f, indent=4) @@ -762,10 +769,10 @@ def save_configs( path_to_vllm_configs = imported_packages["envs"].VLLM_TUNED_CONFIG_FOLDER if path_to_vllm_configs is None: path_where_vllm_is_installed = find_spec("vllm").submodule_search_locations[0] - path_to_vllm_configs = os.path.join(os.path.dirname(path_where_vllm_is_installed), path_to_vllm_cache) - os.makedirs(path_to_vllm_configs, exist_ok=True) - filename_vllm = os.path.join(path_to_vllm_configs, filename) - if not os.path.exists(filename_vllm): + path_to_vllm_configs = pathlib.Path(path_where_vllm_is_installed).parent / path_to_vllm_cache + pathlib.Path(path_to_vllm_configs).mkdir(exist_ok=True, parents=True) + filename_vllm = path_to_vllm_configs / filename + if not pathlib.Path(filename_vllm).exists(): pruna_logger.info(f"Writing best config to {filename_vllm}...") with open(filename_vllm, "w") as f: json.dump({"triton_version": imported_packages["triton"].__version__, **configs}, f, indent=4) diff --git a/src/pruna/config/smash_config.py b/src/pruna/config/smash_config.py index dc68962f..df2d0c50 100644 --- a/src/pruna/config/smash_config.py +++ b/src/pruna/config/smash_config.py @@ -789,7 +789,7 @@ def convert_artifacts_for_json(value: Any) -> Any: return {k: convert_artifacts_for_json(v) for k, v in value.items()} if isinstance(value, list): return [convert_artifacts_for_json(v) for v in value] - if isinstance(value, tuple) or isinstance(value, set): + if isinstance(value, (tuple, set)): return [convert_artifacts_for_json(v) for v in value] if isinstance(value, torch.dtype): # map to canonical string diff --git a/src/pruna/engine/load.py b/src/pruna/engine/load.py index 626ff9f7..aada2ca0 100644 --- a/src/pruna/engine/load.py +++ b/src/pruna/engine/load.py @@ -589,10 +589,7 @@ def load_moe_kernel_tuner(path: str | Path, smash_config: SmashConfig, **kwargs) shard_intermediate_size = payload["shard_intermediate_size"] dtype = payload["dtype"] # Convert dtype string back to torch.dtype if needed - if dtype == "bfloat16": - dtype = torch.bfloat16 - else: - dtype = torch.float16 + dtype = torch.bfloat16 if dtype == "bfloat16" else torch.float16 use_fp8_w8a8 = payload["use_fp8_w8a8"] use_int8_w8a16 = payload["use_int8_w8a16"] block_quant_shape = payload["block_quant_shape"] From 36f63a7c5648d33a4a44b225c958de2305581f2a Mon Sep 17 00:00:00 2001 From: llcnt Date: Tue, 23 Dec 2025 18:07:33 +0000 Subject: [PATCH 13/26] feat: ty check linting --- src/pruna/algorithms/moe_kernel_tuner.py | 13 +++--- src/pruna/engine/load.py | 50 ++++++++++++------------ 2 files changed, 33 insertions(+), 30 deletions(-) diff --git a/src/pruna/algorithms/moe_kernel_tuner.py b/src/pruna/algorithms/moe_kernel_tuner.py index ea710591..49562d71 100644 --- 
a/src/pruna/algorithms/moe_kernel_tuner.py +++ b/src/pruna/algorithms/moe_kernel_tuner.py @@ -272,7 +272,6 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: dtype=dtype, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, - block_quant_shape=None, ) # store artifacts in SmashConfig so they persist across save/load smash_config.artifacts["moe_kernel_tuner"] = payload @@ -438,14 +437,14 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig: "num_warps": config["num_warps"], "num_stages": config["num_stages"], **( - {"waves_per_eu": config["waves_per_eu"]} if "waves_per_eu" in config else {} + {"waves_per_eu": config.get("waves_per_eu")} if "waves_per_eu" in config else {} ), **( - {"matrix_instr_nonkdim": config["matrix_instr_nonkdim"]} + {"matrix_instr_nonkdim": config.get("matrix_instr_nonkdim")} if "matrix_instr_nonkdim" in config else {} ), - **({"kpack": config["kpack"]} if "kpack" in config else {}), + **({"kpack": config.get("kpack")} if "kpack" in config else {}), } @@ -768,7 +767,11 @@ def save_configs( # (iii) Save the config to the vllm cache (where `vllm` expects to find it) path_to_vllm_configs = imported_packages["envs"].VLLM_TUNED_CONFIG_FOLDER if path_to_vllm_configs is None: - path_where_vllm_is_installed = find_spec("vllm").submodule_search_locations[0] + submodule_locations = find_spec("vllm").submodule_search_locations + if submodule_locations is not None and len(submodule_locations) > 0: + path_where_vllm_is_installed = submodule_locations[0] + else: + raise RuntimeError("Could not determine installation path for vllm.") path_to_vllm_configs = pathlib.Path(path_where_vllm_is_installed).parent / path_to_vllm_cache pathlib.Path(path_to_vllm_configs).mkdir(exist_ok=True, parents=True) filename_vllm = path_to_vllm_configs / filename diff --git a/src/pruna/engine/load.py b/src/pruna/engine/load.py index aada2ca0..3b899469 100644 --- a/src/pruna/engine/load.py +++ b/src/pruna/engine/load.py @@ -584,31 +584,31 @@ def load_moe_kernel_tuner(path: str | Path, smash_config: SmashConfig, **kwargs) "MoE kernel tuner artifacts not found in SmashConfig. " "Ensure the tuner ran successfully before saving/loading." ) - best_configs = payload["best_configs_moe_kernel"] - num_experts = payload["num_experts"] - shard_intermediate_size = payload["shard_intermediate_size"] - dtype = payload["dtype"] - # Convert dtype string back to torch.dtype if needed - dtype = torch.bfloat16 if dtype == "bfloat16" else torch.float16 - use_fp8_w8a8 = payload["use_fp8_w8a8"] - use_int8_w8a16 = payload["use_int8_w8a16"] - block_quant_shape = payload["block_quant_shape"] - - # save the config attached to smash_config, inside the hf and vllm caches. 
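When VLLM_TUNED_CONFIG_FOLDER is unset, the config directory is derived from wherever vLLM is installed via importlib.util.find_spec, with an explicit error if the location cannot be resolved. A standalone sketch of that fallback (the cache sub-path shown is an assumed example, not the hyperparameter's actual default):

import pathlib
from importlib.util import find_spec

spec = find_spec("vllm")
locations = spec.submodule_search_locations if spec is not None else None
if not locations:
    raise RuntimeError("Could not determine installation path for vllm.")
path_to_vllm_cache = "vllm/model_executor/layers/fused_moe/configs"  # assumed example value
path_to_vllm_configs = pathlib.Path(locations[0]).parent / path_to_vllm_cache
path_to_vllm_configs.mkdir(exist_ok=True, parents=True)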
- save_configs( - best_configs, - num_experts, - shard_intermediate_size, - dtype, - use_fp8_w8a8, - use_int8_w8a16, - block_quant_shape, - smash_config["moe_kernel_tuner_path_to_huggingface_hub_cache"], - smash_config["moe_kernel_tuner_path_to_vllm_cache"], - imported_packages, - ) - smash_config.load_fns.remove(LOAD_FUNCTIONS.moe_kernel_tuner.name) - return load_transformers_model(path, smash_config, **kwargs) + else: + best_configs = payload["best_configs_moe_kernel"] + num_experts = payload["num_experts"] + shard_intermediate_size = payload["shard_intermediate_size"] + dtype = payload["dtype"] + # Convert dtype string back to torch.dtype if needed + dtype = torch.bfloat16 if dtype == "bfloat16" else torch.float16 + use_fp8_w8a8 = payload["use_fp8_w8a8"] + use_int8_w8a16 = payload["use_int8_w8a16"] + + # save the config attached to smash_config, inside the hf and vllm caches. + save_configs( + best_configs, + num_experts, + shard_intermediate_size, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + None, + smash_config["moe_kernel_tuner_path_to_huggingface_hub_cache"], + smash_config["moe_kernel_tuner_path_to_vllm_cache"], + imported_packages, + ) + smash_config.load_fns.remove(LOAD_FUNCTIONS.moe_kernel_tuner.name) + return load_transformers_model(path, smash_config, **kwargs) class LOAD_FUNCTIONS(Enum): # noqa: N801 From 8c53f8c842b9181b43b8e7da0c1a6f587bd78507 Mon Sep 17 00:00:00 2001 From: llcnt Date: Wed, 24 Dec 2025 09:16:31 +0000 Subject: [PATCH 14/26] fix: npdoc space issue --- src/pruna/algorithms/moe_kernel_tuner.py | 84 +++++++++++------------- src/pruna/engine/load.py | 8 +-- 2 files changed, 44 insertions(+), 48 deletions(-) diff --git a/src/pruna/algorithms/moe_kernel_tuner.py b/src/pruna/algorithms/moe_kernel_tuner.py index 49562d71..daf477be 100644 --- a/src/pruna/algorithms/moe_kernel_tuner.py +++ b/src/pruna/algorithms/moe_kernel_tuner.py @@ -347,33 +347,33 @@ def tune( Parameters ---------- - num_tokens: int + num_tokens : int The number of tokens in the batch. - num_experts: int + num_experts : int The number of experts. - shard_intermediate_size: int + shard_intermediate_size : int The intermediate size of the model in the shard (if using tensor parallelism). - hidden_size: int + hidden_size : int The hidden size of the model. - topk: int + topk : int The number of active experts per token. - dtype: torch.dtype + dtype : torch.dtype The dtype to use for the weights and activations. - use_fp8_w8a8: bool + use_fp8_w8a8 : bool Whether to use fp8_w8a8. - use_int8_w8a16: bool + use_int8_w8a16 : bool Whether to use int8_w8a16. - search_space: list[dict[str, int]] + search_space : list[dict[str, int]] The search space for the kernel (tiling and warp scheduling). - block_quant_shape: list[int] + block_quant_shape : list[int] The block shape for the kernel (None here). - use_deep_gemm: bool + use_deep_gemm : bool Whether to use deep gemm (False here). - imported_packages: Dict[str, Any] + imported_packages : Dict[str, Any] The imported packages (vllm, triton, etc.). - seed: int + seed : int The random seed. - num_iters: int + num_iters : int The number of iterations to average the kernel time on. Returns @@ -421,7 +421,7 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig: Parameters ---------- - config: BenchmarkConfig + config : BenchmarkConfig The configuration to sort. Returns @@ -454,9 +454,9 @@ def get_configs_compute_bound(use_fp16: bool, smash_config: SmashConfigPrefixWra Parameters ---------- - use_fp16: bool + use_fp16 : bool Whether to use fp16. 
- smash_config: SmashConfigPrefixWrapper + smash_config : SmashConfigPrefixWrapper The Smash configuration. Returns @@ -514,31 +514,31 @@ def benchmark_config( Parameters ---------- - config: BenchmarkConfig + config : BenchmarkConfig The configuration to benchmark. - num_tokens: int + num_tokens : int The number of tokens in the batch. - num_experts: int + num_experts : int The number of experts. - shard_intermediate_size: int + shard_intermediate_size : int The intermediate size of the model in the shard (if using tensor parallelism). - hidden_size: int + hidden_size : int The hidden size of the model. - topk: int + topk : int The number of active experts per token. - dtype: torch.dtype + dtype : torch.dtype The dtype to use for the weights and activations. - use_fp8_w8a8: bool + use_fp8_w8a8 : bool Whether to use fp8_w8a8. - use_int8_w8a16: bool + use_int8_w8a16 : bool Whether to use int8_w8a16. - num_iters: int + num_iters : int The number of iterations to run the benchmark. - block_quant_shape: list[int] + block_quant_shape : list[int] The block shape for the kernel (None here). - use_deep_gemm: bool + use_deep_gemm : bool Whether to use deep gemm (False here). - imported_packages: Dict[str, Any] + imported_packages : Dict[str, Any] The imported packages (vllm, triton, etc.). Returns @@ -715,29 +715,25 @@ def save_configs( Parameters ---------- - configs: dict[int, BenchmarkConfig] + configs : dict[int, BenchmarkConfig] The best configs. - num_experts: int + num_experts : int The number of experts. - shard_intermediate_size: int + shard_intermediate_size : int The intermediate size of the model in the shard (if using tensor parallelism). - hidden_size: int - The hidden size of the model. - topk: int - The number of active experts per token. - dtype: torch.dtype + dtype : torch.dtype The dtype to use for the weights and activations. - use_fp8_w8a8: bool + use_fp8_w8a8 : bool Whether to use fp8_w8a8. - use_int8_w8a16: bool + use_int8_w8a16 : bool Whether to use int8_w8a16. - block_quant_shape: list[int] + block_quant_shape : list[int] The block shape for the kernel (None here). - path_to_huggingface_hub_cache: str + path_to_huggingface_hub_cache : str The path to the huggingface hub cache. - path_to_vllm_cache: str + path_to_vllm_cache : str The path to the vllm cache. - imported_packages: Dict[str, Any] + imported_packages : Dict[str, Any] The imported packages (vllm, triton, etc.). """ dtype_str = imported_packages["_get_config_dtype_str"]( diff --git a/src/pruna/engine/load.py b/src/pruna/engine/load.py index 3b899469..cace9019 100644 --- a/src/pruna/engine/load.py +++ b/src/pruna/engine/load.py @@ -563,17 +563,17 @@ def load_moe_kernel_tuner(path: str | Path, smash_config: SmashConfig, **kwargs) Parameters ---------- - path: str | Path + path : str | Path The path to the model directory. - smash_config: SmashConfig + smash_config : SmashConfig The SmashConfig object containing the best configs for the MoE kernel tuner. - **kwargs: Any + **kwargs : Any Additional keyword arguments to pass to the model loading function. Returns ------- Any - The loaded model. + The loaded MoE model. 
""" from pruna.algorithms.moe_kernel_tuner import MoeKernelTuner, save_configs From 15a93eed0bf9928da9a2d75bcaeb42c0c92f7eb4 Mon Sep 17 00:00:00 2001 From: llcnt Date: Wed, 7 Jan 2026 14:49:05 +0000 Subject: [PATCH 15/26] fix: minor bugs from review --- src/pruna/algorithms/moe_kernel_tuner.py | 13 ++++++++++--- src/pruna/config/smash_config.py | 1 + src/pruna/engine/load.py | 4 +++- src/pruna/engine/model_checks.py | 2 +- 4 files changed, 15 insertions(+), 5 deletions(-) diff --git a/src/pruna/algorithms/moe_kernel_tuner.py b/src/pruna/algorithms/moe_kernel_tuner.py index daf477be..e9b5734d 100644 --- a/src/pruna/algorithms/moe_kernel_tuner.py +++ b/src/pruna/algorithms/moe_kernel_tuner.py @@ -197,8 +197,8 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: ) # qwen_moe can use different intermediate size compared to mixtral. intermediate_size = ( - model_config.moe_intermediate_size - if model_config.moe_intermediate_size is not None + getattr(model_config, "moe_intermediate_size", None) + if getattr(model_config, "moe_intermediate_size", None) is not None else model_config.intermediate_size ) hidden_size = model_config.hidden_size # model hidden dim @@ -246,6 +246,7 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: outputs.append(output) configs = ray.get(outputs) + ray.shutdown() # (iv) Sort the configs by batch size and save the best configs best_configs = { @@ -411,7 +412,13 @@ def tune( now = datetime.now() pruna_logger.info(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}") - assert best_config is not None + if best_config is None: + raise RuntimeError( + f"No valid kernel configuration was found for batch_size={num_tokens}. " + "All configurations failed (e.g., due to OutOfResources). " + "This can happen on GPUs with limited resources. " + "Consider reducing your model size, batch size, or tuning search space." + ) return best_config diff --git a/src/pruna/config/smash_config.py b/src/pruna/config/smash_config.py index df2d0c50..2577a207 100644 --- a/src/pruna/config/smash_config.py +++ b/src/pruna/config/smash_config.py @@ -364,6 +364,7 @@ def flush_configuration(self) -> None: self.save_artifacts_fns = [] self.load_artifacts_fns = [] self.reapply_after_load = {} + self.artifacts = {} # reset potentially previously used cache directory self.reset_cache_dir() diff --git a/src/pruna/engine/load.py b/src/pruna/engine/load.py index cace9019..c4fd88e8 100644 --- a/src/pruna/engine/load.py +++ b/src/pruna/engine/load.py @@ -580,10 +580,12 @@ def load_moe_kernel_tuner(path: str | Path, smash_config: SmashConfig, **kwargs) imported_packages = MoeKernelTuner().import_algorithm_packages() payload = getattr(smash_config, "artifacts", {}).get("moe_kernel_tuner") if not payload: - pruna_logger.error( + error_msg = ( "MoE kernel tuner artifacts not found in SmashConfig. " "Ensure the tuner ran successfully before saving/loading." ) + pruna_logger.error(error_msg) + raise RuntimeError(error_msg) else: best_configs = payload["best_configs_moe_kernel"] num_experts = payload["num_experts"] diff --git a/src/pruna/engine/model_checks.py b/src/pruna/engine/model_checks.py index d3163389..fa5fb763 100644 --- a/src/pruna/engine/model_checks.py +++ b/src/pruna/engine/model_checks.py @@ -121,7 +121,7 @@ def is_moe_lm(model: Any) -> bool: bool True if the model is a MoE LM, False otherwise. 
""" - return hasattr(model, "num_experts") + return hasattr(getattr(model, "config", None), "num_experts") def is_transformers_pipeline_with_causal_lm(model: Any) -> bool: From 01e30966afbf5b272324dd5a607ec4485b02308b Mon Sep 17 00:00:00 2001 From: llcnt Date: Mon, 9 Feb 2026 14:52:14 +0000 Subject: [PATCH 16/26] feat: adapt tuned configs saving to the new artifact savings --- src/pruna/algorithms/moe_kernel_tuner.py | 23 ++++---- src/pruna/config/smash_config.py | 47 +--------------- src/pruna/engine/load.py | 57 -------------------- src/pruna/engine/load_artifacts.py | 51 ++++++++++++++++++ src/pruna/engine/save_artifacts.py | 14 +++++ tests/algorithms/testers/moe_kernel_tuner.py | 2 +- 6 files changed, 81 insertions(+), 113 deletions(-) diff --git a/src/pruna/algorithms/moe_kernel_tuner.py b/src/pruna/algorithms/moe_kernel_tuner.py index e9b5734d..176c3495 100644 --- a/src/pruna/algorithms/moe_kernel_tuner.py +++ b/src/pruna/algorithms/moe_kernel_tuner.py @@ -31,8 +31,8 @@ from pruna.algorithms.base.tags import AlgorithmTag as tags from pruna.config.hyperparameters import UnconstrainedHyperparameter from pruna.config.smash_config import SmashConfigPrefixWrapper -from pruna.engine.load import LOAD_FUNCTIONS from pruna.engine.model_checks import is_moe_lm, is_transformers_pipeline_with_moe_lm +from pruna.engine.save_artifacts import SAVE_ARTIFACTS_FUNCTIONS from pruna.logging.logger import pruna_logger @@ -252,7 +252,10 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: best_configs = { M: sort_config(config) for M, config in zip(batch_sizes, configs) } - # save configs in caches (for hf and vllm) + end = time.time() + pruna_logger.info(f"Tuning took {end - start:.2f} seconds") + + # save configs in hf/vllm caches for direct usage save_configs( best_configs, nb_experts, @@ -265,7 +268,7 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: smash_config["path_to_vllm_cache"], imported_packages, ) - # stash results in the SmashConfig for later loading (cannot add new hyperparams to ConfigSpace here) + # results to be saved for later loading payload = dict( best_configs_moe_kernel=best_configs, num_experts=nb_experts, @@ -274,12 +277,14 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, ) - # store artifacts in SmashConfig so they persist across save/load - smash_config.artifacts["moe_kernel_tuner"] = payload - # attach load function to the smash config for loading - smash_config.load_fns.append(LOAD_FUNCTIONS.moe_kernel_tuner.name) - end = time.time() - pruna_logger.info(f"Tuning took {end - start:.2f} seconds") + # store artifacts in pruna cache for later saving/loading + save_dir = smash_config.cache_dir / "moe_kernel_tuned_configs" + save_dir.mkdir(parents=True, exist_ok=True) + with open(save_dir / "moe_kernel_tuner.json", "w") as f: + json.dump(payload, f) + + # attach artifacts save function to the smash config for saving + smash_config.save_artifacts_fns.append(SAVE_ARTIFACTS_FUNCTIONS.moe_kernel_tuner_artifacts.name) # (v) Return the model (untouched; only the configs are saved) return model diff --git a/src/pruna/config/smash_config.py b/src/pruna/config/smash_config.py index 2577a207..8b0f640d 100644 --- a/src/pruna/config/smash_config.py +++ b/src/pruna/config/smash_config.py @@ -39,7 +39,6 @@ "device", "device_map", "cache_dir", - "artifacts", "save_fns", "save_artifacts_fns", "load_fns", @@ -106,9 +105,6 @@ def __init__( # internal 
variable to indicated that a model has been smashed for a specific batch size self.__locked_batch_size = False - # generic container for algorithm-produced artifacts that should be saved with the config - self.artifacts: dict[str, Any] = {} - # ensure the cache directory is deleted on program exit atexit.register(self.cleanup_cache_dir) @@ -325,10 +321,7 @@ def save_to_json(self, path: str | Path) -> None: config_dict[key] = convert_numpy_types(value) for name in ADDITIONAL_ARGS: - value = getattr(self, name) - if name == "artifacts": - value = convert_artifacts_for_json(value) - config_dict[name] = value + config_dict[name] = getattr(self, name) # do not save the old cache directory or device if "cache_dir" in config_dict: @@ -364,7 +357,6 @@ def flush_configuration(self) -> None: self.save_artifacts_fns = [] self.load_artifacts_fns = [] self.reapply_after_load = {} - self.artifacts = {} # reset potentially previously used cache directory self.reset_cache_dir() @@ -765,40 +757,3 @@ def convert_numpy_types(input_value: Any) -> Any: return input_value.item() else: return input_value - - -def convert_artifacts_for_json(value: Any) -> Any: - """ - Convert artifacts to JSON-serializable forms. - - - torch.dtype -> its string name (e.g., 'float16', 'bfloat16') - - Path -> str - - sets/tuples -> lists - - recursively handle dicts/lists - - Parameters - ---------- - value : Any - The value to convert. - - Returns - ------- - Any - The converted value. - """ - if isinstance(value, dict): - return {k: convert_artifacts_for_json(v) for k, v in value.items()} - if isinstance(value, list): - return [convert_artifacts_for_json(v) for v in value] - if isinstance(value, (tuple, set)): - return [convert_artifacts_for_json(v) for v in value] - if isinstance(value, torch.dtype): - # map to canonical string - if value == getattr(torch, "float16", None): - return "float16" - if value == getattr(torch, "bfloat16", None): - return "bfloat16" - return str(value) - if isinstance(value, Path): - return str(value) - return value diff --git a/src/pruna/engine/load.py b/src/pruna/engine/load.py index c4fd88e8..7679b09a 100644 --- a/src/pruna/engine/load.py +++ b/src/pruna/engine/load.py @@ -557,62 +557,6 @@ def load_hqq_diffusers(path: str | Path, smash_config: SmashConfig, **kwargs) -> return model -def load_moe_kernel_tuner(path: str | Path, smash_config: SmashConfig, **kwargs) -> Any: - """ - Load a tuned kernel config inside the hf/vllm caches, then load the model. - - Parameters - ---------- - path : str | Path - The path to the model directory. - smash_config : SmashConfig - The SmashConfig object containing the best configs for the MoE kernel tuner. - **kwargs : Any - Additional keyword arguments to pass to the model loading function. - - Returns - ------- - Any - The loaded MoE model. - """ - from pruna.algorithms.moe_kernel_tuner import MoeKernelTuner, save_configs - - imported_packages = MoeKernelTuner().import_algorithm_packages() - payload = getattr(smash_config, "artifacts", {}).get("moe_kernel_tuner") - if not payload: - error_msg = ( - "MoE kernel tuner artifacts not found in SmashConfig. " - "Ensure the tuner ran successfully before saving/loading." 
- ) - pruna_logger.error(error_msg) - raise RuntimeError(error_msg) - else: - best_configs = payload["best_configs_moe_kernel"] - num_experts = payload["num_experts"] - shard_intermediate_size = payload["shard_intermediate_size"] - dtype = payload["dtype"] - # Convert dtype string back to torch.dtype if needed - dtype = torch.bfloat16 if dtype == "bfloat16" else torch.float16 - use_fp8_w8a8 = payload["use_fp8_w8a8"] - use_int8_w8a16 = payload["use_int8_w8a16"] - - # save the config attached to smash_config, inside the hf and vllm caches. - save_configs( - best_configs, - num_experts, - shard_intermediate_size, - dtype, - use_fp8_w8a8, - use_int8_w8a16, - None, - smash_config["moe_kernel_tuner_path_to_huggingface_hub_cache"], - smash_config["moe_kernel_tuner_path_to_vllm_cache"], - imported_packages, - ) - smash_config.load_fns.remove(LOAD_FUNCTIONS.moe_kernel_tuner.name) - return load_transformers_model(path, smash_config, **kwargs) - - class LOAD_FUNCTIONS(Enum): # noqa: N801 """ Enumeration of load functions for different model types. @@ -649,7 +593,6 @@ class LOAD_FUNCTIONS(Enum): # noqa: N801 pickled = partial(load_pickled) hqq = partial(load_hqq) hqq_diffusers = partial(load_hqq_diffusers) - moe_kernel_tuner = partial(load_moe_kernel_tuner) def __call__(self, *args, **kwargs) -> Any: """ diff --git a/src/pruna/engine/load_artifacts.py b/src/pruna/engine/load_artifacts.py index f79cf5b9..fdef7c87 100644 --- a/src/pruna/engine/load_artifacts.py +++ b/src/pruna/engine/load_artifacts.py @@ -81,6 +81,56 @@ def load_torch_artifacts(model: Any, model_path: str | Path, smash_config: Smash torch.compiler.load_cache_artifacts(artifact_bytes) +def load_moe_kernel_tuner_artifacts(path: str | Path, smash_config: SmashConfig, **kwargs) -> Any: + """ + Load a tuned kernel config inside the hf/vllm caches, then load the model. + + Parameters + ---------- + path : str | Path + The path to the model directory. + smash_config : SmashConfig + The SmashConfig object containing the best configs for the MoE kernel tuner. + **kwargs : Any + Additional keyword arguments to pass to the model loading function. + + Returns + ------- + Any + The loaded MoE model. + """ + from pruna.algorithms.moe_kernel_tuner import MoeKernelTuner, save_configs + + imported_packages = MoeKernelTuner().import_algorithm_packages() + save_dir = Path(path) / "moe_kernel_tuned_configs" + payload = json.load(open(save_dir / "moe_kernel_tuner.json")) + if not payload: + raise ValueError(f"MoE kernel tuner artifacts not found in {save_dir}") + else: + best_configs = payload["best_configs_moe_kernel"] + num_experts = payload["num_experts"] + shard_intermediate_size = payload["shard_intermediate_size"] + dtype = payload["dtype"] + # Convert dtype string back to torch.dtype if needed + dtype = torch.bfloat16 if dtype == "bfloat16" else torch.float16 + use_fp8_w8a8 = payload["use_fp8_w8a8"] + use_int8_w8a16 = payload["use_int8_w8a16"] + + # save the config attached to smash_config, inside the hf and vllm caches. + save_configs( + best_configs, + num_experts, + shard_intermediate_size, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + None, + smash_config["moe_kernel_tuner_path_to_huggingface_hub_cache"], + smash_config["moe_kernel_tuner_path_to_vllm_cache"], + imported_packages, + ) + + class LOAD_ARTIFACTS_FUNCTIONS(Enum): # noqa: N801 """ Enumeration of *artifact* load functions. 
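Artifact loaders are registered as partial-wrapped enum members and later dispatched by the names stored in smash_config.load_artifacts_fns. A minimal sketch of that enum-of-partials pattern with a toy loader:

from enum import Enum
from functools import partial

def load_toy_artifacts(model, model_path, smash_config):
    # Illustrative loader only; real loaders restore caches or tuned configs.
    print(f"loading artifacts from {model_path}")

class TOY_LOAD_ARTIFACTS_FUNCTIONS(Enum):
    toy_artifacts = partial(load_toy_artifacts)

    def __call__(self, *args, **kwargs):
        return self.value(*args, **kwargs)

# Only member names are stored in the SmashConfig, so dispatch happens by string lookup:
for name in ["toy_artifacts"]:
    TOY_LOAD_ARTIFACTS_FUNCTIONS[name](None, "/tmp/model", None)

Wrapping each function in functools.partial keeps it from being bound as an enum method, which is why the existing entries use the same trick.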
@@ -116,6 +166,7 @@ class LOAD_ARTIFACTS_FUNCTIONS(Enum): # noqa: N801 """ torch_artifacts = partial(load_torch_artifacts) + moe_kernel_tuner_artifacts = partial(load_moe_kernel_tuner_artifacts) def __call__(self, *args, **kwargs) -> None: """Call the underlying load function.""" diff --git a/src/pruna/engine/save_artifacts.py b/src/pruna/engine/save_artifacts.py index e9da3e08..f7d2b421 100644 --- a/src/pruna/engine/save_artifacts.py +++ b/src/pruna/engine/save_artifacts.py @@ -14,6 +14,7 @@ from enum import Enum from functools import partial from pathlib import Path +import shutil from typing import Any import torch @@ -86,6 +87,18 @@ def save_torch_artifacts(model: Any, model_path: str | Path, smash_config: Smash smash_config.load_artifacts_fns.append(LOAD_ARTIFACTS_FUNCTIONS.torch_artifacts.name) +def save_moe_kernel_tuner_artifacts(model: Any, model_path: str | Path, smash_config: SmashConfig) -> None: + """ + Move the tuned config from pruna cache into the model directory. + """ + save_dir = Path(model_path) / "moe_kernel_tuned_configs" + save_dir.mkdir(parents=True, exist_ok=True) + shutil.move(Path(smash_config.cache_dir) / "moe_kernel_tuned_configs", save_dir) + + # define here the load artifacts function + smash_config.load_artifacts_fns.append(LOAD_ARTIFACTS_FUNCTIONS.moe_kernel_tuner_artifacts.name) + + class SAVE_ARTIFACTS_FUNCTIONS(Enum): # noqa: N801 """ Enumeration of *artifact* save functions. @@ -121,6 +134,7 @@ class SAVE_ARTIFACTS_FUNCTIONS(Enum): # noqa: N801 """ torch_artifacts = partial(save_torch_artifacts) + moe_kernel_tuner_artifacts = partial(save_moe_kernel_tuner_artifacts) def __call__(self, *args, **kwargs) -> None: """ diff --git a/tests/algorithms/testers/moe_kernel_tuner.py b/tests/algorithms/testers/moe_kernel_tuner.py index 5c4da5c8..17612492 100644 --- a/tests/algorithms/testers/moe_kernel_tuner.py +++ b/tests/algorithms/testers/moe_kernel_tuner.py @@ -6,7 +6,7 @@ class TestMoeKernelTuner(AlgorithmTesterBase): """Test the MoeKernelTuner.""" - models = ["qwen3_coder_tiny"] + models = ["qwen3_next_moe_tiny_random"] reject_models = ["sd_tiny_random"] allow_pickle_files = False algorithm_class = MoeKernelTuner From 3f3bffa384dcdce099f42a56b2fe293ea936008f Mon Sep 17 00:00:00 2001 From: llcnt Date: Mon, 9 Feb 2026 14:54:41 +0000 Subject: [PATCH 17/26] fix: rebase issue adding double commas --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 3c1b42cb..d2363d15 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -137,7 +137,7 @@ dependencies = [ "vbench-pruna; sys_platform != 'darwin'", "imageio-ffmpeg", "jaxtyping", - "peft>=0.17.1",, + "peft>=0.17.1", "vllm>=0.11.0", ] From e779dbb2f1c06ca55b8cc3b6a786e9df4613d3b1 Mon Sep 17 00:00:00 2001 From: llcnt Date: Mon, 9 Feb 2026 15:08:54 +0000 Subject: [PATCH 18/26] fix: ruff linting --- src/pruna/engine/load_artifacts.py | 4 +++- src/pruna/engine/save_artifacts.py | 6 ++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/pruna/engine/load_artifacts.py b/src/pruna/engine/load_artifacts.py index fdef7c87..c2deb62d 100644 --- a/src/pruna/engine/load_artifacts.py +++ b/src/pruna/engine/load_artifacts.py @@ -13,6 +13,7 @@ # limitations under the License. 
from __future__ import annotations +import json from enum import Enum from functools import partial from pathlib import Path @@ -103,7 +104,8 @@ def load_moe_kernel_tuner_artifacts(path: str | Path, smash_config: SmashConfig, imported_packages = MoeKernelTuner().import_algorithm_packages() save_dir = Path(path) / "moe_kernel_tuned_configs" - payload = json.load(open(save_dir / "moe_kernel_tuner.json")) + with open(save_dir / "moe_kernel_tuner.json") as f: + payload = json.load(f) if not payload: raise ValueError(f"MoE kernel tuner artifacts not found in {save_dir}") else: diff --git a/src/pruna/engine/save_artifacts.py b/src/pruna/engine/save_artifacts.py index f7d2b421..403fc43f 100644 --- a/src/pruna/engine/save_artifacts.py +++ b/src/pruna/engine/save_artifacts.py @@ -11,10 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import shutil from enum import Enum from functools import partial from pathlib import Path -import shutil from typing import Any import torch @@ -88,9 +88,7 @@ def save_torch_artifacts(model: Any, model_path: str | Path, smash_config: Smash def save_moe_kernel_tuner_artifacts(model: Any, model_path: str | Path, smash_config: SmashConfig) -> None: - """ - Move the tuned config from pruna cache into the model directory. - """ + """Move the tuned config from pruna cache into the model directory.""" save_dir = Path(model_path) / "moe_kernel_tuned_configs" save_dir.mkdir(parents=True, exist_ok=True) shutil.move(Path(smash_config.cache_dir) / "moe_kernel_tuned_configs", save_dir) From 61cc1d27c4cc4c8460da353d8c948414ca2f2f65 Mon Sep 17 00:00:00 2001 From: llcnt Date: Mon, 9 Feb 2026 15:20:01 +0000 Subject: [PATCH 19/26] feat: upd docstrings on moe artifacts --- src/pruna/engine/load_artifacts.py | 6 +++--- src/pruna/engine/save_artifacts.py | 18 +++++++++++++++++- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/src/pruna/engine/load_artifacts.py b/src/pruna/engine/load_artifacts.py index c2deb62d..eb0a74df 100644 --- a/src/pruna/engine/load_artifacts.py +++ b/src/pruna/engine/load_artifacts.py @@ -84,16 +84,16 @@ def load_torch_artifacts(model: Any, model_path: str | Path, smash_config: Smash def load_moe_kernel_tuner_artifacts(path: str | Path, smash_config: SmashConfig, **kwargs) -> Any: """ - Load a tuned kernel config inside the hf/vllm caches, then load the model. + Load a tuned kernel config inside the hf/vllm caches. Parameters ---------- path : str | Path The path to the model directory. smash_config : SmashConfig - The SmashConfig object containing the best configs for the MoE kernel tuner. + The SmashConfig object. **kwargs : Any - Additional keyword arguments to pass to the model loading function. + Additional keyword arguments to pass to the function. Returns ------- diff --git a/src/pruna/engine/save_artifacts.py b/src/pruna/engine/save_artifacts.py index 403fc43f..7a954595 100644 --- a/src/pruna/engine/save_artifacts.py +++ b/src/pruna/engine/save_artifacts.py @@ -88,7 +88,23 @@ def save_torch_artifacts(model: Any, model_path: str | Path, smash_config: Smash def save_moe_kernel_tuner_artifacts(model: Any, model_path: str | Path, smash_config: SmashConfig) -> None: - """Move the tuned config from pruna cache into the model directory.""" + """ + Move the tuned config from pruna cache into the model directory. + + Parameters + ---------- + model : Any + The model to save artifacts for. 
+ model_path : str | Path + The directory where the model and its artifacts will be saved. + smash_config : SmashConfig + The SmashConfig object. + + Returns + ------- + None + This function does not return anything. + """ save_dir = Path(model_path) / "moe_kernel_tuned_configs" save_dir.mkdir(parents=True, exist_ok=True) shutil.move(Path(smash_config.cache_dir) / "moe_kernel_tuned_configs", save_dir) From dc52abb2f923b19a046c897f1b079af082d275c8 Mon Sep 17 00:00:00 2001 From: llcnt Date: Mon, 9 Feb 2026 16:21:03 +0000 Subject: [PATCH 20/26] feat: add try execpt for ray --- src/pruna/algorithms/moe_kernel_tuner.py | 60 +++++++++++++----------- src/pruna/engine/save_artifacts.py | 6 ++- 2 files changed, 37 insertions(+), 29 deletions(-) diff --git a/src/pruna/algorithms/moe_kernel_tuner.py b/src/pruna/algorithms/moe_kernel_tuner.py index 176c3495..6206342a 100644 --- a/src/pruna/algorithms/moe_kernel_tuner.py +++ b/src/pruna/algorithms/moe_kernel_tuner.py @@ -218,35 +218,43 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192] # use ray to parallelize the tuning - ray.init() + ray.init(ignore_reinit_error=True) - is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16) - search_space = get_configs_compute_bound(is_fp16, smash_config) + search_space = get_configs_compute_bound(smash_config) pruna_logger.info(f"Start tuning over {len(search_space)} configurations...") start = time.time() outputs = [] - for batch_size in batch_sizes: - output = tune.remote( - batch_size, # num_tokens - nb_experts, # num_experts per block - shard_intermediate_size, - hidden_size, - topk, - dtype, - use_fp8_w8a8, - use_int8_w8a16, - search_space, - None, # we don't suport block quantization for now - False, # not use_deep_gemm - imported_packages, - 0, # random seed - smash_config["num_iters"], - ) - outputs.append(output) + configs = [] - configs = ray.get(outputs) - ray.shutdown() + # try/except to catch any exceptions that may stuck the ray workers. + try: + for batch_size in batch_sizes: + output = tune.remote( + batch_size, # num_tokens + nb_experts, # num_experts per block + shard_intermediate_size, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + search_space, + None, # we don't suport block quantization for now + False, # not use_deep_gemm + imported_packages, + 0, # random seed + smash_config["num_iters"], + ) + outputs.append(output) + + configs = ray.get(outputs) + finally: + # Try to shutdown ray, and catch any exceptions. + try: + ray.shutdown() + except Exception as e: + pruna_logger.warning(f"Exception during ray.shutdown(): {e}") # (iv) Sort the configs by batch size and save the best configs best_configs = { @@ -273,7 +281,7 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: best_configs_moe_kernel=best_configs, num_experts=nb_experts, shard_intermediate_size=shard_intermediate_size, - dtype=dtype, + dtype="bfloat16" if dtype == torch.bfloat16 else "float16", use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, ) @@ -460,14 +468,12 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig: } -def get_configs_compute_bound(use_fp16: bool, smash_config: SmashConfigPrefixWrapper) -> list[dict[str, int]]: +def get_configs_compute_bound(smash_config: SmashConfigPrefixWrapper) -> list[dict[str, int]]: """ Get the gridsearch space for the kernel (tiling and warp scheduling). Parameters ---------- - use_fp16 : bool - Whether to use fp16. 
smash_config : SmashConfigPrefixWrapper The Smash configuration. diff --git a/src/pruna/engine/save_artifacts.py b/src/pruna/engine/save_artifacts.py index 7a954595..453b8fee 100644 --- a/src/pruna/engine/save_artifacts.py +++ b/src/pruna/engine/save_artifacts.py @@ -106,8 +106,10 @@ def save_moe_kernel_tuner_artifacts(model: Any, model_path: str | Path, smash_co This function does not return anything. """ save_dir = Path(model_path) / "moe_kernel_tuned_configs" - save_dir.mkdir(parents=True, exist_ok=True) - shutil.move(Path(smash_config.cache_dir) / "moe_kernel_tuned_configs", save_dir) + src_dir = Path(smash_config.cache_dir) / "moe_kernel_tuned_configs" + if save_dir.exists(): + shutil.rmtree(save_dir) + shutil.move(src_dir, save_dir) # define here the load artifacts function smash_config.load_artifacts_fns.append(LOAD_ARTIFACTS_FUNCTIONS.moe_kernel_tuner_artifacts.name) From 243e1b6b088627321f6ea95c230360cf9eaf3e4a Mon Sep 17 00:00:00 2001 From: llcnt Date: Tue, 10 Feb 2026 13:35:47 +0000 Subject: [PATCH 21/26] fix: moe artifacts load fn takes 3 args not 2 --- src/pruna/engine/load_artifacts.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/pruna/engine/load_artifacts.py b/src/pruna/engine/load_artifacts.py index eb0a74df..54de87f3 100644 --- a/src/pruna/engine/load_artifacts.py +++ b/src/pruna/engine/load_artifacts.py @@ -82,13 +82,15 @@ def load_torch_artifacts(model: Any, model_path: str | Path, smash_config: Smash torch.compiler.load_cache_artifacts(artifact_bytes) -def load_moe_kernel_tuner_artifacts(path: str | Path, smash_config: SmashConfig, **kwargs) -> Any: +def load_moe_kernel_tuner_artifacts(model: Any, model_path: str | Path, smash_config: SmashConfig, **kwargs) -> Any: """ Load a tuned kernel config inside the hf/vllm caches. Parameters ---------- - path : str | Path + model : Any + The model to load the artifacts for. + model_path : str | Path The path to the model directory. smash_config : SmashConfig The SmashConfig object. 
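On the saving side (patch 20 above), any stale moe_kernel_tuned_configs directory in the model folder is removed before the freshly tuned one is moved out of the Pruna cache, so repeated saves never nest directories. A minimal sketch of that idempotent move, with example paths:

import shutil
from pathlib import Path

src_dir = Path("/tmp/pruna_cache") / "moe_kernel_tuned_configs"   # assumed cache_dir layout
save_dir = Path("/tmp/saved_model") / "moe_kernel_tuned_configs"  # assumed model_path layout

src_dir.mkdir(parents=True, exist_ok=True)         # stand-in for the tuner's output
save_dir.parent.mkdir(parents=True, exist_ok=True)
if save_dir.exists():
    shutil.rmtree(save_dir)                        # drop stale configs so the move is clean
shutil.move(str(src_dir), str(save_dir))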
@@ -103,7 +105,7 @@ def load_moe_kernel_tuner_artifacts(path: str | Path, smash_config: SmashConfig, from pruna.algorithms.moe_kernel_tuner import MoeKernelTuner, save_configs imported_packages = MoeKernelTuner().import_algorithm_packages() - save_dir = Path(path) / "moe_kernel_tuned_configs" + save_dir = Path(model_path) / "moe_kernel_tuned_configs" with open(save_dir / "moe_kernel_tuner.json") as f: payload = json.load(f) if not payload: From e940c6410909044fcebf3f2f941ef454651ede20 Mon Sep 17 00:00:00 2001 From: llcnt Date: Tue, 17 Feb 2026 17:31:35 +0000 Subject: [PATCH 22/26] fix: review comment draft --- .github/actions/setup-uv-project/action.yml | 2 +- pyproject.toml | 5 +- src/pruna/algorithms/base/tags.py | 18 + src/pruna/algorithms/moe_kernel_tuner.py | 754 +++++------------- .../algorithms/utils/moe_kernel_tuner.py | 549 +++++++++++++ src/pruna/engine/load_artifacts.py | 64 +- src/pruna/engine/save_artifacts.py | 8 +- tests/fixtures.py | 4 +- 8 files changed, 824 insertions(+), 580 deletions(-) create mode 100644 src/pruna/algorithms/utils/moe_kernel_tuner.py diff --git a/.github/actions/setup-uv-project/action.yml b/.github/actions/setup-uv-project/action.yml index d32b697c..ecd89282 100644 --- a/.github/actions/setup-uv-project/action.yml +++ b/.github/actions/setup-uv-project/action.yml @@ -12,4 +12,4 @@ runs: github-token: ${{ github.token }} - shell: bash - run: uv sync --extra dev + run: uv sync --extra dev --extra vllm diff --git a/pyproject.toml b/pyproject.toml index d2363d15..5f2fdb61 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -138,10 +138,13 @@ dependencies = [ "imageio-ffmpeg", "jaxtyping", "peft>=0.17.1", - "vllm>=0.11.0", ] [project.optional-dependencies] +vllm = [ + "vllm>=0.11.0", + "ray", +] stable-fast = [ "xformers>=0.0.30", "stable-fast-pruna==1.0.8", diff --git a/src/pruna/algorithms/base/tags.py b/src/pruna/algorithms/base/tags.py index e5dcd426..3d9d9478 100644 --- a/src/pruna/algorithms/base/tags.py +++ b/src/pruna/algorithms/base/tags.py @@ -80,6 +80,24 @@ class AlgorithmTag(Enum): "Resamplers change the shape of image or video latents during generation to speed up inference.", ) + @classmethod + def tags_compatible_with_moe_kernel(cls) -> list["AlgorithmTag"]: + """ + Return tags that the MoE kernel tuner is compatible with (for ordering). + + Used so that compatible_before / compatible_after can be derived from + the enum and stay in sync when new tags are added. + """ + return [ + cls.KERNEL, + cls.QUANTIZER, + cls.PRUNER, + cls.CACHER, + cls.FACTORIZER, + cls.BATCHER, + cls.COMPILER, + ] + def __init__(self, name: str, description: str): """ Initialize an algorithm tag with name and description. 
diff --git a/src/pruna/algorithms/moe_kernel_tuner.py b/src/pruna/algorithms/moe_kernel_tuner.py index 6206342a..8a5e8594 100644 --- a/src/pruna/algorithms/moe_kernel_tuner.py +++ b/src/pruna/algorithms/moe_kernel_tuner.py @@ -14,16 +14,10 @@ from __future__ import annotations import json -import pathlib import time from collections.abc import Iterable -from datetime import datetime -from importlib.util import find_spec -from itertools import product -from typing import Any, Dict, TypedDict +from typing import Any -import ray -import ray.experimental.tqdm_ray as tqdm_ray import torch from ConfigSpace import CategoricalHyperparameter, OrdinalHyperparameter @@ -32,15 +26,17 @@ from pruna.config.hyperparameters import UnconstrainedHyperparameter from pruna.config.smash_config import SmashConfigPrefixWrapper from pruna.engine.model_checks import is_moe_lm, is_transformers_pipeline_with_moe_lm +from pruna.engine.save import SAVE_FUNCTIONS from pruna.engine.save_artifacts import SAVE_ARTIFACTS_FUNCTIONS from pruna.logging.logger import pruna_logger class MoeKernelTuner(PrunaAlgorithmBase): """ - Tune the MoE Triton kernel for the model. + Tune the MoE (Mixture of Experts) Triton kernel for the model. - Uses vLLM to tune the MoE kernel. + Uses vLLM to tune the MoE kernel. If an existing artifact exists and the + Triton version matches, tuning is skipped and cached configs are reused. """ algorithm_name: str = "moe_kernel_tuner" @@ -53,25 +49,9 @@ class MoeKernelTuner(PrunaAlgorithmBase): processor_required: bool = False runs_on: list[str] = ["cuda", "accelerate"] dataset_required: bool = False - compatible_before: Iterable[str] = [ - tags.KERNEL, - tags.QUANTIZER, - tags.PRUNER, - tags.CACHER, - tags.FACTORIZER, - tags.BATCHER, - tags.COMPILER, - ] - compatible_after: Iterable[str] = [ - tags.KERNEL, - tags.QUANTIZER, - tags.PRUNER, - tags.CACHER, - tags.FACTORIZER, - tags.BATCHER, - tags.COMPILER, - ] - required_install = "``uv pip install vllm``" + compatible_before: Iterable[str] = tags.tags_compatible_with_moe_kernel() + compatible_after: Iterable[str] = tags.tags_compatible_with_moe_kernel() + required_install = "``uv pip install vllm>=0.11.0``" def get_hyperparameters(self) -> list: """ @@ -157,16 +137,17 @@ def model_check_fn(self, model: Any) -> bool: bool True if the model is a valid model for the algorithm, False otherwise. """ - # Hunyuan3-image is a MoE model, but not depending on mixtral if model.__class__.__name__ == "HunyuanImage3ForCausalMM": return True - else: - return is_moe_lm(model) or is_transformers_pipeline_with_moe_lm(model) + return is_moe_lm(model) or is_transformers_pipeline_with_moe_lm(model) def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: """ Tune the MoE Triton kernel for the model. + If an existing artifact exists in the cache with the same Triton version, + tuning is skipped and the cached configs are restored to the HF/vLLM caches. + Parameters ---------- model : Any @@ -184,29 +165,20 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: imported_packages = self.import_algorithm_packages() - # (i) Get the MoE parameters - model_config = model.config + # (i) Get the MoE parameters: exception first (Hunyuan), then general case. 
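The dimension-extraction helpers referenced below live in the new utils module, which is not shown in this patch. A sketch of what the transformers branch is expected to compute, based on the in-line logic it replaces (the real helper also receives the model and may differ in signature):

def extract_transformers_moe_dimensions(model_config, tensor_parallel_size: int):
    # Sketch only: number of experts, per-shard intermediate size, hidden size, top-k.
    num_experts = model_config.num_experts
    topk = model_config.num_experts_per_tok
    # qwen-style MoE configs expose a dedicated moe_intermediate_size
    intermediate_size = getattr(model_config, "moe_intermediate_size", None) or model_config.intermediate_size
    if intermediate_size % tensor_parallel_size != 0:
        raise ValueError(
            f"intermediate_size {intermediate_size} is not divisible by tp {tensor_parallel_size}."
        )
    shard_intermediate_size = 2 * intermediate_size // tensor_parallel_size
    return num_experts, shard_intermediate_size, model_config.hidden_size, topk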
+ model_config = getattr(model, "config", None) if model_config is None: raise ValueError(f"Model {model.__class__.__name__} has no config.") - nb_experts = model_config.num_experts # number of experts - # number of active experts per token - topk = ( - model_config.num_experts_per_tok - if is_moe_lm(model) - else model_config.moe_topk[0] - ) - # qwen_moe can use different intermediate size compared to mixtral. - intermediate_size = ( - getattr(model_config, "moe_intermediate_size", None) - if getattr(model_config, "moe_intermediate_size", None) is not None - else model_config.intermediate_size - ) - hidden_size = model_config.hidden_size # model hidden dim - assert intermediate_size % smash_config["tensor_parallel_size"] == 0, ( - f"intermediate_size {intermediate_size} is not divisible by tp " - f"{smash_config['tensor_parallel_size']}." - ) - shard_intermediate_size = 2 * intermediate_size // smash_config["tensor_parallel_size"] + + tensor_parallel_size = int(smash_config["tensor_parallel_size"]) + if model.__class__.__name__ == "HunyuanImage3ForCausalMM": + nb_experts, shard_intermediate_size, hidden_size, topk = extract_hunyuan_dimensions( + model, model_config, tensor_parallel_size + ) + else: + nb_experts, shard_intermediate_size, hidden_size, topk = ( + extract_transformers_moe_dimensions(model, model_config, tensor_parallel_size) + ) # (ii) Get the compute parameters dtype = smash_config["compute_dtype"] @@ -214,58 +186,73 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: use_fp8_w8a8 = smash_config["weight_dtype"] == "fp8_w8a8" use_int8_w8a16 = smash_config["weight_dtype"] == "int8_w8a16" - # (iii) Tune the kernel over a range of batch sizes + # (iii) Tune the kernel over a range of batch sizes (single GPU per Ray worker). batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192] + ray = imported_packages["ray"] + tune_kernel = imported_packages["tune_kernel"] + get_configs_compute_bound = imported_packages["get_configs_compute_bound"] + ensure_benchmark_config = imported_packages["ensure_benchmark_config"] + save_configs = imported_packages["save_configs"] + NoValidConfigError = imported_packages["NoValidConfigError"] - # use ray to parallelize the tuning ray.init(ignore_reinit_error=True) - search_space = get_configs_compute_bound(smash_config) pruna_logger.info(f"Start tuning over {len(search_space)} configurations...") start = time.time() - outputs = [] - configs = [] + ray_outputs: list[Any] = [] + tune_kernel = ray.remote(num_gpus=1)(tune_kernel) - # try/except to catch any exceptions that may stuck the ray workers. + for batch_size in batch_sizes: + out = tune_kernel.remote( + batch_size, + nb_experts, + shard_intermediate_size, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + search_space, + None, + False, + imported_packages, + 0, # fixed seed for reproducibility + smash_config["num_iters"], + ) + ray_outputs.append(out) + + # Shutdown Ray processes on other devices and log any resulting exception as warnings. 
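+        # ray.get() below blocks until each worker finishes; a NoValidConfigError
+        # from a single worker only skips that batch size, and the finally block
+        # ensures ray.shutdown() runs even if result collection fails.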
+ tuned_config_by_batch_size: dict[int, Any] = {} try: - for batch_size in batch_sizes: - output = tune.remote( - batch_size, # num_tokens - nb_experts, # num_experts per block - shard_intermediate_size, - hidden_size, - topk, - dtype, - use_fp8_w8a8, - use_int8_w8a16, - search_space, - None, # we don't suport block quantization for now - False, # not use_deep_gemm - imported_packages, - 0, # random seed - smash_config["num_iters"], - ) - outputs.append(output) - - configs = ray.get(outputs) + for batch_size, output_ref in zip(batch_sizes, ray_outputs): + try: + raw_config = ray.get(output_ref) + tuned_config_by_batch_size[batch_size] = ensure_benchmark_config(raw_config) + except NoValidConfigError: + pruna_logger.warning( + "No valid config for batch_size=%s; skipping (smaller batch sizes may still be used).", + batch_size, + ) finally: - # Try to shutdown ray, and catch any exceptions. try: ray.shutdown() except Exception as e: - pruna_logger.warning(f"Exception during ray.shutdown(): {e}") + pruna_logger.warning("Exception during ray.shutdown(): %s", e) + + if not tuned_config_by_batch_size: + raise RuntimeError( + "No valid kernel configuration was found for any batch size. " + "All configurations failed (e.g., due to OutOfResources). " + "This can happen on GPUs with limited resources. " + "Consider reducing your model size or tuning search space." + ) - # (iv) Sort the configs by batch size and save the best configs - best_configs = { - M: sort_config(config) for M, config in zip(batch_sizes, configs) - } end = time.time() pruna_logger.info(f"Tuning took {end - start:.2f} seconds") - # save configs in hf/vllm caches for direct usage save_configs( - best_configs, + tuned_config_by_batch_size, nb_experts, shard_intermediate_size, dtype, @@ -276,36 +263,46 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: smash_config["path_to_vllm_cache"], imported_packages, ) - # results to be saved for later loading - payload = dict( - best_configs_moe_kernel=best_configs, - num_experts=nb_experts, - shard_intermediate_size=shard_intermediate_size, - dtype="bfloat16" if dtype == torch.bfloat16 else "float16", - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a16=use_int8_w8a16, - ) - # store artifacts in pruna cache for later saving/loading - save_dir = smash_config.cache_dir / "moe_kernel_tuned_configs" - save_dir.mkdir(parents=True, exist_ok=True) - with open(save_dir / "moe_kernel_tuner.json", "w") as f: - json.dump(payload, f) - # attach artifacts save function to the smash config for saving - smash_config.save_artifacts_fns.append(SAVE_ARTIFACTS_FUNCTIONS.moe_kernel_tuner_artifacts.name) + best_configs_and_hyperparameters = { + "triton_version": triton_version, + "best_configs_moe_kernel": tuned_config_by_batch_size, + "num_experts": nb_experts, + "shard_intermediate_size": shard_intermediate_size, + "dtype": "bfloat16" if dtype == torch.bfloat16 else "float16", + "use_fp8_w8a8": use_fp8_w8a8, + "use_int8_w8a16": use_int8_w8a16, + } + with open(smash_config.cache_dir / "moe_kernel_tuner.json", "w") as f: + json.dump(best_configs_and_hyperparameters, f) - # (v) Return the model (untouched; only the configs are saved) + smash_config.save_artifacts_fns.append( + SAVE_ARTIFACTS_FUNCTIONS.moe_kernel_tuner_artifacts.name + ) return model - def import_algorithm_packages(self) -> Dict[str, Any]: + def import_algorithm_packages(self) -> dict[str, Any]: """ - Import the algorithm packages. + Import the algorithm packages (vLLM, Triton, Ray, and utils). 
+ + Ray is imported here so it is not a top-level dependency of the package. + Utils are imported here so they are only loaded when this algorithm is used. Returns ------- - Dict[str, Any] - The algorithm packages. + dict[str, Any] + The algorithm packages and helpers (tune, save_configs, etc.). """ + import ray + from pruna.algorithms.utils.moe_kernel_tuner import ( + BenchmarkConfig, + NoValidConfigError, + ensure_benchmark_config, + get_configs_compute_bound, + save_configs, + tune_kernel, + ) + import vllm.envs as envs import vllm.model_executor.layers.fused_moe.fused_moe as fused_moe import vllm.platforms as vllm_platforms @@ -324,473 +321,140 @@ def import_algorithm_packages(self) -> Dict[str, Any]: triton=triton, override_config=override_config, envs=envs, + ray=ray, + tune_kernel=_tune_kernel, + get_configs_compute_bound=get_configs_compute_bound, + ensure_benchmark_config=ensure_benchmark_config, + save_configs=save_configs, + NoValidConfigError=NoValidConfigError, + BenchmarkConfig=BenchmarkConfig, ) -class BenchmarkConfig(TypedDict): - """The configuration for the matrix multiplication (tiling and warp scheduling).""" - - BLOCK_SIZE_M: int - BLOCK_SIZE_N: int - BLOCK_SIZE_K: int - GROUP_SIZE_M: int - num_warps: int - num_stages: int - - -# Converts the function into a Ray actor and requests one GPU per actor instance -@ray.remote(num_gpus=1) -def tune( - num_tokens: int, - num_experts: int, - shard_intermediate_size: int, - hidden_size: int, - topk: int, - dtype: torch.dtype, - use_fp8_w8a8: bool, - use_int8_w8a16: bool, - search_space: list[dict[str, int]], - block_quant_shape: list[int], - use_deep_gemm: bool, - imported_packages: Dict[str, Any], - seed: int, - num_iters: int, - ) -> dict[str, int]: +def extract_hunyuan_dimensions( + model: Any, + model_config: Any, + tensor_parallel_size: int, +) -> tuple[int, int, int, int]: """ - Tune a given Triton kernel. + Extract MoE dimensions for HunyuanImage3ForCausalMM. Parameters ---------- - num_tokens : int - The number of tokens in the batch. - num_experts : int - The number of experts. - shard_intermediate_size : int - The intermediate size of the model in the shard (if using tensor parallelism). - hidden_size : int - The hidden size of the model. - topk : int - The number of active experts per token. - dtype : torch.dtype - The dtype to use for the weights and activations. - use_fp8_w8a8 : bool - Whether to use fp8_w8a8. - use_int8_w8a16 : bool - Whether to use int8_w8a16. - search_space : list[dict[str, int]] - The search space for the kernel (tiling and warp scheduling). - block_quant_shape : list[int] - The block shape for the kernel (None here). - use_deep_gemm : bool - Whether to use deep gemm (False here). - imported_packages : Dict[str, Any] - The imported packages (vllm, triton, etc.). - seed : int - The random seed. - num_iters : int - The number of iterations to average the kernel time on. + model : Any + The model (used only for type context). + model_config : Any + The model's config object. Must have num_experts, moe_topk, intermediate_size, hidden_size. + tensor_parallel_size : int + Tensor parallel size (must divide intermediate_size). Returns ------- - dict[str, int] - The best config. + tuple[int, int, int, int] + (nb_experts, shard_intermediate_size, hidden_size, topk). + - nb_experts: number of experts in the MoE layer. + - shard_intermediate_size: intermediate dimension per GPU shard (after SiLU-and-mul). + - hidden_size: model hidden dimension. + - topk: number of active experts per token. 
""" - imported_packages["vllm_platforms"].current_platform.seed_everything(seed) - best_config = None - best_time = float("inf") - - for config in tqdm_ray.tqdm(search_space): - try: - kernel_time = benchmark_config( - config, - num_tokens, - num_experts, - shard_intermediate_size, - hidden_size, - topk, - dtype, - use_fp8_w8a8, - use_int8_w8a16, - num_iters=num_iters, - block_quant_shape=block_quant_shape, - use_deep_gemm=use_deep_gemm, - imported_packages=imported_packages, - ) - except imported_packages["triton"].runtime.autotuner.OutOfResources: - # Some configurations may be invalid and fail to compile. - continue - - if kernel_time < best_time: - best_time, best_config = kernel_time, config - - now = datetime.now() - pruna_logger.info(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}") - if best_config is None: - raise RuntimeError( - f"No valid kernel configuration was found for batch_size={num_tokens}. " - "All configurations failed (e.g., due to OutOfResources). " - "This can happen on GPUs with limited resources. " - "Consider reducing your model size, batch size, or tuning search space." + if getattr(model_config, "num_experts", None) is None: + raise ValueError( + f"Model config {type(model_config).__name__} is missing attribute 'num_experts'." + ) + if getattr(model_config, "moe_topk", None) is None: + raise ValueError( + f"Model config {type(model_config).__name__} is missing attribute 'moe_topk'." + ) + if getattr(model_config, "intermediate_size", None) is None: + raise ValueError( + f"Model config {type(model_config).__name__} is missing attribute 'intermediate_size'." + ) + if getattr(model_config, "hidden_size", None) is None: + raise ValueError( + f"Model config {type(model_config).__name__} is missing attribute 'hidden_size'." ) - return best_config - - -def sort_config(config: BenchmarkConfig) -> BenchmarkConfig: - """ - Sort the configuration (tiling and warp scheduling). - - Parameters - ---------- - config : BenchmarkConfig - The configuration to sort. - - Returns - ------- - BenchmarkConfig - The sorted configuration. - """ - return { - "BLOCK_SIZE_M": config["BLOCK_SIZE_M"], - "BLOCK_SIZE_N": config["BLOCK_SIZE_N"], - "BLOCK_SIZE_K": config["BLOCK_SIZE_K"], - "GROUP_SIZE_M": config["GROUP_SIZE_M"], - "num_warps": config["num_warps"], - "num_stages": config["num_stages"], - **( - {"waves_per_eu": config.get("waves_per_eu")} if "waves_per_eu" in config else {} - ), - **( - {"matrix_instr_nonkdim": config.get("matrix_instr_nonkdim")} - if "matrix_instr_nonkdim" in config - else {} - ), - **({"kpack": config.get("kpack")} if "kpack" in config else {}), - } - -def get_configs_compute_bound(smash_config: SmashConfigPrefixWrapper) -> list[dict[str, int]]: - """ - Get the gridsearch space for the kernel (tiling and warp scheduling). + nb_experts = int(model_config.num_experts) + topk = int(model_config.moe_topk[0]) + intermediate_size = int(model_config.intermediate_size) + hidden_size = int(model_config.hidden_size) - Parameters - ---------- - smash_config : SmashConfigPrefixWrapper - The Smash configuration. + if intermediate_size % tensor_parallel_size != 0: + raise ValueError( + f"Expected tensor_parallel_size to be a divisor of the model's intermediate size " + f"(MLP hidden dimension) {intermediate_size}, but got {tensor_parallel_size}." 
+ ) + shard_intermediate_size = 2 * intermediate_size // tensor_parallel_size + return nb_experts, shard_intermediate_size, hidden_size, topk - Returns - ------- - list[dict[str, int]] - The search space for the kernel (tiling and warp scheduling). - """ - configs: list[BenchmarkConfig] = [] - - # Reduced search space for faster tuning. - block_m_range = [2**i for i in range(4, smash_config["block_size_m_max"] + 1)] - block_n_range = [2**i for i in range(5, smash_config["block_size_n_max"] + 1)] - block_k_range = [2**i for i in range(6, smash_config["block_size_k_max"] + 1)] - num_warps_range = [4, 8] - group_m_range = [1, 16, 32, 64] - num_stage_range = [2, 3, 4, 5] - - param_ranges = { - "BLOCK_SIZE_M": block_m_range, - "BLOCK_SIZE_N": block_n_range, - "BLOCK_SIZE_K": block_k_range, - "GROUP_SIZE_M": group_m_range, - "num_warps": num_warps_range, - "num_stages": num_stage_range, - } - keys, values = zip(*param_ranges.items()) - for config_values in product(*values): - config = dict(zip(keys, config_values)) - configs.append(config) - - return configs - - -def benchmark_config( - config: BenchmarkConfig, - num_tokens: int, - num_experts: int, - shard_intermediate_size: int, - hidden_size: int, - topk: int, - dtype: torch.dtype, - use_fp8_w8a8: bool, - use_int8_w8a16: bool, - num_iters: int = 100, - block_quant_shape: list[int] = None, - use_deep_gemm: bool = False, - imported_packages: Dict[str, Any] = None, -) -> float: +def extract_transformers_moe_dimensions( + model: Any, + model_config: Any, + tensor_parallel_size: int, +) -> tuple[int, int, int, int]: """ - Benchmark a given Triton kernel using CUDAGraph. - - This function is copied from the vllm repository. - https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py + Extract MoE dimensions for standard transformers MoE LMs (e.g. Mixtral, Qwen-MoE). Parameters ---------- - config : BenchmarkConfig - The configuration to benchmark. - num_tokens : int - The number of tokens in the batch. - num_experts : int - The number of experts. - shard_intermediate_size : int - The intermediate size of the model in the shard (if using tensor parallelism). - hidden_size : int - The hidden size of the model. - topk : int - The number of active experts per token. - dtype : torch.dtype - The dtype to use for the weights and activations. - use_fp8_w8a8 : bool - Whether to use fp8_w8a8. - use_int8_w8a16 : bool - Whether to use int8_w8a16. - num_iters : int - The number of iterations to run the benchmark. - block_quant_shape : list[int] - The block shape for the kernel (None here). - use_deep_gemm : bool - Whether to use deep gemm (False here). - imported_packages : Dict[str, Any] - The imported packages (vllm, triton, etc.). + model : Any + The model (used to decide num_experts_per_tok vs moe_topk). + model_config : Any + The model's config. Must have num_experts, intermediate_size (or moe_intermediate_size), + hidden_size, and either num_experts_per_tok or moe_topk. + tensor_parallel_size : int + Tensor parallel size (must divide intermediate_size). Returns ------- - float - The average latency of the kernel. + tuple[int, int, int, int] + (nb_experts, shard_intermediate_size, hidden_size, topk). + - nb_experts: number of experts in the MoE layer. + - shard_intermediate_size: intermediate dimension per GPU shard. + - hidden_size: model hidden dimension. + - topk: number of active experts per token. 
""" - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required for MoeKernelTuner.") - # Ray sets CUDA_VISIBLE_DEVICES per worker to the GPU it scheduled - torch.cuda.set_device(0) - device = torch.device("cuda") - - init_dtype = torch.float16 if use_fp8_w8a8 else dtype - x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device) - if use_int8_w8a16: - w1 = torch.randint( - -127, - 127, - ( - num_experts, - shard_intermediate_size, - hidden_size, - ), - dtype=torch.int8, - device=device, - ) - w2 = torch.randint( - -127, - 127, - ( - num_experts, - hidden_size, - shard_intermediate_size // 2, - ), - dtype=torch.int8, - device=device, - ) - else: - w1 = torch.randn( - num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype, device=device + if getattr(model_config, "num_experts", None) is None: + raise ValueError( + f"Model config {type(model_config).__name__} is missing attribute 'num_experts'." ) - w2 = torch.randn( - num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype, device=device + if getattr(model_config, "hidden_size", None) is None: + raise ValueError( + f"Model config {type(model_config).__name__} is missing attribute 'hidden_size'." ) - gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32, device=device) - - w1_scale = None - w2_scale = None - a1_scale = None - a2_scale = None - if use_int8_w8a16: - w1_scale = torch.randn( - (num_experts, 2 * shard_intermediate_size), dtype=torch.float32, device=device - ) - w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32, device=device) - if use_deep_gemm: - # we use the default block shape for deepgemm - block_quant_shape = [128, 128] - if use_fp8_w8a8: - if block_quant_shape: - block_n, block_k = block_quant_shape[0], block_quant_shape[1] - e = num_experts - n = shard_intermediate_size // 2 - k = hidden_size - factor_for_scale = 1e-2 - n_tiles_w1 = (2 * n + block_n - 1) // block_n - n_tiles_w2 = (k + block_n - 1) // block_n - k_tiles_w1 = (k + block_k - 1) // block_k - k_tiles_w2 = (n + block_k - 1) // block_k - w1_scale = ( - torch.rand((e, n_tiles_w1, k_tiles_w1), dtype=torch.float32, device=device) - * factor_for_scale - ) - w2_scale = ( - torch.rand((e, n_tiles_w2, k_tiles_w2), dtype=torch.float32, device=device) - * factor_for_scale - ) - else: - w1_scale = torch.randn(num_experts, dtype=torch.float32, device=device) - w2_scale = torch.randn(num_experts, dtype=torch.float32, device=device) - - a1_scale = torch.randn(1, dtype=torch.float32, device=device) - a2_scale = torch.randn(1, dtype=torch.float32, device=device) - - w1 = w1.to(device=device, dtype=imported_packages["vllm_platforms"].current_platform.fp8_dtype()) - w2 = w2.to(device=device, dtype=imported_packages["vllm_platforms"].current_platform.fp8_dtype()) - - input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32, device=device) - - def prepare(i: int): - input_gating.copy_(gating_output[i]) - def run(): - if use_fp8_w8a8: - quant_dtype = torch.float8_e4m3fn - elif use_int8_w8a16: - quant_dtype = torch.int8 - else: - quant_dtype = None - - quant_config = imported_packages["FusedMoEQuantConfig"].make( - quant_dtype=quant_dtype, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale, - block_shape=block_quant_shape, - ) + nb_experts = int(model_config.num_experts) + hidden_size = int(model_config.hidden_size) - with imported_packages["override_config"](config): - topk_weights, topk_ids, token_expert_indices = 
imported_packages["FusedMoE"].fused_topk( - x, input_gating, topk, renormalize=not use_deep_gemm + if is_moe_lm(model): + if getattr(model_config, "num_experts_per_tok", None) is None: + raise ValueError( + f"Model config {type(model_config).__name__} is missing attribute 'num_experts_per_tok'." ) - return imported_packages["FusedMoE"].fused_experts( - x, - w1, - w2, - topk_weights, - topk_ids, - inplace=True, - quant_config=quant_config, - allow_deep_gemm=use_deep_gemm, + topk = int(model_config.num_experts_per_tok) + else: + if getattr(model_config, "moe_topk", None) is None: + raise ValueError( + f"Model config {type(model_config).__name__} is missing attribute 'moe_topk'." ) + topk = int(model_config.moe_topk[0]) - # JIT compilation & warmup - run() - torch.cuda.synchronize() - - # Capture 10 invocations with CUDA graph - graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(graph): - for _ in range(10): - run() - torch.cuda.synchronize() - - # Warmup - for _ in range(5): - graph.replay() - torch.cuda.synchronize() - - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - latencies: list[float] = [] - for i in range(num_iters): - prepare(i) - torch.cuda.synchronize() - - start_event.record() - graph.replay() - end_event.record() - end_event.synchronize() - latencies.append(start_event.elapsed_time(end_event)) - avg = sum(latencies) / (num_iters * 10) * 1000 # us - graph.reset() - return avg - - -def save_configs( - configs: dict[int, BenchmarkConfig], - num_experts: int, - shard_intermediate_size: int, - dtype: torch.dtype, - use_fp8_w8a8: bool, - use_int8_w8a16: bool, - block_quant_shape: list[int], - path_to_huggingface_hub_cache: str, - path_to_vllm_cache: str, - imported_packages: Dict[str, Any], -) -> None: - """ - Save the best configs to the hf cache and vllm cache. + moe_intermediate = getattr(model_config, "moe_intermediate_size", None) + if moe_intermediate is not None: + intermediate_size = int(moe_intermediate) + else: + if getattr(model_config, "intermediate_size", None) is None: + raise ValueError( + f"Model config {type(model_config).__name__} is missing attribute " + "'intermediate_size' or 'moe_intermediate_size'." + ) + intermediate_size = int(model_config.intermediate_size) - Parameters - ---------- - configs : dict[int, BenchmarkConfig] - The best configs. - num_experts : int - The number of experts. - shard_intermediate_size : int - The intermediate size of the model in the shard (if using tensor parallelism). - dtype : torch.dtype - The dtype to use for the weights and activations. - use_fp8_w8a8 : bool - Whether to use fp8_w8a8. - use_int8_w8a16 : bool - Whether to use int8_w8a16. - block_quant_shape : list[int] - The block shape for the kernel (None here). - path_to_huggingface_hub_cache : str - The path to the huggingface hub cache. - path_to_vllm_cache : str - The path to the vllm cache. - imported_packages : Dict[str, Any] - The imported packages (vllm, triton, etc.). - """ - dtype_str = imported_packages["_get_config_dtype_str"]( - dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8 - ) - - # (i) Get the name of the config file - # NB from vllm: The current naming convention uses w2.shape[2], which - # is the intermediate size after silu_and_mul. 
- filename = imported_packages["FusedMoE"].get_config_file_name( - num_experts, shard_intermediate_size // 2, dtype_str, block_quant_shape - ) - - # (ii) Save the config to the hf cache (where `kernels` lib expects to find it) - path_to_kernel_configs = ( - pathlib.Path(path_to_huggingface_hub_cache) / - ".cache/huggingface/hub/models--RedHatAI--moe/blobs/configs" - ) - pathlib.Path(path_to_kernel_configs).mkdir(exist_ok=True, parents=True) - filename_hf = path_to_kernel_configs / filename - if not pathlib.Path(filename_hf).exists(): - pruna_logger.info(f"Writing best config to {filename_hf}...") - with open(filename_hf, "w") as f: - json.dump({"triton_version": imported_packages["triton"].__version__, **configs}, f, indent=4) - f.write("\n") - - # (iii) Save the config to the vllm cache (where `vllm` expects to find it) - path_to_vllm_configs = imported_packages["envs"].VLLM_TUNED_CONFIG_FOLDER - if path_to_vllm_configs is None: - submodule_locations = find_spec("vllm").submodule_search_locations - if submodule_locations is not None and len(submodule_locations) > 0: - path_where_vllm_is_installed = submodule_locations[0] - else: - raise RuntimeError("Could not determine installation path for vllm.") - path_to_vllm_configs = pathlib.Path(path_where_vllm_is_installed).parent / path_to_vllm_cache - pathlib.Path(path_to_vllm_configs).mkdir(exist_ok=True, parents=True) - filename_vllm = path_to_vllm_configs / filename - if not pathlib.Path(filename_vllm).exists(): - pruna_logger.info(f"Writing best config to {filename_vllm}...") - with open(filename_vllm, "w") as f: - json.dump({"triton_version": imported_packages["triton"].__version__, **configs}, f, indent=4) - f.write("\n") + if intermediate_size % tensor_parallel_size != 0: + raise ValueError( + f"Expected tensor_parallel_size to be a divisor of the model's intermediate size " + f"(MLP hidden dimension) {intermediate_size}, but got {tensor_parallel_size}." + ) + shard_intermediate_size = 2 * intermediate_size // tensor_parallel_size + return nb_experts, shard_intermediate_size, hidden_size, topk diff --git a/src/pruna/algorithms/utils/moe_kernel_tuner.py b/src/pruna/algorithms/utils/moe_kernel_tuner.py new file mode 100644 index 00000000..747bbd98 --- /dev/null +++ b/src/pruna/algorithms/utils/moe_kernel_tuner.py @@ -0,0 +1,549 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import json +import pathlib +from importlib.util import find_spec +from itertools import product +from typing import Any, Dict, TypedDict + +import torch + +from pruna.config.smash_config import SmashConfigPrefixWrapper +from pruna.logging.logger import pruna_logger + + +class NoValidConfigError(RuntimeError): + """ + Raised when no valid kernel configuration was found for a given batch size. + + All configurations failed (e.g., due to OutOfResources). This can happen on + GPUs with limited resources. The caller may retry with smaller batch sizes. 
+ """ + + +class BenchmarkConfig(TypedDict): + """The configuration for the matrix multiplication (tiling and warp scheduling).""" + + BLOCK_SIZE_M: int + BLOCK_SIZE_N: int + BLOCK_SIZE_K: int + GROUP_SIZE_M: int + num_warps: int + num_stages: int + + +def tune_kernel( + num_tokens: int, + num_experts: int, + shard_intermediate_size: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + search_space: list[dict[str, int]], + block_quant_shape: list[int] | None, + use_deep_gemm: bool, + imported_packages: Dict[str, Any], + seed: int, + num_iters: int, +) -> BenchmarkConfig: + """ + Tune a given Triton kernel (run inside a Ray worker; no Ray at module level). + + Parameters + ---------- + num_tokens : int + Batch size in tokens (sequence length × batch dimension). + num_experts : int + Number of experts in the MoE layer. + shard_intermediate_size : int + Intermediate size of the model in the shard (if using tensor parallelism). + hidden_size : int + Model hidden dimension (input/output of the expert layer). + topk : int + Number of active experts per token. + dtype : torch.dtype + Dtype for weights and activations. + use_fp8_w8a8 : bool + Whether to use fp8_w8a8. + use_int8_w8a16 : bool + Whether to use int8_w8a16. + search_space : list[dict[str, int]] + Search space for the kernel (tiling and warp scheduling). + block_quant_shape : list[int] | None + Block shape for the kernel (None here). + use_deep_gemm : bool + Whether to use deep gemm (False here). + imported_packages : Dict[str, Any] + Imported packages (vllm, triton, etc.). + seed : int + Random seed for reproducibility. + num_iters : int + Number of iterations to average the kernel time on. + + Returns + ------- + BenchmarkConfig + The best config found. + """ + import ray.experimental.tqdm_ray as tqdm_ray + from datetime import datetime + + imported_packages["vllm_platforms"].current_platform.seed_everything(seed) + best_config = None + best_time = float("inf") + + for config in tqdm_ray.tqdm(search_space): + try: + kernel_time = benchmark_config( + config, + num_tokens, + num_experts, + shard_intermediate_size, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + num_iters=num_iters, + block_quant_shape=block_quant_shape, + use_deep_gemm=use_deep_gemm, + imported_packages=imported_packages, + ) + except imported_packages["triton"].runtime.autotuner.OutOfResources: + continue + + if kernel_time < best_time: + best_time, best_config = kernel_time, config + + now = datetime.now() + pruna_logger.info(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}") + if best_config is None: + raise NoValidConfigError( + f"No valid kernel configuration was found for batch_size={num_tokens}. " + "All configurations failed (e.g., due to OutOfResources). " + "This can happen on GPUs with limited resources. " + "Consider reducing your model size, or tuning search space." + ) + return best_config + + +def ensure_benchmark_config(config: dict[str, Any]) -> BenchmarkConfig: + """ + Convert the raw dict returned by tune to a canonical BenchmarkConfig. + + Preserves optional keys (e.g. waves_per_eu, matrix_instr_nonkdim, kpack) when + present so that the config file format stays compatible with vLLM. + + Parameters + ---------- + config : dict[str, Any] + Raw config from tune (same keys as BenchmarkConfig, possibly with extras). + + Returns + ------- + BenchmarkConfig + Config with required keys and any optional keys that were present. 
+ """ + result: dict[str, Any] = { + "BLOCK_SIZE_M": config["BLOCK_SIZE_M"], + "BLOCK_SIZE_N": config["BLOCK_SIZE_N"], + "BLOCK_SIZE_K": config["BLOCK_SIZE_K"], + "GROUP_SIZE_M": config["GROUP_SIZE_M"], + "num_warps": config["num_warps"], + "num_stages": config["num_stages"], + } + if "waves_per_eu" in config: + result["waves_per_eu"] = config.get("waves_per_eu") + if "matrix_instr_nonkdim" in config: + result["matrix_instr_nonkdim"] = config.get("matrix_instr_nonkdim") + if "kpack" in config: + result["kpack"] = config.get("kpack") + return result + + +def get_configs_compute_bound(smash_config: SmashConfigPrefixWrapper) -> list[dict[str, int]]: + """ + Get the grid-search space for the kernel (tiling and warp scheduling). + + Parameters + ---------- + smash_config : SmashConfigPrefixWrapper + The Smash configuration. + + Returns + ------- + list[dict[str, int]] + The search space for the kernel (tiling and warp scheduling). + """ + configs: list[dict[str, int]] = [] + + block_m_range = [2**i for i in range(4, smash_config["block_size_m_max"] + 1)] + block_n_range = [2**i for i in range(5, smash_config["block_size_n_max"] + 1)] + block_k_range = [2**i for i in range(6, smash_config["block_size_k_max"] + 1)] + num_warps_range = [4, 8] + group_m_range = [1, 16, 32, 64] + num_stage_range = [2, 3, 4, 5] + + param_ranges = { + "BLOCK_SIZE_M": block_m_range, + "BLOCK_SIZE_N": block_n_range, + "BLOCK_SIZE_K": block_k_range, + "GROUP_SIZE_M": group_m_range, + "num_warps": num_warps_range, + "num_stages": num_stage_range, + } + + keys, values = zip(*param_ranges.items()) + for config_values in product(*values): + config = dict(zip(keys, config_values)) + configs.append(config) + + return configs + + +def benchmark_config( + config: BenchmarkConfig, + num_tokens: int, + num_experts: int, + shard_intermediate_size: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + num_iters: int = 100, + block_quant_shape: list[int] | None = None, + use_deep_gemm: bool = False, + imported_packages: Dict[str, Any] | None = None, +) -> float: + """ + Benchmark a given Triton kernel using CUDAGraph. + + This function is copied from the vLLM repository. + https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py + + Tuning runs on a single GPU per Ray worker; Ray sets CUDA_VISIBLE_DEVICES + so that set_device(0) in this function selects the assigned GPU. + + Parameters + ---------- + config : BenchmarkConfig + The configuration to benchmark. + num_tokens : int + Batch size in tokens. + num_experts : int + Number of experts. + shard_intermediate_size : int + Intermediate size of the model in the shard (if using tensor parallelism). + hidden_size : int + Model hidden dimension. + topk : int + Number of active experts per token. + dtype : torch.dtype + Dtype for weights and activations. + use_fp8_w8a8 : bool + Whether to use fp8_w8a8. + use_int8_w8a16 : bool + Whether to use int8_w8a16. + num_iters : int + Number of iterations to run the benchmark. + block_quant_shape : list[int] | None + Block shape for the kernel (None here). + use_deep_gemm : bool + Whether to use deep gemm (False here). + imported_packages : Dict[str, Any] | None + Imported packages (vllm, triton, etc.). + + Returns + ------- + float + Average latency per kernel invocation in microseconds (each CUDA graph + replay runs 10 kernel invocations; we average over num_iters replays + then convert to µs). 
+ """ + if imported_packages is None: + raise ValueError("imported_packages is required") + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required for MoeKernelTuner.") + # Ray sets CUDA_VISIBLE_DEVICES per worker to the GPU it scheduled; use device 0. + torch.cuda.set_device(0) + device = torch.device("cuda") + + init_dtype = torch.float16 if use_fp8_w8a8 else dtype + x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device) + if use_int8_w8a16: + w1 = torch.randint( + -127, + 127, + ( + num_experts, + shard_intermediate_size, + hidden_size, + ), + dtype=torch.int8, + device=device, + ) + w2 = torch.randint( + -127, + 127, + ( + num_experts, + hidden_size, + shard_intermediate_size // 2, + ), + dtype=torch.int8, + device=device, + ) + else: + w1 = torch.randn( + num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype, device=device + ) + w2 = torch.randn( + num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype, device=device + ) + gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32, device=device) + + w1_scale = None + w2_scale = None + a1_scale = None + a2_scale = None + if use_int8_w8a16: + w1_scale = torch.randn( + (num_experts, 2 * shard_intermediate_size), dtype=torch.float32, device=device + ) + w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32, device=device) + if use_deep_gemm: + block_quant_shape = [128, 128] + if use_fp8_w8a8: + if block_quant_shape: + block_n, block_k = block_quant_shape[0], block_quant_shape[1] + e = num_experts + n = shard_intermediate_size // 2 + k = hidden_size + factor_for_scale = 1e-2 + n_tiles_w1 = (2 * n + block_n - 1) // block_n + n_tiles_w2 = (k + block_n - 1) // block_n + k_tiles_w1 = (k + block_k - 1) // block_k + k_tiles_w2 = (n + block_k - 1) // block_k + w1_scale = ( + torch.rand((e, n_tiles_w1, k_tiles_w1), dtype=torch.float32, device=device) + * factor_for_scale + ) + w2_scale = ( + torch.rand((e, n_tiles_w2, k_tiles_w2), dtype=torch.float32, device=device) + * factor_for_scale + ) + else: + w1_scale = torch.randn(num_experts, dtype=torch.float32, device=device) + w2_scale = torch.randn(num_experts, dtype=torch.float32, device=device) + + a1_scale = torch.randn(1, dtype=torch.float32, device=device) + a2_scale = torch.randn(1, dtype=torch.float32, device=device) + + w1 = w1.to(device=device, dtype=imported_packages["vllm_platforms"].current_platform.fp8_dtype()) + w2 = w2.to(device=device, dtype=imported_packages["vllm_platforms"].current_platform.fp8_dtype()) + + input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32, device=device) + + def prepare(i: int): + input_gating.copy_(gating_output[i]) + + def run(): + if use_fp8_w8a8: + quant_dtype = torch.float8_e4m3fn + elif use_int8_w8a16: + quant_dtype = torch.int8 + else: + quant_dtype = None + + quant_config = imported_packages["FusedMoEQuantConfig"].make( + quant_dtype=quant_dtype, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + block_shape=block_quant_shape, + ) + + with imported_packages["override_config"](config): + topk_weights, topk_ids, token_expert_indices = imported_packages["FusedMoE"].fused_topk( + x, input_gating, topk, renormalize=not use_deep_gemm + ) + return imported_packages["FusedMoE"].fused_experts( + x, + w1, + w2, + topk_weights, + topk_ids, + inplace=True, + quant_config=quant_config, + allow_deep_gemm=use_deep_gemm, + ) + + run() + torch.cuda.synchronize() + + graph = torch.cuda.CUDAGraph() + with 
torch.cuda.graph(graph): + for _ in range(10): + run() + torch.cuda.synchronize() + + for _ in range(5): + graph.replay() + torch.cuda.synchronize() + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + latencies: list[float] = [] + for i in range(num_iters): + prepare(i) + torch.cuda.synchronize() + + start_event.record() + graph.replay() + end_event.record() + end_event.synchronize() + latencies.append(start_event.elapsed_time(end_event)) + # Average latency per kernel invocation in microseconds (each replay runs 10 invocations). + avg_us = sum(latencies) / (num_iters * 10) * 1000 + graph.reset() + return avg_us + + +def write_config_file( + path: pathlib.Path, + data: dict[str, Any], + merge_if_exists: bool, + write_only_if_missing: bool, + log_label: str, +) -> None: + """ + Write config dict to a JSON file, optionally merging with existing content. + + Parameters + ---------- + path : pathlib.Path + Target file path. + data : dict[str, Any] + Config to write (will be merged into existing if merge_if_exists and file exists). + merge_if_exists : bool + If True and path exists, load existing JSON, update with data, then write. + write_only_if_missing : bool + If True and path already exists, do not write (preserves existing file). + log_label : str + Label for the log message when writing. + """ + if write_only_if_missing and path.exists(): + return + if merge_if_exists and path.exists(): + with open(path) as f: + existing: dict[str, Any] = json.load(f) + existing.update(data) + data = existing + path.parent.mkdir(parents=True, exist_ok=True) + pruna_logger.info(f"Writing best config to {log_label}...") + with open(path, "w") as f: + json.dump(data, f, indent=4) + f.write("\n") + + +def save_configs( + configs: dict[int, BenchmarkConfig], + num_experts: int, + shard_intermediate_size: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + block_quant_shape: list[int] | None, + path_to_huggingface_hub_cache: str, + path_to_vllm_cache: str, + imported_packages: Dict[str, Any], +) -> None: + """ + Save the best configs to the HF cache and vLLM cache. + + Writes one file per cache; for the HF path we merge with existing file content + if present so that other keys are preserved. Uses a shared helper to avoid + duplicating the write logic. + + Parameters + ---------- + configs : dict[int, BenchmarkConfig] + Best config per batch size (keys are batch sizes). + num_experts : int + Number of experts. + shard_intermediate_size : int + Intermediate size of the model in the shard (if using tensor parallelism). + dtype : torch.dtype + Dtype for weights and activations. + use_fp8_w8a8 : bool + Whether to use fp8_w8a8. + use_int8_w8a16 : bool + Whether to use int8_w8a16. + block_quant_shape : list[int] | None + Block shape for the kernel (None here). + path_to_huggingface_hub_cache : str + Path to the Hugging Face Hub cache. + path_to_vllm_cache : str + Path to the vLLM MoE configs directory (used when VLLM_TUNED_CONFIG_FOLDER is None). + imported_packages : Dict[str, Any] + Imported packages (vllm, triton, etc.). + """ + dtype_str = imported_packages["_get_config_dtype_str"]( + dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8 + ) + + filename = imported_packages["FusedMoE"].get_config_file_name( + num_experts, shard_intermediate_size // 2, dtype_str, block_quant_shape + ) + + triton_version = imported_packages["triton"].__version__ + # JSON keys must be strings; configs keys are int (batch sizes). 
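+    # Resulting file layout (values are illustrative only):
+    #   {"triton_version": "3.1.0", "1": {"BLOCK_SIZE_M": 16, ...}, "2": {...}, ...}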
+ data: dict[str, Any] = {"triton_version": triton_version} + for k, v in configs.items(): + data[str(k)] = v + + path_to_kernel_configs = ( + pathlib.Path(path_to_huggingface_hub_cache) / + ".cache/huggingface/hub/models--RedHatAI--moe/blobs/configs" + ) + path_hf = path_to_kernel_configs / filename + write_config_file( + path_hf, data, + merge_if_exists=True, + write_only_if_missing=False, + log_label=str(path_hf), + ) + + path_to_vllm_configs = imported_packages["envs"].VLLM_TUNED_CONFIG_FOLDER + if path_to_vllm_configs is None: + submodule_locations = find_spec("vllm").submodule_search_locations + if submodule_locations is not None and len(submodule_locations) > 0: + path_where_vllm_is_installed = submodule_locations[0] + else: + raise RuntimeError("Could not determine installation path for vllm.") + path_to_vllm_configs = pathlib.Path(path_where_vllm_is_installed).parent / path_to_vllm_cache + path_vllm = pathlib.Path(path_to_vllm_configs) / filename + write_config_file( + path_vllm, data, + merge_if_exists=False, + write_only_if_missing=True, + log_label=str(path_vllm), + ) diff --git a/src/pruna/engine/load_artifacts.py b/src/pruna/engine/load_artifacts.py index 54de87f3..a759b74f 100644 --- a/src/pruna/engine/load_artifacts.py +++ b/src/pruna/engine/load_artifacts.py @@ -102,37 +102,49 @@ def load_moe_kernel_tuner_artifacts(model: Any, model_path: str | Path, smash_co Any The loaded MoE model. """ - from pruna.algorithms.moe_kernel_tuner import MoeKernelTuner, save_configs + from pruna.algorithms.moe_kernel_tuner import MoeKernelTuner + from pruna.algorithms.utils.moe_kernel_tuner import save_configs imported_packages = MoeKernelTuner().import_algorithm_packages() - save_dir = Path(model_path) / "moe_kernel_tuned_configs" + save_dir = Path(model_path) with open(save_dir / "moe_kernel_tuner.json") as f: - payload = json.load(f) - if not payload: + best_configs_and_hyperparameters = json.load(f) + if not best_configs_and_hyperparameters: raise ValueError(f"MoE kernel tuner artifacts not found in {save_dir}") else: - best_configs = payload["best_configs_moe_kernel"] - num_experts = payload["num_experts"] - shard_intermediate_size = payload["shard_intermediate_size"] - dtype = payload["dtype"] - # Convert dtype string back to torch.dtype if needed - dtype = torch.bfloat16 if dtype == "bfloat16" else torch.float16 - use_fp8_w8a8 = payload["use_fp8_w8a8"] - use_int8_w8a16 = payload["use_int8_w8a16"] - - # save the config attached to smash_config, inside the hf and vllm caches. - save_configs( - best_configs, - num_experts, - shard_intermediate_size, - dtype, - use_fp8_w8a8, - use_int8_w8a16, - None, - smash_config["moe_kernel_tuner_path_to_huggingface_hub_cache"], - smash_config["moe_kernel_tuner_path_to_vllm_cache"], - imported_packages, - ) + # check if the triton version is the same as the one used to tune the kernel + triton_version = best_configs_and_hyperparameters["triton_version"] + if triton_version != imported_packages["triton"].__version__: + msg = ( + f"Triton version mismatch: {triton_version} != " + f"{imported_packages['triton'].__version__}. " + "Performance may be degraded or config may be invalid. " + "We recommend re-tuning the kernel in your environment." 
+ ) + raise pruna_logger.info(msg) + else: + best_configs = best_configs_and_hyperparameters["best_configs_moe_kernel"] + num_experts = best_configs_and_hyperparameters["num_experts"] + shard_intermediate_size = best_configs_and_hyperparameters["shard_intermediate_size"] + dtype = best_configs_and_hyperparameters["dtype"] + # Convert dtype string back to torch.dtype if needed + dtype = torch.bfloat16 if dtype == "bfloat16" else torch.float16 + use_fp8_w8a8 = best_configs_and_hyperparameters["use_fp8_w8a8"] + use_int8_w8a16 = best_configs_and_hyperparameters["use_int8_w8a16"] + + # save the config attached to smash_config, inside the hf and vllm caches. + save_configs( + best_configs, + num_experts, + shard_intermediate_size, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + None, + smash_config["moe_kernel_tuner_path_to_huggingface_hub_cache"], + smash_config["moe_kernel_tuner_path_to_vllm_cache"], + imported_packages, + ) class LOAD_ARTIFACTS_FUNCTIONS(Enum): # noqa: N801 diff --git a/src/pruna/engine/save_artifacts.py b/src/pruna/engine/save_artifacts.py index 453b8fee..e13e2060 100644 --- a/src/pruna/engine/save_artifacts.py +++ b/src/pruna/engine/save_artifacts.py @@ -105,11 +105,9 @@ def save_moe_kernel_tuner_artifacts(model: Any, model_path: str | Path, smash_co None This function does not return anything. """ - save_dir = Path(model_path) / "moe_kernel_tuned_configs" - src_dir = Path(smash_config.cache_dir) / "moe_kernel_tuned_configs" - if save_dir.exists(): - shutil.rmtree(save_dir) - shutil.move(src_dir, save_dir) + src_file = Path(smash_config.cache_dir) / "moe_kernel_tuner.json" + dest_file = Path(model_path) / "moe_kernel_tuner.json" + shutil.move(src_file, dest_file) # define here the load artifacts function smash_config.load_artifacts_fns.append(LOAD_ARTIFACTS_FUNCTIONS.moe_kernel_tuner_artifacts.name) diff --git a/tests/fixtures.py b/tests/fixtures.py index 1c2f24c1..f9bb2f7f 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -197,7 +197,7 @@ def get_autoregressive_text_to_image_model(model_id: str) -> tuple[Any, SmashCon "wan_tiny_random": partial(get_diffusers_model, "pruna-test/wan-t2v-tiny-random", torch_dtype=torch.bfloat16), "flux_tiny": partial(get_diffusers_model, "pruna-test/tiny_flux", torch_dtype=torch.float16), "tiny_llama": partial(get_automodel_transformers, "pruna-test/tiny_llama", torch_dtype=torch.bfloat16), - "qwen3_next_moe_tiny_random": partial( - get_automodel_transformers, "tiny-random/qwen3-next-moe", torch_dtype=torch.bfloat16 + "qwen_moe_tiny_random": partial( + get_automodel_transformers, "yujiepan/qwen1.5-moe-tiny-random", torch_dtype=torch.bfloat16 ), } From 12d5b9f65f73f257ad531cecf5cbd59edfa6417d Mon Sep 17 00:00:00 2001 From: llcnt Date: Tue, 17 Feb 2026 18:21:30 +0000 Subject: [PATCH 23/26] fix: linting --- src/pruna/algorithms/moe_kernel_tuner.py | 67 +++++++------------ .../algorithms/utils/moe_kernel_tuner.py | 31 +++------ tests/algorithms/testers/moe_kernel_tuner.py | 2 +- 3 files changed, 34 insertions(+), 66 deletions(-) diff --git a/src/pruna/algorithms/moe_kernel_tuner.py b/src/pruna/algorithms/moe_kernel_tuner.py index 8a5e8594..8876f83a 100644 --- a/src/pruna/algorithms/moe_kernel_tuner.py +++ b/src/pruna/algorithms/moe_kernel_tuner.py @@ -26,8 +26,6 @@ from pruna.config.hyperparameters import UnconstrainedHyperparameter from pruna.config.smash_config import SmashConfigPrefixWrapper from pruna.engine.model_checks import is_moe_lm, is_transformers_pipeline_with_moe_lm -from pruna.engine.save import SAVE_FUNCTIONS -from 
pruna.engine.save_artifacts import SAVE_ARTIFACTS_FUNCTIONS from pruna.logging.logger import pruna_logger @@ -176,8 +174,8 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: model, model_config, tensor_parallel_size ) else: - nb_experts, shard_intermediate_size, hidden_size, topk = ( - extract_transformers_moe_dimensions(model, model_config, tensor_parallel_size) + nb_experts, shard_intermediate_size, hidden_size, topk = extract_transformers_moe_dimensions( + model, model_config, tensor_parallel_size ) # (ii) Get the compute parameters @@ -193,7 +191,7 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: get_configs_compute_bound = imported_packages["get_configs_compute_bound"] ensure_benchmark_config = imported_packages["ensure_benchmark_config"] save_configs = imported_packages["save_configs"] - NoValidConfigError = imported_packages["NoValidConfigError"] + no_valid_config_error = imported_packages["NoValidConfigError"] ray.init(ignore_reinit_error=True) search_space = get_configs_compute_bound(smash_config) @@ -229,7 +227,7 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: try: raw_config = ray.get(output_ref) tuned_config_by_batch_size[batch_size] = ensure_benchmark_config(raw_config) - except NoValidConfigError: + except no_valid_config_error: pruna_logger.warning( "No valid config for batch_size=%s; skipping (smaller batch sizes may still be used).", batch_size, @@ -265,7 +263,7 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: ) best_configs_and_hyperparameters = { - "triton_version": triton_version, + "triton_version": imported_packages["triton"].__version__, "best_configs_moe_kernel": tuned_config_by_batch_size, "num_experts": nb_experts, "shard_intermediate_size": shard_intermediate_size, @@ -276,9 +274,7 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: with open(smash_config.cache_dir / "moe_kernel_tuner.json", "w") as f: json.dump(best_configs_and_hyperparameters, f) - smash_config.save_artifacts_fns.append( - SAVE_ARTIFACTS_FUNCTIONS.moe_kernel_tuner_artifacts.name - ) + smash_config.save_artifacts_fns.append(SAVE_ARTIFACTS_FUNCTIONS.moe_kernel_tuner_artifacts.name) return model def import_algorithm_packages(self) -> dict[str, Any]: @@ -294,15 +290,6 @@ def import_algorithm_packages(self) -> dict[str, Any]: The algorithm packages and helpers (tune, save_configs, etc.). 
""" import ray - from pruna.algorithms.utils.moe_kernel_tuner import ( - BenchmarkConfig, - NoValidConfigError, - ensure_benchmark_config, - get_configs_compute_bound, - save_configs, - tune_kernel, - ) - import vllm.envs as envs import vllm.model_executor.layers.fused_moe.fused_moe as fused_moe import vllm.platforms as vllm_platforms @@ -312,6 +299,14 @@ def import_algorithm_packages(self) -> dict[str, Any]: _get_config_dtype_str, ) from vllm.triton_utils import triton + from pruna.algorithms.utils.moe_kernel_tuner import ( + BenchmarkConfig, + NoValidConfigError, + ensure_benchmark_config, + get_configs_compute_bound, + save_configs, + tune_kernel, + ) return dict( FusedMoEQuantConfig=FusedMoEQuantConfig, @@ -322,7 +317,7 @@ def import_algorithm_packages(self) -> dict[str, Any]: override_config=override_config, envs=envs, ray=ray, - tune_kernel=_tune_kernel, + tune_kernel=tune_kernel, get_configs_compute_bound=get_configs_compute_bound, ensure_benchmark_config=ensure_benchmark_config, save_configs=save_configs, @@ -358,21 +353,13 @@ def extract_hunyuan_dimensions( - topk: number of active experts per token. """ if getattr(model_config, "num_experts", None) is None: - raise ValueError( - f"Model config {type(model_config).__name__} is missing attribute 'num_experts'." - ) + raise ValueError(f"Model config {type(model_config).__name__} is missing attribute 'num_experts'.") if getattr(model_config, "moe_topk", None) is None: - raise ValueError( - f"Model config {type(model_config).__name__} is missing attribute 'moe_topk'." - ) + raise ValueError(f"Model config {type(model_config).__name__} is missing attribute 'moe_topk'.") if getattr(model_config, "intermediate_size", None) is None: - raise ValueError( - f"Model config {type(model_config).__name__} is missing attribute 'intermediate_size'." - ) + raise ValueError(f"Model config {type(model_config).__name__} is missing attribute 'intermediate_size'.") if getattr(model_config, "hidden_size", None) is None: - raise ValueError( - f"Model config {type(model_config).__name__} is missing attribute 'hidden_size'." - ) + raise ValueError(f"Model config {type(model_config).__name__} is missing attribute 'hidden_size'.") nb_experts = int(model_config.num_experts) topk = int(model_config.moe_topk[0]) @@ -416,28 +403,20 @@ def extract_transformers_moe_dimensions( - topk: number of active experts per token. """ if getattr(model_config, "num_experts", None) is None: - raise ValueError( - f"Model config {type(model_config).__name__} is missing attribute 'num_experts'." - ) + raise ValueError(f"Model config {type(model_config).__name__} is missing attribute 'num_experts'.") if getattr(model_config, "hidden_size", None) is None: - raise ValueError( - f"Model config {type(model_config).__name__} is missing attribute 'hidden_size'." - ) + raise ValueError(f"Model config {type(model_config).__name__} is missing attribute 'hidden_size'.") nb_experts = int(model_config.num_experts) hidden_size = int(model_config.hidden_size) if is_moe_lm(model): if getattr(model_config, "num_experts_per_tok", None) is None: - raise ValueError( - f"Model config {type(model_config).__name__} is missing attribute 'num_experts_per_tok'." - ) + raise ValueError(f"Model config {type(model_config).__name__} is missing attribute 'num_experts_per_tok'.") topk = int(model_config.num_experts_per_tok) else: if getattr(model_config, "moe_topk", None) is None: - raise ValueError( - f"Model config {type(model_config).__name__} is missing attribute 'moe_topk'." 
- ) + raise ValueError(f"Model config {type(model_config).__name__} is missing attribute 'moe_topk'.") topk = int(model_config.moe_topk[0]) moe_intermediate = getattr(model_config, "moe_intermediate_size", None) diff --git a/src/pruna/algorithms/utils/moe_kernel_tuner.py b/src/pruna/algorithms/utils/moe_kernel_tuner.py index 747bbd98..51b3a258 100644 --- a/src/pruna/algorithms/utils/moe_kernel_tuner.py +++ b/src/pruna/algorithms/utils/moe_kernel_tuner.py @@ -310,12 +310,8 @@ def benchmark_config( device=device, ) else: - w1 = torch.randn( - num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype, device=device - ) - w2 = torch.randn( - num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype, device=device - ) + w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype, device=device) + w2 = torch.randn(num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype, device=device) gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32, device=device) w1_scale = None @@ -323,9 +319,7 @@ def benchmark_config( a1_scale = None a2_scale = None if use_int8_w8a16: - w1_scale = torch.randn( - (num_experts, 2 * shard_intermediate_size), dtype=torch.float32, device=device - ) + w1_scale = torch.randn((num_experts, 2 * shard_intermediate_size), dtype=torch.float32, device=device) w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32, device=device) if use_deep_gemm: block_quant_shape = [128, 128] @@ -340,14 +334,8 @@ def benchmark_config( n_tiles_w2 = (k + block_n - 1) // block_n k_tiles_w1 = (k + block_k - 1) // block_k k_tiles_w2 = (n + block_k - 1) // block_k - w1_scale = ( - torch.rand((e, n_tiles_w1, k_tiles_w1), dtype=torch.float32, device=device) - * factor_for_scale - ) - w2_scale = ( - torch.rand((e, n_tiles_w2, k_tiles_w2), dtype=torch.float32, device=device) - * factor_for_scale - ) + w1_scale = torch.rand((e, n_tiles_w1, k_tiles_w1), dtype=torch.float32, device=device) * factor_for_scale + w2_scale = torch.rand((e, n_tiles_w2, k_tiles_w2), dtype=torch.float32, device=device) * factor_for_scale else: w1_scale = torch.randn(num_experts, dtype=torch.float32, device=device) w2_scale = torch.randn(num_experts, dtype=torch.float32, device=device) @@ -521,12 +509,12 @@ def save_configs( data[str(k)] = v path_to_kernel_configs = ( - pathlib.Path(path_to_huggingface_hub_cache) / - ".cache/huggingface/hub/models--RedHatAI--moe/blobs/configs" + pathlib.Path(path_to_huggingface_hub_cache) / ".cache/huggingface/hub/models--RedHatAI--moe/blobs/configs" ) path_hf = path_to_kernel_configs / filename write_config_file( - path_hf, data, + path_hf, + data, merge_if_exists=True, write_only_if_missing=False, log_label=str(path_hf), @@ -542,7 +530,8 @@ def save_configs( path_to_vllm_configs = pathlib.Path(path_where_vllm_is_installed).parent / path_to_vllm_cache path_vllm = pathlib.Path(path_to_vllm_configs) / filename write_config_file( - path_vllm, data, + path_vllm, + data, merge_if_exists=False, write_only_if_missing=True, log_label=str(path_vllm), diff --git a/tests/algorithms/testers/moe_kernel_tuner.py b/tests/algorithms/testers/moe_kernel_tuner.py index 17612492..d67a795c 100644 --- a/tests/algorithms/testers/moe_kernel_tuner.py +++ b/tests/algorithms/testers/moe_kernel_tuner.py @@ -6,7 +6,7 @@ class TestMoeKernelTuner(AlgorithmTesterBase): """Test the MoeKernelTuner.""" - models = ["qwen3_next_moe_tiny_random"] + models = ["qwen_moe_tiny_random"] reject_models = ["sd_tiny_random"] 
allow_pickle_files = False algorithm_class = MoeKernelTuner From b16a44f29b213c7c03fa85cbd492f0fbdabe47ad Mon Sep 17 00:00:00 2001 From: llcnt Date: Wed, 18 Feb 2026 08:43:05 +0000 Subject: [PATCH 24/26] feat: ruff on new imports --- src/pruna/algorithms/moe_kernel_tuner.py | 13 +++++++------ src/pruna/algorithms/utils/moe_kernel_tuner.py | 3 ++- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/pruna/algorithms/moe_kernel_tuner.py b/src/pruna/algorithms/moe_kernel_tuner.py index 8876f83a..2ea0cbac 100644 --- a/src/pruna/algorithms/moe_kernel_tuner.py +++ b/src/pruna/algorithms/moe_kernel_tuner.py @@ -26,6 +26,7 @@ from pruna.config.hyperparameters import UnconstrainedHyperparameter from pruna.config.smash_config import SmashConfigPrefixWrapper from pruna.engine.model_checks import is_moe_lm, is_transformers_pipeline_with_moe_lm +from pruna.engine.save_artifacts import SAVE_ARTIFACTS_FUNCTIONS from pruna.logging.logger import pruna_logger @@ -293,12 +294,6 @@ def import_algorithm_packages(self) -> dict[str, Any]: import vllm.envs as envs import vllm.model_executor.layers.fused_moe.fused_moe as fused_moe import vllm.platforms as vllm_platforms - from vllm.model_executor.layers.fused_moe import override_config - from vllm.model_executor.layers.fused_moe.config import ( - FusedMoEQuantConfig, - _get_config_dtype_str, - ) - from vllm.triton_utils import triton from pruna.algorithms.utils.moe_kernel_tuner import ( BenchmarkConfig, NoValidConfigError, @@ -307,6 +302,12 @@ def import_algorithm_packages(self) -> dict[str, Any]: save_configs, tune_kernel, ) + from vllm.model_executor.layers.fused_moe import override_config + from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig, + _get_config_dtype_str, + ) + from vllm.triton_utils import triton return dict( FusedMoEQuantConfig=FusedMoEQuantConfig, diff --git a/src/pruna/algorithms/utils/moe_kernel_tuner.py b/src/pruna/algorithms/utils/moe_kernel_tuner.py index 51b3a258..99b73e09 100644 --- a/src/pruna/algorithms/utils/moe_kernel_tuner.py +++ b/src/pruna/algorithms/utils/moe_kernel_tuner.py @@ -100,9 +100,10 @@ def tune_kernel( BenchmarkConfig The best config found. 
""" - import ray.experimental.tqdm_ray as tqdm_ray from datetime import datetime + import ray.experimental.tqdm_ray as tqdm_ray + imported_packages["vllm_platforms"].current_platform.seed_everything(seed) best_config = None best_time = float("inf") From 2997fdfcb51ebb4c746283a3749e05a67afc0e22 Mon Sep 17 00:00:00 2001 From: llcnt Date: Wed, 18 Feb 2026 08:47:12 +0000 Subject: [PATCH 25/26] fix: ruff on new import --- src/pruna/algorithms/moe_kernel_tuner.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/pruna/algorithms/moe_kernel_tuner.py b/src/pruna/algorithms/moe_kernel_tuner.py index 2ea0cbac..9870a2bc 100644 --- a/src/pruna/algorithms/moe_kernel_tuner.py +++ b/src/pruna/algorithms/moe_kernel_tuner.py @@ -294,6 +294,13 @@ def import_algorithm_packages(self) -> dict[str, Any]: import vllm.envs as envs import vllm.model_executor.layers.fused_moe.fused_moe as fused_moe import vllm.platforms as vllm_platforms + from vllm.model_executor.layers.fused_moe import override_config + from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig, + _get_config_dtype_str, + ) + from vllm.triton_utils import triton + from pruna.algorithms.utils.moe_kernel_tuner import ( BenchmarkConfig, NoValidConfigError, @@ -302,12 +309,6 @@ def import_algorithm_packages(self) -> dict[str, Any]: save_configs, tune_kernel, ) - from vllm.model_executor.layers.fused_moe import override_config - from vllm.model_executor.layers.fused_moe.config import ( - FusedMoEQuantConfig, - _get_config_dtype_str, - ) - from vllm.triton_utils import triton return dict( FusedMoEQuantConfig=FusedMoEQuantConfig, From f684372f36ccaa81381441d6d8fcbb87ab034a84 Mon Sep 17 00:00:00 2001 From: llcnt Date: Wed, 18 Feb 2026 08:56:33 +0000 Subject: [PATCH 26/26] feat: adapt pruna logger when triton version mismatch --- src/pruna/engine/load_artifacts.py | 48 +++++++++++++++--------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/src/pruna/engine/load_artifacts.py b/src/pruna/engine/load_artifacts.py index a759b74f..607fe641 100644 --- a/src/pruna/engine/load_artifacts.py +++ b/src/pruna/engine/load_artifacts.py @@ -121,30 +121,30 @@ def load_moe_kernel_tuner_artifacts(model: Any, model_path: str | Path, smash_co "Performance may be degraded or config may be invalid. " "We recommend re-tuning the kernel in your environment." ) - raise pruna_logger.info(msg) - else: - best_configs = best_configs_and_hyperparameters["best_configs_moe_kernel"] - num_experts = best_configs_and_hyperparameters["num_experts"] - shard_intermediate_size = best_configs_and_hyperparameters["shard_intermediate_size"] - dtype = best_configs_and_hyperparameters["dtype"] - # Convert dtype string back to torch.dtype if needed - dtype = torch.bfloat16 if dtype == "bfloat16" else torch.float16 - use_fp8_w8a8 = best_configs_and_hyperparameters["use_fp8_w8a8"] - use_int8_w8a16 = best_configs_and_hyperparameters["use_int8_w8a16"] - - # save the config attached to smash_config, inside the hf and vllm caches. 
- save_configs( - best_configs, - num_experts, - shard_intermediate_size, - dtype, - use_fp8_w8a8, - use_int8_w8a16, - None, - smash_config["moe_kernel_tuner_path_to_huggingface_hub_cache"], - smash_config["moe_kernel_tuner_path_to_vllm_cache"], - imported_packages, - ) + pruna_logger.info(msg) + + best_configs = best_configs_and_hyperparameters["best_configs_moe_kernel"] + num_experts = best_configs_and_hyperparameters["num_experts"] + shard_intermediate_size = best_configs_and_hyperparameters["shard_intermediate_size"] + dtype = best_configs_and_hyperparameters["dtype"] + # Convert dtype string back to torch.dtype if needed + dtype = torch.bfloat16 if dtype == "bfloat16" else torch.float16 + use_fp8_w8a8 = best_configs_and_hyperparameters["use_fp8_w8a8"] + use_int8_w8a16 = best_configs_and_hyperparameters["use_int8_w8a16"] + + # save the config attached to smash_config, inside the hf and vllm caches. + save_configs( + best_configs, + num_experts, + shard_intermediate_size, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + None, + smash_config["moe_kernel_tuner_path_to_huggingface_hub_cache"], + smash_config["moe_kernel_tuner_path_to_vllm_cache"], + imported_packages, + ) class LOAD_ARTIFACTS_FUNCTIONS(Enum): # noqa: N801
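
The load-path fix in the last hunk matters beyond style: Logger.info() returns None, so the earlier `raise pruna_logger.info(msg)` would have failed with "TypeError: exceptions must derive from BaseException" instead of reporting the Triton version mismatch. The corrected code logs the warning and then still re-saves the tuned configs into the Hugging Face and vLLM caches. A minimal sketch of the intended pattern, assuming a standard-library logger as a stand-in for pruna_logger and an illustrative helper name and parameters not taken from the diff:

import logging

pruna_logger = logging.getLogger("pruna")  # stand-in for pruna.logging.logger.pruna_logger


def warn_on_triton_mismatch(tuned_version: str, installed_version: str) -> None:
    """Log (rather than raise) when the Triton version used for tuning differs from the installed one."""
    if tuned_version != installed_version:
        msg = (
            f"MoE kernel configs were tuned with triton=={tuned_version}, "
            f"but triton=={installed_version} is installed. "
            "Performance may be degraded or config may be invalid. "
            "We recommend re-tuning the kernel in your environment."
        )
        # `raise pruna_logger.info(msg)` would fail here: info() returns None,
        # and None cannot be raised as an exception.
        pruna_logger.info(msg)
    # Execution continues either way, so the tuned configs can still be written
    # to the caches afterwards (save_configs(...) in the patched code).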