diff --git a/examples/README.md b/examples/README.md
index b427d39d2c2..90251320be4 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -15,6 +15,12 @@ Intel® Neural Compressor validated examples with multiple compression technique
+    <tr>
+        <td>FLUX.1-dev</td>
+        <td>Text to Image</td>
+        <td>Quantization (MXFP8+FP8)</td>
+        <td><a href="./pytorch/diffusion_model/diffusers/flux">link</a></td>
+    </tr>
         <td>Llama-4-Scout-17B-16E-Instruct</td>
         <td>Multimodal Modeling</td>
diff --git a/examples/pytorch/diffusion_model/diffusers/flux/README.md b/examples/pytorch/diffusion_model/diffusers/flux/README.md
new file mode 100644
index 00000000000..bca01cbf32c
--- /dev/null
+++ b/examples/pytorch/diffusion_model/diffusers/flux/README.md
@@ -0,0 +1,44 @@
+# Step-by-Step
+
+This example quantizes FLUX.1-dev and validates the accuracy of the quantized model.
+
+# Prerequisite
+
+## 1. Environment
+
+```shell
+pip install -r requirements.txt
+# Before the neural-compressor v3.6 release, use `INC_PT_ONLY=1 pip install git+https://github.com/intel/neural-compressor.git@v3.6rc` to get the latest updates
+pip install neural-compressor-pt==3.6
+# Before the auto-round v0.8.0 release, use `pip install git+https://github.com/intel/auto-round.git@v0.8.0rc2` to get the latest updates
+pip install auto-round==0.8.0
+```
+
+## 2. Prepare Model
+
+```shell
+hf download black-forest-labs/FLUX.1-dev --local-dir FLUX.1-dev
+```
+
+## 3. Prepare Dataset
+```shell
+wget https://github.com/mlcommons/inference/raw/refs/heads/master/text_to_image/coco2014/captions/captions_source.tsv
+```
+
+# Run
+
+## Quantization
+
+```bash
+bash run_quant.sh --topology=flux_mxfp8 --input_model=FLUX.1-dev --output_model=mxfp8_model
+```
+- topology: supported values are `flux_fp8` and `flux_mxfp8`
+
+
+## Evaluation
+
+```bash
+CUDA_VISIBLE_DEVICES=0,1,2,3 bash run_benchmark.sh --topology=flux_mxfp8 --input_model=FLUX.1-dev --quantized_model=mxfp8_model
+```
+
+- CUDA_VISIBLE_DEVICES: the evaluation prompts are split into one subset per visible GPU so that image generation runs in parallel
diff --git a/examples/pytorch/diffusion_model/diffusers/flux/dataset_split.py b/examples/pytorch/diffusion_model/diffusers/flux/dataset_split.py
new file mode 100644
index 00000000000..56015d1f3db
--- /dev/null
+++ b/examples/pytorch/diffusion_model/diffusers/flux/dataset_split.py
@@ -0,0 +1,22 @@
+import argparse
+import pandas as pd
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--split_num', type=int)
+parser.add_argument('--limit', default=-1, type=int)
+parser.add_argument('--input_file', type=str)
+parser.add_argument('--output_file', default="subset", type=str)
+args = parser.parse_args()
+
+# load the TSV file
+df = pd.read_csv(args.input_file, sep='\t')
+
+if args.limit > 0:
+    df = df.iloc[0:args.limit]
+
+num = (len(df) + args.split_num - 1) // args.split_num  # ceiling division so the last subset keeps the remaining rows
+for i in range(args.split_num):
+    start = i * num
+    end = min((i + 1) * num, len(df))
+    df_subset = df.iloc[start:end]
+    df_subset.to_csv(f"{args.output_file}_{i}.tsv", sep='\t', index=False)
diff --git a/examples/pytorch/diffusion_model/diffusers/flux/main.py b/examples/pytorch/diffusion_model/diffusers/flux/main.py
new file mode 100644
index 00000000000..c9e0b98d5c0
--- /dev/null
+++ b/examples/pytorch/diffusion_model/diffusers/flux/main.py
@@ -0,0 +1,182 @@
+# Copyright (c) 2025 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+import sys
+import argparse
+
+import pandas as pd
+import tabulate
+import torch
+
+from diffusers import AutoPipelineForText2Image, FluxTransformer2DModel
+from functools import partial
+from neural_compressor.torch.quantization import (
+    AutoRoundConfig,
+    convert,
+    prepare,
+)
+from auto_round.data_type.mxfp import quant_mx_rceil
+from auto_round.data_type.fp8 import quant_fp8_sym
+from auto_round.utils import get_block_names, get_module
+from auto_round.compressors.diffusion.eval import metric_map
+from auto_round.compressors.diffusion.dataset import get_diffusion_dataloader
+
+
+parser = argparse.ArgumentParser(
+    description="Flux quantization.", formatter_class=argparse.ArgumentDefaultsHelpFormatter
+)
+parser.add_argument("--model", "--model_name", "--model_name_or_path", help="model name or path")
+parser.add_argument('--scheme', default="MXFP8", type=str, help="quantization scheme.")
+parser.add_argument("--quantize", action="store_true")
+parser.add_argument("--inference", action="store_true")
+parser.add_argument("--accuracy", action="store_true")
+parser.add_argument("--dataset", type=str, default="coco2014", help="the dataset for quantization training.")
+parser.add_argument("--output_dir", "--quantized_model_path", default="./tmp_autoround", type=str, help="the directory to save the quantized model")
+parser.add_argument("--eval_dataset", default="captions_source.tsv", type=str, help="eval datasets")
+parser.add_argument("--output_image_path", default="./tmp_imgs", type=str, help="the directory to save generated images")
+parser.add_argument("--iters", "--iter", default=1000, type=int, help="tuning iters")
+parser.add_argument("--limit", default=-1, type=int, help="limit the number of prompts for evaluation")
+
+args = parser.parse_args()
+
+
+def inference_worker(eval_file, pipe, image_save_dir):
+    gen_kwargs = {
+        "guidance_scale": 7.5,
+        "num_inference_steps": 50,
+        "generator": None,
+    }
+
+    dataloader, _, _ = get_diffusion_dataloader(eval_file, nsamples=args.limit, bs=1)
+    for image_ids, prompts in dataloader:
+
+        new_ids = []
+        new_prompts = []
+        for idx, image_id in enumerate(image_ids):
+            image_id = image_id.item()
+
+            if os.path.exists(os.path.join(image_save_dir, str(image_id) + ".png")):
+                continue
+            new_ids.append(image_id)
+            new_prompts.append(prompts[idx])
+
+        if len(new_prompts) == 0:
+            continue
+
+        output = pipe(prompt=new_prompts, **gen_kwargs)
+        for idx, image_id in enumerate(new_ids):
+            output.images[idx].save(os.path.join(image_save_dir, str(image_id) + ".png"))
+
+
+def tune():
+    pipe = AutoPipelineForText2Image.from_pretrained(args.model, torch_dtype=torch.bfloat16)
+    model = pipe.transformer
+    layer_config = {}
+    kwargs = {}
+    if args.scheme == "FP8":
+        for n, m in model.named_modules():
+            if m.__class__.__name__ == "Linear":
+                layer_config[n] = {"bits": 8, "data_type": "fp", "group_size": 0}
+    elif args.scheme == "MXFP8":
+        kwargs["scheme"] = {
+            "bits": 8,
+            "group_size": 32,
+            "data_type": "mx_fp",
+        }
+
+    qconfig = AutoRoundConfig(
+        iters=args.iters,
+        dataset=args.dataset,
+        layer_config=layer_config,
+        num_inference_steps=3,
+        export_format="fake",
+        nsamples=128,
+        batch_size=1,
+        output_dir=args.output_dir,
+        **kwargs
+    )
+    model = prepare(model, qconfig)
+    model = convert(model, qconfig, pipeline=pipe)
+
+if __name__ == '__main__':
+    device = "cpu" if torch.cuda.device_count() == 0 else "cuda"
+
+    if args.quantize:
+        print(f"Start to quantize {args.model}.")
+        tune()
+        exit(0)
+
+    if args.inference:
+        pipe = AutoPipelineForText2Image.from_pretrained(args.model, torch_dtype=torch.bfloat16)
+
+        if not os.path.exists(args.output_image_path):
+            os.makedirs(args.output_image_path)
+
+        if os.path.exists(args.output_dir) and os.path.exists(os.path.join(args.output_dir, "diffusion_pytorch_model.safetensors.index.json")):
+            print(f"Loading quantized model from {args.output_dir}")
+            model = FluxTransformer2DModel.from_pretrained(args.output_dir, torch_dtype=torch.bfloat16)
+
+            # replace each Linear's forward with an activation QDQ wrapper
+            if args.scheme == "MXFP8":
+                def act_qdq_forward(module, x, *args, **kwargs):
+                    qdq_x, _, _ = quant_mx_rceil(x, bits=8, group_size=32, data_type="mx_fp_rceil")
+                    return module.orig_forward(qdq_x, *args, **kwargs)
+
+                all_quant_blocks = get_block_names(model)
+
+                for block_names in all_quant_blocks:
+                    for block_name in block_names:
+                        block = get_module(model, block_name)
+                        for n, m in block.named_modules():
+                            if m.__class__.__name__ == "Linear":
+                                m.orig_forward = m.forward
+                                m.forward = partial(act_qdq_forward, m)
+
+            if args.scheme == "FP8":
+                def act_qdq_forward(module, x, *args, **kwargs):
+                    qdq_x, _, _ = quant_fp8_sym(x, group_size=0)
+                    return module.orig_forward(qdq_x, *args, **kwargs)
+
+                for n, m in model.named_modules():
+                    if m.__class__.__name__ == "Linear":
+                        m.orig_forward = m.forward
+                        m.forward = partial(act_qdq_forward, m)
+
+            pipe.transformer = model
+
+        else:
+            print("quantized_model_path is not provided or the quantized model does not exist; evaluating BF16 accuracy.")
+
+        inference_worker(args.eval_dataset, pipe.to(device), args.output_image_path)
+
+    if args.accuracy:
+        df = pd.read_csv(args.eval_dataset, sep="\t")
+        prompt_list = []
+        image_list = []
+        for index, row in df.iterrows():
+            assert "id" in row and "caption" in row
+            caption_id = row["id"]
+            caption_text = row["caption"]
+            if os.path.exists(os.path.join(args.output_image_path, str(caption_id) + ".png")):
+                prompt_list.append(caption_text)
+                image_list.append(os.path.join(args.output_image_path, str(caption_id) + ".png"))
+
+        result = {}
+        metrics = ["clip", "clip-iqa", "imagereward"]
+        for metric in metrics:
+            result.update(metric_map[metric](prompt_list, image_list, device))
+
+        print(tabulate.tabulate(result.items(), tablefmt="grid"))
diff --git a/examples/pytorch/diffusion_model/diffusers/flux/requirements.txt b/examples/pytorch/diffusion_model/diffusers/flux/requirements.txt
new file mode 100644
index 00000000000..1d6637869b3
--- /dev/null
+++ b/examples/pytorch/diffusion_model/diffusers/flux/requirements.txt
@@ -0,0 +1,7 @@
+diffusers==0.35.1
+pandas==2.2.2
+clip==0.2.0
+image-reward==1.5
+tabulate
+torchmetrics==1.8.2
+transformers==4.55.0
diff --git a/examples/pytorch/diffusion_model/diffusers/flux/run_benchmark.sh b/examples/pytorch/diffusion_model/diffusers/flux/run_benchmark.sh
new file mode 100644
index 00000000000..7fe1006ba69
--- /dev/null
+++ b/examples/pytorch/diffusion_model/diffusers/flux/run_benchmark.sh
@@ -0,0 +1,92 @@
+#!/bin/bash
+set -x
+
+function main {
+
+  init_params "$@"
+  run_benchmark
+
+}
+
+# init params
+function init_params {
+  for var in "$@"
+  do
+    case $var in
+      --topology=*)
+          topology=$(echo $var |cut -f2 -d=)
+      ;;
+      --dataset_location=*)
+          dataset_location=$(echo $var |cut -f2 -d=)
+      ;;
+      --input_model=*)
+          input_model=$(echo $var |cut -f2 -d=)
+      ;;
+      --quantized_model=*)
+          tuned_checkpoint=$(echo $var |cut -f2 -d=)
+      ;;
+      --limit=*)
+          limit=$(echo $var |cut -f2 -d=)
+      ;;
+      --output_image_path=*)
+          output_image_path=$(echo $var |cut -f2 -d=)
+      ;;
+      *)
+          echo "Error: No such parameter: ${var}"
+          exit 1
+      ;;
+    esac
+  done
+
+}
+
+
+# run_benchmark
+function run_benchmark {
+    dataset_location=${dataset_location:="captions_source.tsv"}
+    limit=${limit:=-1}
+    output_image_path=${output_image_path:="./tmp_imgs"}
+
+    if [ "${topology}" = "flux_fp8" ]; then
+        extra_cmd="--scheme FP8 --inference"
+    elif [ "${topology}" = "flux_mxfp8" ]; then
+        extra_cmd="--scheme MXFP8 --inference"
+    fi
+
+    if [ -n "$CUDA_VISIBLE_DEVICES" ]; then
+        gpu_list="${CUDA_VISIBLE_DEVICES:-}"
+        IFS=',' read -ra gpu_ids <<< "$gpu_list"
+        visible_gpus=${#gpu_ids[@]}
+        echo "visible_gpus: ${visible_gpus}"
+
+        python dataset_split.py --split_num ${visible_gpus} --input_file ${dataset_location} --limit ${limit}
+
+        for ((i=0; i
 bool:
         return any([v.get("data_type", None) == "fp8_to_int_sym" for v in self.quant_config.values()])
@@ -252,13 +263,16 @@ def convert(self, model: torch.nn.Module, *args, **kwargs):
         Returns:
             The quantized model.
         """
+        pipe = kwargs.pop("pipeline", None)
         tokenizer = getattr(model.orig_model, "tokenizer", None)
         if tokenizer is not None:
             delattr(model.orig_model, "tokenizer")
-        else:
+        elif pipe is None:
             tokenizer = "Placeholder"
         self.dataset = CapturedDataloader(model.args_list, model.kwargs_list)
         model = model.orig_model
+        if pipe is not None:
+            model = pipe
         rounder = AutoRound(
             model,
             layer_config=self.layer_config,
@@ -307,6 +321,9 @@ def convert(self, model: torch.nn.Module, *args, **kwargs):
             truncation=self.truncation,
             enable_torch_compile=self.enable_torch_compile,
             quant_lm_head=self.quant_lm_head,
+            guidance_scale=self.guidance_scale,
+            num_inference_steps=self.num_inference_steps,
+            generator_seed=self.generator_seed,
         )
 
         if self.enable_w4afp8:
diff --git a/neural_compressor/torch/quantization/algorithm_entry.py b/neural_compressor/torch/quantization/algorithm_entry.py
index 4936fbbe213..1f9b50d7339 100644
--- a/neural_compressor/torch/quantization/algorithm_entry.py
+++ b/neural_compressor/torch/quantization/algorithm_entry.py
@@ -608,6 +608,7 @@ def autoround_quantize_entry(
         "act_data_type": act_data_type,
     }
     layer_config = quant_config.to_dict().get("layer_config", None)
+    dataset = quant_config.to_dict().get("dataset", "NeelNanda/pile-10k")
     output_dir = quant_config.to_dict().get("output_dir", "temp_auto_round")
     enable_full_range = quant_config.enable_full_range
     batch_size = quant_config.batch_size
@@ -642,6 +643,9 @@
     scheme = quant_config.scheme
     device_map = quant_config.device_map
     quant_lm_head = quant_config.quant_lm_head
+    guidance_scale = quant_config.to_dict().get("guidance_scale", 7.5)
+    num_inference_steps = quant_config.to_dict().get("num_inference_steps", 50)
+    generator_seed = quant_config.to_dict().get("generator_seed", None)
 
     kwargs.pop("example_inputs")
     quantizer = get_quantizer(
@@ -665,6 +669,7 @@
         batch_size=batch_size,
         amp=amp,
         lr_scheduler=lr_scheduler,
+        dataset=dataset,
         enable_quanted_input=enable_quanted_input,
         enable_minmax_tuning=enable_minmax_tuning,
         lr=lr,
@@ -694,6 +699,9 @@
         scheme=scheme,
         device_map=device_map,
         quant_lm_head=quant_lm_head,
+        guidance_scale=guidance_scale,
+        num_inference_steps=num_inference_steps,
+        generator_seed=generator_seed,
     )
     model = quantizer.execute(model=model, mode=mode, *args, **kwargs)
     model.qconfig = configs_mapping
diff --git a/neural_compressor/torch/quantization/quantize.py b/neural_compressor/torch/quantization/quantize.py
index 60e138876d5..d9bad24283b 100644
--- a/neural_compressor/torch/quantization/quantize.py
+++ b/neural_compressor/torch/quantization/quantize.py
@@ -228,6 +228,7 @@ def convert(
     model: torch.nn.Module,
     quant_config: BaseConfig = None,
     inplace: bool = True,
+    **kwargs,
 ):
     """Convert the prepared model to a quantized model.
 
@@ -284,6 +285,7 @@
         configs_mapping,
         example_inputs=example_inputs,
         mode=Mode.CONVERT,
+        **kwargs,
     )
     setattr(q_model, "is_quantized", True)
     return q_model
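Taken together: `convert()` in quantize.py now forwards extra keyword arguments, `autoround_quantize_entry` picks up the diffusion-specific options (`dataset`, `guidance_scale`, `num_inference_steps`, `generator_seed`), and `AutoRoundQuantizer.convert` substitutes the full diffusion pipeline for the captured transformer when one is passed. A condensed usage sketch of that path, mirroring `tune()` in the new `main.py` (the model path and output directory below are placeholders):

```python
import torch
from diffusers import AutoPipelineForText2Image
from neural_compressor.torch.quantization import AutoRoundConfig, convert, prepare

# load the BF16 FLUX pipeline; only the transformer is quantized
pipe = AutoPipelineForText2Image.from_pretrained("FLUX.1-dev", torch_dtype=torch.bfloat16)

config = AutoRoundConfig(
    scheme={"bits": 8, "group_size": 32, "data_type": "mx_fp"},  # MXFP8 weight scheme
    iters=1000,
    dataset="coco2014",
    nsamples=128,
    batch_size=1,
    num_inference_steps=3,      # diffusion option read in autoround_quantize_entry
    export_format="fake",
    output_dir="mxfp8_model",   # placeholder output directory
)

transformer = prepare(pipe.transformer, config)
# the pipeline is forwarded through convert() to AutoRound so calibration can run
# real text-to-image denoising steps instead of a text-only dataloader
transformer = convert(transformer, config, pipeline=pipe)
```

The fake-quantized weights are exported to `output_dir`; `main.py --inference` later reloads them with `FluxTransformer2DModel.from_pretrained`, re-applies the activation QDQ wrappers, and generates images for the CLIP, CLIP-IQA, and ImageReward metrics.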