From d26d64dc22efdb16d39373b856a0e94ff07a35ad Mon Sep 17 00:00:00 2001
From: Kaihui-intel
Date: Thu, 31 Jul 2025 13:13:56 +0800
Subject: [PATCH 1/5] add stable diffusion int8 example

Signed-off-by: Kaihui-intel
---
 .../stable_diffusion/static_quant/README.md   |  60 +++
 .../static_quant/download_dataset.sh          |  24 ++
 .../stable_diffusion/static_quant/main.py     | 371 ++++++++++++++++++
 .../static_quant/requirements.txt             |   6 +
 4 files changed, 461 insertions(+)
 create mode 100644 examples/3.x_api/pytorch/diffusion_model/diffusers/stable_diffusion/static_quant/README.md
 create mode 100644 examples/3.x_api/pytorch/diffusion_model/diffusers/stable_diffusion/static_quant/download_dataset.sh
 create mode 100644 examples/3.x_api/pytorch/diffusion_model/diffusers/stable_diffusion/static_quant/main.py
 create mode 100644 examples/3.x_api/pytorch/diffusion_model/diffusers/stable_diffusion/static_quant/requirements.txt

diff --git a/examples/3.x_api/pytorch/diffusion_model/diffusers/stable_diffusion/static_quant/README.md b/examples/3.x_api/pytorch/diffusion_model/diffusers/stable_diffusion/static_quant/README.md
new file mode 100644
index 00000000000..4fe8bbcdc26
--- /dev/null
+++ b/examples/3.x_api/pytorch/diffusion_model/diffusers/stable_diffusion/static_quant/README.md
@@ -0,0 +1,60 @@
# Stable Diffusion

Best-known configurations for Stable Diffusion quantization and inference with static quantization.

## Model Information

| **Use Case** | **Framework** | **Model Repo** | **Branch/Commit/Tag** | **Optional Patch** |
|:---:| :---: |:--------------:|:---------------------:|:------------------:|
| Inference | PyTorch | https://huggingface.co/stabilityai/stable-diffusion-2-1 | - | - |

# Pre-Requisite
* Installation of PyTorch and [Intel Extension for PyTorch](https://intel.github.io/intel-extension-for-pytorch/#installation)

### Datasets

Download the 2017 [COCO dataset](https://cocodataset.org) using the `download_dataset.sh` script.
Export the `DATASET_DIR` environment variable to specify the directory where the dataset will be downloaded. This environment variable is used again by the quantization and inference commands below.
```
export DATASET_DIR=
bash download_dataset.sh
```

# Quantization and Inference
quantization
```shell
python main.py \
    --dataset_path=${DATASET_DIR} \
    --quantized_model_path=${INT8_MODEL} \
    --compile_inductor \
    --precision=int8-bf16 \
    --calibration
```
inference
```shell
python main.py \
    --dataset_path=${DATASET_DIR} \
    --precision=int8-bf16 \
    --benchmark \
    -w 1 \
    -i 10 \
    --quantized_model_path=${INT8_MODEL} \ --compile_inductor
```
## FID evaluation
We have also evaluated FID scores on the COCO2017 validation dataset for the BF16 model and for the mixed BF16+INT8 model. FID results are listed below.

| Model | BF16 | INT8+BF16 |
|----------------------|-------|-----------|
| stable-diffusion-2-1 | 27.94 | 27.14 |

To evaluate the FID score on the COCO2017 validation dataset for the mixed BF16+INT8 model, use the command below.

```bash
python main.py \
    --dataset_path=${DATASET_DIR} \
    --precision=int8-bf16 \
    --accuracy \
    --quantized_model_path=${INT8_MODEL} \ --compile_inductor
```
\ No newline at end of file
diff --git a/examples/3.x_api/pytorch/diffusion_model/diffusers/stable_diffusion/static_quant/download_dataset.sh b/examples/3.x_api/pytorch/diffusion_model/diffusers/stable_diffusion/static_quant/download_dataset.sh
new file mode 100644
index 00000000000..8771da1442a
--- /dev/null
+++ b/examples/3.x_api/pytorch/diffusion_model/diffusers/stable_diffusion/static_quant/download_dataset.sh
@@ -0,0 +1,24 @@
#!/usr/bin/env bash
#
# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

DATASET_DIR=${DATASET_DIR-$PWD}

# Download and unpack the COCO2017 validation images and captions into DATASET_DIR.
dir=$(pwd)
mkdir -p ${DATASET_DIR}; cd ${DATASET_DIR}
curl -O http://images.cocodataset.org/zips/val2017.zip; unzip val2017.zip
curl -O http://images.cocodataset.org/annotations/annotations_trainval2017.zip; unzip annotations_trainval2017.zip
cd $dir
diff --git a/examples/3.x_api/pytorch/diffusion_model/diffusers/stable_diffusion/static_quant/main.py b/examples/3.x_api/pytorch/diffusion_model/diffusers/stable_diffusion/static_quant/main.py
new file mode 100644
index 00000000000..035fa014f2e
--- /dev/null
+++ b/examples/3.x_api/pytorch/diffusion_model/diffusers/stable_diffusion/static_quant/main.py
@@ -0,0 +1,371 @@
#
# -*- coding: utf-8 -*-
#
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse
import logging
import os
import time
import threading
from tqdm import tqdm

import torch
from PIL import Image
from diffusers import DiffusionPipeline
from torchmetrics.image.fid import FrechetInceptionDistance
import torchvision.datasets as dset
import torchvision.transforms as transforms

logging.getLogger().setLevel(logging.INFO)

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name_or_path", type=str, default="stabilityai/stable-diffusion-2-1", help="Model path")
    parser.add_argument("--quantized_model_path", type=str, default="quantized_model.pt", help="INT8 model path")
    parser.add_argument("--dataset_path", type=str, default=None, help="COCO2017 dataset path")
    parser.add_argument("--prompt", type=str, default="A big burly grizzly bear is show with grass in the background.", help="input prompt")
    parser.add_argument("--output_dir", type=str, default=None, help="output path")
    parser.add_argument("--seed", type=int, default=0, help="random seed")
    parser.add_argument('--precision', type=str, default="fp32", help='precision: fp32, bf32, bf16, fp16, int8-bf16, int8-fp32')
    parser.add_argument('--calibration', action='store_true', default=False, help='run the calibration step and save the quantized model')
    parser.add_argument('--compile_inductor', action='store_true', default=False, help='compile with inductor backend')
    parser.add_argument('--profile', action='store_true', default=False, help='profile')
    parser.add_argument('--benchmark', action='store_true', default=False, help='test performance')
    parser.add_argument('--accuracy', action='store_true', default=False, help='test accuracy')
    parser.add_argument('-w', '--warmup_iterations', default=-1, type=int, help='number of warmup iterations to run')
    parser.add_argument('-i', '--iterations', default=-1, type=int, help='number of total iterations to run')
    parser.add_argument('--world-size', default=-1, type=int, help='number of nodes for distributed training')
    parser.add_argument('--rank', default=-1, type=int, help='node rank for distributed training')
    parser.add_argument('--dist-url', default='env://', type=str, help='url used to set up distributed training')
    parser.add_argument('--dist-backend', default='ccl', type=str, help='distributed backend')
    parser.add_argument("--weight-sharing", action='store_true', default=False, help="use weight sharing across instances when benchmarking inference")
    parser.add_argument("--number-instance", default=0, type=int, help="number of instances for latency testing; only works when weight sharing is enabled")

    args = parser.parse_args()
    return args

def run_weights_sharing_model(pipe, tid, args):
    # Each thread runs the shared pipeline end to end and reports its own latency.
    total_time = 0
    for i in range(args.iterations + args.warmup_iterations):
        # run model
        start = time.time()
        if args.precision == "bf16" or args.precision == "fp16" or args.precision == "int8-bf16":
            with torch.autocast("cpu", dtype=args.dtype), torch.no_grad():
                output = pipe(args.prompt, generator=torch.manual_seed(args.seed)).images
        else:
            with torch.no_grad():
                output = pipe(args.prompt, generator=torch.manual_seed(args.seed)).images
        end = time.time()
        print('time per prompt(s): {:.2f}'.format((end - start)))
        if i >= args.warmup_iterations:
            total_time += end - start

    print("Instance num: ", tid)
    print("Latency: {:.2f} s".format(total_time / args.iterations))
    print("Throughput: {:.5f} samples/sec".format(args.iterations / total_time))

def main():

    # Overall flow: parse args, pick the dtype from --precision, load the
    # diffusers pipeline, optionally export/quantize the U-Net with PT2E,
    # compile with the inductor backend, then run benchmark and/or accuracy.
    args = parse_args()

    
logging.info(f"Parameters {args}") + + # CCL related + os.environ['MASTER_ADDR'] = str(os.environ.get('MASTER_ADDR', '127.0.0.1')) + os.environ['MASTER_PORT'] = '29500' + os.environ['RANK'] = str(os.environ.get('PMI_RANK', 0)) + os.environ['WORLD_SIZE'] = str(os.environ.get('PMI_SIZE', 1)) + + if args.dist_url == "env://" and args.world_size == -1: + args.world_size = int(os.environ["WORLD_SIZE"]) + print("World size: ", args.world_size) + + args.distributed = args.world_size > 1 + if args.distributed: + if args.dist_url == "env://" and args.rank == -1: + args.rank = int(os.environ["RANK"]) + + # load model + pipe = DiffusionPipeline.from_pretrained(args.model_name_or_path) + if not args.accuracy: + pipe.safety_checker = None + + # data type + if args.precision == "fp32": + print("Running fp32 ...") + args.dtype=torch.float32 + elif args.precision == "bf32": + print("Running bf32 ...") + args.dtype=torch.float32 + elif args.precision == "bf16": + print("Running bf16 ...") + args.dtype=torch.bfloat16 + elif args.precision == "fp16": + print("Running fp16 ...") + args.dtype=torch.half + elif args.precision == "int8-bf16": + print("Running int8-bf16 ...") + args.dtype=torch.bfloat16 + elif args.precision == "int8-fp32": + print("Running int8-fp32 ...") + args.dtype=torch.float32 + else: + raise ValueError("--precision needs to be the following: fp32, bf32, bf16, fp16, int8-bf16, int8-fp32") + + if args.compile_inductor: + pipe.precision = torch.float32 + elif args.model_name_or_path == "SimianLuo/LCM_Dreamshaper_v7" and args.precision == "int8-bf16": + pipe.precision = torch.float32 + else: + pipe.precision = args.dtype + if args.model_name_or_path == "stabilityai/stable-diffusion-2-1": + text_encoder_input = torch.ones((1, 77), dtype=torch.int64) + input = torch.randn(2, 4, 96, 96).to(memory_format=torch.channels_last).to(dtype=pipe.precision), torch.tensor(921), torch.randn(2, 77, 1024).to(dtype=pipe.precision) + elif args.model_name_or_path == "SimianLuo/LCM_Dreamshaper_v7": + text_encoder_input = torch.ones((1, 77), dtype=torch.int64) + input = torch.randn(1, 4, 96, 96).to(memory_format=torch.channels_last).to(dtype=pipe.precision), torch.tensor(921), torch.randn(1, 77, 768).to(dtype=pipe.precision), torch.randn(1, 256).to(dtype=pipe.precision) + else: + raise ValueError("This script currently only supports stabilityai/stable-diffusion-2-1 and SimianLuo/LCM_Dreamshaper_v7.") + + if args.distributed: + import oneccl_bindings_for_pytorch + torch.distributed.init_process_group(backend=args.dist_backend, + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank) + print("Rank and world size: ", torch.distributed.get_rank()," ", torch.distributed.get_world_size()) + # print("Create DistributedDataParallel in CPU") + # pipe = torch.nn.parallel.DistributedDataParallel(pipe) + + # prepare dataloader + val_coco = dset.CocoCaptions(root = '{}/val2017'.format(args.dataset_path), + annFile = '{}/annotations/captions_val2017.json'.format(args.dataset_path), + transform=transforms.Compose([transforms.Resize((512, 512)), transforms.PILToTensor(), ])) + + if args.distributed: + val_sampler = torch.utils.data.distributed.DistributedSampler(val_coco, shuffle=False) + else: + val_sampler = None + + val_dataloader = torch.utils.data.DataLoader(val_coco, + batch_size=1, + shuffle=False, + num_workers=0, + sampler=val_sampler) + + + # torch.compile with inductor backend + if args.compile_inductor: + print("torch.compile with inductor backend ...") + # 
torch._inductor.config.profiler_mark_wrapper_call = True + # torch._inductor.config.cpp.enable_kernel_profile = True + torch._inductor.config.cpp.enable_concat_linear = True + from torch._inductor import config as inductor_config + inductor_config.cpp_wrapper = True + if args.precision == "fp32": + with torch.no_grad(): + pipe.unet = torch.compile(pipe.unet) + pipe.unet(*input) + pipe.unet(*input) + pipe.text_encoder = torch.compile(pipe.text_encoder) + pipe.vae.decode = torch.compile(pipe.vae.decode) + elif args.precision == "bf16": + with torch.autocast("cpu", ), torch.no_grad(): + pipe.unet = torch.compile(pipe.unet) + pipe.unet(*input) + pipe.unet(*input) + pipe.text_encoder = torch.compile(pipe.text_encoder) + pipe.vae.decode = torch.compile(pipe.vae.decode) + elif args.precision == "fp16": + with torch.autocast("cpu", dtype=torch.half), torch.no_grad(): + pipe.unet = torch.compile(pipe.unet) + pipe.unet(*input) + pipe.unet(*input) + pipe.text_encoder = torch.compile(pipe.text_encoder) + pipe.vae.decode = torch.compile(pipe.vae.decode) + elif args.precision == "int8-fp32" or args.precision == "int8-bf16": + from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e + import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq + from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer + from torch.export import export_for_training + if args.calibration: + with torch.no_grad(): + pipe.traced_unet = export_for_training(pipe.unet, input).module() + quantizer = X86InductorQuantizer() + if args.model_name_or_path == "SimianLuo/LCM_Dreamshaper_v7": + quantizer.set_global(xiq.get_default_x86_inductor_quantization_config()) \ + .set_module_name_qconfig("up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_q", None) \ + .set_module_name_qconfig("up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_k", None) \ + .set_module_name_qconfig("up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_v", None) \ + .set_module_name_qconfig("up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_out.0", None) \ + .set_module_name_qconfig("up_blocks.2.attentions.2.transformer_blocks.0.ff.net.2", None) \ + .set_module_name_qconfig("up_blocks.2.attentions.2.transformer_blocks.0.ff.net.0.proj", None) \ + .set_module_name_qconfig("up_blocks.2.resnets.0.time_emb_proj", None) \ + .set_module_name_qconfig("up_blocks.2.resnets.1.time_emb_proj", None) \ + .set_module_name_qconfig("up_blocks.2.resnets.2.time_emb_proj", None) \ + .set_module_name_qconfig("up_blocks.3.attentions.0.transformer_blocks.0.attn1.to_q", None) \ + .set_module_name_qconfig("up_blocks.3.attentions.0.transformer_blocks.0.attn1.to_k", None) \ + .set_module_name_qconfig("up_blocks.3.attentions.0.transformer_blocks.0.attn1.to_v", None) \ + .set_module_name_qconfig("up_blocks.3.attentions.0.transformer_blocks.0.attn1.to_out.0", None) \ + .set_module_name_qconfig("up_blocks.3.attentions.0.transformer_blocks.0.attn2.to_q", None) \ + .set_module_name_qconfig("up_blocks.3.attentions.0.transformer_blocks.0.attn2.to_k", None) \ + .set_module_name_qconfig("up_blocks.3.attentions.0.transformer_blocks.0.attn2.to_v", None) \ + .set_module_name_qconfig("up_blocks.3.attentions.0.transformer_blocks.0.attn2.to_out.0", None) \ + .set_module_name_qconfig("up_blocks.3.attentions.0.transformer_blocks.0.ff.net.2", None) \ + .set_module_name_qconfig("up_blocks.3.attentions.0.transformer_blocks.0.ff.net.0.proj", None) \ + .set_module_name_qconfig("up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_q", None) 
\ + .set_module_name_qconfig("up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_k", None) \ + .set_module_name_qconfig("up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_v", None) \ + .set_module_name_qconfig("up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_out.0", None) \ + .set_module_name_qconfig("up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_q", None) \ + .set_module_name_qconfig("up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_k", None) \ + .set_module_name_qconfig("up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_v", None) \ + .set_module_name_qconfig("up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_out.0", None) \ + .set_module_name_qconfig("up_blocks.3.attentions.1.transformer_blocks.0.ff.net.2", None) \ + .set_module_name_qconfig("up_blocks.3.attentions.1.transformer_blocks.0.ff.net.0.proj", None) \ + .set_module_name_qconfig("up_blocks.3.attentions.2.transformer_blocks.0.attn1.to_q", None) \ + .set_module_name_qconfig("up_blocks.3.attentions.2.transformer_blocks.0.attn1.to_k", None) \ + .set_module_name_qconfig("up_blocks.3.attentions.2.transformer_blocks.0.attn1.to_v", None) \ + .set_module_name_qconfig("up_blocks.3.attentions.2.transformer_blocks.0.attn1.to_out.0", None) \ + .set_module_name_qconfig("up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_q", None) \ + .set_module_name_qconfig("up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_k", None) \ + .set_module_name_qconfig("up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_v", None) \ + .set_module_name_qconfig("up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_out.0", None) \ + .set_module_name_qconfig("up_blocks.3.attentions.2.transformer_blocks.0.ff.net.2", None) \ + .set_module_name_qconfig("up_blocks.3.attentions.2.transformer_blocks.0.ff.net.0.proj", None) \ + .set_module_name_qconfig("up_blocks.3.resnets.0.time_emb_proj", None) \ + .set_module_name_qconfig("up_blocks.3.resnets.1.time_emb_proj", None) \ + .set_module_name_qconfig("up_blocks.3.resnets.2.time_emb_proj", None) \ + .set_module_name_qconfig("mid_block.attentions.0.transformer_blocks.0.attn1.to_q", None) \ + .set_module_name_qconfig("mid_block.attentions.0.transformer_blocks.0.attn1.to_k", None) \ + .set_module_name_qconfig("mid_block.attentions.0.transformer_blocks.0.attn1.to_v", None) \ + .set_module_name_qconfig("mid_block.attentions.0.transformer_blocks.0.attn1.to_out.0", None) \ + .set_module_name_qconfig("mid_block.attentions.0.transformer_blocks.0.attn2.to_q", None) \ + .set_module_name_qconfig("mid_block.attentions.0.transformer_blocks.0.attn2.to_k", None) \ + .set_module_name_qconfig("mid_block.attentions.0.transformer_blocks.0.attn2.to_v", None) \ + .set_module_name_qconfig("mid_block.attentions.0.transformer_blocks.0.attn2.to_out.0", None) \ + .set_module_name_qconfig("mid_block.attentions.0.transformer_blocks.0.ff.net.2", None) \ + .set_module_name_qconfig("mid_block.attentions.0.transformer_blocks.0.ff.net.0.proj", None) \ + .set_module_name_qconfig("mid_block.resnets.0.time_emb_proj", None) \ + .set_module_name_qconfig("mid_block.resnets.slice(1, None, None)._modules.0.time_emb_proj", None) + else: + quantizer.set_global(xiq.get_default_x86_inductor_quantization_config()) + pipe.traced_unet = prepare_pt2e(pipe.traced_unet, quantizer) + # calibration + if args.model_name_or_path == "SimianLuo/LCM_Dreamshaper_v7": + for i, (images, prompts) in enumerate(tqdm(val_dataloader)): + prompt = prompts[0][0] + pipe(prompt, generator=torch.manual_seed(args.seed)) + if i == 119: + break + 
else: + pipe(args.prompt) + pipe.traced_unet = convert_pt2e(pipe.traced_unet) + + quantized_unet = torch.export.export(pipe.traced_unet, input) + torch.export.save(quantized_unet, args.quantized_model_path) + print(".........calibration step done..........") + return + else: + quantized_unet = torch.export.load(args.quantized_model_path) + pipe.traced_unet = quantized_unet.module() + torch.ao.quantization.move_exported_model_to_eval(pipe.traced_unet) + if args.precision == "int8-fp32": + with torch.no_grad(): + pipe.traced_unet = torch.compile(pipe.traced_unet) + pipe.traced_unet(*input) + pipe.traced_unet(*input) + pipe.text_encoder = torch.compile(pipe.text_encoder) + pipe.vae.decode = torch.compile(pipe.vae.decode) + elif args.precision == "int8-bf16": + with torch.autocast("cpu", ), torch.no_grad(): + pipe.traced_unet = torch.compile(pipe.traced_unet) + pipe.traced_unet(*input) + pipe.traced_unet(*input) + pipe.text_encoder = torch.compile(pipe.text_encoder) + pipe.vae.decode = torch.compile(pipe.vae.decode) + else: + raise ValueError("If you want to use torch.compile with inductor backend, --precision needs to be the following: fp32, bf16, int8-bf16, int8-fp32") + + # benchmark + if args.benchmark: + print("Running benchmark ...") + if args.weight_sharing: + print("weight sharing ...") + threads = [] + for i in range(1, args.number_instance+1): + thread = threading.Thread(target=run_weights_sharing_model, args=(pipe, i, args)) + threads.append(thread) + thread.start() + for thread in threads: + thread.join() + exit() + else: + total_time = 0 + for i in range(args.iterations + args.warmup_iterations): + # run model + start = time.time() + if args.precision == "bf16" or args.precision == "fp16" or args.precision == "int8-bf16": + with torch.autocast("cpu", dtype=args.dtype), torch.no_grad(): + output = pipe(args.prompt, generator=torch.manual_seed(args.seed)).images + else: + with torch.no_grad(): + output = pipe(args.prompt, generator=torch.manual_seed(args.seed)).images + end = time.time() + print('time per prompt(s): {:.2f}'.format((end - start))) + if i >= args.warmup_iterations: + total_time += end - start + + print("Latency: {:.2f} s".format(total_time / args.iterations)) + print("Throughput: {:.5f} samples/sec".format(args.iterations / total_time)) + + if args.accuracy: + print("Running accuracy ...") + # run model + if args.distributed: + torch.distributed.barrier() + fid = FrechetInceptionDistance(normalize=True) + for i, (images, prompts) in enumerate(tqdm(val_dataloader)): + prompt = prompts[0][0] + real_image = images[0] + print("prompt: ", prompt) + if args.precision == "bf16" or args.precision == "fp16" or args.precision == "int8-bf16": + with torch.autocast("cpu", dtype=args.dtype), torch.no_grad(): + output = pipe(prompt, generator=torch.manual_seed(args.seed), output_type="numpy").images + else: + with torch.no_grad(): + output = pipe(prompt, generator=torch.manual_seed(args.seed), output_type="numpy").images + + if args.output_dir: + if not os.path.exists(args.output_dir): + os.mkdir(args.output_dir) + image_name = time.strftime("%Y%m%d_%H%M%S") + Image.fromarray((output[0] * 255).round().astype("uint8")).save(f"{args.output_dir}/fake_image_{image_name}.png") + Image.fromarray(real_image.permute(1, 2, 0).numpy()).save(f"{args.output_dir}/real_image_{image_name}.png") + + fake_image = torch.tensor(output[0]).unsqueeze(0).permute(0, 3, 1, 2) + real_image = real_image.unsqueeze(0) / 255.0 + + fid.update(real_image, real=True) + fid.update(fake_image, real=False) + + if 
args.iterations > 0 and i == args.iterations - 1: + break + + print(f"FID: {float(fid.compute())}") + +if __name__ == '__main__': + main() diff --git a/examples/3.x_api/pytorch/diffusion_model/diffusers/stable_diffusion/static_quant/requirements.txt b/examples/3.x_api/pytorch/diffusion_model/diffusers/stable_diffusion/static_quant/requirements.txt new file mode 100644 index 00000000000..c4e248e7e41 --- /dev/null +++ b/examples/3.x_api/pytorch/diffusion_model/diffusers/stable_diffusion/static_quant/requirements.txt @@ -0,0 +1,6 @@ +diffusers +tqdm +torch-fidelity +torchmetrics +pycocotools +transformers \ No newline at end of file From 2fddc159595b53bbe11ebc756ed7f77cb9c96095 Mon Sep 17 00:00:00 2001 From: Kaihui-intel Date: Thu, 31 Jul 2025 13:18:39 +0800 Subject: [PATCH 2/5] update copyright Signed-off-by: Kaihui-intel --- .../diffusers/stable_diffusion/static_quant/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/3.x_api/pytorch/diffusion_model/diffusers/stable_diffusion/static_quant/main.py b/examples/3.x_api/pytorch/diffusion_model/diffusers/stable_diffusion/static_quant/main.py index 035fa014f2e..82b65146402 100644 --- a/examples/3.x_api/pytorch/diffusion_model/diffusers/stable_diffusion/static_quant/main.py +++ b/examples/3.x_api/pytorch/diffusion_model/diffusers/stable_diffusion/static_quant/main.py @@ -1,7 +1,7 @@ # # -*- coding: utf-8 -*- # -# Copyright (c) 2023 Intel Corporation +# Copyright (c) 2025 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From e0ccc9dbe99ac9385ac228eb6a6e491352eb8507 Mon Sep 17 00:00:00 2001 From: Kaihui-intel Date: Thu, 31 Jul 2025 13:39:25 +0800 Subject: [PATCH 3/5] update README Signed-off-by: Kaihui-intel --- .../diffusers/stable_diffusion/static_quant/README.md | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/3.x_api/pytorch/diffusion_model/diffusers/stable_diffusion/static_quant/README.md b/examples/3.x_api/pytorch/diffusion_model/diffusers/stable_diffusion/static_quant/README.md index 4fe8bbcdc26..beb0b59d80f 100644 --- a/examples/3.x_api/pytorch/diffusion_model/diffusers/stable_diffusion/static_quant/README.md +++ b/examples/3.x_api/pytorch/diffusion_model/diffusers/stable_diffusion/static_quant/README.md @@ -9,9 +9,6 @@ Stable Diffusion quantization and inference best known configurations with stati | Inference | PyTorch | https://huggingface.co/stabilityai/stable-diffusion-2-1 | - | - | # Pre-Requisite -* Installation of PyTorch and [Intel Extension for PyTorch](https://intel.github.io/intel-extension-for-pytorch/#installation) - - ### Datasets @@ -26,6 +23,7 @@ bash download_dataset.sh quantization ```shell python main.py \ + --model_name_or_path stabilityai/stable-diffusion-2-1 \ --dataset_path=${DATASET_DIR} \ --quantized_model_path=${INT8_MODEL} \ --compile_inductor \ @@ -35,6 +33,7 @@ python main.py \ inference ```shell python main.py \ + --model_name_or_path stabilityai/stable-diffusion-2-1 \ --dataset_path=${DATASET_DIR} \ --precision=int8-bf16 \ --benchmark \ From ec5399c96fb9c9a382a12c25e3b5594484ccdf9a Mon Sep 17 00:00:00 2001 From: Kaihui-intel Date: Mon, 11 Aug 2025 09:21:17 +0800 Subject: [PATCH 4/5] update FID data Signed-off-by: Kaihui-intel --- .../stable_diffusion/static_quant/README.md | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/examples/3.x_api/pytorch/diffusion_model/diffusers/stable_diffusion/static_quant/README.md 
b/examples/3.x_api/pytorch/diffusion_model/diffusers/stable_diffusion/static_quant/README.md
index beb0b59d80f..5354ef990fe 100644
--- a/examples/3.x_api/pytorch/diffusion_model/diffusers/stable_diffusion/static_quant/README.md
+++ b/examples/3.x_api/pytorch/diffusion_model/diffusers/stable_diffusion/static_quant/README.md
@@ -4,9 +4,10 @@ Stable Diffusion quantization and inference best known configurations with stati
 
 ## Model Information
 
-| **Use Case** | **Framework** | **Model Repo** | **Branch/Commit/Tag** | **Optional Patch** |
-|:---:| :---: |:--------------:|:---------------------:|:------------------:|
-| Inference | PyTorch | https://huggingface.co/stabilityai/stable-diffusion-2-1 | - | - |
+| **Framework** | **Model Repo** |
+|:-------------:|:-------------------------------------------------------------------:|
+| PyTorch | https://huggingface.co/stabilityai/stable-diffusion-2-1 |
+| PyTorch | https://huggingface.co/SimianLuo/LCM_Dreamshaper_v7 |
 
 # Pre-Requisite
 
@@ -39,14 +40,16 @@ python main.py \
     --benchmark \
     -w 1 \
     -i 10 \
-    --quantized_model_path=${INT8_MODEL} \ --compile_inductor
+    --quantized_model_path=${INT8_MODEL} \
+    --compile_inductor
 ```
 ## FID evaluation
 We have also evaluated FID scores on the COCO2017 validation dataset for the BF16 model and for the mixed BF16+INT8 model. FID results are listed below.
 
 | Model | BF16 | INT8+BF16 |
 |----------------------|-------|-----------|
-| stable-diffusion-2-1 | 27.94 | 27.14 |
+| stable-diffusion-2-1 | 27.8630 | 27.8618 |
+| SimianLuo/LCM_Dreamshaper_v7 | 42.1710 | 42.3138 |
 
 To evaluate the FID score on the COCO2017 validation dataset for the mixed BF16+INT8 model, use the command below.
 
@@ -55,5 +58,6 @@ python main.py \
     --dataset_path=${DATASET_DIR} \
     --precision=int8-bf16 \
     --accuracy \
-    --quantized_model_path=${INT8_MODEL} \ --compile_inductor
+    --quantized_model_path=${INT8_MODEL} \
+    --compile_inductor
 ```
\ No newline at end of file

From 7a4de07c339cbde38c003ce773696ed2231c731f Mon Sep 17 00:00:00 2001
From: Kaihui-intel
Date: Tue, 26 Aug 2025 13:04:10 +0800
Subject: [PATCH 5/5] add jenkins test

Signed-off-by: Kaihui-intel
---
 examples/.config/model_params_pytorch_3x.json |  16 ++++
 .../static_quant/run_benchmark.sh             |  95 +++++++++++++++++++
 .../static_quant/run_quant.sh                 |  61 ++++++++++++
 3 files changed, 172 insertions(+)
 create mode 100644 examples/3.x_api/pytorch/diffusion_model/diffusers/stable_diffusion/static_quant/run_benchmark.sh
 create mode 100644 examples/3.x_api/pytorch/diffusion_model/diffusers/stable_diffusion/static_quant/run_quant.sh

diff --git a/examples/.config/model_params_pytorch_3x.json b/examples/.config/model_params_pytorch_3x.json
index 0d75e7b3902..cafaab4157a 100644
--- a/examples/.config/model_params_pytorch_3x.json
+++ b/examples/.config/model_params_pytorch_3x.json
@@ -228,6 +228,22 @@
       "main_script": "main.py",
       "batch_size": 1
     },
+    "sd21_static_int8":{
+      "model_src_dir": "diffusion_model/diffusers/stable_diffusion/static_quant",
+      "dataset_location": "/tf_dataset2/datasets/coco2017/coco/",
+      "input_model": "",
+      "main_script": "main.py",
+      "batch_size": 1,
+      "iters": 10
+    },
+    "lcm_static_int8":{
+      "model_src_dir": "diffusion_model/diffusers/stable_diffusion/static_quant",
+      "dataset_location": "/tf_dataset2/datasets/coco2017/coco/",
+      "input_model": "",
+      "main_script": "main.py",
+      "batch_size": 1,
+      "iters": 10
+    },
     "resnet18_mixed_precision": {
       "model_src_dir": "cv/mixed_precision",
       "dataset_location": "/tf_dataset/pytorch/ImageNet/raw",
diff --git 
a/examples/3.x_api/pytorch/diffusion_model/diffusers/stable_diffusion/static_quant/run_benchmark.sh b/examples/3.x_api/pytorch/diffusion_model/diffusers/stable_diffusion/static_quant/run_benchmark.sh new file mode 100644 index 00000000000..c227d5de9bb --- /dev/null +++ b/examples/3.x_api/pytorch/diffusion_model/diffusers/stable_diffusion/static_quant/run_benchmark.sh @@ -0,0 +1,95 @@ +#!/bin/bash +set -x + +function main { + + init_params "$@" + run_benchmark + +} + +# init params +function init_params { + iters=10 + batch_size=8 + tuned_checkpoint=saved_results + echo ${max_eval_samples} + for var in "$@" + do + case $var in + --topology=*) + topology=$(echo $var |cut -f2 -d=) + ;; + --dataset_location=*) + dataset_location=$(echo $var |cut -f2 -d=) + ;; + --input_model=*) + input_model=$(echo $var |cut -f2 -d=) + ;; + --mode=*) + mode=$(echo $var |cut -f2 -d=) + ;; + --batch_size=*) + batch_size=$(echo $var |cut -f2 -d=) + ;; + --iters=*) + iters=$(echo ${var} |cut -f2 -d=) + ;; + --optimized=*) + optimized=$(echo ${var} |cut -f2 -d=) + ;; + --config=*) + tuned_checkpoint=$(echo $var |cut -f2 -d=) + ;; + *) + echo "Error: No such parameter: ${var}" + exit 1 + ;; + esac + done + +} + + +# run_benchmark +function run_benchmark { + extra_cmd='' + mode_cmd='' + DATASET_DIR=${dataset_location} + tuned_checkpoint="unet_quantized_model.pt2" + + if [[ ${mode} == "accuracy" ]]; then + mode_cmd=" --accuracy " + elif [[ ${mode} == "performance" ]]; then + mode_cmd=" --benchmark -w 1 -i ${iters} " + else + echo "Error: No such mode: ${mode}" + exit 1 + fi + + if [[ ${optimized} == "true" ]]; then + extra_cmd=$extra_cmd" --quantized_model_path=${tuned_checkpoint} --precision=int8-bf16 " + else + extra_cmd=$extra_cmd" --precision=bf16 " + fi + echo $extra_cmd + + if [ "${topology}" = "sd21_static_int8" ]; then + model_name_or_path="stabilityai/stable-diffusion-2-1" + elif [ "${topology}" = "lcm_static_int8" ]; then + model_name_or_path="SimianLuo/LCM_Dreamshaper_v7" + else + echo "Error: No such topology: ${topology}" + exit 1 + fi + + + python main.py \ + --model_name_or_path ${model_name_or_path} \ + --dataset_path=${DATASET_DIR} \ + --compile_inductor \ + ${extra_cmd} ${mode_cmd} + +} + +main "$@" diff --git a/examples/3.x_api/pytorch/diffusion_model/diffusers/stable_diffusion/static_quant/run_quant.sh b/examples/3.x_api/pytorch/diffusion_model/diffusers/stable_diffusion/static_quant/run_quant.sh new file mode 100644 index 00000000000..1ebb861212f --- /dev/null +++ b/examples/3.x_api/pytorch/diffusion_model/diffusers/stable_diffusion/static_quant/run_quant.sh @@ -0,0 +1,61 @@ +#!/bin/bash +set -x + +function main { + + init_params "$@" + run_tuning + +} + +# init params +function init_params { + for var in "$@" + do + case $var in + --topology=*) + topology=$(echo $var |cut -f2 -d=) + ;; + --dataset_location=*) + dataset_location=$(echo $var |cut -f2 -d=) + ;; + --input_model=*) + input_model=$(echo $var |cut -f2 -d=) + ;; + --output_model=*) + tuned_checkpoint=$(echo $var |cut -f2 -d=) + ;; + *) + echo "Error: No such parameter: ${var}" + exit 1 + ;; + esac + done + +} + +# run_tuning +function run_tuning { + extra_cmd='' + DATASET_DIR=${dataset_location} + tuned_checkpoint="unet_quantized_model.pt2" + + if [ "${topology}" = "sd21_static_int8" ]; then + model_name_or_path="stabilityai/stable-diffusion-2-1" + elif [ "${topology}" = "lcm_static_int8" ]; then + model_name_or_path="SimianLuo/LCM_Dreamshaper_v7" + else + echo "Error: No such topology: ${topology}" + exit 1 + fi + + python main.py \ + 
--model_name_or_path ${model_name_or_path} \ + --dataset_path=${DATASET_DIR} \ + --quantized_model_path=${tuned_checkpoint} \ + --compile_inductor \ + --precision=int8-bf16 \ + --calibration +} + +main "$@"
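
For reference, a minimal local invocation of the two helper scripts added in PATCH 5 is sketched below; the dataset path is a placeholder, and `unet_quantized_model.pt2` is the output name hardcoded in `run_quant.sh`.

```bash
# Sketch: exercising the PATCH 5 entry points locally (paths are placeholders).
export DATASET_DIR=/path/to/coco2017   # must contain val2017/ and annotations/

# Calibrate and export the INT8 U-Net (writes unet_quantized_model.pt2).
bash run_quant.sh --topology=sd21_static_int8 --dataset_location=${DATASET_DIR}

# Benchmark the INT8+BF16 model for 10 iterations after 1 warmup run.
bash run_benchmark.sh --topology=sd21_static_int8 --dataset_location=${DATASET_DIR} \
    --mode=performance --optimized=true --iters=10

# Evaluate FID accuracy of the INT8+BF16 model.
bash run_benchmark.sh --topology=sd21_static_int8 --dataset_location=${DATASET_DIR} \
    --mode=accuracy --optimized=true
```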