diff --git a/examples/README.md b/examples/README.md
index b427d39d2c2..90251320be4 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -15,6 +15,12 @@ Intel® Neural Compressor validated examples with multiple compression technique
+
+ | FLUX.1-dev |
+ Text to Image |
+ Quantization (MXFP8+FP8) |
+ link |
+
| Llama-4-Scout-17B-16E-Instruct |
Multimodal Modeling |
diff --git a/examples/pytorch/diffusion_model/diffusers/flux/README.md b/examples/pytorch/diffusion_model/diffusers/flux/README.md
new file mode 100644
index 00000000000..bca01cbf32c
--- /dev/null
+++ b/examples/pytorch/diffusion_model/diffusers/flux/README.md
@@ -0,0 +1,44 @@
+# Step-by-Step
+
+This example quantizes the FLUX.1-dev text-to-image model with AutoRound and validates the accuracy of the quantized model.
+
+# Prerequisite
+
+## 1. Environment
+
+```shell
+pip install -r requirements.txt
+# Use `INC_PT_ONLY=1 pip install git+https://github.com/intel/neural-compressor.git@v3.6rc` for the latest updates before neural-compressor v3.6 release
+pip install neural-compressor-pt==3.6
+# Use `pip install git+https://github.com/intel/auto-round.git@v0.8.0rc2` for the latest updates before auto-round v0.8.0 release
+pip install auto-round==0.8.0
+```
+
+## 2. Prepare Model
+
+```shell
+hf download black-forest-labs/FLUX.1-dev --local-dir FLUX.1-dev
+```
+
+## 3. Prepare Dataset
+```shell
+wget https://github.com/mlcommons/inference/raw/refs/heads/master/text_to_image/coco2014/captions/captions_source.tsv
+```
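+
+The evaluation flow in `main.py` expects the TSV to provide `id` and `caption` columns. An optional sanity check on the downloaded file:
+
+```shell
+# show the header row and the first two caption rows
+head -n 3 captions_source.tsv
+```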
+
+# Run
+
+## Quantization
+
+```bash
+bash run_quant.sh --topology=flux_mxfp8 --input_model=FLUX.1-dev --output_model=mxfp8_model
+```
+- topology: supported values are flux_fp8 and flux_mxfp8
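+
+Under the hood, `run_quant.sh` drives `main.py` with the `--quantize` flag. A roughly equivalent direct invocation (a sketch only; the flags below come from `main.py`, and the wrapper may pass additional defaults) is:
+
+```bash
+# quantize the FLUX.1-dev transformer with the MXFP8 scheme and save it to mxfp8_model
+python main.py --quantize --model FLUX.1-dev --scheme MXFP8 --output_dir mxfp8_model
+```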
+
+
+## Evaluation
+
+```bash
+CUDA_VISIBLE_DEVICES=0,1,2,3 bash run_benchmark.sh --topology=flux_mxfp8 --input_model=FLUX.1-dev --quantized_model=mxfp8_model
+```
+
+- CUDA_VISIBLE_DEVICES: the evaluation file is split into one subset per visible GPU so that image generation runs in parallel
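+
+`run_benchmark.sh` roughly follows the flow sketched below: the caption file is split with `dataset_split.py`, one `main.py --inference` process is launched per GPU, and metrics are computed with `--accuracy` once all images are generated. This is a simplified illustration based on `dataset_split.py` and the flags defined in `main.py`; the script itself handles the exact argument plumbing.
+
+```bash
+# split captions_source.tsv into one subset per GPU (writes subset_0.tsv, subset_1.tsv, ...)
+python dataset_split.py --split_num 4 --input_file captions_source.tsv
+
+# generate images for each subset on its own GPU (illustrative loop)
+for i in 0 1 2 3; do
+    CUDA_VISIBLE_DEVICES=$i python main.py --inference --scheme MXFP8 \
+        --model FLUX.1-dev --quantized_model_path mxfp8_model \
+        --eval_dataset subset_${i}.tsv --output_image_path ./tmp_imgs &
+done
+wait
+
+# compute CLIP, CLIP-IQA and ImageReward scores over the generated images
+python main.py --accuracy --eval_dataset captions_source.tsv --output_image_path ./tmp_imgs
+```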
diff --git a/examples/pytorch/diffusion_model/diffusers/flux/dataset_split.py b/examples/pytorch/diffusion_model/diffusers/flux/dataset_split.py
new file mode 100644
index 00000000000..56015d1f3db
--- /dev/null
+++ b/examples/pytorch/diffusion_model/diffusers/flux/dataset_split.py
@@ -0,0 +1,22 @@
+import argparse
+import pandas as pd
+
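+# Split a caption TSV into N roughly equal shards, one per GPU, e.g.:
+#   python dataset_split.py --split_num 4 --input_file captions_source.tsv
+# writes subset_0.tsv ... subset_3.tsv (the default --output_file prefix is "subset").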
+parser = argparse.ArgumentParser()
+parser.add_argument('--split_num', type=int)
+parser.add_argument('--limit', default=-1, type=int)
+parser.add_argument('--input_file', type=str)
+parser.add_argument('--output_file', default="subset", type=str)
+args = parser.parse_args()
+
+# load the TSV file
+df = pd.read_csv(args.input_file, sep='\t')
+
+if args.limit > 0:
+ df = df.iloc[0:args.limit]
+
+num = (len(df) + args.split_num - 1) // args.split_num  # ceiling division so no caption rows are dropped
+for i in range(args.split_num):
+ start = i * num
+ end = min((i + 1) * num, len(df))
+ df_subset = df.iloc[start:end]
+ df_subset.to_csv(f"{args.output_file}_{i}.tsv", sep='\t', index=False)
diff --git a/examples/pytorch/diffusion_model/diffusers/flux/main.py b/examples/pytorch/diffusion_model/diffusers/flux/main.py
new file mode 100644
index 00000000000..c9e0b98d5c0
--- /dev/null
+++ b/examples/pytorch/diffusion_model/diffusers/flux/main.py
@@ -0,0 +1,182 @@
+# Copyright (c) 2025 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+import sys
+import argparse
+
+import pandas as pd
+import tabulate
+import torch
+
+from diffusers import AutoPipelineForText2Image, FluxTransformer2DModel
+from functools import partial
+from neural_compressor.torch.quantization import (
+ AutoRoundConfig,
+ convert,
+ prepare,
+)
+from auto_round.data_type.mxfp import quant_mx_rceil
+from auto_round.data_type.fp8 import quant_fp8_sym
+from auto_round.utils import get_block_names, get_module
+from auto_round.compressors.diffusion.eval import metric_map
+from auto_round.compressors.diffusion.dataset import get_diffusion_dataloader
+
+
+parser = argparse.ArgumentParser(
+ description="Flux quantization.", formatter_class=argparse.ArgumentDefaultsHelpFormatter
+)
+parser.add_argument("--model", "--model_name", "--model_name_or_path", help="model name or path")
+parser.add_argument('--scheme', default="MXFP8", type=str, help="quantization scheme, MXFP8 or FP8.")
+parser.add_argument("--quantize", action="store_true")
+parser.add_argument("--inference", action="store_true")
+parser.add_argument("--accuracy", action="store_true")
+parser.add_argument("--dataset", type=str, default="coco2014", help="the dataset for quantization training.")
+parser.add_argument("--output_dir", "--quantized_model_path", default="./tmp_autoround", type=str, help="the directory to save quantized model")
+parser.add_argument("--eval_dataset", default="captions_source.tsv", type=str, help="eval datasets")
+parser.add_argument("--output_image_path", default="./tmp_imgs", type=str, help="the directory to save quantized model")
+parser.add_argument("--iters", "--iter", default=1000, type=int, help="tuning iters")
+parser.add_argument("--limit", default=-1, type=int, help="limit the number of prompts for evaluation")
+
+args = parser.parse_args()
+
+
+def inference_worker(eval_file, pipe, image_save_dir):
+ gen_kwargs = {
+ "guidance_scale": 7.5,
+ "num_inference_steps": 50,
+ "generator": None,
+ }
+
+ dataloader, _, _ = get_diffusion_dataloader(eval_file, nsamples=args.limit, bs=1)
+ for image_ids, prompts in dataloader:
+
+ new_ids = []
+ new_prompts = []
+ for idx, image_id in enumerate(image_ids):
+ image_id = image_id.item()
+
+ if os.path.exists(os.path.join(image_save_dir, str(image_id) + ".png")):
+ continue
+ new_ids.append(image_id)
+ new_prompts.append(prompts[idx])
+
+ if len(new_prompts) == 0:
+ continue
+
+ output = pipe(prompt=new_prompts, **gen_kwargs)
+ for idx, image_id in enumerate(new_ids):
+ output.images[idx].save(os.path.join(image_save_dir, str(image_id) + ".png"))
+
+
+def tune():
+ pipe = AutoPipelineForText2Image.from_pretrained(args.model, torch_dtype=torch.bfloat16)
+ model = pipe.transformer
+ layer_config = {}
+ kwargs = {}
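+    # FP8: apply an fp8 weight config to every Linear layer through layer_config.
+    # MXFP8: pass an MX fp8 scheme (group_size 32) to AutoRound via kwargs.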
+ if args.scheme == "FP8":
+ for n, m in model.named_modules():
+ if m.__class__.__name__ == "Linear":
+ layer_config[n] = {"bits": 8, "data_type": "fp", "group_size": 0}
+ elif args.scheme == "MXFP8":
+ kwargs["scheme"] = {
+ "bits": 8,
+ "group_size": 32,
+ "data_type": "mx_fp",
+ }
+
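+    # export_format="fake" saves a QDQ (dequantized) checkpoint that the evaluation path
+    # reloads with FluxTransformer2DModel.from_pretrained.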
+ qconfig = AutoRoundConfig(
+ iters=args.iters,
+ dataset=args.dataset,
+ layer_config=layer_config,
+ num_inference_steps=3,
+ export_format="fake",
+ nsamples=128,
+ batch_size=1,
+ output_dir=args.output_dir,
+ **kwargs
+ )
+ model = prepare(model, qconfig)
+ model = convert(model, qconfig, pipeline=pipe)
+
+if __name__ == '__main__':
+ device = "cpu" if torch.cuda.device_count() == 0 else "cuda"
+
+ if args.quantize:
+ print(f"Start to quantize {args.model}.")
+ tune()
+ exit(0)
+
+ if args.inference:
+ pipe = AutoPipelineForText2Image.from_pretrained(args.model, torch_dtype=torch.bfloat16)
+
+ if not os.path.exists(args.output_image_path):
+ os.makedirs(args.output_image_path)
+
+ if os.path.exists(args.output_dir) and os.path.exists(os.path.join(args.output_dir, "diffusion_pytorch_model.safetensors.index.json")):
+ print(f"Loading quantized model from {args.output_dir}")
+ model = FluxTransformer2DModel.from_pretrained(args.output_dir, torch_dtype=torch.bfloat16)
+
+            # The checkpoint was exported with export_format="fake", i.e. weights are already
+            # QDQ'd; wrap each Linear's forward to also quantize-dequantize the activations.
+ if args.scheme == "MXFP8":
+ def act_qdq_forward(module, x, *args, **kwargs):
+ qdq_x, _, _ = quant_mx_rceil(x, bits=8, group_size=32, data_type="mx_fp_rceil")
+ return module.orig_forward(qdq_x, *args, **kwargs)
+
+ all_quant_blocks = get_block_names(model)
+
+ for block_names in all_quant_blocks:
+ for block_name in block_names:
+ block = get_module(model, block_name)
+ for n, m in block.named_modules():
+ if m.__class__.__name__ == "Linear":
+ m.orig_forward = m.forward
+ m.forward = partial(act_qdq_forward, m)
+
+ if args.scheme == "FP8":
+ def act_qdq_forward(module, x, *args, **kwargs):
+ qdq_x, _, _ = quant_fp8_sym(x, group_size=0)
+ return module.orig_forward(qdq_x, *args, **kwargs)
+
+ for n, m in model.named_modules():
+ if m.__class__.__name__ == "Linear":
+ m.orig_forward = m.forward
+ m.forward = partial(act_qdq_forward, m)
+
+ pipe.transformer = model
+
+ else:
+ print("Don't supply quantized_model_path or quantized model doesn't exist, evaluate BF16 accuracy.")
+
+ inference_worker(args.eval_dataset, pipe.to(device), args.output_image_path)
+
+ if args.accuracy:
+ df = pd.read_csv(args.eval_dataset, sep="\t")
+ prompt_list = []
+ image_list = []
+ for index, row in df.iterrows():
+ assert "id" in row and "caption" in row
+ caption_id = row["id"]
+ caption_text = row["caption"]
+ if os.path.exists(os.path.join(args.output_image_path, str(caption_id) + ".png")):
+ prompt_list.append(caption_text)
+ image_list.append(os.path.join(args.output_image_path, str(caption_id) + ".png"))
+
+ result = {}
+ metrics = ["clip", "clip-iqa", "imagereward"]
+ for metric in metrics:
+ result.update(metric_map[metric](prompt_list, image_list, device))
+
+ print(tabulate.tabulate(result.items(), tablefmt="grid"))
diff --git a/examples/pytorch/diffusion_model/diffusers/flux/requirements.txt b/examples/pytorch/diffusion_model/diffusers/flux/requirements.txt
new file mode 100644
index 00000000000..1d6637869b3
--- /dev/null
+++ b/examples/pytorch/diffusion_model/diffusers/flux/requirements.txt
@@ -0,0 +1,6 @@
+diffusers==0.35.1
+pandas==2.2.2
+clip==0.2.0
+image-reward==1.5
+torchmetrics==1.8.2
+transformers==4.55.0
diff --git a/examples/pytorch/diffusion_model/diffusers/flux/run_benchmark.sh b/examples/pytorch/diffusion_model/diffusers/flux/run_benchmark.sh
new file mode 100644
index 00000000000..7fe1006ba69
--- /dev/null
+++ b/examples/pytorch/diffusion_model/diffusers/flux/run_benchmark.sh
@@ -0,0 +1,92 @@
+#!/bin/bash
+set -x
+
+function main {
+
+ init_params "$@"
+ run_benchmark
+
+}
+
+# init params
+function init_params {
+ for var in "$@"
+ do
+ case $var in
+ --topology=*)
+ topology=$(echo $var |cut -f2 -d=)
+ ;;
+ --dataset_location=*)
+ dataset_location=$(echo $var |cut -f2 -d=)
+ ;;
+ --input_model=*)
+ input_model=$(echo $var |cut -f2 -d=)
+ ;;
+ --quantized_model=*)
+ tuned_checkpoint=$(echo $var |cut -f2 -d=)
+ ;;
+ --limit=*)
+ limit=$(echo $var |cut -f2 -d=)
+ ;;
+ --output_image_path=*)
+ output_image_path=$(echo $var |cut -f2 -d=)
+ ;;
+ *)
+ echo "Error: No such parameter: ${var}"
+ exit 1
+ ;;
+ esac
+ done
+
+}
+
+
+# run_benchmark
+function run_benchmark {
+ dataset_location=${dataset_location:="captions_source.tsv"}
+ limit=${limit:=-1}
+ output_image_path=${output_image_path:="./tmp_imgs"}
+
+ if [ "${topology}" = "flux_fp8" ]; then
+ extra_cmd="--scheme FP8 --inference"
+ elif [ "${topology}" = "flux_mxfp8" ]; then
+ extra_cmd="--scheme MXFP8 --inference"
+ fi
+
+ if [ -n "$CUDA_VISIBLE_DEVICES" ]; then
+ gpu_list="${CUDA_VISIBLE_DEVICES:-}"
+ IFS=',' read -ra gpu_ids <<< "$gpu_list"
+ visible_gpus=${#gpu_ids[@]}
+ echo "visible_gpus: ${visible_gpus}"
+
+ python dataset_split.py --split_num ${visible_gpus} --input_file ${dataset_location} --limit ${limit}
+
+ for ((i=0; i bool:
return any([v.get("data_type", None) == "fp8_to_int_sym" for v in self.quant_config.values()])
@@ -252,13 +263,16 @@ def convert(self, model: torch.nn.Module, *args, **kwargs):
Returns:
The quantized model.
"""
+ pipe = kwargs.pop("pipeline", None)
tokenizer = getattr(model.orig_model, "tokenizer", None)
if tokenizer is not None:
delattr(model.orig_model, "tokenizer")
- else:
+ elif pipe is None:
tokenizer = "Placeholder"
self.dataset = CapturedDataloader(model.args_list, model.kwargs_list)
model = model.orig_model
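+    # If a diffusers pipeline was supplied, quantize through the whole pipeline rather than the bare transformer.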
+ if pipe is not None:
+ model = pipe
rounder = AutoRound(
model,
layer_config=self.layer_config,
@@ -307,6 +321,9 @@ def convert(self, model: torch.nn.Module, *args, **kwargs):
truncation=self.truncation,
enable_torch_compile=self.enable_torch_compile,
quant_lm_head=self.quant_lm_head,
+ guidance_scale=self.guidance_scale,
+ num_inference_steps=self.num_inference_steps,
+ generator_seed=self.generator_seed,
)
if self.enable_w4afp8:
diff --git a/neural_compressor/torch/quantization/algorithm_entry.py b/neural_compressor/torch/quantization/algorithm_entry.py
index 4936fbbe213..1f9b50d7339 100644
--- a/neural_compressor/torch/quantization/algorithm_entry.py
+++ b/neural_compressor/torch/quantization/algorithm_entry.py
@@ -608,6 +608,7 @@ def autoround_quantize_entry(
"act_data_type": act_data_type,
}
layer_config = quant_config.to_dict().get("layer_config", None)
+ dataset = quant_config.to_dict().get("dataset", "NeelNanda/pile-10k")
output_dir = quant_config.to_dict().get("output_dir", "temp_auto_round")
enable_full_range = quant_config.enable_full_range
batch_size = quant_config.batch_size
@@ -642,6 +643,9 @@ def autoround_quantize_entry(
scheme = quant_config.scheme
device_map = quant_config.device_map
quant_lm_head = quant_config.quant_lm_head
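+    # diffusion-specific generation settings, forwarded to the AutoRound quantizer below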
+ guidance_scale = quant_config.to_dict().get("guidance_scale", 7.5)
+ num_inference_steps = quant_config.to_dict().get("num_inference_steps", 50)
+ generator_seed = quant_config.to_dict().get("generator_seed", None)
kwargs.pop("example_inputs")
quantizer = get_quantizer(
@@ -665,6 +669,7 @@ def autoround_quantize_entry(
batch_size=batch_size,
amp=amp,
lr_scheduler=lr_scheduler,
+ dataset=dataset,
enable_quanted_input=enable_quanted_input,
enable_minmax_tuning=enable_minmax_tuning,
lr=lr,
@@ -694,6 +699,9 @@ def autoround_quantize_entry(
scheme=scheme,
device_map=device_map,
quant_lm_head=quant_lm_head,
+ guidance_scale=guidance_scale,
+ num_inference_steps=num_inference_steps,
+ generator_seed=generator_seed,
)
model = quantizer.execute(model=model, mode=mode, *args, **kwargs)
model.qconfig = configs_mapping
diff --git a/neural_compressor/torch/quantization/quantize.py b/neural_compressor/torch/quantization/quantize.py
index 60e138876d5..d9bad24283b 100644
--- a/neural_compressor/torch/quantization/quantize.py
+++ b/neural_compressor/torch/quantization/quantize.py
@@ -228,6 +228,7 @@ def convert(
model: torch.nn.Module,
quant_config: BaseConfig = None,
inplace: bool = True,
+ **kwargs,
):
"""Convert the prepared model to a quantized model.
@@ -284,6 +285,7 @@ def convert(
configs_mapping,
example_inputs=example_inputs,
mode=Mode.CONVERT,
+ **kwargs,
)
setattr(q_model, "is_quantized", True)
return q_model