diff --git a/.github/actions/setup-uv-project/action.yml b/.github/actions/setup-uv-project/action.yml index d32b697c..ecd89282 100644 --- a/.github/actions/setup-uv-project/action.yml +++ b/.github/actions/setup-uv-project/action.yml @@ -12,4 +12,4 @@ runs: github-token: ${{ github.token }} - shell: bash - run: uv sync --extra dev + run: uv sync --extra dev --extra vllm diff --git a/pyproject.toml b/pyproject.toml index 18193b70..5f2fdb61 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -141,6 +141,10 @@ dependencies = [ ] [project.optional-dependencies] +vllm = [ + "vllm>=0.11.0", + "ray", +] stable-fast = [ "xformers>=0.0.30", "stable-fast-pruna==1.0.8", diff --git a/src/pruna/algorithms/base/tags.py b/src/pruna/algorithms/base/tags.py index e5dcd426..3d9d9478 100644 --- a/src/pruna/algorithms/base/tags.py +++ b/src/pruna/algorithms/base/tags.py @@ -80,6 +80,24 @@ class AlgorithmTag(Enum): "Resamplers change the shape of image or video latents during generation to speed up inference.", ) + @classmethod + def tags_compatible_with_moe_kernel(cls) -> list["AlgorithmTag"]: + """ + Return tags that the MoE kernel tuner is compatible with (for ordering). + + Used so that compatible_before / compatible_after can be derived from + the enum and stay in sync when new tags are added. + """ + return [ + cls.KERNEL, + cls.QUANTIZER, + cls.PRUNER, + cls.CACHER, + cls.FACTORIZER, + cls.BATCHER, + cls.COMPILER, + ] + def __init__(self, name: str, description: str): """ Initialize an algorithm tag with name and description. diff --git a/src/pruna/algorithms/moe_kernel_tuner.py b/src/pruna/algorithms/moe_kernel_tuner.py new file mode 100644 index 00000000..9870a2bc --- /dev/null +++ b/src/pruna/algorithms/moe_kernel_tuner.py @@ -0,0 +1,441 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import json +import time +from collections.abc import Iterable +from typing import Any + +import torch +from ConfigSpace import CategoricalHyperparameter, OrdinalHyperparameter + +from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase +from pruna.algorithms.base.tags import AlgorithmTag as tags +from pruna.config.hyperparameters import UnconstrainedHyperparameter +from pruna.config.smash_config import SmashConfigPrefixWrapper +from pruna.engine.model_checks import is_moe_lm, is_transformers_pipeline_with_moe_lm +from pruna.engine.save_artifacts import SAVE_ARTIFACTS_FUNCTIONS +from pruna.logging.logger import pruna_logger + + +class MoeKernelTuner(PrunaAlgorithmBase): + """ + Tune the MoE (Mixture of Experts) Triton kernel for the model. + + Uses vLLM to tune the MoE kernel. If an existing artifact exists and the + Triton version matches, tuning is skipped and cached configs are reused. 
+ """ + + algorithm_name: str = "moe_kernel_tuner" + group_tags: list[str] = [tags.KERNEL] + save_fn: None = None + references: dict[str, str] = { + "GitHub": "https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py", + } + tokenizer_required: bool = False + processor_required: bool = False + runs_on: list[str] = ["cuda", "accelerate"] + dataset_required: bool = False + compatible_before: Iterable[str] = tags.tags_compatible_with_moe_kernel() + compatible_after: Iterable[str] = tags.tags_compatible_with_moe_kernel() + required_install = "``uv pip install vllm>=0.11.0``" + + def get_hyperparameters(self) -> list: + """ + Configure all algorithm-specific hyperparameters with ConfigSpace. + + Returns + ------- + list + The hyperparameters. + """ + return [ + CategoricalHyperparameter( + "compute_dtype", + choices=["bfloat16", "float16"], + default_value="bfloat16", + meta=dict(desc="Compute dtype to use."), + ), + CategoricalHyperparameter( + "weight_dtype", + choices=["fp16", "fp8_w8a8", "int8_w8a16"], + default_value="fp16", + meta=dict(desc="Dtype to use for the weights (and activations)."), + ), + OrdinalHyperparameter( + "tensor_parallel_size", + sequence=[1, 2, 4, 8, 16, 32], + default_value=1, + meta=dict(desc="Tensor parallel size to use if the model can not fit on a single GPU."), + ), + UnconstrainedHyperparameter( + "path_to_huggingface_hub_cache", + default_value="~", + meta=dict( + desc=( + "Path to the Hugging Face Hub cache directory " + "(that contains `kernels` configs). If not provided, " + "the cache will be saved in the current working directory." + ) + ), + ), + UnconstrainedHyperparameter( + "path_to_vllm_cache", + default_value="vllm/model_executor/layers/fused_moe/configs", + meta=dict(desc="Path to the vLLM MoE configs directory."), + ), + OrdinalHyperparameter( + "num_iters", + sequence=[1, 20, 50, 100], + default_value=20, + meta=dict(desc="Number of iterations to average the kernel times on."), + ), + OrdinalHyperparameter( + "block_size_m_max", + sequence=[4, 5, 6, 7, 8, 9, 10], + default_value=8, + meta=dict(desc="Maximum (log) block size for tiling through input dimension."), + ), + OrdinalHyperparameter( + "block_size_n_max", + sequence=[5, 6, 7, 8, 9, 10], + default_value=8, + meta=dict(desc="Maximum (log) block size for tiling through output dimension."), + ), + OrdinalHyperparameter( + "block_size_k_max", + sequence=[6, 7, 8, 9, 10], + default_value=8, + meta=dict(desc="Maximum (log) block size for tiling through intermediate dimension."), + ), + ] + + def model_check_fn(self, model: Any) -> bool: + """ + Check if the model is a MoE model. + + Parameters + ---------- + model : Any + The model to check. + + Returns + ------- + bool + True if the model is a valid model for the algorithm, False otherwise. + """ + if model.__class__.__name__ == "HunyuanImage3ForCausalMM": + return True + return is_moe_lm(model) or is_transformers_pipeline_with_moe_lm(model) + + def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: + """ + Tune the MoE Triton kernel for the model. + + If an existing artifact exists in the cache with the same Triton version, + tuning is skipped and the cached configs are restored to the HF/vLLM caches. + + Parameters + ---------- + model : Any + The model to wrap. + smash_config : SmashConfigPrefixWrapper + The configuration for the application of the algorithm. + + Returns + ------- + Any + The untouched model. 
+ """ + if is_transformers_pipeline_with_moe_lm(model): + return self._apply_to_model_within_transformers_pipeline(model, smash_config) + + imported_packages = self.import_algorithm_packages() + + # (i) Get the MoE parameters: exception first (Hunyuan), then general case. + model_config = getattr(model, "config", None) + if model_config is None: + raise ValueError(f"Model {model.__class__.__name__} has no config.") + + tensor_parallel_size = int(smash_config["tensor_parallel_size"]) + if model.__class__.__name__ == "HunyuanImage3ForCausalMM": + nb_experts, shard_intermediate_size, hidden_size, topk = extract_hunyuan_dimensions( + model, model_config, tensor_parallel_size + ) + else: + nb_experts, shard_intermediate_size, hidden_size, topk = extract_transformers_moe_dimensions( + model, model_config, tensor_parallel_size + ) + + # (ii) Get the compute parameters + dtype = smash_config["compute_dtype"] + dtype = torch.bfloat16 if dtype == "bfloat16" else torch.float16 + use_fp8_w8a8 = smash_config["weight_dtype"] == "fp8_w8a8" + use_int8_w8a16 = smash_config["weight_dtype"] == "int8_w8a16" + + # (iii) Tune the kernel over a range of batch sizes (single GPU per Ray worker). + batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192] + ray = imported_packages["ray"] + tune_kernel = imported_packages["tune_kernel"] + get_configs_compute_bound = imported_packages["get_configs_compute_bound"] + ensure_benchmark_config = imported_packages["ensure_benchmark_config"] + save_configs = imported_packages["save_configs"] + no_valid_config_error = imported_packages["NoValidConfigError"] + + ray.init(ignore_reinit_error=True) + search_space = get_configs_compute_bound(smash_config) + pruna_logger.info(f"Start tuning over {len(search_space)} configurations...") + + start = time.time() + ray_outputs: list[Any] = [] + tune_kernel = ray.remote(num_gpus=1)(tune_kernel) + + for batch_size in batch_sizes: + out = tune_kernel.remote( + batch_size, + nb_experts, + shard_intermediate_size, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + search_space, + None, + False, + imported_packages, + 0, # fixed seed for reproducibility + smash_config["num_iters"], + ) + ray_outputs.append(out) + + # Shutdown Ray processes on other devices and log any resulting exception as warnings. + tuned_config_by_batch_size: dict[int, Any] = {} + try: + for batch_size, output_ref in zip(batch_sizes, ray_outputs): + try: + raw_config = ray.get(output_ref) + tuned_config_by_batch_size[batch_size] = ensure_benchmark_config(raw_config) + except no_valid_config_error: + pruna_logger.warning( + "No valid config for batch_size=%s; skipping (smaller batch sizes may still be used).", + batch_size, + ) + finally: + try: + ray.shutdown() + except Exception as e: + pruna_logger.warning("Exception during ray.shutdown(): %s", e) + + if not tuned_config_by_batch_size: + raise RuntimeError( + "No valid kernel configuration was found for any batch size. " + "All configurations failed (e.g., due to OutOfResources). " + "This can happen on GPUs with limited resources. " + "Consider reducing your model size or tuning search space." 
+ ) + + end = time.time() + pruna_logger.info(f"Tuning took {end - start:.2f} seconds") + + save_configs( + tuned_config_by_batch_size, + nb_experts, + shard_intermediate_size, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + None, + smash_config["path_to_huggingface_hub_cache"], + smash_config["path_to_vllm_cache"], + imported_packages, + ) + + best_configs_and_hyperparameters = { + "triton_version": imported_packages["triton"].__version__, + "best_configs_moe_kernel": tuned_config_by_batch_size, + "num_experts": nb_experts, + "shard_intermediate_size": shard_intermediate_size, + "dtype": "bfloat16" if dtype == torch.bfloat16 else "float16", + "use_fp8_w8a8": use_fp8_w8a8, + "use_int8_w8a16": use_int8_w8a16, + } + with open(smash_config.cache_dir / "moe_kernel_tuner.json", "w") as f: + json.dump(best_configs_and_hyperparameters, f) + + smash_config.save_artifacts_fns.append(SAVE_ARTIFACTS_FUNCTIONS.moe_kernel_tuner_artifacts.name) + return model + + def import_algorithm_packages(self) -> dict[str, Any]: + """ + Import the algorithm packages (vLLM, Triton, Ray, and utils). + + Ray is imported here so it is not a top-level dependency of the package. + Utils are imported here so they are only loaded when this algorithm is used. + + Returns + ------- + dict[str, Any] + The algorithm packages and helpers (tune, save_configs, etc.). + """ + import ray + import vllm.envs as envs + import vllm.model_executor.layers.fused_moe.fused_moe as fused_moe + import vllm.platforms as vllm_platforms + from vllm.model_executor.layers.fused_moe import override_config + from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig, + _get_config_dtype_str, + ) + from vllm.triton_utils import triton + + from pruna.algorithms.utils.moe_kernel_tuner import ( + BenchmarkConfig, + NoValidConfigError, + ensure_benchmark_config, + get_configs_compute_bound, + save_configs, + tune_kernel, + ) + + return dict( + FusedMoEQuantConfig=FusedMoEQuantConfig, + _get_config_dtype_str=_get_config_dtype_str, + FusedMoE=fused_moe, + vllm_platforms=vllm_platforms, + triton=triton, + override_config=override_config, + envs=envs, + ray=ray, + tune_kernel=tune_kernel, + get_configs_compute_bound=get_configs_compute_bound, + ensure_benchmark_config=ensure_benchmark_config, + save_configs=save_configs, + NoValidConfigError=NoValidConfigError, + BenchmarkConfig=BenchmarkConfig, + ) + + +def extract_hunyuan_dimensions( + model: Any, + model_config: Any, + tensor_parallel_size: int, +) -> tuple[int, int, int, int]: + """ + Extract MoE dimensions for HunyuanImage3ForCausalMM. + + Parameters + ---------- + model : Any + The model (used only for type context). + model_config : Any + The model's config object. Must have num_experts, moe_topk, intermediate_size, hidden_size. + tensor_parallel_size : int + Tensor parallel size (must divide intermediate_size). + + Returns + ------- + tuple[int, int, int, int] + (nb_experts, shard_intermediate_size, hidden_size, topk). + - nb_experts: number of experts in the MoE layer. + - shard_intermediate_size: intermediate dimension per GPU shard (after SiLU-and-mul). + - hidden_size: model hidden dimension. + - topk: number of active experts per token. 
+ """ + if getattr(model_config, "num_experts", None) is None: + raise ValueError(f"Model config {type(model_config).__name__} is missing attribute 'num_experts'.") + if getattr(model_config, "moe_topk", None) is None: + raise ValueError(f"Model config {type(model_config).__name__} is missing attribute 'moe_topk'.") + if getattr(model_config, "intermediate_size", None) is None: + raise ValueError(f"Model config {type(model_config).__name__} is missing attribute 'intermediate_size'.") + if getattr(model_config, "hidden_size", None) is None: + raise ValueError(f"Model config {type(model_config).__name__} is missing attribute 'hidden_size'.") + + nb_experts = int(model_config.num_experts) + topk = int(model_config.moe_topk[0]) + intermediate_size = int(model_config.intermediate_size) + hidden_size = int(model_config.hidden_size) + + if intermediate_size % tensor_parallel_size != 0: + raise ValueError( + f"Expected tensor_parallel_size to be a divisor of the model's intermediate size " + f"(MLP hidden dimension) {intermediate_size}, but got {tensor_parallel_size}." + ) + shard_intermediate_size = 2 * intermediate_size // tensor_parallel_size + return nb_experts, shard_intermediate_size, hidden_size, topk + + +def extract_transformers_moe_dimensions( + model: Any, + model_config: Any, + tensor_parallel_size: int, +) -> tuple[int, int, int, int]: + """ + Extract MoE dimensions for standard transformers MoE LMs (e.g. Mixtral, Qwen-MoE). + + Parameters + ---------- + model : Any + The model (used to decide num_experts_per_tok vs moe_topk). + model_config : Any + The model's config. Must have num_experts, intermediate_size (or moe_intermediate_size), + hidden_size, and either num_experts_per_tok or moe_topk. + tensor_parallel_size : int + Tensor parallel size (must divide intermediate_size). + + Returns + ------- + tuple[int, int, int, int] + (nb_experts, shard_intermediate_size, hidden_size, topk). + - nb_experts: number of experts in the MoE layer. + - shard_intermediate_size: intermediate dimension per GPU shard. + - hidden_size: model hidden dimension. + - topk: number of active experts per token. + """ + if getattr(model_config, "num_experts", None) is None: + raise ValueError(f"Model config {type(model_config).__name__} is missing attribute 'num_experts'.") + if getattr(model_config, "hidden_size", None) is None: + raise ValueError(f"Model config {type(model_config).__name__} is missing attribute 'hidden_size'.") + + nb_experts = int(model_config.num_experts) + hidden_size = int(model_config.hidden_size) + + if is_moe_lm(model): + if getattr(model_config, "num_experts_per_tok", None) is None: + raise ValueError(f"Model config {type(model_config).__name__} is missing attribute 'num_experts_per_tok'.") + topk = int(model_config.num_experts_per_tok) + else: + if getattr(model_config, "moe_topk", None) is None: + raise ValueError(f"Model config {type(model_config).__name__} is missing attribute 'moe_topk'.") + topk = int(model_config.moe_topk[0]) + + moe_intermediate = getattr(model_config, "moe_intermediate_size", None) + if moe_intermediate is not None: + intermediate_size = int(moe_intermediate) + else: + if getattr(model_config, "intermediate_size", None) is None: + raise ValueError( + f"Model config {type(model_config).__name__} is missing attribute " + "'intermediate_size' or 'moe_intermediate_size'." 
+ ) + intermediate_size = int(model_config.intermediate_size) + + if intermediate_size % tensor_parallel_size != 0: + raise ValueError( + f"Expected tensor_parallel_size to be a divisor of the model's intermediate size " + f"(MLP hidden dimension) {intermediate_size}, but got {tensor_parallel_size}." + ) + shard_intermediate_size = 2 * intermediate_size // tensor_parallel_size + return nb_experts, shard_intermediate_size, hidden_size, topk diff --git a/src/pruna/algorithms/utils/moe_kernel_tuner.py b/src/pruna/algorithms/utils/moe_kernel_tuner.py new file mode 100644 index 00000000..99b73e09 --- /dev/null +++ b/src/pruna/algorithms/utils/moe_kernel_tuner.py @@ -0,0 +1,539 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import json +import pathlib +from importlib.util import find_spec +from itertools import product +from typing import Any, Dict, TypedDict + +import torch + +from pruna.config.smash_config import SmashConfigPrefixWrapper +from pruna.logging.logger import pruna_logger + + +class NoValidConfigError(RuntimeError): + """ + Raised when no valid kernel configuration was found for a given batch size. + + All configurations failed (e.g., due to OutOfResources). This can happen on + GPUs with limited resources. The caller may retry with smaller batch sizes. + """ + + +class BenchmarkConfig(TypedDict): + """The configuration for the matrix multiplication (tiling and warp scheduling).""" + + BLOCK_SIZE_M: int + BLOCK_SIZE_N: int + BLOCK_SIZE_K: int + GROUP_SIZE_M: int + num_warps: int + num_stages: int + + +def tune_kernel( + num_tokens: int, + num_experts: int, + shard_intermediate_size: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + search_space: list[dict[str, int]], + block_quant_shape: list[int] | None, + use_deep_gemm: bool, + imported_packages: Dict[str, Any], + seed: int, + num_iters: int, +) -> BenchmarkConfig: + """ + Tune a given Triton kernel (run inside a Ray worker; no Ray at module level). + + Parameters + ---------- + num_tokens : int + Batch size in tokens (sequence length × batch dimension). + num_experts : int + Number of experts in the MoE layer. + shard_intermediate_size : int + Intermediate size of the model in the shard (if using tensor parallelism). + hidden_size : int + Model hidden dimension (input/output of the expert layer). + topk : int + Number of active experts per token. + dtype : torch.dtype + Dtype for weights and activations. + use_fp8_w8a8 : bool + Whether to use fp8_w8a8. + use_int8_w8a16 : bool + Whether to use int8_w8a16. + search_space : list[dict[str, int]] + Search space for the kernel (tiling and warp scheduling). + block_quant_shape : list[int] | None + Block shape for the kernel (None here). + use_deep_gemm : bool + Whether to use deep gemm (False here). + imported_packages : Dict[str, Any] + Imported packages (vllm, triton, etc.). + seed : int + Random seed for reproducibility. 
+ num_iters : int + Number of iterations to average the kernel time on. + + Returns + ------- + BenchmarkConfig + The best config found. + """ + from datetime import datetime + + import ray.experimental.tqdm_ray as tqdm_ray + + imported_packages["vllm_platforms"].current_platform.seed_everything(seed) + best_config = None + best_time = float("inf") + + for config in tqdm_ray.tqdm(search_space): + try: + kernel_time = benchmark_config( + config, + num_tokens, + num_experts, + shard_intermediate_size, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + num_iters=num_iters, + block_quant_shape=block_quant_shape, + use_deep_gemm=use_deep_gemm, + imported_packages=imported_packages, + ) + except imported_packages["triton"].runtime.autotuner.OutOfResources: + continue + + if kernel_time < best_time: + best_time, best_config = kernel_time, config + + now = datetime.now() + pruna_logger.info(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}") + if best_config is None: + raise NoValidConfigError( + f"No valid kernel configuration was found for batch_size={num_tokens}. " + "All configurations failed (e.g., due to OutOfResources). " + "This can happen on GPUs with limited resources. " + "Consider reducing your model size, or tuning search space." + ) + return best_config + + +def ensure_benchmark_config(config: dict[str, Any]) -> BenchmarkConfig: + """ + Convert the raw dict returned by tune to a canonical BenchmarkConfig. + + Preserves optional keys (e.g. waves_per_eu, matrix_instr_nonkdim, kpack) when + present so that the config file format stays compatible with vLLM. + + Parameters + ---------- + config : dict[str, Any] + Raw config from tune (same keys as BenchmarkConfig, possibly with extras). + + Returns + ------- + BenchmarkConfig + Config with required keys and any optional keys that were present. + """ + result: dict[str, Any] = { + "BLOCK_SIZE_M": config["BLOCK_SIZE_M"], + "BLOCK_SIZE_N": config["BLOCK_SIZE_N"], + "BLOCK_SIZE_K": config["BLOCK_SIZE_K"], + "GROUP_SIZE_M": config["GROUP_SIZE_M"], + "num_warps": config["num_warps"], + "num_stages": config["num_stages"], + } + if "waves_per_eu" in config: + result["waves_per_eu"] = config.get("waves_per_eu") + if "matrix_instr_nonkdim" in config: + result["matrix_instr_nonkdim"] = config.get("matrix_instr_nonkdim") + if "kpack" in config: + result["kpack"] = config.get("kpack") + return result + + +def get_configs_compute_bound(smash_config: SmashConfigPrefixWrapper) -> list[dict[str, int]]: + """ + Get the grid-search space for the kernel (tiling and warp scheduling). + + Parameters + ---------- + smash_config : SmashConfigPrefixWrapper + The Smash configuration. + + Returns + ------- + list[dict[str, int]] + The search space for the kernel (tiling and warp scheduling). 
+ """ + configs: list[dict[str, int]] = [] + + block_m_range = [2**i for i in range(4, smash_config["block_size_m_max"] + 1)] + block_n_range = [2**i for i in range(5, smash_config["block_size_n_max"] + 1)] + block_k_range = [2**i for i in range(6, smash_config["block_size_k_max"] + 1)] + num_warps_range = [4, 8] + group_m_range = [1, 16, 32, 64] + num_stage_range = [2, 3, 4, 5] + + param_ranges = { + "BLOCK_SIZE_M": block_m_range, + "BLOCK_SIZE_N": block_n_range, + "BLOCK_SIZE_K": block_k_range, + "GROUP_SIZE_M": group_m_range, + "num_warps": num_warps_range, + "num_stages": num_stage_range, + } + + keys, values = zip(*param_ranges.items()) + for config_values in product(*values): + config = dict(zip(keys, config_values)) + configs.append(config) + + return configs + + +def benchmark_config( + config: BenchmarkConfig, + num_tokens: int, + num_experts: int, + shard_intermediate_size: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + num_iters: int = 100, + block_quant_shape: list[int] | None = None, + use_deep_gemm: bool = False, + imported_packages: Dict[str, Any] | None = None, +) -> float: + """ + Benchmark a given Triton kernel using CUDAGraph. + + This function is copied from the vLLM repository. + https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py + + Tuning runs on a single GPU per Ray worker; Ray sets CUDA_VISIBLE_DEVICES + so that set_device(0) in this function selects the assigned GPU. + + Parameters + ---------- + config : BenchmarkConfig + The configuration to benchmark. + num_tokens : int + Batch size in tokens. + num_experts : int + Number of experts. + shard_intermediate_size : int + Intermediate size of the model in the shard (if using tensor parallelism). + hidden_size : int + Model hidden dimension. + topk : int + Number of active experts per token. + dtype : torch.dtype + Dtype for weights and activations. + use_fp8_w8a8 : bool + Whether to use fp8_w8a8. + use_int8_w8a16 : bool + Whether to use int8_w8a16. + num_iters : int + Number of iterations to run the benchmark. + block_quant_shape : list[int] | None + Block shape for the kernel (None here). + use_deep_gemm : bool + Whether to use deep gemm (False here). + imported_packages : Dict[str, Any] | None + Imported packages (vllm, triton, etc.). + + Returns + ------- + float + Average latency per kernel invocation in microseconds (each CUDA graph + replay runs 10 kernel invocations; we average over num_iters replays + then convert to µs). + """ + if imported_packages is None: + raise ValueError("imported_packages is required") + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required for MoeKernelTuner.") + # Ray sets CUDA_VISIBLE_DEVICES per worker to the GPU it scheduled; use device 0. 
+ torch.cuda.set_device(0) + device = torch.device("cuda") + + init_dtype = torch.float16 if use_fp8_w8a8 else dtype + x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device) + if use_int8_w8a16: + w1 = torch.randint( + -127, + 127, + ( + num_experts, + shard_intermediate_size, + hidden_size, + ), + dtype=torch.int8, + device=device, + ) + w2 = torch.randint( + -127, + 127, + ( + num_experts, + hidden_size, + shard_intermediate_size // 2, + ), + dtype=torch.int8, + device=device, + ) + else: + w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype, device=device) + w2 = torch.randn(num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype, device=device) + gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32, device=device) + + w1_scale = None + w2_scale = None + a1_scale = None + a2_scale = None + if use_int8_w8a16: + w1_scale = torch.randn((num_experts, 2 * shard_intermediate_size), dtype=torch.float32, device=device) + w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32, device=device) + if use_deep_gemm: + block_quant_shape = [128, 128] + if use_fp8_w8a8: + if block_quant_shape: + block_n, block_k = block_quant_shape[0], block_quant_shape[1] + e = num_experts + n = shard_intermediate_size // 2 + k = hidden_size + factor_for_scale = 1e-2 + n_tiles_w1 = (2 * n + block_n - 1) // block_n + n_tiles_w2 = (k + block_n - 1) // block_n + k_tiles_w1 = (k + block_k - 1) // block_k + k_tiles_w2 = (n + block_k - 1) // block_k + w1_scale = torch.rand((e, n_tiles_w1, k_tiles_w1), dtype=torch.float32, device=device) * factor_for_scale + w2_scale = torch.rand((e, n_tiles_w2, k_tiles_w2), dtype=torch.float32, device=device) * factor_for_scale + else: + w1_scale = torch.randn(num_experts, dtype=torch.float32, device=device) + w2_scale = torch.randn(num_experts, dtype=torch.float32, device=device) + + a1_scale = torch.randn(1, dtype=torch.float32, device=device) + a2_scale = torch.randn(1, dtype=torch.float32, device=device) + + w1 = w1.to(device=device, dtype=imported_packages["vllm_platforms"].current_platform.fp8_dtype()) + w2 = w2.to(device=device, dtype=imported_packages["vllm_platforms"].current_platform.fp8_dtype()) + + input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32, device=device) + + def prepare(i: int): + input_gating.copy_(gating_output[i]) + + def run(): + if use_fp8_w8a8: + quant_dtype = torch.float8_e4m3fn + elif use_int8_w8a16: + quant_dtype = torch.int8 + else: + quant_dtype = None + + quant_config = imported_packages["FusedMoEQuantConfig"].make( + quant_dtype=quant_dtype, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + block_shape=block_quant_shape, + ) + + with imported_packages["override_config"](config): + topk_weights, topk_ids, token_expert_indices = imported_packages["FusedMoE"].fused_topk( + x, input_gating, topk, renormalize=not use_deep_gemm + ) + return imported_packages["FusedMoE"].fused_experts( + x, + w1, + w2, + topk_weights, + topk_ids, + inplace=True, + quant_config=quant_config, + allow_deep_gemm=use_deep_gemm, + ) + + run() + torch.cuda.synchronize() + + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + for _ in range(10): + run() + torch.cuda.synchronize() + + for _ in range(5): + graph.replay() + torch.cuda.synchronize() + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + latencies: list[float] = [] + for i in range(num_iters): + 
prepare(i) + torch.cuda.synchronize() + + start_event.record() + graph.replay() + end_event.record() + end_event.synchronize() + latencies.append(start_event.elapsed_time(end_event)) + # Average latency per kernel invocation in microseconds (each replay runs 10 invocations). + avg_us = sum(latencies) / (num_iters * 10) * 1000 + graph.reset() + return avg_us + + +def write_config_file( + path: pathlib.Path, + data: dict[str, Any], + merge_if_exists: bool, + write_only_if_missing: bool, + log_label: str, +) -> None: + """ + Write config dict to a JSON file, optionally merging with existing content. + + Parameters + ---------- + path : pathlib.Path + Target file path. + data : dict[str, Any] + Config to write (will be merged into existing if merge_if_exists and file exists). + merge_if_exists : bool + If True and path exists, load existing JSON, update with data, then write. + write_only_if_missing : bool + If True and path already exists, do not write (preserves existing file). + log_label : str + Label for the log message when writing. + """ + if write_only_if_missing and path.exists(): + return + if merge_if_exists and path.exists(): + with open(path) as f: + existing: dict[str, Any] = json.load(f) + existing.update(data) + data = existing + path.parent.mkdir(parents=True, exist_ok=True) + pruna_logger.info(f"Writing best config to {log_label}...") + with open(path, "w") as f: + json.dump(data, f, indent=4) + f.write("\n") + + +def save_configs( + configs: dict[int, BenchmarkConfig], + num_experts: int, + shard_intermediate_size: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + block_quant_shape: list[int] | None, + path_to_huggingface_hub_cache: str, + path_to_vllm_cache: str, + imported_packages: Dict[str, Any], +) -> None: + """ + Save the best configs to the HF cache and vLLM cache. + + Writes one file per cache; for the HF path we merge with existing file content + if present so that other keys are preserved. Uses a shared helper to avoid + duplicating the write logic. + + Parameters + ---------- + configs : dict[int, BenchmarkConfig] + Best config per batch size (keys are batch sizes). + num_experts : int + Number of experts. + shard_intermediate_size : int + Intermediate size of the model in the shard (if using tensor parallelism). + dtype : torch.dtype + Dtype for weights and activations. + use_fp8_w8a8 : bool + Whether to use fp8_w8a8. + use_int8_w8a16 : bool + Whether to use int8_w8a16. + block_quant_shape : list[int] | None + Block shape for the kernel (None here). + path_to_huggingface_hub_cache : str + Path to the Hugging Face Hub cache. + path_to_vllm_cache : str + Path to the vLLM MoE configs directory (used when VLLM_TUNED_CONFIG_FOLDER is None). + imported_packages : Dict[str, Any] + Imported packages (vllm, triton, etc.). + """ + dtype_str = imported_packages["_get_config_dtype_str"]( + dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8 + ) + + filename = imported_packages["FusedMoE"].get_config_file_name( + num_experts, shard_intermediate_size // 2, dtype_str, block_quant_shape + ) + + triton_version = imported_packages["triton"].__version__ + # JSON keys must be strings; configs keys are int (batch sizes). 
+ data: dict[str, Any] = {"triton_version": triton_version} + for k, v in configs.items(): + data[str(k)] = v + + path_to_kernel_configs = ( + pathlib.Path(path_to_huggingface_hub_cache) / ".cache/huggingface/hub/models--RedHatAI--moe/blobs/configs" + ) + path_hf = path_to_kernel_configs / filename + write_config_file( + path_hf, + data, + merge_if_exists=True, + write_only_if_missing=False, + log_label=str(path_hf), + ) + + path_to_vllm_configs = imported_packages["envs"].VLLM_TUNED_CONFIG_FOLDER + if path_to_vllm_configs is None: + submodule_locations = find_spec("vllm").submodule_search_locations + if submodule_locations is not None and len(submodule_locations) > 0: + path_where_vllm_is_installed = submodule_locations[0] + else: + raise RuntimeError("Could not determine installation path for vllm.") + path_to_vllm_configs = pathlib.Path(path_where_vllm_is_installed).parent / path_to_vllm_cache + path_vllm = pathlib.Path(path_to_vllm_configs) / filename + write_config_file( + path_vllm, + data, + merge_if_exists=False, + write_only_if_missing=True, + log_label=str(path_vllm), + ) diff --git a/src/pruna/engine/load_artifacts.py b/src/pruna/engine/load_artifacts.py index f79cf5b9..607fe641 100644 --- a/src/pruna/engine/load_artifacts.py +++ b/src/pruna/engine/load_artifacts.py @@ -13,6 +13,7 @@ # limitations under the License. from __future__ import annotations +import json from enum import Enum from functools import partial from pathlib import Path @@ -81,6 +82,71 @@ def load_torch_artifacts(model: Any, model_path: str | Path, smash_config: Smash torch.compiler.load_cache_artifacts(artifact_bytes) +def load_moe_kernel_tuner_artifacts(model: Any, model_path: str | Path, smash_config: SmashConfig, **kwargs) -> Any: + """ + Load a tuned kernel config inside the hf/vllm caches. + + Parameters + ---------- + model : Any + The model to load the artifacts for. + model_path : str | Path + The path to the model directory. + smash_config : SmashConfig + The SmashConfig object. + **kwargs : Any + Additional keyword arguments to pass to the function. + + Returns + ------- + Any + The loaded MoE model. + """ + from pruna.algorithms.moe_kernel_tuner import MoeKernelTuner + from pruna.algorithms.utils.moe_kernel_tuner import save_configs + + imported_packages = MoeKernelTuner().import_algorithm_packages() + save_dir = Path(model_path) + with open(save_dir / "moe_kernel_tuner.json") as f: + best_configs_and_hyperparameters = json.load(f) + if not best_configs_and_hyperparameters: + raise ValueError(f"MoE kernel tuner artifacts not found in {save_dir}") + else: + # check if the triton version is the same as the one used to tune the kernel + triton_version = best_configs_and_hyperparameters["triton_version"] + if triton_version != imported_packages["triton"].__version__: + msg = ( + f"Triton version mismatch: {triton_version} != " + f"{imported_packages['triton'].__version__}. " + "Performance may be degraded or config may be invalid. " + "We recommend re-tuning the kernel in your environment." 
+ ) + pruna_logger.info(msg) + + best_configs = best_configs_and_hyperparameters["best_configs_moe_kernel"] + num_experts = best_configs_and_hyperparameters["num_experts"] + shard_intermediate_size = best_configs_and_hyperparameters["shard_intermediate_size"] + dtype = best_configs_and_hyperparameters["dtype"] + # Convert dtype string back to torch.dtype if needed + dtype = torch.bfloat16 if dtype == "bfloat16" else torch.float16 + use_fp8_w8a8 = best_configs_and_hyperparameters["use_fp8_w8a8"] + use_int8_w8a16 = best_configs_and_hyperparameters["use_int8_w8a16"] + + # save the config attached to smash_config, inside the hf and vllm caches. + save_configs( + best_configs, + num_experts, + shard_intermediate_size, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + None, + smash_config["moe_kernel_tuner_path_to_huggingface_hub_cache"], + smash_config["moe_kernel_tuner_path_to_vllm_cache"], + imported_packages, + ) + + class LOAD_ARTIFACTS_FUNCTIONS(Enum): # noqa: N801 """ Enumeration of *artifact* load functions. @@ -116,6 +182,7 @@ class LOAD_ARTIFACTS_FUNCTIONS(Enum): # noqa: N801 """ torch_artifacts = partial(load_torch_artifacts) + moe_kernel_tuner_artifacts = partial(load_moe_kernel_tuner_artifacts) def __call__(self, *args, **kwargs) -> None: """Call the underlying load function.""" diff --git a/src/pruna/engine/model_checks.py b/src/pruna/engine/model_checks.py index d3163389..fa5fb763 100644 --- a/src/pruna/engine/model_checks.py +++ b/src/pruna/engine/model_checks.py @@ -121,7 +121,7 @@ def is_moe_lm(model: Any) -> bool: bool True if the model is a MoE LM, False otherwise. """ - return hasattr(model, "num_experts") + return hasattr(getattr(model, "config", None), "num_experts") def is_transformers_pipeline_with_causal_lm(model: Any) -> bool: diff --git a/src/pruna/engine/save_artifacts.py b/src/pruna/engine/save_artifacts.py index e9da3e08..e13e2060 100644 --- a/src/pruna/engine/save_artifacts.py +++ b/src/pruna/engine/save_artifacts.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import shutil from enum import Enum from functools import partial from pathlib import Path @@ -86,6 +87,32 @@ def save_torch_artifacts(model: Any, model_path: str | Path, smash_config: Smash smash_config.load_artifacts_fns.append(LOAD_ARTIFACTS_FUNCTIONS.torch_artifacts.name) +def save_moe_kernel_tuner_artifacts(model: Any, model_path: str | Path, smash_config: SmashConfig) -> None: + """ + Move the tuned config from pruna cache into the model directory. + + Parameters + ---------- + model : Any + The model to save artifacts for. + model_path : str | Path + The directory where the model and its artifacts will be saved. + smash_config : SmashConfig + The SmashConfig object. + + Returns + ------- + None + This function does not return anything. + """ + src_file = Path(smash_config.cache_dir) / "moe_kernel_tuner.json" + dest_file = Path(model_path) / "moe_kernel_tuner.json" + shutil.move(src_file, dest_file) + + # define here the load artifacts function + smash_config.load_artifacts_fns.append(LOAD_ARTIFACTS_FUNCTIONS.moe_kernel_tuner_artifacts.name) + + class SAVE_ARTIFACTS_FUNCTIONS(Enum): # noqa: N801 """ Enumeration of *artifact* save functions. 
@@ -121,6 +148,7 @@ class SAVE_ARTIFACTS_FUNCTIONS(Enum): # noqa: N801 """ torch_artifacts = partial(save_torch_artifacts) + moe_kernel_tuner_artifacts = partial(save_moe_kernel_tuner_artifacts) def __call__(self, *args, **kwargs) -> None: """ diff --git a/tests/algorithms/testers/moe_kernel_tuner.py b/tests/algorithms/testers/moe_kernel_tuner.py new file mode 100644 index 00000000..d67a795c --- /dev/null +++ b/tests/algorithms/testers/moe_kernel_tuner.py @@ -0,0 +1,13 @@ +from pruna.algorithms.moe_kernel_tuner import MoeKernelTuner + +from .base_tester import AlgorithmTesterBase + + +class TestMoeKernelTuner(AlgorithmTesterBase): + """Test the MoeKernelTuner.""" + + models = ["qwen_moe_tiny_random"] + reject_models = ["sd_tiny_random"] + allow_pickle_files = False + algorithm_class = MoeKernelTuner + metrics = ["perplexity"] diff --git a/tests/fixtures.py b/tests/fixtures.py index 1c2f24c1..f9bb2f7f 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -197,7 +197,7 @@ def get_autoregressive_text_to_image_model(model_id: str) -> tuple[Any, SmashCon "wan_tiny_random": partial(get_diffusers_model, "pruna-test/wan-t2v-tiny-random", torch_dtype=torch.bfloat16), "flux_tiny": partial(get_diffusers_model, "pruna-test/tiny_flux", torch_dtype=torch.float16), "tiny_llama": partial(get_automodel_transformers, "pruna-test/tiny_llama", torch_dtype=torch.bfloat16), - "qwen3_next_moe_tiny_random": partial( - get_automodel_transformers, "tiny-random/qwen3-next-moe", torch_dtype=torch.bfloat16 + "qwen_moe_tiny_random": partial( + get_automodel_transformers, "yujiepan/qwen1.5-moe-tiny-random", torch_dtype=torch.bfloat16 ), }
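
Usage sketch (not part of the diff above). A minimal example of how the new tuner might be driven through pruna's public API, assuming the algorithm is selected via its "kernel" group key like other pruna algorithms and that its hyperparameters are addressed with the "moe_kernel_tuner_" prefix (the same prefixed keys the artifact loader above reads); the tiny MoE checkpoint is the one used in the test fixtures. Requires a CUDA GPU and the new vllm extra (uv sync --extra vllm).

import torch
from transformers import AutoModelForCausalLM

from pruna import SmashConfig, smash

# Tiny MoE model from the test fixtures; assumed sufficient for a smoke test.
model = AutoModelForCausalLM.from_pretrained(
    "yujiepan/qwen1.5-moe-tiny-random", torch_dtype=torch.bfloat16
).to("cuda")

smash_config = SmashConfig()
# Assumed group/algorithm selection pattern; adjust if your pruna version differs.
smash_config["kernel"] = "moe_kernel_tuner"
smash_config["moe_kernel_tuner_num_iters"] = 20            # number of timed CUDA-graph replays to average over
smash_config["moe_kernel_tuner_weight_dtype"] = "fp16"     # or "fp8_w8a8" / "int8_w8a16"
smash_config["moe_kernel_tuner_tensor_parallel_size"] = 1  # must divide the MoE intermediate size

# Tuning writes the best per-batch-size configs into the HF and vLLM caches and records
# them in moe_kernel_tuner.json so they can be restored when the smashed model is reloaded.
smashed_model = smash(model=model, smash_config=smash_config)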