From ac8bef8128a00fdb39c34d856fa733d05d7f2b83 Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Mon, 2 Feb 2026 14:16:47 +0800 Subject: [PATCH 1/4] judge llm --- ais_bench/benchmark/cli/workers.py | 139 +++++++++++++++++- .../aime2025/aime2025_gen_0_shot_llmjudge.py | 117 +++++++++++++++ ais_bench/benchmark/datasets/aime2025.py | 24 ++- ais_bench/benchmark/datasets/base.py | 19 +++ .../benchmark/datasets/utils/datasets.py | 6 +- .../benchmark/datasets/utils/llm_judge.py | 36 +++++ ais_bench/benchmark/utils/file/file.py | 31 +++- 7 files changed, 358 insertions(+), 14 deletions(-) create mode 100644 ais_bench/benchmark/configs/datasets/aime2025/aime2025_gen_0_shot_llmjudge.py create mode 100644 ais_bench/benchmark/datasets/utils/llm_judge.py diff --git a/ais_bench/benchmark/cli/workers.py b/ais_bench/benchmark/cli/workers.py index ca997164..f75cce7e 100644 --- a/ais_bench/benchmark/cli/workers.py +++ b/ais_bench/benchmark/cli/workers.py @@ -1,3 +1,4 @@ +import os import os.path as osp import copy from abc import ABC, abstractmethod @@ -8,12 +9,15 @@ from ais_bench.benchmark.registry import PARTITIONERS, RUNNERS, build_from_cfg from ais_bench.benchmark.utils.config.run import get_config_type from ais_bench.benchmark.utils.logging.logger import AISLogger +from ais_bench.benchmark.utils.logging.exceptions import PredictionInvalidException +from ais_bench.benchmark.utils.logging.error_codes import TMAN_CODES from ais_bench.benchmark.partitioners import NaivePartitioner from ais_bench.benchmark.runners import LocalRunner from ais_bench.benchmark.tasks import OpenICLEvalTask, OpenICLApiInferTask, OpenICLInferTask from ais_bench.benchmark.summarizers import DefaultSummarizer, DefaultPerfSummarizer from ais_bench.benchmark.calculators import DefaultPerfMetricCalculator from ais_bench.benchmark.cli.utils import fill_model_path_if_datasets_need +from ais_bench.benchmark.utils.file.file import load_jsonl, dump_jsonl logger = AISLogger() @@ -75,7 +79,7 @@ def do_work(self, cfg: ConfigDict): logger.info("Merging datasets with the same model and inferencer...") tasks = self._merge_datasets(tasks) - runner = RUNNERS.build(cfg.infer.runner) + runner = RUNNERS.build(cfg.judge_infer.runner) runner(tasks) logger.info("Inference tasks completed.") @@ -108,6 +112,117 @@ def _update_tasks_cfg(self, tasks, cfg: ConfigDict): task.attack = cfg.attack +class JudgeInfer(BaseWorker): + def update_cfg(self, cfg: ConfigDict) -> None: + def get_task_type() -> str: + if cfg["models"][0]["attr"] == "service": + return get_config_type(OpenICLApiInferTask) + else: + return get_config_type(OpenICLInferTask) + + new_cfg = dict( + judge_infer=dict( + partitioner=dict(type=get_config_type(NaivePartitioner)), + runner=dict( + max_num_workers=self.args.max_num_workers, + max_workers_per_gpu=self.args.max_workers_per_gpu, + debug=self.args.debug, + task=dict(type=get_task_type()), + type=get_config_type(LocalRunner), + ), + ), + ) + + cfg.merge_from_dict(new_cfg) + if cfg.cli_args.debug: + cfg.judge_infer.runner.debug = True + cfg.judge_infer.partitioner["out_dir"] = osp.join(cfg["work_dir"], "predictions/") + return cfg + + def do_work(self, cfg: ConfigDict): + partitioner = PARTITIONERS.build(cfg.judge_infer.partitioner) + logger.info("Starting inference tasks...") + tasks = partitioner(cfg) + + # delete the tasks without judge_infer_cfg + new_tasks = [] + for task in tasks: + if task["datasets"][0][0].get("judge_infer_cfg"): + new_tasks.append(task) + tasks = new_tasks + if len(tasks) == 0: + return + + # update tasks cfg before run + self._update_tasks_cfg(tasks, cfg) + + if ( + cfg.get("cli_args", {}).get("merge_ds", False) + or cfg.get("cli_args", {}).get("mode") == "perf" # performance mode will enable merge datasets by default + ): + logger.info("Merging datasets with the same model and inferencer...") + tasks = self._merge_datasets(tasks) + + runner = RUNNERS.build(cfg.judge_infer.runner) + runner(tasks) + self._result_post_process(tasks, cfg) + logger.info("Inference tasks completed.") + + def _merge_datasets(self, tasks): + # merge datasets with the same model, dataset type and inferencer + task_groups = defaultdict(list) + for task in tasks: + key = ( + task["models"][0]["abbr"] # same model + + "_" + + str(task['datasets'][0][0]['type']) # same dataset type + + "_" + + str(task["datasets"][0][0]["infer_cfg"]["inferencer"]) # same inferencer with the same args + ) + task_groups[key].append(task) + new_tasks = [] + for key, task_group in task_groups.items(): + new_task = copy.deepcopy(task_group[0]) + if len(task_group) > 1: + for t in task_group[1:]: + new_task["datasets"][0].extend(t["datasets"][0]) + new_tasks.append(new_task) + return new_tasks + + def _update_tasks_cfg(self, tasks, cfg: ConfigDict): + # update parameters to correct sub cfg + if hasattr(cfg, "attack"): + for task in tasks: + cfg.attack.dataset = task.datasets[0][0].abbr + task.attack = cfg.attack + + # update judge cfgs to model cfgs and data + for task in tasks: + task["datasets"][0][0]["predictions_path"] = osp.join(cfg.judge_infer.partitioner.out_dir, task["models"][0]["abbr"], f'{task["datasets"][0][0]["abbr"]}.jsonl') + if not osp.exists(task["datasets"][0][0]["predictions_path"]): + raise PredictionInvalidException(TMAN_CODES.UNKNOWN_ERROR, f"Predictions path {task['datasets'][0][0]['predictions_path']} does not exist.") + task["datasets"][0][0]["abbr"] = f'{task["datasets"][0][0]["abbr"]}-{task["datasets"][0][0]["judge_infer_cfg"]["judge_model"].pop("additional_abbr")}' + model_abbr = task["models"][0]["abbr"] + task["models"][0] = task["datasets"][0][0]["judge_infer_cfg"].pop("judge_model") + task["models"][0]["abbr"] = model_abbr + task["datasets"][0][0]["type"] = task["datasets"][0][0]["judge_infer_cfg"].pop("judge_dataset_type") + task["datasets"][0][0]["reader_cfg"] = task["datasets"][0][0]["judge_infer_cfg"].pop("judge_reader_cfg") + task["datasets"][0][0]["infer_cfg"] = task["datasets"][0][0].pop("judge_infer_cfg") + + def _result_post_process(self, tasks, cfg: ConfigDict): + # Reconstruct the judge infer predictions to normal predictions format + for task in tasks: + model_org_prediction_path = task["datasets"][0][0]["predictions_path"] + model_preds: dict = {item["uuid"]: item for item in load_jsonl(model_org_prediction_path)} + judge_org_prediction_path = osp.join(cfg.judge_infer.partitioner.out_dir, task["models"][0]["abbr"], f'{task["datasets"][0][0]["abbr"]}.jsonl') + judge_preds: list = load_jsonl(judge_org_prediction_path) + for i, pred in enumerate(judge_preds): + uuid = pred["gold"] + judge_preds[i]["id"] = model_preds[uuid]["id"] + os.remove(judge_org_prediction_path) + dump_jsonl(judge_preds, judge_org_prediction_path) + + class Eval(BaseWorker): def update_cfg(self, cfg: ConfigDict) -> None: new_cfg = dict( @@ -138,7 +253,7 @@ def do_work(self, cfg: ConfigDict): logger.info("Starting evaluation tasks...") tasks = partitioner(cfg) - # update tasks cfg before run + # Update tasks cfg before run self._update_tasks_cfg(tasks, cfg) runner = RUNNERS.build(cfg.eval.runner) @@ -151,9 +266,11 @@ def do_work(self, cfg: ConfigDict): logger.info("Evaluation tasks completed.") def _update_tasks_cfg(self, tasks, cfg: ConfigDict): - # update parameters to correct sub cfg - pass - + # Replace default model config to judge model config + for task in tasks: + if task["datasets"][0][0].get("judge_infer_cfg"): + task["datasets"][0][0]["abbr"] = f'{task["datasets"][0][0]["abbr"]}-{task["datasets"][0][0]["judge_infer_cfg"]["judge_model"].pop("additional_abbr")}' + task["datasets"][0][0].pop("judge_infer_cfg") class AccViz(BaseWorker): def update_cfg(self, cfg: ConfigDict) -> None: @@ -171,6 +288,7 @@ def update_cfg(self, cfg: ConfigDict) -> None: def do_work(self, cfg: ConfigDict) -> int: logger.info("Summarizing evaluation results...") summarizer_cfg = cfg.get("summarizer", {}) + cfg = self._cfg_pre_process(cfg) # For subjective summarizer if summarizer_cfg.get("function", None): @@ -203,6 +321,13 @@ def do_work(self, cfg: ConfigDict) -> int: summarizer = build_from_cfg(summarizer_cfg) summarizer.summarize(time_str=self.args.cfg_time_str) + def _cfg_pre_process(self, cfg: ConfigDict) -> None: + for i, dataset in enumerate(cfg.datasets): + if dataset.get("judge_infer_cfg"): + cfg.datasets[i]["abbr"] = f'{cfg.datasets[i]["abbr"]}-{cfg.datasets[i]["judge_infer_cfg"]["judge_model"].pop("additional_abbr")}' + cfg.datasets[i].pop("judge_infer_cfg") + return cfg + class PerfViz(BaseWorker): def update_cfg(self, cfg: ConfigDict) -> None: @@ -233,9 +358,9 @@ def do_work(self, cfg: ConfigDict) -> int: WORK_FLOW = dict( - all=[Infer, Eval, AccViz], + all=[Infer, JudgeInfer, Eval, AccViz], infer=[Infer], - eval=[Eval, AccViz], + eval=[JudgeInfer, Eval, AccViz], viz=[AccViz], perf=[Infer, PerfViz], perf_viz=[PerfViz], diff --git a/ais_bench/benchmark/configs/datasets/aime2025/aime2025_gen_0_shot_llmjudge.py b/ais_bench/benchmark/configs/datasets/aime2025/aime2025_gen_0_shot_llmjudge.py new file mode 100644 index 00000000..c87fb232 --- /dev/null +++ b/ais_bench/benchmark/configs/datasets/aime2025/aime2025_gen_0_shot_llmjudge.py @@ -0,0 +1,117 @@ +from ais_bench.benchmark.openicl.icl_prompt_template import PromptTemplate +from ais_bench.benchmark.openicl.icl_retriever import ZeroRetriever +from ais_bench.benchmark.openicl.icl_inferencer import GenInferencer +from ais_bench.benchmark.models import VLLMCustomAPIChat +from ais_bench.benchmark.utils.postprocess.model_postprocessors import extract_non_reasoning_content +from ais_bench.benchmark.datasets import ( + Aime2025Dataset, + Aime2025JDGDataset, +) +from ais_bench.benchmark.datasets.utils.llm_judge import get_a_or_b, LLMJudgeCorrectEvaluator + + +aime2025_reader_cfg = dict(input_columns=["question"], output_column="answer") + + +aime2025_infer_cfg = dict( + prompt_template=dict( + type=PromptTemplate, + template=dict( + round=[ + dict( + role="HUMAN", + prompt="{question}\nRemember to put your final answer within \\boxed{}.", + ), + ], + ), + ), + retriever=dict(type=ZeroRetriever), + inferencer=dict(type=GenInferencer), +) + +GRADER_TEMPLATE = """ + Please as a grading expert, judge whether the final answers given by the candidates below are consistent with the standard answers, that is, whether the candidates answered correctly. + + Here are some evaluation criteria: + 1. Please refer to the given standard answer. You don't need to re-generate the answer to the question because the standard answer has been given. You only need to judge whether the candidate's answer is consistent with the standard answer according to the form of the question. Don't try to answer the original question. You can assume that the standard answer is definitely correct. + 2. Because the candidate's answer may be different from the standard answer in the form of expression, before making a judgment, please understand the question and the standard answer first, and then judge whether the candidate's answer is correct, but be careful not to try to answer the original question. + 3. Some answers may contain multiple items, such as multiple-choice questions, multiple-select questions, fill-in-the-blank questions, etc. As long as the answer is the same as the standard answer, it is enough. For multiple-select questions and multiple-blank fill-in-the-blank questions, the candidate needs to answer all the corresponding options or blanks correctly to be considered correct. + 4. Some answers may be expressed in different ways, such as some answers may be a mathematical expression, some answers may be a textual description, as long as the meaning expressed is the same. And some formulas are expressed in different ways, but they are equivalent and correct. + 5. If the prediction is given with \\boxed{}, please ignore the \\boxed{} and only judge whether the candidate's answer is consistent with the standard answer. + + Please judge whether the following answers are consistent with the standard answer based on the above criteria. Grade the predicted answer of this new question as one of: + A: CORRECT + B: INCORRECT + Just return the letters "A" or "B", with no text around it. + + Here is your task. Simply reply with either CORRECT, INCORRECT. Don't apologize or correct yourself if there was a mistake; we are just trying to grade the answer. + + + : \n{question}\n\n\n + : \n{answer}\n\n\n + : \n{model_answer}\n\n\n + + Judging the correctness of candidates' answers: +""".strip() + +aime2025_judge_infer_cfg = dict( + judge_reader_cfg = dict(input_columns=["question", "answer", "model_answer"], output_column="model_pred_uuid"), + judge_model=dict( + attr="service", + type=VLLMCustomAPIChat, + additional_abbr="judge", # Be added after dataset abbr + path="", + model="", + stream=True, + request_rate=0, + use_timestamp=False, + retry=2, + api_key="", + host_ip="localhost", + host_port=8080, + url="", + max_out_len=512, + batch_size=1, + trust_remote_code=False, + generation_kwargs=dict( + temperature=0.01, + ignore_eos=False, + ), + pred_postprocessor=dict(type=extract_non_reasoning_content), + ), + judge_dataset_type=Aime2025JDGDataset, + prompt_template=dict( + type=PromptTemplate, + template=dict( + begin=[ + dict( + role='SYSTEM', + fallback_role='HUMAN', + prompt="You are a helpful assistant who evaluates the correctness and quality of models' outputs.", + ) + ], + round=[ + dict(role='HUMAN', prompt=GRADER_TEMPLATE), + ], + ), + ), + retriever=dict(type=ZeroRetriever), + inferencer=dict(type=GenInferencer), +) + +aime2025_eval_cfg = dict( + evaluator=dict(type=LLMJudgeCorrectEvaluator), + pred_postprocessor=dict(type=get_a_or_b), +) + +aime2025_datasets = [ + dict( + abbr="aime2025", + type=Aime2025Dataset, + path="ais_bench/datasets/aime2025/aime2025.jsonl", + reader_cfg=aime2025_reader_cfg, + infer_cfg=aime2025_infer_cfg, + judge_infer_cfg=aime2025_judge_infer_cfg, + eval_cfg=aime2025_eval_cfg, + ) +] diff --git a/ais_bench/benchmark/datasets/aime2025.py b/ais_bench/benchmark/datasets/aime2025.py index 6e67b07d..548e28b2 100644 --- a/ais_bench/benchmark/datasets/aime2025.py +++ b/ais_bench/benchmark/datasets/aime2025.py @@ -1,16 +1,15 @@ -import json +import json from datasets import Dataset from ais_bench.benchmark.registry import LOAD_DATASET from ais_bench.benchmark.datasets.utils.datasets import get_data_path -from .base import BaseDataset +from .base import BaseDataset, BaseJDGDatasetMethod @LOAD_DATASET.register_module() class Aime2025Dataset(BaseDataset): - @staticmethod def load(path, **kwargs): path = get_data_path(path) @@ -20,3 +19,22 @@ def load(path, **kwargs): line = json.loads(line.strip()) dataset.append(line) return Dataset.from_list(dataset) + +class Aime2025JDGDataset(Aime2025Dataset): + def load(self, path, predictions_path, **kwargs): + + dataset_content = Aime2025Dataset.load(path, **kwargs) + + # 加载被测模型的推理结果(排序后) + predictions: list = BaseJDGDatasetMethod.load_from_predictions(predictions_path) + + # 为数据集添加 model_answer 列 + dataset_list = [] + + for item in predictions: + item_dict = dataset_content[int(item["id"])] + item_dict["model_answer"] = item["prediction"] + item_dict["model_pred_uuid"] = item["uuid"] # Be filled in gold + dataset_list.append(item_dict) + + return Dataset.from_list(dataset_list) diff --git a/ais_bench/benchmark/datasets/base.py b/ais_bench/benchmark/datasets/base.py index de062a5d..f3deb8f3 100644 --- a/ais_bench/benchmark/datasets/base.py +++ b/ais_bench/benchmark/datasets/base.py @@ -1,3 +1,5 @@ +import os + from abc import abstractmethod from typing import List, Dict, Optional, Union @@ -8,6 +10,7 @@ from ais_bench.benchmark.utils.logging.logger import AISLogger from ais_bench.benchmark.utils.logging.error_codes import DSET_CODES from ais_bench.benchmark.utils.logging.exceptions import ParameterValueError +from ais_bench.benchmark.utils.file.file import load_jsonl disable_progress_bar() # disable mapping progress bar, preventing terminal interface contamination @@ -108,3 +111,19 @@ def test(self): @abstractmethod def load(**kwargs) -> Union[Dataset, DatasetDict]: pass + +class BaseJDGDatasetMethod: + @staticmethod + def load_from_predictions(prediction_path: str) -> Dict: + """Load predictions from a directory and merge them with the dataset. + + Args: + prediction_dir (str): The directory containing prediction files. + + Returns: + Dataset: The merged dataset with predictions. + """ + if os.path.exists(prediction_path): + preds = load_jsonl(prediction_path) + preds.sort(key=lambda x: x.get('id',0)) + return preds diff --git a/ais_bench/benchmark/datasets/utils/datasets.py b/ais_bench/benchmark/datasets/utils/datasets.py index 9f473d6a..6bb31701 100644 --- a/ais_bench/benchmark/datasets/utils/datasets.py +++ b/ais_bench/benchmark/datasets/utils/datasets.py @@ -69,7 +69,7 @@ def get_sample_data(data_list: list, sample_mode: str = "default", request_count data_list (list): Data list. sample_mode (str): Sample mode. request_count (int): Request count. - + Raises: ValueError: If sample mode is not supported. ValueError: If request count is negative. @@ -101,7 +101,7 @@ def get_sample_data(data_list: list, sample_mode: str = "default", request_count return shuffle_data else: raise ValueError(f"Sample mode: {sample_mode} is not supported!") - + def get_meta_json(dataset_path, meta_path): ori_meta_path = meta_path if not meta_path: @@ -389,7 +389,7 @@ def _to_float(text: str): return relative_change <= max_relative_change else: return prediction.lower() == target.lower() - + def anls_compute(groundtruth, prediction): gt_answer = ' '.join(groundtruth.strip().lower().split()) det_answer = ' '.join(prediction.strip().lower().split()) diff --git a/ais_bench/benchmark/datasets/utils/llm_judge.py b/ais_bench/benchmark/datasets/utils/llm_judge.py new file mode 100644 index 00000000..a17e72b7 --- /dev/null +++ b/ais_bench/benchmark/datasets/utils/llm_judge.py @@ -0,0 +1,36 @@ +import re + +from ais_bench.benchmark.utils.logging import AISLogger +from ais_bench.benchmark.registry import (ICL_EVALUATORS, LOAD_DATASET, + TEXT_POSTPROCESSORS) +from ais_bench.benchmark.openicl.icl_evaluator import BaseEvaluator +logger = AISLogger() + +@TEXT_POSTPROCESSORS.register_module("get_a_or_b") +def get_a_or_b(pred: str) -> str: + """从模型回复中提取A或B""" + match = re.search(r'[AB]', pred) + return match.group(0) if match else 'B' + + +@ICL_EVALUATORS.register_module() +class LLMJudgeCorrectEvaluator(BaseEvaluator): + + def __init__(self): + super().__init__() + + def score(self, predictions, references): + if len(predictions) != len(references): + return {'error': 'preds and refrs have different length'} + correct = 0 + count = 0 + details = [] + for i, j in zip(predictions, references): + detail = {'pred': i, 'answer': j, 'correct': False} + count += 1 + if i == "A": + correct += 1 + detail['correct'] = True + details.append(detail) + result = {'accuracy': 100 * correct / count, 'details': details} + return result \ No newline at end of file diff --git a/ais_bench/benchmark/utils/file/file.py b/ais_bench/benchmark/utils/file/file.py index d6bfde67..47f048dc 100644 --- a/ais_bench/benchmark/utils/file/file.py +++ b/ais_bench/benchmark/utils/file/file.py @@ -1,6 +1,8 @@ from typing import List, Tuple, Union import os import json +import mmap +import orjson import fnmatch import tabulate @@ -226,4 +228,31 @@ def check_mm_custom(path): return False if line["type"] not in ["image", "video", "audio"]: return False - return True \ No newline at end of file + return True + +def load_jsonl(path: str) -> List[dict]: + """Load JSONL file into a list of dictionaries. + + Args: + path: Path to the JSONL file + + Returns: + List of dictionaries, each representing a line in the JSONL file + """ + preds = [] + with open(path, "rb") as f: + mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) + for line in iter(mm.readline, b""): + preds.append(orjson.loads(line)) + return preds + +def dump_jsonl(data: List[dict], path: str): + """Dump a list of dictionaries to a JSONL file. + + Args: + data: List of dictionaries to be dumped + path: Path to the output JSONL file + """ + with open(path, 'wb') as f: + for item in data: + f.write(orjson.dumps(item) + b'\n') \ No newline at end of file From 16a9848d2e55a34ae445ac26dc1d92581d4c1744 Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Thu, 5 Feb 2026 09:42:36 +0800 Subject: [PATCH 2/4] reconstruct the judgedatasets --- ais_bench/benchmark/cli/workers.py | 27 ++++++- .../aime2025/aime2025_gen_0_shot_llmjudge.py | 9 ++- ais_bench/benchmark/datasets/aime2025.py | 25 ++---- ais_bench/benchmark/datasets/base.py | 77 ++++++++++++++----- .../benchmark/datasets/utils/llm_judge.py | 22 +++++- 5 files changed, 113 insertions(+), 47 deletions(-) diff --git a/ais_bench/benchmark/cli/workers.py b/ais_bench/benchmark/cli/workers.py index f75cce7e..77a5d59d 100644 --- a/ais_bench/benchmark/cli/workers.py +++ b/ais_bench/benchmark/cli/workers.py @@ -1,6 +1,7 @@ import os import os.path as osp import copy +import shutil from abc import ABC, abstractmethod from collections import defaultdict @@ -201,7 +202,7 @@ def _update_tasks_cfg(self, tasks, cfg: ConfigDict): task["datasets"][0][0]["predictions_path"] = osp.join(cfg.judge_infer.partitioner.out_dir, task["models"][0]["abbr"], f'{task["datasets"][0][0]["abbr"]}.jsonl') if not osp.exists(task["datasets"][0][0]["predictions_path"]): raise PredictionInvalidException(TMAN_CODES.UNKNOWN_ERROR, f"Predictions path {task['datasets'][0][0]['predictions_path']} does not exist.") - task["datasets"][0][0]["abbr"] = f'{task["datasets"][0][0]["abbr"]}-{task["datasets"][0][0]["judge_infer_cfg"]["judge_model"].pop("additional_abbr")}' + task["datasets"][0][0]["abbr"] = f'{task["datasets"][0][0]["abbr"]}-{task["datasets"][0][0]["judge_infer_cfg"]["judge_model"]["abbr"]}' model_abbr = task["models"][0]["abbr"] task["models"][0] = task["datasets"][0][0]["judge_infer_cfg"].pop("judge_model") task["models"][0]["abbr"] = model_abbr @@ -263,15 +264,35 @@ def do_work(self, cfg: ConfigDict): runner(task_part) else: runner(tasks) + self._result_post_process(tasks, cfg) logger.info("Evaluation tasks completed.") def _update_tasks_cfg(self, tasks, cfg: ConfigDict): # Replace default model config to judge model config + self.judge_result_paths = {} for task in tasks: if task["datasets"][0][0].get("judge_infer_cfg"): - task["datasets"][0][0]["abbr"] = f'{task["datasets"][0][0]["abbr"]}-{task["datasets"][0][0]["judge_infer_cfg"]["judge_model"].pop("additional_abbr")}' + new_dataset_abbr = f'{task["datasets"][0][0]["abbr"]}-{task["datasets"][0][0]["judge_infer_cfg"]["judge_model"]["abbr"]}' + org_dataset_abbr = task["datasets"][0][0]["abbr"] + self.judge_result_paths[new_dataset_abbr] = org_dataset_abbr + task["datasets"][0][0]["abbr"] = new_dataset_abbr task["datasets"][0][0].pop("judge_infer_cfg") + def _result_post_process(self, tasks, cfg: ConfigDict): + # Copy judge infer result to normal name + + for task in tasks: + if task["datasets"][0][0]["abbr"] in self.judge_result_paths.keys(): + cur_results_path = osp.join(cfg.eval.partitioner.out_dir, task["models"][0]["abbr"], f'{task["datasets"][0][0]["abbr"]}.jsonl') + final_org_results_path = osp.join(cfg.eval.partitioner.out_dir, task["models"][0]["abbr"], f'{self.judge_result_paths[task["datasets"][0][0]["abbr"]]}.jsonl') + if os.path.exists(final_org_results_path): + os.remove(final_org_results_path) + + if os.path.exists(cur_results_path): + # 基于cur_results_path的文件复制一份final_org_results_path + shutil.copy(cur_results_path, final_org_results_path) + + class AccViz(BaseWorker): def update_cfg(self, cfg: ConfigDict) -> None: summarizer_cfg = cfg.get("summarizer", {}) @@ -324,7 +345,7 @@ def do_work(self, cfg: ConfigDict) -> int: def _cfg_pre_process(self, cfg: ConfigDict) -> None: for i, dataset in enumerate(cfg.datasets): if dataset.get("judge_infer_cfg"): - cfg.datasets[i]["abbr"] = f'{cfg.datasets[i]["abbr"]}-{cfg.datasets[i]["judge_infer_cfg"]["judge_model"].pop("additional_abbr")}' + cfg.datasets[i]["abbr"] = f'{cfg.datasets[i]["abbr"]}-{cfg.datasets[i]["judge_infer_cfg"]["judge_model"]["abbr"]}' cfg.datasets[i].pop("judge_infer_cfg") return cfg diff --git a/ais_bench/benchmark/configs/datasets/aime2025/aime2025_gen_0_shot_llmjudge.py b/ais_bench/benchmark/configs/datasets/aime2025/aime2025_gen_0_shot_llmjudge.py index c87fb232..7ece227c 100644 --- a/ais_bench/benchmark/configs/datasets/aime2025/aime2025_gen_0_shot_llmjudge.py +++ b/ais_bench/benchmark/configs/datasets/aime2025/aime2025_gen_0_shot_llmjudge.py @@ -38,10 +38,11 @@ 3. Some answers may contain multiple items, such as multiple-choice questions, multiple-select questions, fill-in-the-blank questions, etc. As long as the answer is the same as the standard answer, it is enough. For multiple-select questions and multiple-blank fill-in-the-blank questions, the candidate needs to answer all the corresponding options or blanks correctly to be considered correct. 4. Some answers may be expressed in different ways, such as some answers may be a mathematical expression, some answers may be a textual description, as long as the meaning expressed is the same. And some formulas are expressed in different ways, but they are equivalent and correct. 5. If the prediction is given with \\boxed{}, please ignore the \\boxed{} and only judge whether the candidate's answer is consistent with the standard answer. + 6. If the candidate's answer is semantically incomplete at the end, please judge it as inconsistent. Please judge whether the following answers are consistent with the standard answer based on the above criteria. Grade the predicted answer of this new question as one of: - A: CORRECT - B: INCORRECT + A: Means the answer is consistent with the standard answer. + B: Means the answer is inconsistent with the standard answer. Just return the letters "A" or "B", with no text around it. Here is your task. Simply reply with either CORRECT, INCORRECT. Don't apologize or correct yourself if there was a mistake; we are just trying to grade the answer. @@ -51,7 +52,7 @@ : \n{answer}\n\n\n : \n{model_answer}\n\n\n - Judging the correctness of candidates' answers: + Judging the correctness of candidates' answers, please return the the letters "A" or "B" first before your thinking: """.strip() aime2025_judge_infer_cfg = dict( @@ -59,7 +60,7 @@ judge_model=dict( attr="service", type=VLLMCustomAPIChat, - additional_abbr="judge", # Be added after dataset abbr + abbr="judge", # Be added after dataset abbr path="", model="", stream=True, diff --git a/ais_bench/benchmark/datasets/aime2025.py b/ais_bench/benchmark/datasets/aime2025.py index 548e28b2..b6b13a1c 100644 --- a/ais_bench/benchmark/datasets/aime2025.py +++ b/ais_bench/benchmark/datasets/aime2025.py @@ -4,9 +4,9 @@ from ais_bench.benchmark.registry import LOAD_DATASET from ais_bench.benchmark.datasets.utils.datasets import get_data_path +from ais_bench.benchmark.datasets.utils.llm_judge import LLMJudgeDataset -from .base import BaseDataset, BaseJDGDatasetMethod - +from ais_bench.benchmark.datasets.base import BaseDataset @LOAD_DATASET.register_module() class Aime2025Dataset(BaseDataset): @@ -20,21 +20,8 @@ def load(path, **kwargs): dataset.append(line) return Dataset.from_list(dataset) -class Aime2025JDGDataset(Aime2025Dataset): - def load(self, path, predictions_path, **kwargs): - - dataset_content = Aime2025Dataset.load(path, **kwargs) - - # 加载被测模型的推理结果(排序后) - predictions: list = BaseJDGDatasetMethod.load_from_predictions(predictions_path) - # 为数据集添加 model_answer 列 - dataset_list = [] - - for item in predictions: - item_dict = dataset_content[int(item["id"])] - item_dict["model_answer"] = item["prediction"] - item_dict["model_pred_uuid"] = item["uuid"] # Be filled in gold - dataset_list.append(item_dict) - - return Dataset.from_list(dataset_list) +@LOAD_DATASET.register_module() +class Aime2025JDGDataset(LLMJudgeDataset): + def _get_dataset_class(self): + return Aime2025Dataset diff --git a/ais_bench/benchmark/datasets/base.py b/ais_bench/benchmark/datasets/base.py index f3deb8f3..243a52b0 100644 --- a/ais_bench/benchmark/datasets/base.py +++ b/ais_bench/benchmark/datasets/base.py @@ -1,7 +1,5 @@ -import os - from abc import abstractmethod -from typing import List, Dict, Optional, Union +from typing import List, Dict, Optional, Union, Type from datasets import Dataset, DatasetDict from datasets.utils.logging import disable_progress_bar @@ -10,7 +8,6 @@ from ais_bench.benchmark.utils.logging.logger import AISLogger from ais_bench.benchmark.utils.logging.error_codes import DSET_CODES from ais_bench.benchmark.utils.logging.exceptions import ParameterValueError -from ais_bench.benchmark.utils.file.file import load_jsonl disable_progress_bar() # disable mapping progress bar, preventing terminal interface contamination @@ -109,21 +106,61 @@ def test(self): return self.reader.dataset['test'] @abstractmethod - def load(**kwargs) -> Union[Dataset, DatasetDict]: + def load(self, **kwargs) -> Union[Dataset, DatasetDict]: pass -class BaseJDGDatasetMethod: - @staticmethod - def load_from_predictions(prediction_path: str) -> Dict: - """Load predictions from a directory and merge them with the dataset. - - Args: - prediction_dir (str): The directory containing prediction files. - - Returns: - Dataset: The merged dataset with predictions. - """ - if os.path.exists(prediction_path): - preds = load_jsonl(prediction_path) - preds.sort(key=lambda x: x.get('id',0)) - return preds + +class BaseJDGDataset(BaseDataset): + def __init__(self, + reader_cfg: Optional[Dict] = {}, + k: Union[int, List[int]] = 1, + n: int = 1, + **kwargs): + self.dataset_instance = self._init_org_datasets_instance(reader_cfg, k, n, **kwargs) + super().__init__(reader_cfg, k, n, **kwargs) + + def load(self, predictions_path: str, **kwargs): + + dataset_content = self.dataset_instance.dataset["test"] + + # 加载被测模型的推理结果(排序后) + predictions: list = self._load_from_predictions(predictions_path) + + # 为数据集添加 model_answer 列 + if isinstance(dataset_content, Dataset): + dataset_list = [] + for item in predictions: + item_dict = dataset_content[int(item["id"])] + item_dict["model_answer"] = item["prediction"] + item_dict["model_pred_uuid"] = item["uuid"] # Be filled in gold + dataset_list.append(item_dict) + elif isinstance(dataset_content, DatasetDict): + dataset_list = [] + for key in dataset_content: + for item in predictions: + item_dict = dataset_content[key][int(item["id"])] + item_dict["model_answer"] = item["prediction"] + item_dict["model_pred_uuid"] = item["uuid"] # Be filled in gold + dataset_list.append(item_dict) + else: + raise ValueError(f"Unsupported dataset type: {type(dataset_content)}") + + return Dataset.from_list(dataset_list) + + @abstractmethod + def _load_from_predictions(self, prediction_path: str) -> Dict: + pass + + @abstractmethod + def _get_dataset_class(self): + return BaseDataset + + def _init_org_datasets_instance( + self, + reader_cfg: Optional[Dict] = {}, + k: Union[int, List[int]] = 1, + n: int = 1, + **kwargs): + dataset_class = self._get_dataset_class() + return dataset_class(reader_cfg, k, n, **kwargs) + diff --git a/ais_bench/benchmark/datasets/utils/llm_judge.py b/ais_bench/benchmark/datasets/utils/llm_judge.py index a17e72b7..8b6b18f0 100644 --- a/ais_bench/benchmark/datasets/utils/llm_judge.py +++ b/ais_bench/benchmark/datasets/utils/llm_judge.py @@ -1,18 +1,38 @@ import re +import os from ais_bench.benchmark.utils.logging import AISLogger from ais_bench.benchmark.registry import (ICL_EVALUATORS, LOAD_DATASET, TEXT_POSTPROCESSORS) from ais_bench.benchmark.openicl.icl_evaluator import BaseEvaluator +from ais_bench.benchmark.datasets.base import BaseJDGDataset +from ais_bench.benchmark.utils.file.file import load_jsonl logger = AISLogger() + @TEXT_POSTPROCESSORS.register_module("get_a_or_b") def get_a_or_b(pred: str) -> str: """从模型回复中提取A或B""" - match = re.search(r'[AB]', pred) + match = re.search(r'[AB]', pred[-1:]) return match.group(0) if match else 'B' +class LLMJudgeDataset(BaseJDGDataset): + def _load_from_predictions(self, prediction_path: str): + """Load predictions from a directory and merge them with the dataset. + + Args: + prediction_path (str): The path to the prediction file. + + Returns: + Dataset: The merged dataset with predictions. + """ + if os.path.exists(prediction_path): + preds = load_jsonl(prediction_path) + preds.sort(key=lambda x: x.get('id',0)) + return preds + + @ICL_EVALUATORS.register_module() class LLMJudgeCorrectEvaluator(BaseEvaluator): From 312bb1d7cb3c5d384f9cc00aedb3d5ef143bff6b Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Fri, 13 Feb 2026 15:12:26 +0800 Subject: [PATCH 3/4] suppport gedit infer --- ais_bench/benchmark/cli/workers.py | 2 +- .../configs/datasets/gedit/gedit_gen.py | 44 +++ .../models/lmm_models/qwen_image_edit.py | 18 + ais_bench/benchmark/datasets/g_edit.py | 94 +++++ ais_bench/benchmark/models/__init__.py | 3 +- .../benchmark/models/local_models/__init__.py | 0 .../benchmark/models/local_models/base.py | 22 +- .../local_models/qwen_image_edit_mindie_sd.py | 335 ++++++++++++++++++ ais_bench/benchmark/models/output.py | 64 +++- .../icl_inferencer/icl_lmm_gen_inferencer.py | 75 ++++ .../icl_inferencer/output_handler/__init__.py | 0 .../output_handler/base_handler.py | 16 +- .../output_handler/bfcl_v3_output_handler.py | 10 +- .../gen_inferencer_output_handler.py | 2 + .../lmm_gen_inferencer_output_handler.py | 72 ++++ .../ppl_inferencer_output_handler.py | 20 +- .../icl_prompt_template_mm.py | 3 +- ais_bench/benchmark/utils/image_process.py | 14 + .../multi_device_run_qwen_image_edit.py | 31 ++ 19 files changed, 803 insertions(+), 22 deletions(-) create mode 100644 ais_bench/benchmark/configs/datasets/gedit/gedit_gen.py create mode 100644 ais_bench/benchmark/configs/models/lmm_models/qwen_image_edit.py create mode 100644 ais_bench/benchmark/datasets/g_edit.py create mode 100644 ais_bench/benchmark/models/local_models/__init__.py create mode 100644 ais_bench/benchmark/models/local_models/qwen_image_edit_mindie_sd.py create mode 100644 ais_bench/benchmark/openicl/icl_inferencer/icl_lmm_gen_inferencer.py create mode 100644 ais_bench/benchmark/openicl/icl_inferencer/output_handler/__init__.py create mode 100644 ais_bench/benchmark/openicl/icl_inferencer/output_handler/lmm_gen_inferencer_output_handler.py create mode 100644 ais_bench/benchmark/utils/image_process.py create mode 100644 ais_bench/configs/lmm_exmaple/multi_device_run_qwen_image_edit.py diff --git a/ais_bench/benchmark/cli/workers.py b/ais_bench/benchmark/cli/workers.py index 77a5d59d..689762a0 100644 --- a/ais_bench/benchmark/cli/workers.py +++ b/ais_bench/benchmark/cli/workers.py @@ -80,7 +80,7 @@ def do_work(self, cfg: ConfigDict): logger.info("Merging datasets with the same model and inferencer...") tasks = self._merge_datasets(tasks) - runner = RUNNERS.build(cfg.judge_infer.runner) + runner = RUNNERS.build(cfg.infer.runner) runner(tasks) logger.info("Inference tasks completed.") diff --git a/ais_bench/benchmark/configs/datasets/gedit/gedit_gen.py b/ais_bench/benchmark/configs/datasets/gedit/gedit_gen.py new file mode 100644 index 00000000..57509dee --- /dev/null +++ b/ais_bench/benchmark/configs/datasets/gedit/gedit_gen.py @@ -0,0 +1,44 @@ +from ais_bench.benchmark.openicl.icl_prompt_template.icl_prompt_template_mm import MMPromptTemplate +from ais_bench.benchmark.openicl.icl_retriever import ZeroRetriever +from ais_bench.benchmark.openicl.icl_inferencer.icl_lmm_gen_inferencer import LMMGenInferencer +from ais_bench.benchmark.datasets.g_edit import GEditDataset, GEditEvaluator + + +gedit_reader_cfg = dict( + input_columns=['question', 'image'], + output_column='task_type' +) + + +gedit_infer_cfg = dict( + prompt_template=dict( + type=MMPromptTemplate, + template=dict( + round=[ + dict(role="HUMAN", prompt_mm={ + "text": {"type": "text", "text": "{question}"}, + "image": {"type": "image_url", "image_url": {"url": "data:image/png;base64,{image}"}}, + }) + ] + ) + ), + retriever=dict(type=ZeroRetriever), + inferencer=dict(type=LMMGenInferencer) +) + +gedit_eval_cfg = dict( + evaluator=dict(type=GEditEvaluator) +) + +gedit_datasets = [ + dict( + abbr='gedit', + type=GEditDataset, + path='ais_bench/datasets/GEdit-Bench', # 数据集路径,使用相对路径时相对于源码根路径,支持绝对路径 + split_count=1, + split_index=0, + reader_cfg=gedit_reader_cfg, + infer_cfg=gedit_infer_cfg, + eval_cfg=gedit_eval_cfg + ) +] diff --git a/ais_bench/benchmark/configs/models/lmm_models/qwen_image_edit.py b/ais_bench/benchmark/configs/models/lmm_models/qwen_image_edit.py new file mode 100644 index 00000000..c3e92dc4 --- /dev/null +++ b/ais_bench/benchmark/configs/models/lmm_models/qwen_image_edit.py @@ -0,0 +1,18 @@ +from ais_bench.benchmark.models.local_models.qwen_image_edit_mindie_sd import QwenImageEditModel + +models = [ + dict( + attr="local", # local or service + type=QwenImageEditModel, # transformers >= 4.33.0 用这个,prompt 是构造成对话格式 + abbr='qwen-image-edit', + path='/home/yanhe/models/Qwen-Image-Edit-2509/', # path to model dir, current value is just a example + device_kwargs=dict( + ), + infer_kwargs=dict( # 模型参数参考 huggingface.co/docs/transformers/v4.50.0/en/model_doc/auto#transformers.AutoModel.from_pretrained + num_inference_steps=50, + num_images_per_prompt=1, + ), + run_cfg = dict(num_gpus=1, num_procs=1), # 多卡/多机多卡 参数,使用torchrun拉起任务 + batch_size=1, # 每次推理的batch size + ) +] \ No newline at end of file diff --git a/ais_bench/benchmark/datasets/g_edit.py b/ais_bench/benchmark/datasets/g_edit.py new file mode 100644 index 00000000..9d9224b6 --- /dev/null +++ b/ais_bench/benchmark/datasets/g_edit.py @@ -0,0 +1,94 @@ +import json +from datasets import Dataset, load_from_disk, concatenate_datasets +from concurrent.futures import ThreadPoolExecutor, as_completed + +from ais_bench.benchmark.registry import LOAD_DATASET +from ais_bench.benchmark.openicl import BaseEvaluator +from ais_bench.benchmark.datasets.utils.datasets import get_data_path +from ais_bench.benchmark.datasets.utils.llm_judge import LLMJudgeDataset +from ais_bench.benchmark.utils.image_process import pil_to_base64 +from PIL import Image +from tqdm import tqdm + +from ais_bench.benchmark.datasets.base import BaseDataset +from ais_bench.benchmark.utils.prompt import AIS_CONTENT_TAG, AIS_TEXT_START, AIS_IMAGE_START + +GEDIT_COUNT = 10 + +class GEditEvaluator(BaseEvaluator): + def score(self, predictions, references): + details = [] + for i, pred in enumerate(predictions): + details.append({ + 'pred': pred, + 'ref': references[i], + }) + result = {'accuracy': 100 * len(predictions) / len(references), 'details': details} + return result + +@LOAD_DATASET.register_module() +class GEditDataset(BaseDataset): + @staticmethod + def load(path, use_raw=False, split_count=1, split_index=0, **kwargs): + path = get_data_path(path) + dataset = load_from_disk(path) + + # 数据集切分:分成 split_count 份,取第 split_index 份 + if split_count > 1: + total_len = len(dataset) + base_size = total_len // split_count # 每份基础大小 + remainder = total_len % split_count # 余数 + + # 计算当前 split_index 的起始和结束位置 + # 前 remainder 份每份多一个元素 + if split_index < remainder: + start_idx = split_index * (base_size + 1) + end_idx = start_idx + (base_size + 1) + else: + start_idx = remainder * (base_size + 1) + (split_index - remainder) * base_size + end_idx = start_idx + base_size + + dataset = dataset.select(range(start_idx, end_idx)) + else: + dataset = dataset.select(range(GEDIT_COUNT)) + + if use_raw: + image_column = 'input_image_raw' + else: + image_column = 'input_image' + + def process_example_to_dataset(example): + """处理单条数据并转换为 Dataset""" + image_url = pil_to_base64(example[image_column], "PNG") + example['content'] = AIS_IMAGE_START + image_url + AIS_CONTENT_TAG \ + + AIS_TEXT_START + example['instruction'] + AIS_CONTENT_TAG + # 使用 from_dict 替代 from_list 以提高性能 + data_dict = {key: [example[key]] for key in example.keys()} + return Dataset.from_dict(data_dict) + + max_workers = 4 # Adjust based on system resources + processed_datasets = [None] * len(dataset) + + with ThreadPoolExecutor(max_workers=max_workers) as executor: + # 提交所有任务 + with tqdm(total=len(dataset), desc=f"Submitting tasks split_count: {split_count}, split_index={split_index}", unit="example") as submit_pbar: + futures = {} + for i, example in enumerate(dataset): + future = executor.submit(process_example_to_dataset, example) + futures[future] = i + submit_pbar.update(1) + + # 收集处理完成的 Dataset + with tqdm(total=len(dataset), desc="Processing GEdit dataset", unit="example") as pbar: + for future in as_completed(futures): + idx = futures[future] + processed_datasets[idx] = future.result() + pbar.update(1) + + # 合并所有 Dataset + return concatenate_datasets(processed_datasets) + +@LOAD_DATASET.register_module() +class GEditJDGDataset(LLMJudgeDataset): + def _get_dataset_class(self): + return GEditDataset \ No newline at end of file diff --git a/ais_bench/benchmark/models/__init__.py b/ais_bench/benchmark/models/__init__.py index 5908d946..12230bf1 100644 --- a/ais_bench/benchmark/models/__init__.py +++ b/ais_bench/benchmark/models/__init__.py @@ -14,4 +14,5 @@ from ais_bench.benchmark.models.api_models.triton_api import TritonCustomAPIStream # noqa: F401 from ais_bench.benchmark.models.api_models.tgi_api import TGICustomAPIStream # noqa: F401 from ais_bench.benchmark.models.api_models.vllm_custom_api_chat import VllmMultiturnAPIChatStream # noqa: F401 -from ais_bench.benchmark.models.local_models.vllm_offline_vl import VLLMOfflineVLModel \ No newline at end of file +from ais_bench.benchmark.models.local_models.vllm_offline_vl import VLLMOfflineVLModel +from ais_bench.benchmark.models.local_models.qwen_image_edit_mindie_sd import QwenImageEditModel \ No newline at end of file diff --git a/ais_bench/benchmark/models/local_models/__init__.py b/ais_bench/benchmark/models/local_models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ais_bench/benchmark/models/local_models/base.py b/ais_bench/benchmark/models/local_models/base.py index b2766aab..74d43283 100644 --- a/ais_bench/benchmark/models/local_models/base.py +++ b/ais_bench/benchmark/models/local_models/base.py @@ -57,7 +57,7 @@ def __init__(self, self.is_synthetic = False @abstractmethod - def _generate(self, input, max_out_len: int) -> List[str]: + def generate(self, inputs, max_out_len: int) -> List[str]: """Generate result given a input. Args: @@ -133,17 +133,6 @@ def parse_template(self, prompt_template: PromptType, mode: str) -> str: """ return self.template_parser.parse_template(prompt_template, mode) - def generate_from_template(self, templates: List[PromptType], **kwargs): - """Generate completion from a list of templates. - - Args: - templates (List[PromptType]): A list of templates. - max_out_len (int): The maximum length of the output. - """ - inputs = self.parse_template(templates, mode='gen') - max_out_lens = kwargs.get("max_out_lens", [None] * len(templates)) - return self.generate(inputs, max_out_lens, **kwargs) - def get_token_len_from_template( self, templates: Union[PromptType, List[PromptType]], @@ -204,6 +193,15 @@ def sync_inputs(self, inputs: str) -> str: def to(self, device): self.model.to(device) +class BaseLMModel(BaseModel): + """Base class for language models. + """ + def generate(self, inputs, outputs, **kwargs) -> List[str]: + raise AISBenchNotImplementedError( + MODEL_CODES.UNKNOWN_ERROR, + f'{self.__class__.__name__} does not supported' + ' to be called in base classes') + class LMTemplateParser: """Intermidate prompt template parser, specifically for language models. diff --git a/ais_bench/benchmark/models/local_models/qwen_image_edit_mindie_sd.py b/ais_bench/benchmark/models/local_models/qwen_image_edit_mindie_sd.py new file mode 100644 index 00000000..b4115533 --- /dev/null +++ b/ais_bench/benchmark/models/local_models/qwen_image_edit_mindie_sd.py @@ -0,0 +1,335 @@ +# flake8: noqa +# yapf: disable +import os +import time +from typing import Dict, List, Optional, Union + +import torch +import torch_npu +import base64 +import io +from PIL import Image + +from ais_bench.benchmark.models.local_models.base import BaseLMModel +from ais_bench.benchmark.registry import MODELS +from ais_bench.benchmark.utils.prompt import PromptList +from ais_bench.benchmark.utils.logging import AISLogger +from ais_bench.benchmark.utils.logging.error_codes import UTILS_CODES +from ais_bench.benchmark.models.local_models.huggingface_above_v4_33 import (_convert_chat_messages, + _get_meta_template, + ) + +# 解决 diffuser 0.35.1 torch2.1 报错 +def custom_op( + name, + fn=None, + /, + *, + mutates_args, + device_types=None, + schema=None, + tags=None, +): + def decorator(func): + return func + + if fn is not None: + return decorator(fn) + + return decorator + +def register_fake( + op, + fn=None, + /, + *, + lib=None, + _stacklevel: int = 1, + allow_override: bool = False, +): + def decorator(func): + return func + + if fn is not None: + return decorator(fn) + + return decorator + +if hasattr(torch, 'library'): + torch.library.custom_op = custom_op + torch.library.register_fake = register_fake + +# 导入 qwen_image_edit 相关模块 +try: + from ais_bench.benchmark.models.local_models.qwenimage_edit.transformer_qwenimage import QwenImageTransformer2DModel + from ais_bench.benchmark.models.local_models.qwenimage_edit.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline + from mindiesd import CacheConfig, CacheAgent +except ImportError as e: + raise ImportError(f"请确保 qwenimage_edit 模块在 Python 路径中: {e}") + +PromptType = Union[PromptList, str] + +# 模型推理相关配置常量 +DEFAULT_MODEL_PATH = "/home/yanhe/models/Qwen-Image-Edit-2509/" +DEFAULT_TORCH_DTYPE = "bfloat16" +DEFAULT_DEVICE = "npu" +DEFAULT_DEVICE_ID = 0 +DEFAULT_NUM_INFERENCE_STEPS = 1 # 40 +DEFAULT_TRUE_CFG_SCALE = 4.0 +DEFAULT_GUIDANCE_SCALE = 1.0 +DEFAULT_SEED = 0 +DEFAULT_NUM_IMAGES_PER_PROMPT = 1 +DEFAULT_QUANT_DESC_PATH = None + +# 缓存配置开关 +COND_CACHE = bool(int(os.environ.get('COND_CACHE', 0))) +UNCOND_CACHE = bool(int(os.environ.get('UNCOND_CACHE', 0))) + + +@MODELS.register_module() +class QwenImageEditModel(BaseLMModel): + """Model wrapper for Qwen-Image-Edit-2509 models. + + Args: + path (str): The path to the model. + model_kwargs (dict): Additional model arguments. + sample_kwargs (dict): Additional sampling arguments. + vision_kwargs (dict): Additional vision arguments. + meta_template (Optional[Dict]): The model's meta prompt template. + """ + + def __init__(self, + path: str = DEFAULT_MODEL_PATH, + device_kwargs: dict = dict(), + infer_kwargs: dict = dict(), + meta_template: Optional[Dict] = None, + **other_kwargs): + self.logger = AISLogger() + self.path = path + self.max_out_len = other_kwargs.get('max_out_len', None) + self.template_parser = _get_meta_template(meta_template) + + # 设备配置 + self.device = device_kwargs.get('device', DEFAULT_DEVICE) + #self.device_id = device_kwargs.get('device_id', DEFAULT_DEVICE_ID) + # 在这里声明环境变量 + self.logger.debug(f"device id from kwargs: {device_kwargs.get('device_id', DEFAULT_DEVICE_ID)}") + os.environ["ASCEND_RT_VISIBLE_DEVICES"] = f"{device_kwargs.get('device_id', DEFAULT_DEVICE_ID)}" + self.device_id = DEFAULT_DEVICE_ID + self.device_str = f"{self.device}:{DEFAULT_DEVICE_ID}" + self.logger.debug(f"device_str: {self.device_str}; device_id: {self.device_id}") + self.logger.debug(f"ASCEND_RT_VISIBLE_DEVICES: {os.getenv('ASCEND_RT_VISIBLE_DEVICES')}") + + # 模型配置 + self.torch_dtype = other_kwargs.get('torch_dtype', DEFAULT_TORCH_DTYPE) + self.torch_dtype = torch.bfloat16 if self.torch_dtype == "bfloat16" else torch.float32 + + # 推理配置 + self.num_inference_steps = infer_kwargs.get('num_inference_steps', DEFAULT_NUM_INFERENCE_STEPS) + self.true_cfg_scale = infer_kwargs.get('true_cfg_scale', DEFAULT_TRUE_CFG_SCALE) + self.guidance_scale = infer_kwargs.get('guidance_scale', DEFAULT_GUIDANCE_SCALE) + self.seed = infer_kwargs.get('seed', DEFAULT_SEED) + self.num_images_per_prompt = infer_kwargs.get('num_images_per_prompt', DEFAULT_NUM_IMAGES_PER_PROMPT) + self.quant_desc_path = infer_kwargs.get('quant_desc_path', DEFAULT_QUANT_DESC_PATH) + + # 加载模型 + self._load_model() + + # 缓存配置 + if COND_CACHE or UNCOND_CACHE: + # 保守cache + cache_config = CacheConfig( + method="dit_block_cache", + blocks_count=60, + steps_count=self.num_inference_steps, + step_start=10, + step_interval=3, + step_end=35, + block_start=10, + block_end=50 + ) + self.pipeline.transformer.cache_cond = CacheAgent(cache_config) if COND_CACHE else None + self.pipeline.transformer.cache_uncond = CacheAgent(cache_config) if UNCOND_CACHE else None + self.logger.info("启用缓存配置") + + def _load_model(self): + """加载模型""" + self.logger.info(f"从 {self.path} 加载模型...") + + # 设置设备 + if self.device == "npu": + torch.npu.set_device(self.device_id) + + # 加载 transformer + transformer = QwenImageTransformer2DModel.from_pretrained( + os.path.join(self.path, 'transformer'), + torch_dtype=self.torch_dtype, + device_map=None, # 禁用自动设备映射 + low_cpu_mem_usage=True # 启用CPU低内存模式 + ) + + # 量化配置 + if self.quant_desc_path: + from mindiesd import quantize + self.logger.info("Quantizing Transformer (单独量化核心组件)...") + quantize( + model=transformer, + quant_des_path=self.quant_desc_path, + use_nz=True, + ) + if self.device == "npu": + torch.npu.empty_cache() # 清理NPU显存缓存 + + # 加载 pipeline + self.pipeline = QwenImageEditPlusPipeline.from_pretrained( + self.path, + transformer=transformer, + torch_dtype=self.torch_dtype, + device_map=None, + low_cpu_mem_usage=True + ) + + # VAE优化配置(避免显存溢出) + self.pipeline.vae.use_slicing = True + self.pipeline.vae.use_tiling = True + + # 移动模型到目标设备 + self.pipeline.to(self.device_str) + self.pipeline.set_progress_bar_config(disable=None) # 显示进度条 + + def _get_meta_template(self, meta_template): + """获取元模板""" + class DummyTemplateParser: + def parse_template(self, prompt_template, mode): + return prompt_template + return DummyTemplateParser() + + def _generate(self, input) -> List[Image]: + """Generate result given a input. + + Args: + input (PromptType): A string or PromptDict. + The PromptDict should be organized in AISBench' + API format. + max_out_len (int): The maximum length of the output. + + Returns: + str: The generated string. + """ + # 处理输入格式 + images = [] + prompts = [] + neg_prompts = [] + print(f"in _generate") + #self.logger.info(f"输入: {input}") + if isinstance(input, str): + prompts.append(input) + neg_prompts.append("") + elif isinstance(input, list): + # 处理包含图像的输入 + for item in input[0]["prompt"]: + if item["type"] == "image_url": + base64_url = item["image_url"]["url"].split(",")[1] + img = Image.open(io.BytesIO(base64.b64decode(base64_url))).convert("RGB") + images.append(img) + elif item["type"] == "text": + prompts.append(item["text"]) + neg_prompts.append("") + else: + prompts.append("") + neg_prompts.append("") + + # 如果没有图像输入,使用默认图像 + if not images: + raise ValueError("QwenImageEditModel requires image input") + + # 执行推理 + results = [] + for prompt, neg_prompt in zip(prompts, neg_prompts): + # 准备输入参数 + print("in _generate loop") + inputs = { + "image": images, + "prompt": prompt, + "negative_prompt": neg_prompt, + "generator": torch.Generator(device=self.device_str).manual_seed(self.seed), + "true_cfg_scale": self.true_cfg_scale, + "guidance_scale": self.guidance_scale, + "num_inference_steps": self.num_inference_steps, + "num_images_per_prompt": self.num_images_per_prompt, + } + + # 执行推理并计时 + if self.device == "npu": + torch.npu.synchronize() # 昇腾设备同步 + start_time = time.time() + + with torch.inference_mode(): + output = self.pipeline(**inputs) + + if self.device == "npu": + torch.npu.synchronize() + end_time = time.time() + infer_time = end_time - start_time + self.logger.info(f"推理完成,耗时: {infer_time:.2f}秒") + + return output + + def encode(self, prompt: str) -> torch.Tensor: + """Encode prompt to tokens. Not necessary for most cases. + + Args: + prompt (str): Input string. + + Returns: + torch.Tensor: Encoded tokens. + """ + raise NotImplementedError(f'{self.__class__.__name__} does not implement `encode` method.') + + def decode(self, tokens: torch.Tensor) -> str: + """Decode tokens to text. Not necessary for most cases. + + Args: + tokens (torch.Tensor): Input tokens. + + Returns: + str: Decoded text. + """ + raise NotImplementedError(f'{self.__class__.__name__} does not implement `decode` method.') + + def get_token_len(self, prompt: str) -> int: + """Get lengths of the tokenized strings. + + Args: + prompt (str): Input string. + + Returns: + int: Length of the input tokens + """ + # 对于图像编辑模型,token长度计算可能不同,这里返回一个默认值 + return len(prompt.split()) + + def generate(self, inputs, outputs, **kwargs): + """Generate completion from inputs. + + Args: + inputs: Inputs for generation. + max_out_lens: Maximum output lengths. + **kwargs: Additional keyword arguments. + + Returns: + List[str]: Generated completions. + """ + #self.logger.info(f"model {inputs=}") + if not isinstance(inputs, list): + inputs = [inputs] + + for i, input in enumerate(inputs): + result = self._generate(input) + # result is QwenImagePipelineOutput with 'images' attribute + if hasattr(result, 'images') and result.images: + outputs[i].success = True + outputs[i].content = result.images # 将图像列表赋值给 content + else: + outputs[i].success = False + outputs[i].content = [""] diff --git a/ais_bench/benchmark/models/output.py b/ais_bench/benchmark/models/output.py index f0676ae3..935466f5 100644 --- a/ais_bench/benchmark/models/output.py +++ b/ais_bench/benchmark/models/output.py @@ -1,5 +1,9 @@ +import os import time from abc import abstractmethod +from typing import Union + +from PIL import Image import numpy as np @@ -174,4 +178,62 @@ def update_extra_details_data_from_text_response(self, text_response: dict) -> N for item in text_response.get("choices", []): message = item.get("message", {}) self.extra_details_data["message"] = message - return # only one message is allowed \ No newline at end of file + return # only one message is allowed + + +LLM_META_DATA_TYPE = Union[Image, str] + + +class LMMOutput(Output): + def __init__(self, perf_mode: bool = False) -> None: + super().__init__(perf_mode) + self.content: list[LLM_META_DATA_TYPE] = [""] + self.HANDLER_MAP = { + Image.Image: self._handle_image, + str: self._handle_text, + } + + def get_prediction(self, save_dir: str) -> dict: + """Get the final prediction by combining content and reasoning. + + Returns: + dict: Combined prediction content + """ + output = [] + for i, item in enumerate(self.content): + output.append(self.HANDLER_MAP[type(item)](save_dir, i)) + if len(output) == 1: + return output[0] + else: + return output + + def _handle_image(self, save_dir: str, index: int) -> str: + """Handle image content. + + Args: + save_dir: Directory to save image + index: Index of image in content list + + Returns: + str: Last two levels of image path + """ + image = self.content[index] + image_path = os.path.join(save_dir, f"image_{self.uuid}_{index}.png") + if os.path.exists(image_path): + os.remove(image_path) + image.save(image_path) + return os.path.join(*image_path.split(os.sep)[-2:]) + + def _handle_text(self, save_dir: str, index: int) -> str: + """Handle text content. + + Args: + save_dir: Directory to save text + index: Index of text in content list + + Returns: + str: Text content + """ + return self.content[index] + + diff --git a/ais_bench/benchmark/openicl/icl_inferencer/icl_lmm_gen_inferencer.py b/ais_bench/benchmark/openicl/icl_inferencer/icl_lmm_gen_inferencer.py new file mode 100644 index 00000000..48604ba9 --- /dev/null +++ b/ais_bench/benchmark/openicl/icl_inferencer/icl_lmm_gen_inferencer.py @@ -0,0 +1,75 @@ +''' +Author: SJTUyh yh_silence@alumni.sjtu.edu.cn +Date: 2026-02-11 16:38:01 +LastEditors: SJTUyh yh_silence@alumni.sjtu.edu.cn +LastEditTime: 2026-02-12 18:39:02 +FilePath: \benchmark\ais_bench\benchmark\openicl\icl_inferencer\icl_lmm_gen_inferencer.py +Description: 这是默认设置,请设置`customMade`, 打开koroFileHeader查看配置 进行设置: https://github.com/OBKoro1/koro1FileHeader/wiki/%E9%85%8D%E7%BD%AE +''' +import uuid +from typing import List, Optional + +from ais_bench.benchmark.models.output import LMMOutput +from ais_bench.benchmark.registry import ICL_INFERENCERS +from ais_bench.benchmark.openicl.icl_retriever import BaseRetriever +from ais_bench.benchmark.openicl.icl_inferencer.icl_gen_inferencer import GenInferencer +from ais_bench.benchmark.openicl.icl_inferencer.output_handler.lmm_gen_inferencer_output_handler import LMMGenInferencerOutputHandler + + +@ICL_INFERENCERS.register_module() +class LMMGenInferencer(GenInferencer): + def __init__( + self, + model_cfg, + stopping_criteria: List[str] = [], + batch_size: Optional[int] = 1, + mode: Optional[str] = "infer", + gen_field_replace_token: Optional[str] = "", + output_json_filepath: Optional[str] = "./icl_inference_output", + save_every: Optional[int] = 1, + **kwargs, + ) -> None: + super().__init__( + model_cfg=model_cfg, + stopping_criteria=stopping_criteria, + batch_size=batch_size, + mode=mode, + gen_field_replace_token=gen_field_replace_token, + output_json_filepath=output_json_filepath, + save_every=save_every, + **kwargs, + ) + + self.output_handler = LMMGenInferencerOutputHandler(perf_mode=self.perf_mode, + save_every=self.save_every) + def inference(self, retriever: BaseRetriever, output_json_filepath: Optional[str] = None) -> List: + self.output_handler.set_output_path(output_json_filepath) + return super().inference(retriever, output_json_filepath) + + def batch_inference( + self, + datum, + ) -> None: + """Perform batch inference on the given dataloader. + + Args: + dataloader: DataLoader containing the inference data + + Returns: + List of inference results + """ + indexs = datum.pop("index") + inputs = datum.pop("prompt") + data_abbrs = datum.pop("data_abbr") + outputs = [LMMOutput(self.perf_mode) for _ in range(len(indexs))] + for output in outputs: + output.uuid = str(uuid.uuid4()).replace("-", "") + golds = datum.pop("gold", [None] * len(inputs)) + self.model.generate(inputs, outputs, **datum) + + for index, input, output, data_abbr, gold in zip( + indexs, inputs, outputs, data_abbrs, golds + ): + self.output_handler.report_cache_info_sync( + index, input, output, data_abbr, gold + ) \ No newline at end of file diff --git a/ais_bench/benchmark/openicl/icl_inferencer/output_handler/__init__.py b/ais_bench/benchmark/openicl/icl_inferencer/output_handler/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ais_bench/benchmark/openicl/icl_inferencer/output_handler/base_handler.py b/ais_bench/benchmark/openicl/icl_inferencer/output_handler/base_handler.py index 42cef866..9d2a450b 100644 --- a/ais_bench/benchmark/openicl/icl_inferencer/output_handler/base_handler.py +++ b/ais_bench/benchmark/openicl/icl_inferencer/output_handler/base_handler.py @@ -56,7 +56,14 @@ def __init__(self, perf_mode: bool = False, save_every: int = 100) -> None: self.save_every = save_every @abstractmethod - def get_prediction_result(self, output: Union[str, Output], gold: Optional[str] = None, input: Optional[Union[str, List[str]]] = None) -> dict: + def get_prediction_result( + self, + output: Union[str, Output], + gold: Optional[str] = None, + input: Optional[Union[str, List[str]]] = None, + data_abbr: Optional[str] = "" + + ) -> dict: """ Get the prediction result. @@ -64,7 +71,7 @@ def get_prediction_result(self, output: Union[str, Output], gold: Optional[str] output (Union[str, Output]): Output result from inference gold (Optional[str]): Ground truth data for comparison input (Optional[Union[str, List[str]]]): Input data for the inference - + data_abbr (Optional[str]): Abbreviation of the dataset Returns: dict: Prediction result """ @@ -74,6 +81,7 @@ def get_prediction_result(self, output: Union[str, Output], gold: Optional[str] def get_result( self, conn: sqlite3.Connection, + data_abbr: str, input: Union[str, List[str]], output: Union[str, Output], gold: Optional[str] = None, @@ -113,7 +121,7 @@ def get_result( if gold: result_data["gold"] = gold else: - result_data = self.get_prediction_result(output, gold=gold, input=input) + result_data = self.get_prediction_result(output, gold=gold, input=input, data_abbr=data_abbr) if not result_data.get("success", True): self.all_success = False if isinstance(output, Output) and hasattr(output, "error_info"): @@ -365,7 +373,7 @@ def run_cache_consumer( try: uid = str(uuid.uuid4())[:8] - result_data = self.get_result(conn, *item[2:]) + result_data = self.get_result(conn, *item[1:]) id, data_abbr = item[0], item[1] json_data = { "data_abbr": data_abbr, diff --git a/ais_bench/benchmark/openicl/icl_inferencer/output_handler/bfcl_v3_output_handler.py b/ais_bench/benchmark/openicl/icl_inferencer/output_handler/bfcl_v3_output_handler.py index 0b47eec6..8bf8fbb7 100644 --- a/ais_bench/benchmark/openicl/icl_inferencer/output_handler/bfcl_v3_output_handler.py +++ b/ais_bench/benchmark/openicl/icl_inferencer/output_handler/bfcl_v3_output_handler.py @@ -11,7 +11,13 @@ class BFCLV3OutputHandler(BaseInferencerOutputHandler): """ Output handler for BFCLV3 inference tasks. """ - def get_prediction_result(self, output: FunctionCallOutput, gold: Optional[str] = None, input: Optional[Union[str, List[str]]] = None) -> dict: + def get_prediction_result( + self, + output: FunctionCallOutput, + gold: Optional[str] = None, + input: Optional[Union[str, List[str]]] = None, + data_abbr: Optional[str] = "" + ) -> dict: """ Get the prediction result for BFCLV3 inference tasks. @@ -19,6 +25,8 @@ def get_prediction_result(self, output: FunctionCallOutput, gold: Optional[str] output (FunctionCallOutput): Output result from inference gold (Optional[str]): Ground truth data for comparison input (Optional[Union[str, List[str]]]): Input data for the inference (not used in this implementation) + data_abbr (Optional[str]): Abbreviation of the dataset (not used in this implementation) + Returns: dict: Prediction result containing success, uuid, prediction (tool_calls), and inference_log Raises: diff --git a/ais_bench/benchmark/openicl/icl_inferencer/output_handler/gen_inferencer_output_handler.py b/ais_bench/benchmark/openicl/icl_inferencer/output_handler/gen_inferencer_output_handler.py index 111799d2..726f841c 100644 --- a/ais_bench/benchmark/openicl/icl_inferencer/output_handler/gen_inferencer_output_handler.py +++ b/ais_bench/benchmark/openicl/icl_inferencer/output_handler/gen_inferencer_output_handler.py @@ -27,6 +27,7 @@ def get_prediction_result( output: Union[str, Output], gold: Optional[str] = None, input: Optional[Union[str, List[str]]] = None, + data_abbr: Optional[str] = "", ) -> dict: """ Get the prediction result for accuracy mode. @@ -35,6 +36,7 @@ def get_prediction_result( output (Union[str, Output]): Output result from inference gold (Optional[str]): Ground truth data for comparison input (Optional[Union[str, List[str]]]): Input data for the inference + data_abbr (Optional[str]): Abbreviation of the dataset Returns: dict: Prediction result diff --git a/ais_bench/benchmark/openicl/icl_inferencer/output_handler/lmm_gen_inferencer_output_handler.py b/ais_bench/benchmark/openicl/icl_inferencer/output_handler/lmm_gen_inferencer_output_handler.py new file mode 100644 index 00000000..70a05336 --- /dev/null +++ b/ais_bench/benchmark/openicl/icl_inferencer/output_handler/lmm_gen_inferencer_output_handler.py @@ -0,0 +1,72 @@ +from typing import List, Optional, Union +import sqlite3 +import uuid +from pathlib import Path + +from ais_bench.benchmark.openicl.icl_inferencer.output_handler.base_handler import BaseInferencerOutputHandler +from ais_bench.benchmark.models.output import LMMOutput +from ais_bench.benchmark.utils.logging.error_codes import ICLI_CODES +from ais_bench.benchmark.utils.logging.exceptions import AISBenchImplementationError + +class LMMGenInferencerOutputHandler(BaseInferencerOutputHandler): + """ + Output handler for generation-based inference tasks. + + This handler specializes in processing generation model outputs, + supporting both performance measurement and accuracy evaluation modes. + It handles different data formats and provides appropriate result storage. + + Attributes: + all_success (bool): Flag indicating if all operations were successful + perf_mode (bool): Whether in performance measurement mode + cache_queue (queue.Queue): Queue for caching results before writing + """ + def set_output_path(self, output_path: str) -> None: + self.output_path = output_path + + def get_prediction_result( + self, + output: Union[str, LMMOutput], + gold: Optional[str] = None, + input: Optional[Union[str, List[str]]] = None, + data_abbr: Optional[str] = "", + ) -> dict: + """ + Get the prediction result for accuracy mode. + + Args: + output (Union[str, LMMOutput]): Output result from inference + gold (Optional[str]): Ground truth data for comparison + input (Optional[Union[str, List[str]]]): Input data for the inference + data_abbr (Optional[str]): Abbreviation of the dataset + + Returns: + dict: Prediction result + """ + try: + save_dir = Path(self.output_path) / f"{data_abbr}_out_file" + if not save_dir.exists(): + save_dir.mkdir(parents=True, exist_ok=True) + for item in input[0]['prompt']: + if item.get('image_url'): + item['image_url']['url'] = item['image_url']['url'][:256] + result_data = { + "success": ( + output.success if isinstance(output, LMMOutput) else True + ), + "uuid": output.uuid if isinstance(output, LMMOutput) else str(uuid.uuid4()).replace("-", ""), + "origin_prompt": input if input is not None else "", + "prediction": ( + output.get_prediction(save_dir) + if isinstance(output, LMMOutput) + else output + ), + } + if gold: + result_data["gold"] = gold + except Exception as e: + import traceback + print(f"[ERROR] LMMGenInferencerOutputHandler.get_prediction_result failed: {type(e).__name__}: {e}") + print(f"[ERROR] Traceback: {traceback.format_exc()}") + raise RuntimeError(f"Failed to get prediction result: {e}") + return result_data \ No newline at end of file diff --git a/ais_bench/benchmark/openicl/icl_inferencer/output_handler/ppl_inferencer_output_handler.py b/ais_bench/benchmark/openicl/icl_inferencer/output_handler/ppl_inferencer_output_handler.py index bf5ac30e..44da5be5 100644 --- a/ais_bench/benchmark/openicl/icl_inferencer/output_handler/ppl_inferencer_output_handler.py +++ b/ais_bench/benchmark/openicl/icl_inferencer/output_handler/ppl_inferencer_output_handler.py @@ -46,7 +46,25 @@ def __init__(self, perf_mode: bool = False, save_every: int = 100) -> None: super().__init__(save_every) self.perf_mode = perf_mode - def get_prediction_result(self, output: Union[str, PPLResponseOutput], gold: Optional[str] = None, input: Union[str, List[str]] = None) -> dict: + def get_prediction_result( + self, + output: Union[str, PPLResponseOutput], + gold: Optional[str] = None, + input: Optional[Union[str, List[str]]] = None, + data_abbr: Optional[str] = "" + ) -> dict: + """ + Get the prediction result for performance mode. + + Args: + output (Union[str, PPLResponseOutput]): Model output + gold (Optional[str]): Ground truth data for comparison + input (Optional[Union[str, List[str]]]): Input data for the inference + data_abbr (Optional[str]): Abbreviation of the dataset + + Returns: + dict: Prediction result + """ if not isinstance(output, PPLResponseOutput): raise AISBenchImplementationError(ICLI_CODES.UNKNOWN_ERROR, f"Output is not a PPLResponseOutput") result_data = { diff --git a/ais_bench/benchmark/openicl/icl_prompt_template/icl_prompt_template_mm.py b/ais_bench/benchmark/openicl/icl_prompt_template/icl_prompt_template_mm.py index 83d42a4e..34562b5f 100644 --- a/ais_bench/benchmark/openicl/icl_prompt_template/icl_prompt_template_mm.py +++ b/ais_bench/benchmark/openicl/icl_prompt_template/icl_prompt_template_mm.py @@ -39,7 +39,7 @@ def check_mm_template(self): if "prompt_mm" not in data.keys(): return False return True - + def format_mm_url(self, template, entry): """ for mm_custom dataset @@ -103,6 +103,7 @@ def generate_item( template = self.format_mm_url(self.template, entry) template = self._encode_template(template, ice=False) template = template.format_mm(**entry) + for i, item in enumerate(template): if "prompt_mm" in item: template[i]["prompt_mm"] = self.get_mm_template(item) diff --git a/ais_bench/benchmark/utils/image_process.py b/ais_bench/benchmark/utils/image_process.py new file mode 100644 index 00000000..0863dcdf --- /dev/null +++ b/ais_bench/benchmark/utils/image_process.py @@ -0,0 +1,14 @@ +import base64 +from io import BytesIO +from PIL import Image + +def pil_to_base64(image, format="JPEG"): + """ + Convert PIL Image to base64 string + """ + if not isinstance(image, Image.Image): + raise ValueError("Input must be a PIL Image object") + buffered = BytesIO() + image.save(buffered, format) + img_str = base64.b64encode(buffered.getvalue()).decode() + return img_str \ No newline at end of file diff --git a/ais_bench/configs/lmm_exmaple/multi_device_run_qwen_image_edit.py b/ais_bench/configs/lmm_exmaple/multi_device_run_qwen_image_edit.py new file mode 100644 index 00000000..467ed0b3 --- /dev/null +++ b/ais_bench/configs/lmm_exmaple/multi_device_run_qwen_image_edit.py @@ -0,0 +1,31 @@ +from mmengine.config import read_base + +with read_base(): + from ais_bench.benchmark.configs.models.lmm_models.qwen_image_edit import models as qwen_image_edit_models + from ais_bench.benchmark.configs.summarizers.example import summarizer + from ais_bench.benchmark.configs.datasets.gedit.gedit_gen import gedit_datasets + +device_list = [0, 1, 2, 3] + +datasets = [] +models = [] +model_dataset_combinations = [] + +for i in device_list: + model_config = {k: v for k, v in qwen_image_edit_models[0].items()} + model_config['abbr'] = f"{model_config['abbr']}-{i}" + model_config['device_kwargs'] = dict(model_config['device_kwargs']) + model_config['device_kwargs']['device_id'] = i + models.append(model_config) + + dataset_config = {k: v for k, v in gedit_datasets[0].items()} + dataset_config['abbr'] = f"{dataset_config['abbr']}-{i}" + dataset_config['split_count'] = len(device_list) + dataset_config['split_index'] = i + datasets.append(dataset_config) + + # 关键:为每个设备创建一个独立的 model-dataset 组合 + model_dataset_combinations.append({ + 'models': [model_config], # 只包含当前模型 + 'datasets': [dataset_config] # 只包含当前数据集 + }) \ No newline at end of file From ce167ed1a6a505086d4921417a3065309560428e Mon Sep 17 00:00:00 2001 From: SJTUyh Date: Fri, 13 Feb 2026 15:15:23 +0800 Subject: [PATCH 4/4] add qwen image edit dep --- .../local_models/qwenimage_edit/__init__.py | 0 .../local_models/qwenimage_edit/attn_layer.py | 201 ++++ .../qwenimage_edit/distributed/__init__.py | 0 .../qwenimage_edit/distributed/all_to_all.py | 156 +++ .../distributed/group_coordinator.py | 640 ++++++++++++ .../distributed/parallel_mgr.py | 404 ++++++++ .../qwenimage_edit/distributed/utils.py | 152 +++ .../pipeline_qwenimage_edit_plus.py | 964 ++++++++++++++++++ .../scheduling_flow_match_euler_discrete.py | 563 ++++++++++ .../qwenimage_edit/transformer_qwenimage.py | 792 ++++++++++++++ 10 files changed, 3872 insertions(+) create mode 100644 ais_bench/benchmark/models/local_models/qwenimage_edit/__init__.py create mode 100644 ais_bench/benchmark/models/local_models/qwenimage_edit/attn_layer.py create mode 100644 ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/__init__.py create mode 100644 ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/all_to_all.py create mode 100644 ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/group_coordinator.py create mode 100644 ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/parallel_mgr.py create mode 100644 ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/utils.py create mode 100644 ais_bench/benchmark/models/local_models/qwenimage_edit/pipeline_qwenimage_edit_plus.py create mode 100644 ais_bench/benchmark/models/local_models/qwenimage_edit/scheduling_flow_match_euler_discrete.py create mode 100644 ais_bench/benchmark/models/local_models/qwenimage_edit/transformer_qwenimage.py diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/__init__.py b/ais_bench/benchmark/models/local_models/qwenimage_edit/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/attn_layer.py b/ais_bench/benchmark/models/local_models/qwenimage_edit/attn_layer.py new file mode 100644 index 00000000..2d0e58e7 --- /dev/null +++ b/ais_bench/benchmark/models/local_models/qwenimage_edit/attn_layer.py @@ -0,0 +1,201 @@ +import torch +from torch import Tensor +import torch_npu + +import torch.distributed as dist +from yunchang import LongContextAttention +try: + from yunchang.kernels import AttnType +except ImportError: + raise ImportError("Please install yunchang 0.6.0 or later") + + +import math +import os +from typing import Any + +from mindiesd import attention_forward + + + +# from yunchang.comm.all_to_all import SeqAllToAll4D +# from yunchang.globals import HAS_SPARSE_SAGE_ATTENTION + +from ais_bench.benchmark.models.local_models.qwenimage_edit.distributed.all_to_all import SeqAllToAll4D +import logging + +from ais_bench.benchmark.models.local_models.qwenimage_edit.distributed.parallel_mgr import ( + get_sequence_parallel_world_size, + get_sequence_parallel_rank, + get_sp_group +) + + +logger = logging.getLogger(__name__) +MAX_TOKEN = 2147483647 + + +class xFuserLongContextAttention_new4(LongContextAttention): + ring_impl_type_supported_kv_cache = ["basic"] + + def __init__( + self, + scatter_idx: int = 2, + gather_idx: int = 1, + ring_impl_type: str = "basic", + use_pack_qkv: bool = False, + use_kv_cache: bool = False, + use_sync: bool = False, + attn_type: AttnType = AttnType.FA, + attn_processor: torch.nn.Module = None, + q_descale=None, + k_descale=None, + v_descale=None, + ) -> None: + """ + Arguments: + scatter_idx: int = 2, the scatter dimension index for Ulysses All2All + gather_idx: int = 1, the gather dimension index for Ulysses All2All + ring_impl_type: str = "basic", the ring implementation type, currently only support "basic" + use_pack_qkv: bool = False, whether to use pack qkv in the input + use_kv_cache: bool = False, whether to use kv cache in the attention layer, which is applied in PipeFusion. + attn_type: AttnType = AttnType.FA, the attention type supported inside long context attention, including "TORCH", "FA", "FA3", "SAGE_FP16", "SAGE_FP8" + attn_processor: nn.Module = None, the attention processor can be passed in to replace the attention processor if attn_type is do not support it. + """ + super().__init__( + scatter_idx=scatter_idx, + gather_idx=gather_idx, + ring_impl_type=ring_impl_type, + use_pack_qkv=use_pack_qkv, + use_sync=use_sync, + attn_type = attn_type, + ) + self.use_kv_cache = use_kv_cache + self.q_descale = q_descale + self.k_descale = k_descale + self.v_descale = v_descale + + # 校验:仅"basic"类型的环形实现支持KV缓存 + if ( + use_kv_cache + and ring_impl_type not in self.ring_impl_type_supported_kv_cache + ): + raise RuntimeError( + f"ring_impl_type: {ring_impl_type} do not support SP kv cache." + ) + + self.attn_processor = attn_processor + + @torch.compiler.disable + def forward( + self, + attn, + query: Tensor, # [B, S_image/ulysses_size, H, D] + key: Tensor, + value: Tensor, + *, + joint_tensor_query=None, # [B, S_text, H, D] + joint_tensor_key=None, + joint_tensor_value=None, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + joint_strategy="none", + txt_pad_len = 0 + ) -> Tensor: + """forward + + Arguments: + attn (Attention): the attention module + query (Tensor): query input to the layer + key (Tensor): key input to the layer + value (Tensor): value input to the layer + args: other args, + joint_tensor_query: Tensor = None, a replicated tensor among processes appended to the front or rear of query, depends the joint_strategy + joint_tensor_key: Tensor = None, a replicated tensor among processes appended to the front or rear of key, depends the joint_strategy + joint_tensor_value: Tensor = None, a replicated tensor among processes appended to the front or rear of value, depends the joint_strategy, + *args: the args same as flash_attn_interface + joint_strategy: str = "none", the joint strategy for joint attention, currently only support "front" and "rear" + + Returns: + * output (Tensor): context output + """ + + + sp_world_size = get_sequence_parallel_world_size() # USP + sp_rank = get_sequence_parallel_rank() + + + + joint_tensor_query = torch.chunk(joint_tensor_query, sp_world_size, dim=2)[sp_rank] # [B, S_text, H, D] --> [B, S_text, H/ulysses_size, D] + joint_tensor_key = torch.chunk(joint_tensor_key, sp_world_size, dim=2)[sp_rank] + joint_tensor_value = torch.chunk(joint_tensor_value, sp_world_size, dim=2)[sp_rank] + + + + # 3 X (bs, seq_len/N, head_cnt, head_size) -> 3 X (bs, seq_len, head_cnt/N, head_size) + # scatter 2, gather 1 + if self.use_pack_qkv: + # (3*bs, seq_len/N, head_cnt, head_size) + qkv = torch.cat([query, key, value]).contiguous() + # (3*bs, seq_len, head_cnt/N, head_size) + qkv = SeqAllToAll4D.apply( + self.ulysses_pg, qkv, self.scatter_idx, self.gather_idx, + ) + qkv = torch.chunk(qkv, 3, dim=0) + query_layer, key_layer, value_layer = qkv + + else: + # 非打包模式:分别对Q/K/V进行通信拆分 + query_layer = SeqAllToAll4D.apply( + self.ulysses_pg, query, self.scatter_idx, self.gather_idx , # [B, S_image/ulysses_size, H, D] --> [B, S_image, H/ulysses_size, D] + ) + key_layer = SeqAllToAll4D.apply( + self.ulysses_pg, key, self.scatter_idx, self.gather_idx, + ) + value_layer = SeqAllToAll4D.apply( + self.ulysses_pg, value, self.scatter_idx, self.gather_idx, + ) + + # Concatenate for joint attention + # Order: [text, image] + joint_query = torch.cat([joint_tensor_query, query_layer], dim=1) # (B, S_txt + S_img, H/ulysses_size, D_head) + joint_key = torch.cat([joint_tensor_key, key_layer], dim=1) + joint_value = torch.cat([joint_tensor_value, value_layer], dim=1) + + + out = attention_forward( + joint_query, + joint_key, + joint_value, + opt_mode="manual", + op_type="fused_attn_score", + layout="BNSD" + ) + + if type(out) == tuple: + context_layer, _, _ = out + else: + context_layer = out + + txt_seq_len = joint_tensor_query.shape[1] + + text_out = context_layer[:, :txt_seq_len, :, :].contiguous() # 强制连续 + image_out = context_layer[:, txt_seq_len:, :, :].contiguous() + + # (bs, seq_len, head_cnt/N, head_size) -> (bs, seq_len/N, head_cnt, head_size) + # scatter 1, gather 2 + image_out = SeqAllToAll4D.apply( + self.ulysses_pg, image_out, self.gather_idx, self.scatter_idx # [B, S_image, H/ulysses_size, D] --> [B, S_image/ulysses_size, H, D] + ) + + text_out = get_sp_group().all_gather(text_out, dim=2) # (B, S_txt , H/ulysses_size, D_head) --> (B, S_txt , H, D_head) + + output = torch.cat([text_out, image_out], dim=1) + # out e.g., [s/p::h] + return output + diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/__init__.py b/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/all_to_all.py b/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/all_to_all.py new file mode 100644 index 00000000..2ffea2f1 --- /dev/null +++ b/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/all_to_all.py @@ -0,0 +1,156 @@ +# Copyright (c) Microsoft Corporation and Jiarui Fang +# SPDX-License-Identifier: Apache-2.0 + +import torch + +from typing import Any, Tuple +from torch import Tensor +from torch.nn import Module + +import torch.distributed as dist + + +def all_to_all_4D( + input: torch.tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None, use_sync: bool = False +) -> torch.tensor: + """ + all-to-all for QKV + + Args: + input (torch.tensor): a tensor sharded along dim scatter dim + scatter_idx (int): default 1 + gather_idx (int): default 2 + group : torch process group + use_sync (bool): whether to synchronize after all-to-all + + Returns: + torch.tensor: resharded tensor (bs, seqlen/P, hc, hs) + """ + assert ( + input.dim() == 4 + ), f"input must be 4D tensor, got {input.dim()} and shape {input.shape}" + + seq_world_size = dist.get_world_size(group) + + # 分支 1:scatter_idx=2 且 gather_idx=1(Ulysses 并行的 “拆分多头” 场景),按「多头维度(dim2)」拆分张量,同时将「序列维度(dim1)」重组为完整长度。 + if scatter_idx == 2 and gather_idx == 1: + # input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen/P, hc, hs) output: (bs, seqlen, hc/P, hs) + bs, shard_seqlen, hc, hs = input.shape + seqlen = shard_seqlen * seq_world_size + shard_hc = hc // seq_world_size + + # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! + # (bs, seqlen/P, hc, hs) -reshape-> (bs, seq_len/P, P, hc/P, hs) -transpose(0,2)-> (P, seq_len/P, bs, hc/P, hs) + input_t = ( + input.reshape(bs, shard_seqlen, seq_world_size, shard_hc, hs) + .transpose(0, 2) + .contiguous() + ) + + output = torch.empty_like(input_t) + # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single + # (P, seq_len/P, bs, hc/P, hs) scatter seqlen -all2all-> (P, seq_len/P, bs, hc/P, hs) scatter head + + if seq_world_size > 1: + if use_sync: + dist.all_to_all_single(output, input_t, group=group) + else: + comm = dist.all_to_all_single(output, input_t, group=group, async_op=True) + + def getter(): + comm.wait() + comm_output = output.reshape(seqlen, bs, shard_hc, hs) + + # (seq_len, bs, hc/P, hs) -reshape-> (bs, seq_len, hc/P, hs) + comm_output = comm_output.transpose(0, 1).contiguous().reshape(bs, seqlen, shard_hc, hs) + return comm_output + + return getter + else: + output = input_t + # if scattering the seq-dim, transpose the heads back to the original dimension + output = output.reshape(seqlen, bs, shard_hc, hs) + + # (seq_len, bs, hc/P, hs) -reshape-> (bs, seq_len, hc/P, hs) + output = output.transpose(0, 1).contiguous().reshape(bs, seqlen, shard_hc, hs) + + return output + + # 分支 2:scatter_idx=1 且 gather_idx=2(Ulysses 并行的 “合并多头” 场景),与分支 1 相反,按「序列维度(dim1)」拆分张量,同时将「多头维度(dim2)」重组为完整多头数。 + elif scatter_idx == 1 and gather_idx == 2: + # input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen, hc/P, hs) output: (bs, seqlen/P, hc, hs) + bs, seqlen, shard_hc, hs = input.shape + hc = shard_hc * seq_world_size + shard_seqlen = seqlen // seq_world_size + seq_world_size = dist.get_world_size(group) + + # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! + # (bs, seqlen, hc/P, hs) -reshape-> (bs, P, seq_len/P, hc/P, hs) -transpose(0, 3)-> (hc/P, P, seqlen/P, bs, hs) -transpose(0, 1) -> (P, hc/P, seqlen/P, bs, hs) + input_t = ( + input.reshape(bs, seq_world_size, shard_seqlen, shard_hc, hs) + .transpose(0, 3) + .transpose(0, 1) + .contiguous() + .reshape(seq_world_size, shard_hc, shard_seqlen, bs, hs) + ) + + output = torch.empty_like(input_t) + # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single + # (P, bs x hc/P, seqlen/P, hs) scatter seqlen -all2all-> (P, bs x seq_len/P, hc/P, hs) scatter head + if seq_world_size > 1: + if use_sync: + dist.all_to_all_single(output, input_t, group=group) + else: + comm = dist.all_to_all_single(output, input_t, group=group, async_op=True) + + def getter(): + comm.wait() + comm_output = output.reshape(hc, shard_seqlen, bs, hs) + + # (hc, seqlen/N, bs, hs) -tranpose(0,2)-> (bs, seqlen/N, hc, hs) + comm_output = comm_output.transpose(0, 2).contiguous().reshape(bs, shard_seqlen, hc, hs) + return comm_output + + return getter + else: + output = input_t + + # if scattering the seq-dim, transpose the heads back to the original dimension + output = output.reshape(hc, shard_seqlen, bs, hs) + + # (hc, seqlen/N, bs, hs) -tranpose(0,2)-> (bs, seqlen/N, hc, hs) + output = output.transpose(0, 2).contiguous().reshape(bs, shard_seqlen, hc, hs) + + return output + else: + raise RuntimeError("scatter_idx must be 1 or 2 and gather_idx must be 1 or 2") + + +class SeqAllToAll4D(torch.autograd.Function): + @staticmethod + def forward( + ctx: Any, + group: dist.ProcessGroup, + input: Tensor, + scatter_idx: int, + gather_idx: int, + use_sync: bool = False, + ) -> Tensor: + + ctx.group = group + ctx.scatter_idx = scatter_idx + ctx.gather_idx = gather_idx + ctx.use_sync = use_sync + return all_to_all_4D(input, scatter_idx, gather_idx, group=group, use_sync=use_sync) + + @staticmethod + def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]: + return ( + None, + SeqAllToAll4D.apply( + ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.use_sync + ), + None, + None, + None, + ) \ No newline at end of file diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/group_coordinator.py b/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/group_coordinator.py new file mode 100644 index 00000000..c48d22a6 --- /dev/null +++ b/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/group_coordinator.py @@ -0,0 +1,640 @@ +# Copyright 2024 xDiT team. +# Adapted from +# https://github.com/vllm-project/vllm/blob/main/vllm/distributed/parallel_state.py +# Copyright 2023 The vLLM team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +from collections import namedtuple +from typing import Any, Dict, List, Optional, Tuple, Union +import pickle + +import torch +import torch_npu +import torch.distributed +from torch.distributed import Backend, ProcessGroup + +import logging + +TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) + + +def _split_tensor_dict( + tensor_dict: Dict[str, Union[torch.Tensor, Any]], prefix: str = "" +) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]: + """Split the tensor dictionary into two parts: + 1. A list of (key, value) pairs. If the value is a tensor, it is replaced + by its metadata. + 2. A list of tensors. + + If the Tensor is nested under `tensor_dict["key1"]["key2"]`, the key of its + metadata will be "key1%key2". + """ + metadata_list: List[Tuple[str, Any]] = [] + tensor_list = [] + for key, value in tensor_dict.items(): + if "%" in key: + logging.error( + "Avoid having '%' in key " + "as it is used as a separator for nested entries." + ) + if isinstance(value, torch.Tensor): + # Note: we cannot use `value.device` here, + # because it contains not only the device type but also the device + # index (e.g. "npu:0"). We only need the device type. + # receiving side will set the device index. + device = value.device.type + metadata_list.append( + (prefix + key, TensorMetadata(device, value.dtype, value.size())) + ) + tensor_list.append(value) + elif isinstance(value, dict): + if len(value) == 0: + metadata_list.append((prefix + key, value)) + inner_metadata_list, inner_tensor_list = _split_tensor_dict( + value, prefix + key + "%" + ) + metadata_list.extend(inner_metadata_list) + tensor_list.extend(inner_tensor_list) + else: + metadata_list.append((prefix + key, value)) + return metadata_list, tensor_list + + +def _update_nested_dict(nested_dict, flattened_key, value): + key_splits = flattened_key.split("%") + cur_dict = nested_dict + for k in key_splits[:-1]: + if k not in cur_dict: + cur_dict[k] = {} + cur_dict = cur_dict[k] + cur_dict[key_splits[-1]] = value + + +class GroupCoordinator: + """ + PyTorch ProcessGroup wrapper for a group of processes. + PyTorch ProcessGroup is bound to one specific communication backend, + e.g. NCCL, Gloo, MPI, etc. + GroupCoordinator takes charge of all the communication operations among + the processes in the group. It can route the communication to + a specific implementation (e.g. switch allreduce implementation + based on the tensor size and npu graph mode). + """ + + # available attributes: + rank: int # global rank + ranks: List[int] # global ranks in the group + world_size: int # size of the group + # difference between `local_rank` and `rank_in_group`: + # if we have a group of size 4 across two nodes: + # Process | Node | Rank | Local Rank | Rank in Group + # 0 | 0 | 0 | 0 | 0 + # 1 | 0 | 1 | 1 | 1 + # 2 | 1 | 2 | 0 | 2 + # 3 | 1 | 3 | 1 | 3 + local_rank: int # local rank used to assign devices + rank_in_group: int # rank inside the group + cpu_group: ProcessGroup # group for CPU communication + device_group: ProcessGroup # group for device communication + + def __init__( + self, + group_ranks: List[List[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], + ): + + self.rank = torch.distributed.get_rank() + self.local_rank = local_rank + self.device_group = None + self.cpu_group = None + + for ranks in group_ranks: + device_group = torch.distributed.new_group( + ranks, backend=torch_distributed_backend + ) + # a group with `gloo` backend, to allow direct coordination between + # processes through the CPU. + + # 原代码(导致超时) + # cpu_group = torch.distributed.new_group(ranks, backend="gloo") + + # 修改后(使用HCCL后端) + cpu_group = torch.distributed.new_group(ranks, backend="hccl") # 适配昇腾环境 + + if self.rank in ranks: + self.ranks = ranks + self.world_size = len(ranks) + self.rank_in_group = ranks.index(self.rank) + self.device_group = device_group + self.cpu_group = cpu_group + + if torch.npu.is_available(): + self.device = torch.device(f"npu:{local_rank}") + else: + self.device = torch.device("cpu") + + @property + def first_rank(self): + """Return the global rank of the first process in the group""" + return self.ranks[0] + + @property + def last_rank(self): + """Return the global rank of the last process in the group""" + return self.ranks[-1] + + @property + def is_first_rank(self): + """Return whether the caller is the first process in the group""" + return self.rank == self.first_rank + + @property + def is_last_rank(self): + """Return whether the caller is the last process in the group""" + return self.rank == self.last_rank + + @property + def next_rank(self): + """Return the global rank of the process that follows the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group + 1) % world_size] + + @property + def prev_rank(self): + """Return the global rank of the process that precedes the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group - 1) % world_size] + + @property + def group_next_rank(self): + """Return the group rank of the process that follows the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return (rank_in_group + 1) % world_size + + @property + def group_prev_rank(self): + """Return the group rank of the process that precedes the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return (rank_in_group - 1) % world_size + + @property + def skip_rank(self): + """Return the global rank of the process that skip connects with the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(world_size - rank_in_group - 1) % world_size] + + @property + def group_skip_rank(self): + """Return the group rank of the process that skip connects with the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return (world_size - rank_in_group - 1) % world_size + + def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: + """ + NOTE: This operation will be applied in-place or out-of-place. + Always assume this function modifies its input, but use the return + value as the output. + """ + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return input_ + else: + torch.distributed.all_reduce(input_, group=self.device_group) + return input_ + + def all_gather( + self, input_: torch.Tensor, dim: int = 0, separate_tensors: bool = False, async_op: bool = False + ) -> Union[torch.Tensor, List[torch.Tensor]]: + + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + # Allocate output tensor. + input_size = list(input_.size()) + input_size[0] *= world_size + output_tensor = torch.empty( + input_size, dtype=input_.dtype, device=input_.device + ) + + # All-gather. + if async_op: + current_input_size = input_size.copy() # 复制列表 + current_world_size = world_size + current_dim = dim + current_separate_tensors = separate_tensors + comm = torch.distributed.all_gather_into_tensor( # ljf 报错 + output_tensor, input_, group=self.device_group, async_op=async_op + ) + + def getter(): + comm.wait() + nonlocal output_tensor # 声明为非局部变量 + + if current_dim != 0: + # 使用捕获的变量,而不是外部变量 + temp_size = current_input_size + temp_size[0] //= current_world_size + output_tensor = output_tensor.reshape([current_world_size] + temp_size) + output_tensor = output_tensor.movedim(0, current_dim) + + if current_separate_tensors: + tensor_list = [ + output_tensor.view(-1) + .narrow(0, input_.numel() * i, input_.numel()) + .view_as(input_) + for i in range(current_world_size) + ] + return tensor_list + else: + current_input_size[current_dim] = current_input_size[current_dim] * world_size + # Reshape + output_tensor = output_tensor.reshape(current_input_size) + return output_tensor + + return getter + else: + torch.distributed.all_gather_into_tensor( # ljf 报错 + output_tensor, input_, group=self.device_group + ) + if dim != 0: + input_size[0] //= world_size + output_tensor = output_tensor.reshape([world_size, ] + input_size) + output_tensor = output_tensor.movedim(0, dim) + + if separate_tensors: + tensor_list = [ + output_tensor.view(-1) + .narrow(0, input_.numel() * i, input_.numel()) + .view_as(input_) + for i in range(world_size) + ] + return tensor_list + else: + input_size = list(input_.size()) + input_size[dim] = input_size[dim] * world_size + # Reshape + output_tensor = output_tensor.reshape(input_size) + return output_tensor + + def gather(self, input_: torch.Tensor, dst: int = 0, dim: int = -1) -> torch.Tensor: + """ + NOTE: We assume that the input tensor is on the same device across + all the ranks. + NOTE: `dst` is the local rank of the destination rank. + """ + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + # Allocate output tensor. + if self.rank_in_group == dst: + gather_list = [torch.empty_like(input_) for _ in range(world_size)] + else: + gather_list = None + # Gather. + torch.distributed.gather( + input_, gather_list, dst=self.ranks[dst], group=self.device_group + ) + if self.rank_in_group == dst: + output_tensor = torch.cat(gather_list, dim=dim) + else: + output_tensor = None + return output_tensor + + def broadcast(self, input_: torch.Tensor, src: int = 0): + """Broadcast the input tensor. + NOTE: `src` is the local rank of the source rank. + """ + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return input_ + # Broadcast. + torch.distributed.broadcast( + input_, src=self.ranks[src], group=self.device_group + ) + return input_ + + def broadcast_object(self, obj: Optional[Any] = None, src: int = 0): + """Broadcast the input object. + NOTE: `src` is the local rank of the source rank. + """ + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return obj + if self.shm_broadcaster is not None: + return self.shm_broadcaster.broadcast_object(obj) + if self.rank_in_group == src: + torch.distributed.broadcast_object_list( + [obj], src=self.ranks[src], group=self.cpu_group + ) + return obj + else: + recv = [None] + torch.distributed.broadcast_object_list( + recv, src=self.ranks[src], group=self.cpu_group + ) + return recv[0] + + def broadcast_object_list( + self, obj_list: List[Any], src: int = 0, group: Optional[ProcessGroup] = None + ): + """Broadcast the input object list. + NOTE: `src` is the local rank of the source rank. + """ + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return obj_list + # Broadcast. + torch.distributed.broadcast_object_list( + obj_list, src=self.ranks[src], group=self.device_group + ) + return obj_list + + def send_object(self, obj: Any, dst: int) -> None: + """Send the input object list to the destination rank.""" + """NOTE: `dst` is the local rank of the destination rank.""" + + # Serialize object to tensor and get the size as well + object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8) + + size_tensor = torch.tensor( + [object_tensor.numel()], dtype=torch.long, device="cpu" + ) + + # Send object size + + torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group) + + # Send object + torch.distributed.send(object_tensor, dst=self.ranks[dst], group=self.cpu_group) + + return None + + def recv_object(self, src: int) -> Any: + """Receive the input object list from the source rank.""" + """NOTE: `src` is the local rank of the source rank.""" + + size_tensor = torch.empty(1, dtype=torch.long, device="cpu") + + # Receive object size + rank_size = torch.distributed.recv( + size_tensor, src=self.ranks[src], group=self.cpu_group + ) + + # Tensor to receive serialized objects into. + object_tensor = torch.empty( # type: ignore[call-overload] + size_tensor.item(), # type: ignore[arg-type] + dtype=torch.uint8, + device="cpu", + ) + + rank_object = torch.distributed.recv( + object_tensor, src=self.ranks[src], group=self.cpu_group + ) + + obj = pickle.loads(object_tensor.numpy().tobytes()) + + return obj + + def broadcast_tensor_dict( + self, + tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None, + src: int = 0, + group: Optional[ProcessGroup] = None, + metadata_group: Optional[ProcessGroup] = None, + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Broadcast the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return tensor_dict + + group = self.device_group + metadata_group = self.cpu_group + src = self.ranks[src] + + rank = self.rank + if rank == src: + metadata_list: List[Tuple[Any, Any]] = [] + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `broadcast_object_list` has serialization & deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + self.broadcast_object(metadata_list, src=src) + async_handles = [] + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast( + tensor, src=src, group=metadata_group, async_op=True + ) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast( + tensor, src=src, group=group, async_op=True + ) + async_handles.append(handle) + for async_handle in async_handles: + async_handle.wait() + + else: + metadata_list = self.broadcast_object(None, src=src) + tensor_dict = {} + async_handles = [] + for key, value in metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty( + value.size, dtype=value.dtype, device=value.device + ) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + _update_nested_dict(tensor_dict, key, tensor) + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast( + tensor, src=src, group=metadata_group, async_op=True + ) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast( + tensor, src=src, group=group, async_op=True + ) + async_handles.append(handle) + _update_nested_dict(tensor_dict, key, tensor) + else: + _update_nested_dict(tensor_dict, key, value) + for async_handle in async_handles: + async_handle.wait() + return tensor_dict + + def send_tensor_dict( + self, + tensor_dict: Dict[str, Union[torch.Tensor, Any]], + dst: Optional[int] = None, + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Send the input tensor dictionary. + NOTE: `dst` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return tensor_dict + + group = self.device_group + metadata_group = self.cpu_group + + if dst is None: + dst = self.group_next_rank + + metadata_list: List[Tuple[Any, Any]] = [] + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `send_object_list` has serialization & deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + self.send_object(metadata_list, dst=dst) + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip sending empty tensors. + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + torch.distributed.send( + tensor, dst=self.ranks[dst], group=metadata_group + ) + else: + # use group for GPU tensors + torch.distributed.send(tensor, dst=self.ranks[dst], group=group) + return None + + def recv_tensor_dict( + self, src: Optional[int] = None + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Recv the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return None + + group = self.device_group + metadata_group = self.cpu_group + + if src is None: + src = self.group_prev_rank + + recv_metadata_list = self.recv_object(src=src) + tensor_dict: Dict[str, Any] = {} + for key, value in recv_metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty(value.size, dtype=value.dtype, device=value.device) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + _update_nested_dict(tensor_dict, key, tensor) + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + torch.distributed.recv( + tensor, src=self.ranks[src], group=metadata_group + ) + else: + # use group for GPU tensors + torch.distributed.recv(tensor, src=self.ranks[src], group=group) + _update_nested_dict(tensor_dict, key, tensor) + else: + _update_nested_dict(tensor_dict, key, value) + return tensor_dict + + def barrier(self): + """Barrier synchronization among the group. + NOTE: don't use `device_group` here! `barrier` in NCCL is + terrible because it is internally a broadcast operation with + secretly created GPU tensors. It is easy to mess up the current + device. Use the CPU group instead. + """ + torch.distributed.barrier(group=self.cpu_group) + + def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: + """Sends a tensor to the destination rank in a non-blocking way""" + """NOTE: `dst` is the rank_in_group of the destination rank.""" + if dst is None: + dst = self.group_next_rank + + torch.distributed.send( + tensor, + self.ranks[dst], + group=( + self.device_groups[self.rank_in_group % 2] + if self.world_size == 2 + else self.device_group + ), + ) + + def recv( + self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None + ) -> torch.Tensor: + """Receives a tensor from the src rank.""" + """NOTE: `src` is the rank_in_group of the source rank.""" + if src is None: + src = self.group_prev_rank + + tensor = torch.empty(size, dtype=dtype, device=self.device) + torch.distributed.recv( + tensor, + self.ranks[src], + ( + self.device_groups[(self.rank_in_group + 1) % 2] + if self.world_size == 2 + else self.device_group + ), + ) + return tensor + + def destroy(self): + if self.device_group is not None: + torch.distributed.destroy_process_group(self.device_group) + self.device_group = None + if self.cpu_group is not None: + torch.distributed.destroy_process_group(self.cpu_group) + self.cpu_group = None + + +class SequenceParallelGroupCoordinator(GroupCoordinator): + def __init__( + self, + group_ranks: List[List[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], + **kwargs, + ): + super().__init__( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=torch_distributed_backend, + ) + self.ulysses_group = kwargs.get("ulysses_group", None) + self.ulysses_world_size = torch.distributed.get_world_size(self.ulysses_group) + self.ulysses_rank = torch.distributed.get_rank(self.ulysses_group) + + self.ring_group = kwargs.get("ring_group", None) + self.ring_world_size = torch.distributed.get_world_size(self.ring_group) + self.ring_rank = torch.distributed.get_rank(self.ring_group) \ No newline at end of file diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/parallel_mgr.py b/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/parallel_mgr.py new file mode 100644 index 00000000..0b6ef343 --- /dev/null +++ b/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/parallel_mgr.py @@ -0,0 +1,404 @@ +import os +from typing import List, Optional +from dataclasses import dataclass +import torch.distributed as dist +import torch_npu +import logging +from ais_bench.benchmark.models.local_models.qwenimage_edit.distributed.utils import RankGenerator, generate_masked_orthogonal_rank_groups +from ais_bench.benchmark.models.local_models.qwenimage_edit.distributed.group_coordinator import GroupCoordinator, SequenceParallelGroupCoordinator + +#--------- ljf ------------------- +import torch +import torch.distributed +try: + import torch_musa + from torch_musa.core.device import set_device, device_count +except ModuleNotFoundError: + pass +#--------------------------- + +from yunchang import set_seq_parallel_pg +from yunchang.globals import PROCESS_GROUP + +_WORLD: Optional[GroupCoordinator] = None +_TP: Optional[GroupCoordinator] = None +_SP: Optional[SequenceParallelGroupCoordinator] = None +_CFG: Optional[GroupCoordinator] = None + + +@dataclass +class ParallelConfig: + tp_degree: int = 1 + sp_degree: int = 1 + ulysses_degree: int = 1 + ring_degree: int = 1 + use_cfg_parallel: bool = False + world_size: int = 1 + + def __post_init__(self): + if self.use_cfg_parallel: + self.cfg_degree = 2 + else: + self.cfg_degree = 1 + if not self.tp_degree * self.sp_degree * self.cfg_degree <= self.world_size: + logging.error( + "tp_degree * sp_degree * cfg_degree must be less than or equal to " + "world_size because of classifier free guidance" + ) + if not (self.world_size % (self.tp_degree * self.sp_degree * self.cfg_degree) == 0): + logging.error("world_size must be divisible by tp_degree * sp_degree * cfg_degree") + + +# * QUERY +def get_world_group() -> GroupCoordinator: + if _WORLD is None: + logging.error("world group is not initialized") + return _WORLD + + +# TP +def get_tp_group() -> GroupCoordinator: + assert _TP is not None, "tensor model parallel group is not initialized" + return _TP + + +def get_tensor_model_parallel_world_size(): + """Return world size for the tensor model parallel group.""" + return get_tp_group().world_size + + +def get_tensor_model_parallel_rank(): + """Return my rank for the tensor model parallel group.""" + return get_tp_group().rank_in_group + + +# SP +def get_sp_group() -> SequenceParallelGroupCoordinator: + if _SP is None: + logging.error("pipeline model parallel group is not initialized") + return _SP + + +def get_sequence_parallel_state(): + """Return state for the sequence parallel group.""" + return _SP is not None + + +def get_sequence_parallel_world_size(): + """Return world size for the sequence parallel group.""" + if not get_sequence_parallel_state(): + return 1 + return get_sp_group().world_size + + +def get_sequence_parallel_rank(): + """Return my rank for the sequence parallel group.""" + if not get_sequence_parallel_state(): + return 0 + return get_sp_group().rank_in_group + + +# CFG +def get_cfg_group() -> GroupCoordinator: + if _CFG is None: + logging.error("classifier_free_guidance parallel group is not initialized") + return _CFG + + +def get_cfg_state(): + """Return state for the sequence parallel group.""" + return _CFG is not None + + +def get_classifier_free_guidance_world_size(): + """Return world size for the classifier_free_guidance parallel group.""" + if not get_cfg_state(): + return 1 + return get_cfg_group().world_size + + +def get_classifier_free_guidance_rank(): + """Return my rank for the classifier_free_guidance parallel group.""" + if not get_cfg_state(): + return 0 + return get_cfg_group().rank_in_group + + +def init_world_group( + ranks: List[int], local_rank: int, backend: str +) -> GroupCoordinator: + return GroupCoordinator( + group_ranks=[ranks], + local_rank=local_rank, + torch_distributed_backend=backend, + ) + +# wan2.1 的 +def init_distributed_environment( + world_size: int = -1, + rank: int = -1, + distributed_init_method: str = "env://", + local_rank: int = -1, + backend: str = "hccl", +): + logging.debug( + "world_size=%d rank=%d local_rank=%d " "distributed_init_method=%s backend=%s", + world_size, + rank, + local_rank, + distributed_init_method, + backend, + ) + if not dist.is_initialized(): + if distributed_init_method is None: + logging.error( + "distributed_init_method must be provided when initializing " + "distributed environment" + ) + # this backend is used for WORLD + dist.init_process_group( + backend=backend, + init_method=distributed_init_method, + world_size=world_size, + rank=rank, + ) + # set the local rank + # local_rank is not available in torch ProcessGroup, + # see https://github.com/pytorch/pytorch/issues/122816 + if local_rank == -1: + # local rank not set, this usually happens in single-node + # setting, where we can use rank as local rank + if distributed_init_method == "env://": + local_rank = int(os.getenv('LOCAL_RANK', 0)) + torch_npu.npu.set_device(local_rank) + else: + local_rank = rank + global _WORLD + if _WORLD is None: + ranks = list(range(dist.get_world_size())) + _WORLD = init_world_group(ranks, local_rank, backend) + else: + if not _WORLD.world_size == dist.get_world_size(): + logging.error("world group already initialized with a different world size") + + +# def init_distributed_environment( +# world_size: int = -1, +# rank: int = -1, +# distributed_init_method: str = "env://", +# local_rank: int = -1, +# backend: str = "hccl", +# ): +# logging.debug( +# "world_size=%d rank=%d local_rank=%d " "distributed_init_method=%s backend=%s", +# world_size, +# rank, +# local_rank, +# distributed_init_method, +# backend, +# ) +# if not torch.distributed.is_initialized(): +# assert distributed_init_method is not None, ( +# "distributed_init_method must be provided when initializing " +# "distributed environment" +# ) +# # this backend is used for WORLD +# torch.distributed.init_process_group( +# backend=backend, +# init_method=distributed_init_method, +# world_size=world_size, +# rank=rank, +# ) +# set_device(torch.distributed.get_rank() % device_count()) +# # set the local rank +# # local_rank is not available in torch ProcessGroup, +# # see https://github.com/pytorch/pytorch/issues/122816 +# if local_rank == -1: +# # local rank not set, this usually happens in single-node +# # setting, where we can use rank as local rank +# if distributed_init_method == "env://": +# # local_rank = int(os.getenv('LOCAL_RANK', 0)) +# local_rank = dist.get_rank() +# print(f"init_distributed_environment 里面 local_rank {local_rank}") +# else: +# local_rank = rank +# global _WORLD +# if _WORLD is None: +# ranks = list(range(torch.distributed.get_world_size())) +# _WORLD = init_world_group(ranks, local_rank, backend) +# print(f"_WORLD 初始化") +# else: +# assert ( +# _WORLD.world_size == torch.distributed.get_world_size() +# ), "world group already initialized with a different world size" +# print(f"_WORLD 没有 初始化") + + +def model_parallel_is_initialized(): + """Check if tensor and pipeline parallel groups are initialized.""" + return ( + _CFG is not None + and _SP is not None + and _TP is not None + ) + + +def init_model_parallel_group( + group_ranks: List[List[int]], + local_rank: int, + backend: str, + parallel_mode: str, + **kwargs, +) -> GroupCoordinator: + if parallel_mode not in [ + "tensor", + "sequence", + "classifier_free_guidance", + ]: + logging.error(f"parallel_mode {parallel_mode} is not supported") + if parallel_mode == "sequence": # ulysses + return SequenceParallelGroupCoordinator( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=backend, + **kwargs, + ) + else: + return GroupCoordinator( # cfg + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=backend, + ) + + +def initialize_model_parallel( + classifier_free_guidance_degree: int = 1, + sequence_parallel_degree: int = 1, + ulysses_degree: int = 1, + ring_degree: int = 1, + tensor_parallel_degree: int = 1, + backend: Optional[str] = None, +) -> None: + """ + Initialize model parallel groups. + + Arguments: + classifier_free_guidance_degree: number of GPUs used for Classifier Free Guidance (CFG) + sequence_parallel_degree: number of GPUs used for sequence parallelism. + tensor_parallel_degree: number of GPUs used for tensor parallelism. + backend: distributed backend of pytorch collective comm. + """ + # Get world size and rank. Ensure some consistencies. + if not dist.is_initialized(): + logging.error("dist is not initialized") + world_size: int = dist.get_world_size() + backend = backend + + if ( + world_size + != classifier_free_guidance_degree + * sequence_parallel_degree + * tensor_parallel_degree + ): + raise RuntimeError( + f"world_size ({world_size}) is not equal to " + f"sequence_parallel_degree ({sequence_parallel_degree}) x " + f"classifier_free_guidance_degree " + f"({classifier_free_guidance_degree}) x " + f"tensor_parallel_degree " + f"({tensor_parallel_degree})" + ) + + rank_generator: RankGenerator = RankGenerator( + tensor_parallel_degree, + sequence_parallel_degree, + classifier_free_guidance_degree, + "tp-sp-cfg", + ) + + global _CFG + if _CFG is not None: + logging.error("classifier_free_guidance group is already initialized") + _CFG = init_model_parallel_group( + group_ranks=rank_generator.get_ranks("cfg"), + local_rank=get_world_group().local_rank, + backend=backend, + parallel_mode="classifier_free_guidance", + ) + + global _SP + if _SP is not None: + logging.error("sequence parallel group is already initialized") + set_seq_parallel_pg( + sp_ulysses_degree=ulysses_degree, + sp_ring_degree=ring_degree, + rank=get_world_group().rank_in_group, + world_size=world_size + ) + _SP = init_model_parallel_group( + group_ranks=rank_generator.get_ranks("sp"), + local_rank=get_world_group().local_rank, + backend=backend, + parallel_mode="sequence", + ulysses_group=PROCESS_GROUP.ULYSSES_PG, + ring_group=PROCESS_GROUP.RING_PG, + ) + + global _TP + assert _TP is None, "Tensor parallel group is already initialized" + _TP = init_model_parallel_group( + group_ranks=rank_generator.get_ranks("tp"), + local_rank=get_world_group().local_rank, + backend=backend, + parallel_mode="tensor", + ) + + +def destroy_model_parallel(): + """Set the groups to none and destroy them.""" + global _CFG + if _CFG: + _CFG.destroy() + _CFG = None + + global _SP + if _SP: + _SP.destroy() + _SP = None + + global _TP + if _TP: + _TP.destroy() + _TP = None + + +def destroy_distributed_environment(): + global _WORLD + if _WORLD: + _WORLD.destroy() + _WORLD = None + if dist.is_initialized(): + dist.destroy_process_group() + + +def init_parallel_env(parallel_config: ParallelConfig): + if not model_parallel_is_initialized(): + logging.warning("Model parallel is not initialized, initializing...") + init_distributed_environment( + world_size=dist.get_world_size(), + rank=dist.get_rank(), + backend='hccl', + ) + initialize_model_parallel( + classifier_free_guidance_degree=parallel_config.cfg_degree, + sequence_parallel_degree=parallel_config.sp_degree, + ulysses_degree=parallel_config.ulysses_degree, + ring_degree=parallel_config.ring_degree, + tensor_parallel_degree=parallel_config.tp_degree, + ) + + +def finalize_parallel_env(): + if model_parallel_is_initialized(): + destroy_model_parallel() + destroy_distributed_environment() diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/utils.py b/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/utils.py new file mode 100644 index 00000000..c53ae68f --- /dev/null +++ b/ais_bench/benchmark/models/local_models/qwenimage_edit/distributed/utils.py @@ -0,0 +1,152 @@ +from typing import List +import logging + + +def generate_masked_orthogonal_rank_groups( + world_size: int, parallel_size: List[int], mask: List[bool] +) -> List[List[int]]: + """Generate orthogonal parallel groups based on the parallel size and mask. + + Arguments: + world_size (int): world size + + parallel_size (List[int]): + The parallel size of each orthogonal parallel type. For example, if + tensor_parallel_size = 2, pipeline_model_parallel_group = 3, data_parallel_size = 4, + and the parallel mapping order is tp-pp-dp, then the parallel_size = [2, 3, 4]. + + mask (List[bool]): + The mask controls which parallel methods the generated groups represent. If mask[i] is + True, it means the generated group contains the i-th parallelism method. For example, + if parallel_size = [tp_size, pp_size, dp_size], and mask = [True, False , True], then + the generated group is the `tp-dp` group, if the mask = [False, True, False], then the + generated group is the `pp` group. + """ + + def prefix_product(a: List[int], init=1) -> List[int]: + r = [init] + for v in a: + init = init * v + r.append(init) + return r + + def inner_product(a: List[int], b: List[int]) -> int: + return sum([x * y for x, y in zip(a, b)]) + + def decompose(index, shape, stride=None): + """ + This function solve the math problem below: + There is an equation: + index = sum(idx[i] * stride[i]) + And given the value of index, stride. + Return the idx. + This function will used to get the pp/dp/pp_rank + from group_index and rank_in_group. + """ + if stride is None: + stride = prefix_product(shape) + idx = [(index // d) % s for s, d in zip(shape, stride)] + # stride is a prefix_product result. And the value of stride[-1] + # is not used. + if not ( + sum([x * y for x, y in zip(idx, stride[:-1])]) == index + ): + logging.error("idx {} with shape {} mismatch the return idx {}".format(index, shape, idx)) + return idx + + masked_shape = [s for s, m in zip(parallel_size, mask) if m] + unmasked_shape = [s for s, m in zip(parallel_size, mask) if not m] + + global_stride = prefix_product(parallel_size) + masked_stride = [d for d, m in zip(global_stride, mask) if m] + unmasked_stride = [d for d, m in zip(global_stride, mask) if not m] + + group_size = prefix_product(masked_shape)[-1] + num_of_group = world_size // group_size + + ranks = [] + for group_index in range(num_of_group): + # get indices from unmaksed for group_index. + decomposed_group_idx = decompose(group_index, unmasked_shape) + rank = [] + for rank_in_group in range(group_size): + # get indices from masked for rank_in_group. + decomposed_rank_idx = decompose(rank_in_group, masked_shape) + rank.append( + inner_product(decomposed_rank_idx, masked_stride) + + inner_product(decomposed_group_idx, unmasked_stride) + ) + ranks.append(rank) + return ranks + + +class RankGenerator(object): + def __init__( + self, + tp: int, + sp: int, + cfg: int, + order: str, + rank_offset: int = 0, + ) -> None: + self.tp = tp + self.sp = sp + self.cfg = cfg + self.rank_offset = rank_offset + self.world_size = tp * sp * cfg + + self.name_to_size = { + "sp": self.sp, + "cfg": self.cfg, + "tp": self.tp, + } + order = order.lower() + + for name in self.name_to_size.keys(): + if name not in order and self.name_to_size[name] != 1: + raise RuntimeError( + f"The size of ({name}) is ({self.name_to_size[name]}), but you haven't specified the order ({self.order})." + ) + elif name not in order: + order = order + "-" + name + + self.order = order + self.ordered_size = [] + + for token in order.split("-"): + self.ordered_size.append(self.name_to_size[token]) + + def get_mask(self, order: str, token: str): + ordered_token = order.split("-") + token = token.split("-") + mask = [False] * len(ordered_token) + for t in token: + mask[ordered_token.index(t)] = True + return mask + + def get_ranks(self, token): + """Get rank group by input token. + + Arguments: + token (str): + Specify the ranks type that want to get. If we want + to obtain multiple parallel types, we can use a hyphen + '-' to separate them. For example, if we want to obtain + the TP_DP group, the token should be 'tp-dp'. + + independent_ep (bool: True): + This flag controls whether we treat EP and DP independently. + EP shares ranks with DP, if we want to get ranks related to + EP, we should set the flag. For example, get_ranks('dp', True) + will get DP modulo EP group, and get_ranks('dp', False) will + get full DP group. + """ + mask = self.get_mask(self.order, token) + ranks = generate_masked_orthogonal_rank_groups( + self.world_size, self.ordered_size, mask + ) + if self.rank_offset > 0: + for rank_group in ranks: + for i, _ in enumerate(rank_group): + rank_group[i] += self.rank_offset + return ranks \ No newline at end of file diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/pipeline_qwenimage_edit_plus.py b/ais_bench/benchmark/models/local_models/qwenimage_edit/pipeline_qwenimage_edit_plus.py new file mode 100644 index 00000000..3d5c4b3b --- /dev/null +++ b/ais_bench/benchmark/models/local_models/qwenimage_edit/pipeline_qwenimage_edit_plus.py @@ -0,0 +1,964 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor + +from diffusers.image_processor import PipelineImageInput, VaeImageProcessor +from diffusers.loaders import QwenImageLoraLoaderMixin +from diffusers.models import AutoencoderKLQwenImage + +# from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from ais_bench.benchmark.models.local_models.qwenimage_edit.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler + +from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.qwenimage.pipeline_output import QwenImagePipelineOutput + +from ais_bench.benchmark.models.local_models.qwenimage_edit.transformer_qwenimage import QwenImageTransformer2DModel +from ais_bench.benchmark.models.local_models.qwenimage_edit.distributed.parallel_mgr import ( + get_sequence_parallel_world_size, + get_classifier_free_guidance_world_size, + get_classifier_free_guidance_rank, + get_cfg_group, + init_distributed_environment, + initialize_model_parallel, + get_sequence_parallel_rank, + get_sp_group +) + +#------------------ljf------------------- +import os +import torch_npu +#------------------------------------- + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +#------------------ljf--------------- +USE_NPU = False +if torch.npu.is_available(): + USE_NPU = True + + +COND_CACHE = bool(int(os.environ.get('COND_CACHE', 0))) +UNCOND_CACHE = bool(int(os.environ.get('UNCOND_CACHE', 0))) +#----------------------------------- + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from PIL import Image + >>> from diffusers import QwenImageEditPlusPipeline + >>> from diffusers.utils import load_image + + >>> pipe = QwenImageEditPlusPipeline.from_pretrained("Qwen/Qwen-Image-Edit-2509", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png" + ... ).convert("RGB") + >>> prompt = ( + ... "Make Pikachu hold a sign that says 'Qwen Edit is awesome', yarn art style, detailed, vibrant colors" + ... ) + >>> # Depending on the variant being used, the pipeline call will slightly vary. + >>> # Refer to the pipeline documentation for more details. + >>> image = pipe(image, prompt, num_inference_steps=50).images[0] + >>> image.save("qwenimage_edit_plus.png") + ``` +""" + +CONDITION_IMAGE_SIZE = 384 * 384 +VAE_IMAGE_SIZE = 1024 * 1024 + + +# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def calculate_dimensions(target_area, ratio): + width = math.sqrt(target_area * ratio) + height = width / ratio + + width = round(width / 32) * 32 + height = round(height / 32) * 32 + + return width, height + + +class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): + r""" + The Qwen-Image-Edit pipeline for image editing. + + Args: + transformer ([`QwenImageTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`Qwen2.5-VL-7B-Instruct`]): + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant. + tokenizer (`QwenTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLQwenImage, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: Qwen2Tokenizer, + processor: Qwen2VLProcessor, + transformer: QwenImageTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + processor=processor, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16 + # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.tokenizer_max_length = 1024 + + self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + self.prompt_template_encode_start_idx = 64 + self.default_sample_size = 128 + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden + def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + + return split_result + + def _get_qwen_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + image: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>" + if isinstance(image, list): + base_img_prompt = "" + for i, img in enumerate(image): + base_img_prompt += img_prompt_template.format(i + 1) + elif image is not None: + base_img_prompt = img_prompt_template.format(1) + else: + base_img_prompt = "" + + template = self.prompt_template_encode + + drop_idx = self.prompt_template_encode_start_idx + txt = [template.format(base_img_prompt + e) for e in prompt] + + model_inputs = self.processor( + text=txt, + images=image, + padding=True, + return_tensors="pt", + ).to(device) + + outputs = self.text_encoder( + input_ids=model_inputs.input_ids, + attention_mask=model_inputs.attention_mask, + pixel_values=model_inputs.pixel_values, + image_grid_thw=model_inputs.image_grid_thw, + output_hidden_states=True, + ) + + hidden_states = outputs.hidden_states[-1] + split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return prompt_embeds, encoder_attention_mask + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + image: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 1024, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + image (`torch.Tensor`, *optional*): + image to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device) + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + return prompt_embeds, prompt_embeds_mask + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_embeds_mask is None: + raise ValueError( + "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 1024: + raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width) + + return latents + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax") + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.latent_channels, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std) + .view(1, self.latent_channels, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + image_latents = (image_latents - latents_mean) / latents_std + + return image_latents + + def prepare_latents( + self, + images, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, 1, num_channels_latents, height, width) + + image_latents = None + if images is not None: + if not isinstance(images, list): + images = [images] + all_image_latents = [] + for image in images: + image = image.to(device=device, dtype=dtype) + if image.shape[1] != self.latent_channels: + image_latents = self._encode_vae_image(image=image, generator=generator) + else: + image_latents = image + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + image_latent_height, image_latent_width = image_latents.shape[3:] + image_latents = self._pack_latents( + image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width + ) + all_image_latents.append(image_latents) + image_latents = torch.cat(all_image_latents, dim=1) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + else: + latents = latents.to(device=device, dtype=dtype) + + # print(f"device {device}, ljf 随机生成latents latents {latents}") + return latents, image_latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: Optional[PipelineImageInput] = None, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + true_cfg_scale: float = 4.0, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: Optional[float] = None, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + true_cfg_scale (`float`, *optional*, defaults to 1.0): + true_cfg_scale (`float`, *optional*, defaults to 1.0): Guidance scale as defined in [Classifier-Free + Diffusion Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of + equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is + enabled by setting `true_cfg_scale > 1` and a provided `negative_prompt`. Higher guidance scale + encourages to generate images that are closely linked to the text `prompt`, usually at the expense of + lower image quality. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to None): + A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance + where the guidance scale is applied during inference through noise prediction rescaling, guidance + distilled models take the guidance scale directly as an input parameter during forward pass. Guidance + scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images + that are closely linked to the text `prompt`, usually at the expense of lower image quality. This + parameter in the pipeline is there to support future guidance-distilled models when they come up. It is + ignored when not using guidance distilled models. To enable traditional classifier-free guidance, + please pass `true_cfg_scale > 1.0` and `negative_prompt` (even an empty negative prompt like " " should + enable classifier-free guidance computations). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is a list with the generated images. + """ + image_size = image[-1].size if isinstance(image, list) else image.size + calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1]) + height = height or calculated_height + width = width or calculated_width + + multiple_of = self.vae_scale_factor * 2 + width = width // multiple_of * multiple_of + height = height // multiple_of * multiple_of + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # 3. Preprocess image + if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): + if not isinstance(image, list): + image = [image] + condition_image_sizes = [] + condition_images = [] + vae_image_sizes = [] + vae_images = [] + for img in image: + image_width, image_height = img.size + condition_width, condition_height = calculate_dimensions( + CONDITION_IMAGE_SIZE, image_width / image_height + ) + vae_width, vae_height = calculate_dimensions(VAE_IMAGE_SIZE, image_width / image_height) + condition_image_sizes.append((condition_width, condition_height)) + vae_image_sizes.append((vae_width, vae_height)) + condition_images.append(self.image_processor.resize(img, condition_height, condition_width)) + vae_images.append(self.image_processor.preprocess(img, vae_height, vae_width).unsqueeze(2)) + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None + ) + + if true_cfg_scale > 1 and not has_neg_prompt: + logger.warning( + f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided." + ) + elif true_cfg_scale <= 1 and has_neg_prompt: + logger.warning( + " negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1" + ) + + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + prompt_embeds, prompt_embeds_mask = self.encode_prompt( + image=condition_images, + prompt=prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + if do_true_cfg: + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( + image=condition_images, + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, image_latents = self.prepare_latents( + vae_images, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, # ljf None + ) + img_shapes = [ + [ + (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2), + *[ + (1, vae_height // self.vae_scale_factor // 2, vae_width // self.vae_scale_factor // 2) + for vae_width, vae_height in vae_image_sizes + ], + ] + ] * batch_size + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds and guidance_scale is None: + raise ValueError("guidance_scale is required for guidance-distilled model.") + elif self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + elif not self.transformer.config.guidance_embeds and guidance_scale is not None: + logger.warning( + f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled." + ) + guidance = None + elif not self.transformer.config.guidance_embeds and guidance_scale is None: + guidance = None + + if self.attention_kwargs is None: + self._attention_kwargs = {} + + txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None + negative_txt_seq_lens = ( + negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None + ) + + # 6. Denoising loop + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + latent_model_input = latents + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + if get_classifier_free_guidance_world_size() == 2: + if get_classifier_free_guidance_rank() == 0: + with self.transformer.cache_context("uncond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=negative_prompt_embeds_mask, + encoder_hidden_states=negative_prompt_embeds, + img_shapes=img_shapes, + txt_seq_lens=negative_txt_seq_lens, + attention_kwargs=self.attention_kwargs, + return_dict=False, + use_cache=UNCOND_CACHE, #-------------ljf------------- + if_cond=False, #----------------ljf------------- + )[0] + noise_pred = noise_pred[:, : latents.size(1)] + + else: + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=prompt_embeds_mask, + encoder_hidden_states=prompt_embeds, + img_shapes=img_shapes, + txt_seq_lens=txt_seq_lens, + attention_kwargs=self.attention_kwargs, + return_dict=False, + use_cache=COND_CACHE, #-------ljf-------- + if_cond=True, #------------ljf----------- + )[0] + noise_pred = noise_pred[:, : latents.size(1)] + + + noise_pred_uncond, noise_pred_text = get_cfg_group().all_gather(noise_pred, separate_tensors=True) + + comb_pred = noise_pred_uncond + true_cfg_scale * (noise_pred_text - noise_pred_uncond) + + cond_norm = torch.norm(noise_pred_text, dim=-1, keepdim=True) # 修正代码 + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + else: + #------------ljf 原始代码--------------- + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=prompt_embeds_mask, + encoder_hidden_states=prompt_embeds, + img_shapes=img_shapes, + txt_seq_lens=txt_seq_lens, + attention_kwargs=self.attention_kwargs, + return_dict=False, + use_cache=COND_CACHE, #-------ljf-------- + if_cond=True, #------------ljf----------- + )[0] + noise_pred = noise_pred[:, : latents.size(1)] + + if do_true_cfg: + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=negative_prompt_embeds_mask, + encoder_hidden_states=negative_prompt_embeds, + img_shapes=img_shapes, + txt_seq_lens=negative_txt_seq_lens, + attention_kwargs=self.attention_kwargs, + return_dict=False, + use_cache=UNCOND_CACHE, #-------------ljf------------- + if_cond=False, #----------------ljf------------- + )[0] + neg_noise_pred = neg_noise_pred[:, : latents.size(1)] + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + #---------------------------------------------------- + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] # + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return QwenImagePipelineOutput(images=image) \ No newline at end of file diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/scheduling_flow_match_euler_discrete.py b/ais_bench/benchmark/models/local_models/qwenimage_edit/scheduling_flow_match_euler_discrete.py new file mode 100644 index 00000000..0e67901d --- /dev/null +++ b/ais_bench/benchmark/models/local_models/qwenimage_edit/scheduling_flow_match_euler_discrete.py @@ -0,0 +1,563 @@ +# Copyright 2025 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput, is_scipy_available, logging +from diffusers.schedulers.scheduling_utils import SchedulerMixin + + +if is_scipy_available(): + import scipy.stats + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: torch.FloatTensor + + +class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + Euler scheduler. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + shift (`float`, defaults to 1.0): + The shift value for the timestep schedule. + use_dynamic_shifting (`bool`, defaults to False): + Whether to apply timestep shifting on-the-fly based on the image resolution. + base_shift (`float`, defaults to 0.5): + Value to stabilize image generation. Increasing `base_shift` reduces variation and image is more consistent + with desired output. + max_shift (`float`, defaults to 1.15): + Value change allowed to latent vectors. Increasing `max_shift` encourages more variation and image may be + more exaggerated or stylized. + base_image_seq_len (`int`, defaults to 256): + The base image sequence length. + max_image_seq_len (`int`, defaults to 4096): + The maximum image sequence length. + invert_sigmas (`bool`, defaults to False): + Whether to invert the sigmas. + shift_terminal (`float`, defaults to None): + The end value of the shifted timestep schedule. + use_karras_sigmas (`bool`, defaults to False): + Whether to use Karras sigmas for step sizes in the noise schedule during sampling. + use_exponential_sigmas (`bool`, defaults to False): + Whether to use exponential sigmas for step sizes in the noise schedule during sampling. + use_beta_sigmas (`bool`, defaults to False): + Whether to use beta sigmas for step sizes in the noise schedule during sampling. + time_shift_type (`str`, defaults to "exponential"): + The type of dynamic resolution-dependent timestep shifting to apply. Either "exponential" or "linear". + stochastic_sampling (`bool`, defaults to False): + Whether to use stochastic sampling. + """ + + _compatibles = [] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, + use_dynamic_shifting: bool = False, + base_shift: Optional[float] = 0.5, + max_shift: Optional[float] = 1.15, + base_image_seq_len: Optional[int] = 256, + max_image_seq_len: Optional[int] = 4096, + invert_sigmas: bool = False, + shift_terminal: Optional[float] = None, + use_karras_sigmas: Optional[bool] = False, + use_exponential_sigmas: Optional[bool] = False, + use_beta_sigmas: Optional[bool] = False, + time_shift_type: str = "exponential", + stochastic_sampling: bool = False, + ): + if self.config.use_beta_sigmas and not is_scipy_available(): + raise ImportError("Make sure to install scipy if you want to use beta sigmas.") + if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: + raise ValueError( + "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used." + ) + if time_shift_type not in {"exponential", "linear"}: + raise ValueError("`time_shift_type` must either be 'exponential' or 'linear'.") + + timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy() + timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32) + + sigmas = timesteps / num_train_timesteps + if not use_dynamic_shifting: + # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + + self.timesteps = sigmas * num_train_timesteps + + self._step_index = None + self._begin_index = None + + self._shift = shift + + self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + @property + def shift(self): + """ + The value used for shifting. + """ + return self._shift + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def set_shift(self, shift: float): + self._shift = shift + + def scale_noise( + self, + sample: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + noise: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + """ + Forward process in flow-matching + + Args: + sample (`torch.FloatTensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.FloatTensor`: + A scaled input sample. + """ + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype) + + if sample.device.type == "mps" and torch.is_floating_point(timestep): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32) + timestep = timestep.to(sample.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(sample.device) + timestep = timestep.to(sample.device) + + # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timestep.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timestep.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(sample.shape): + sigma = sigma.unsqueeze(-1) + + sample = sigma * noise + (1.0 - sigma) * sample + + return sample + + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + def time_shift(self, mu: float, sigma: float, t: torch.Tensor): + if self.config.time_shift_type == "exponential": + return self._time_shift_exponential(mu, sigma, t) + elif self.config.time_shift_type == "linear": + return self._time_shift_linear(mu, sigma, t) + + def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor: + r""" + Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config + value. + + Reference: + https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51 + + Args: + t (`torch.Tensor`): + A tensor of timesteps to be stretched and shifted. + + Returns: + `torch.Tensor`: + A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`. + """ + one_minus_z = 1 - t + scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal) + stretched_t = 1 - (one_minus_z / scale_factor) + return stretched_t + + def set_timesteps( + self, + num_inference_steps: Optional[int] = None, + device: Union[str, torch.device] = None, + sigmas: Optional[List[float]] = None, + mu: Optional[float] = None, + timesteps: Optional[List[float]] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`, *optional*): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + sigmas (`List[float]`, *optional*): + Custom values for sigmas to be used for each diffusion step. If `None`, the sigmas are computed + automatically. + mu (`float`, *optional*): + Determines the amount of shifting applied to sigmas when performing resolution-dependent timestep + shifting. + timesteps (`List[float]`, *optional*): + Custom values for timesteps to be used for each diffusion step. If `None`, the timesteps are computed + automatically. + """ + if self.config.use_dynamic_shifting and mu is None: + raise ValueError("`mu` must be passed when `use_dynamic_shifting` is set to be `True`") + + if sigmas is not None and timesteps is not None: + if len(sigmas) != len(timesteps): + raise ValueError("`sigmas` and `timesteps` should have the same length") + + if num_inference_steps is not None: + if (sigmas is not None and len(sigmas) != num_inference_steps) or ( + timesteps is not None and len(timesteps) != num_inference_steps + ): + raise ValueError( + "`sigmas` and `timesteps` should have the same length as num_inference_steps, if `num_inference_steps` is provided" + ) + else: + num_inference_steps = len(sigmas) if sigmas is not None else len(timesteps) + + self.num_inference_steps = num_inference_steps + + # 1. Prepare default sigmas + is_timesteps_provided = timesteps is not None + + if is_timesteps_provided: + timesteps = np.array(timesteps).astype(np.float32) + + if sigmas is None: + if timesteps is None: + timesteps = np.linspace( + self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps + ) + sigmas = timesteps / self.config.num_train_timesteps + else: + sigmas = np.array(sigmas).astype(np.float32) + num_inference_steps = len(sigmas) + + # 2. Perform timestep shifting. Either no shifting is applied, or resolution-dependent shifting of + # "exponential" or "linear" type is applied + if self.config.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) + else: + sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas) + + # 3. If required, stretch the sigmas schedule to terminate at the configured `shift_terminal` value + if self.config.shift_terminal: + sigmas = self.stretch_shift_to_terminal(sigmas) + + # 4. If required, convert sigmas to one of karras, exponential, or beta sigma schedules + if self.config.use_karras_sigmas: + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + elif self.config.use_exponential_sigmas: + sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + elif self.config.use_beta_sigmas: + sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + + # 5. Convert sigmas and timesteps to tensors and move to specified device + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) + if not is_timesteps_provided: + timesteps = sigmas * self.config.num_train_timesteps + else: + timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device) + + # 6. Append the terminal sigma value. + # If a model requires inverted sigma schedule for denoising but timesteps without inversion, the + # `invert_sigmas` flag can be set to `True`. This case is only required in Mochi + if self.config.invert_sigmas: + sigmas = 1.0 - sigmas + timesteps = sigmas * self.config.num_train_timesteps + sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)]) + else: + sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + + self.timesteps = timesteps + self.sigmas = sigmas + self._step_index = None + self._begin_index = None + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + s_churn: float = 0.0, + s_tmin: float = 0.0, + s_tmax: float = float("inf"), + s_noise: float = 1.0, + generator: Optional[torch.Generator] = None, + per_token_timesteps: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + s_churn (`float`): + s_tmin (`float`): + s_tmax (`float`): + s_noise (`float`, defaults to 1.0): + Scaling factor for noise added to the sample. + generator (`torch.Generator`, *optional*): + A random number generator. + per_token_timesteps (`torch.Tensor`, *optional*): + The timesteps for each token in the sample. + return_dict (`bool`): + Whether or not to return a + [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or tuple. + + Returns: + [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or `tuple`: + If return_dict is `True`, + [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] is returned, + otherwise a tuple is returned where the first element is the sample tensor. + """ + + if ( + isinstance(timestep, int) + or isinstance(timestep, torch.IntTensor) + or isinstance(timestep, torch.LongTensor) + ): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `FlowMatchEulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + + if per_token_timesteps is not None: + per_token_sigmas = per_token_timesteps / self.config.num_train_timesteps + + sigmas = self.sigmas[:, None, None] + lower_mask = sigmas < per_token_sigmas[None] - 1e-6 + lower_sigmas = lower_mask * sigmas + lower_sigmas, _ = lower_sigmas.max(dim=0) + + current_sigma = per_token_sigmas[..., None] + next_sigma = lower_sigmas[..., None] + dt = current_sigma - next_sigma + else: + sigma_idx = self.step_index + sigma = self.sigmas[sigma_idx] + sigma_next = self.sigmas[sigma_idx + 1] + + current_sigma = sigma + next_sigma = sigma_next + dt = sigma_next - sigma + + if self.config.stochastic_sampling: + print("ljf 进入采样器,涉及随机") + x0 = sample - current_sigma * model_output + noise = torch.randn_like(sample) + prev_sample = (1.0 - next_sigma) * x0 + next_sigma * noise + else: + print("ljf 进入采样器,无随机") + prev_sample = sample + dt * model_output + + # upon completion increase step index by one + self._step_index += 1 + if per_token_timesteps is None: + # Cast sample back to model compatible dtype + prev_sample = prev_sample.to(model_output.dtype) + + if not return_dict: + return (prev_sample,) + + return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample) + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential + def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: + """Constructs an exponential noise schedule.""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps)) + return sigmas + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta + def _convert_to_beta( + self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6 + ) -> torch.Tensor: + """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = np.array( + [ + sigma_min + (ppf * (sigma_max - sigma_min)) + for ppf in [ + scipy.stats.beta.ppf(timestep, alpha, beta) + for timestep in 1 - np.linspace(0, 1, num_inference_steps) + ] + ] + ) + return sigmas + + def _time_shift_exponential(self, mu, sigma, t): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + def _time_shift_linear(self, mu, sigma, t): + return mu / (mu + (1 / t - 1) ** sigma) + + def __len__(self): + return self.config.num_train_timesteps \ No newline at end of file diff --git a/ais_bench/benchmark/models/local_models/qwenimage_edit/transformer_qwenimage.py b/ais_bench/benchmark/models/local_models/qwenimage_edit/transformer_qwenimage.py new file mode 100644 index 00000000..52b2a002 --- /dev/null +++ b/ais_bench/benchmark/models/local_models/qwenimage_edit/transformer_qwenimage.py @@ -0,0 +1,792 @@ +# Copyright 2025 Qwen-Image Team, The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import math +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin +from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from diffusers.utils.torch_utils import maybe_allow_in_graph +# from diffusers.models._modeling_parallel import ContextParallelInput, ContextParallelOutput + +from diffusers.models.attention import AttentionMixin, FeedForward +from diffusers.models.attention_dispatch import dispatch_attention_fn +from diffusers.models.attention_processor import Attention +from diffusers.models.cache_utils import CacheMixin +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import AdaLayerNormContinuous, RMSNorm + +#------------ljf-------- +import torch_npu +from mindiesd import attention_forward +import os +ROPE_FUSE = bool(int(os.environ.get('ROPE_FUSE', 0))) +ADALN_FUSE = bool(int(os.environ.get('ADALN_FUSE', 0))) +#--------------------------- + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +class AdaLayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.hidden_size = hidden_size + self.eps = eps + + def forward(self, x, mod_params): + shift, scale, gate = mod_params.chunk(3, dim=-1) + scale = (1 + scale.unsqueeze(1)) + shift = shift.unsqueeze(1) + return torch_npu.npu_layer_norm_eval( + x, normalized_shape=[self.hidden_size], weight=scale, bias=shift, eps=self.eps), gate.unsqueeze(1) + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +) -> torch.Tensor: + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + Args + timesteps (torch.Tensor): + a 1-D Tensor of N indices, one per batch element. These may be fractional. + embedding_dim (int): + the dimension of the output. + flip_sin_to_cos (bool): + Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False) + downscale_freq_shift (float): + Controls the delta between frequencies between dimensions + scale (float): + Scaling factor applied to the embeddings. + max_period (int): + Controls the maximum frequency of the embeddings + Returns + torch.Tensor: an [N x dim] Tensor of positional embeddings. + """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device + ) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent).to(timesteps.dtype) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def apply_rotary_emb_qwen( + x: torch.Tensor, + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], + use_real: bool = True, + use_real_unbind_dim: int = -1, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings + to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are + reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting + tensors contain rotary embeddings and are returned as real tensors. + + Args: + x (`torch.Tensor`): + Query or key tensor to apply rotary embeddings. [B, S, H, D] xk (torch.Tensor): Key tensor to apply + freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],) + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + """ + if use_real: + cos, sin = freqs_cis # [S, D] + cos = cos[None, None] + sin = sin[None, None] + cos, sin = cos.to(x.device), sin.to(x.device) + + if use_real_unbind_dim == -1: + # Used for flux, cogvideox, hunyuan-dit + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + elif use_real_unbind_dim == -2: + # Used for Stable Audio, OmniGen, CogView4 and Cosmos + x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2] + x_rotated = torch.cat([-x_imag, x_real], dim=-1) + else: + raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") + + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + + return out + else: + if not ROPE_FUSE: #----------------- ljf -------------------- + x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(1) + x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) + else: + cos = freqs_cis.real + sin = freqs_cis.imag + seqlen = cos.shape[0] + + cos = cos.unsqueeze(0).unsqueeze(2).unsqueeze(-1).expand(-1, -1, -1, -1, 2).reshape(1, seqlen, 1, -1) + sin = sin.unsqueeze(0).unsqueeze(2).unsqueeze(-1).expand(-1, -1, -1, -1, 2).reshape(1, seqlen, 1, -1) + + x_out = torch_npu.npu_rotary_mul(x, cos, sin, 'interleave') + return x_out.type_as(x) + + +class QwenTimestepProjEmbeddings(nn.Module): + def __init__(self, embedding_dim): + super().__init__() + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + def forward(self, timestep, hidden_states): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype)) # (N, D) + + conditioning = timesteps_emb + + return conditioning + + +class QwenEmbedRope(nn.Module): + def __init__(self, theta: int, axes_dim: List[int], scale_rope=False): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + pos_index = torch.arange(4096) + neg_index = torch.arange(4096).flip(0) * -1 - 1 + self.pos_freqs = torch.cat( + [ + self.rope_params(pos_index, self.axes_dim[0], self.theta), + self.rope_params(pos_index, self.axes_dim[1], self.theta), + self.rope_params(pos_index, self.axes_dim[2], self.theta), + ], + dim=1, + ) + self.neg_freqs = torch.cat( + [ + self.rope_params(neg_index, self.axes_dim[0], self.theta), + self.rope_params(neg_index, self.axes_dim[1], self.theta), + self.rope_params(neg_index, self.axes_dim[2], self.theta), + ], + dim=1, + ) + self.rope_cache = {} + + # DO NOT USING REGISTER BUFFER HERE, IT WILL CAUSE COMPLEX NUMBERS LOSE ITS IMAGINARY PART + self.scale_rope = scale_rope + + def rope_params(self, index, dim, theta=10000): + """ + Args: + index: [0, 1, 2, 3] 1D Tensor representing the position index of the token + """ + assert dim % 2 == 0 + freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim))) + freqs = torch.polar(torch.ones_like(freqs), freqs) + return freqs + + def forward(self, video_fhw, txt_seq_lens, device): + """ + Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args: + txt_length: [bs] a list of 1 integers representing the length of the text + """ + if self.pos_freqs.device != device: + self.pos_freqs = self.pos_freqs.to(device) + self.neg_freqs = self.neg_freqs.to(device) + + if isinstance(video_fhw, list): + video_fhw = video_fhw[0] + if not isinstance(video_fhw, list): + video_fhw = [video_fhw] + + vid_freqs = [] + max_vid_index = 0 + for idx, fhw in enumerate(video_fhw): + frame, height, width = fhw + rope_key = f"{idx}_{height}_{width}" + #-----------ljf------------------- + # if not torch.compiler.is_compiling(): + # if rope_key not in self.rope_cache: + # self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx) + # video_freq = self.rope_cache[rope_key] + # else: + video_freq = self._compute_video_freqs(frame, height, width, idx) + video_freq = video_freq.to(device) + vid_freqs.append(video_freq) + + if self.scale_rope: + max_vid_index = max(height // 2, width // 2, max_vid_index) + else: + max_vid_index = max(height, width, max_vid_index) + + max_len = max(txt_seq_lens) + txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...] + vid_freqs = torch.cat(vid_freqs, dim=0) + + return vid_freqs, txt_freqs + + @functools.lru_cache(maxsize=None) + def _compute_video_freqs(self, frame, height, width, idx=0): + seq_lens = frame * height * width + freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + + freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) + if self.scale_rope: + freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0) + freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0) + freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1) + else: + freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) + + freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1) + return freqs.clone().contiguous() + + +class QwenDoubleStreamAttnProcessor2_0: + """ + Attention processor for Qwen double-stream architecture, matching DoubleStreamLayerMegatron logic. This processor + implements joint attention computation where text and image streams are processed together. + """ + + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "QwenDoubleStreamAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, # Image stream + encoder_hidden_states: torch.FloatTensor = None, # Text stream + encoder_hidden_states_mask: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + if encoder_hidden_states is None: + raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)") + + seq_txt = encoder_hidden_states.shape[1] + + # Compute QKV for image stream (sample projections) + img_query = attn.to_q(hidden_states) + img_key = attn.to_k(hidden_states) + img_value = attn.to_v(hidden_states) + + # Compute QKV for text stream (context projections) + txt_query = attn.add_q_proj(encoder_hidden_states) + txt_key = attn.add_k_proj(encoder_hidden_states) + txt_value = attn.add_v_proj(encoder_hidden_states) + + # Reshape for multi-head attention + img_query = img_query.unflatten(-1, (attn.heads, -1)) + img_key = img_key.unflatten(-1, (attn.heads, -1)) + img_value = img_value.unflatten(-1, (attn.heads, -1)) + + txt_query = txt_query.unflatten(-1, (attn.heads, -1)) + txt_key = txt_key.unflatten(-1, (attn.heads, -1)) + txt_value = txt_value.unflatten(-1, (attn.heads, -1)) + + # Apply QK normalization + if attn.norm_q is not None: + img_query = attn.norm_q(img_query) + if attn.norm_k is not None: + img_key = attn.norm_k(img_key) + if attn.norm_added_q is not None: + txt_query = attn.norm_added_q(txt_query) + if attn.norm_added_k is not None: + txt_key = attn.norm_added_k(txt_key) + + # Apply RoPE + if image_rotary_emb is not None: + img_freqs, txt_freqs = image_rotary_emb + img_query = apply_rotary_emb_qwen(img_query, img_freqs, use_real=False) + img_key = apply_rotary_emb_qwen(img_key, img_freqs, use_real=False) + txt_query = apply_rotary_emb_qwen(txt_query, txt_freqs, use_real=False) + txt_key = apply_rotary_emb_qwen(txt_key, txt_freqs, use_real=False) + + + + # print("ljf img_query ", img_query) + # exit() + + # Concatenate for joint attention + # Order: [text, image] + joint_query = torch.cat([txt_query, img_query], dim=1) + joint_key = torch.cat([txt_key, img_key], dim=1) + joint_value = torch.cat([txt_value, img_value], dim=1) + + # Compute joint attention + # joint_hidden_states = dispatch_attention_fn( + # joint_query, + # joint_key, + # joint_value, + # attn_mask=attention_mask, + # dropout_p=0.0, + # is_causal=False, + # backend=self._attention_backend, + # parallel_config=self._parallel_config, + # ) + #--------------------ljf------------------------ + joint_hidden_states = attention_forward(joint_query, joint_key, joint_value, + opt_mode="manual", op_type="fused_attn_score", layout="BNSD") + #--------------------------------------------- + + # Reshape back + joint_hidden_states = joint_hidden_states.flatten(2, 3) + joint_hidden_states = joint_hidden_states.to(joint_query.dtype) + + # Split attention outputs back + txt_attn_output = joint_hidden_states[:, :seq_txt, :] # Text part + img_attn_output = joint_hidden_states[:, seq_txt:, :] # Image part + + # Apply output projections + img_attn_output = attn.to_out[0](img_attn_output) + if len(attn.to_out) > 1: + img_attn_output = attn.to_out[1](img_attn_output) # dropout + + txt_attn_output = attn.to_add_out(txt_attn_output) + + return img_attn_output, txt_attn_output + + +@maybe_allow_in_graph +class QwenImageTransformerBlock(nn.Module): + def __init__( + self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6 + ): + super().__init__() + + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + + # Image processing modules + self.img_mod = nn.Sequential( + nn.SiLU(), + nn.Linear(dim, 6 * dim, bias=True), # For scale, shift, gate for norm1 and norm2 + ) + #---------------ljf------------------ + if not ADALN_FUSE: + self.img_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + else: + self.img_norm1 = AdaLayerNorm(dim, eps=eps) + #------------------------------------ + + self.attn = Attention( + query_dim=dim, + cross_attention_dim=None, # Enable cross attention for joint computation + added_kv_proj_dim=dim, # Enable added KV projections for text stream + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + context_pre_only=False, + bias=True, + processor=QwenDoubleStreamAttnProcessor2_0(), + qk_norm=qk_norm, + eps=eps, + ) + #--------------ljf------------------- + if not ADALN_FUSE: + self.img_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + else: + self.img_norm2 = AdaLayerNorm(dim, eps=eps) + #-------------------------------- + + self.img_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + # Text processing modules + self.txt_mod = nn.Sequential( + nn.SiLU(), + nn.Linear(dim, 6 * dim, bias=True), # For scale, shift, gate for norm1 and norm2 + ) + #---------------ljf------------------- + if not ADALN_FUSE: + self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + else: + self.txt_norm1 = AdaLayerNorm(dim, eps=eps) + #------------------------------------------- + + # Text doesn't need separate attention - it's handled by img_attn joint computation + #---------------------------ljf-------------- + if not ADALN_FUSE: + self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + else: + self.txt_norm2 = AdaLayerNorm(dim, eps=eps) + #---------------------------------------- + + self.txt_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + def _modulate(self, x, mod_params): + """Apply modulation to input tensor""" + shift, scale, gate = mod_params.chunk(3, dim=-1) + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1) + + def forward( + self, + hidden_states: torch.Tensor, + # encoder_hidden_states: torch.Tensor, + encoder_hidden_states, + encoder_hidden_states_mask: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + txt_pad_len = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Get modulation parameters for both streams + img_mod_params = self.img_mod(temb) # [B, 6*dim] + txt_mod_params = self.txt_mod(temb) # [B, 6*dim] + + # Split modulation parameters for norm1 and norm2 + img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1) # Each [B, 3*dim] + txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) # Each [B, 3*dim] + + # Process image stream - norm1 + modulation + #------------------ljf------------------ + if not ADALN_FUSE: + img_normed = self.img_norm1(hidden_states) + img_modulated, img_gate1 = self._modulate(img_normed, img_mod1) + else: + img_modulated, img_gate1 = self.img_norm1(hidden_states, img_mod1) + #---------------------------------------- + + # Process text stream - norm1 + modulation + #----------------------ljf--------------- + if not ADALN_FUSE: + txt_normed = self.txt_norm1(encoder_hidden_states) + txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1) + else: + txt_modulated, txt_gate1 = self.txt_norm1(encoder_hidden_states, txt_mod1) + #---------------------------------- + + + # Use QwenAttnProcessor2_0 for joint attention computation + # This directly implements the DoubleStreamLayerMegatron logic: + # 1. Computes QKV for both streams + # 2. Applies QK normalization and RoPE + # 3. Concatenates and runs joint attention + # 4. Splits results back to separate streams + joint_attention_kwargs = joint_attention_kwargs or {} + attn_output = self.attn( + hidden_states=img_modulated, # Image stream (will be processed as "sample") + encoder_hidden_states=txt_modulated, # Text stream (will be processed as "context") + encoder_hidden_states_mask=encoder_hidden_states_mask, + image_rotary_emb=image_rotary_emb, + # txt_pad_len = txt_pad_len, + **joint_attention_kwargs, + ) + + # QwenAttnProcessor2_0 returns (img_output, txt_output) when encoder_hidden_states is provided + # ljf (B, S_txt_split , H*D_head), (B, S_img_split , H*D_head) + img_attn_output, txt_attn_output = attn_output + + # Apply attention gates and add residual (like in Megatron) + hidden_states = hidden_states + img_gate1 * img_attn_output + encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output + + # Process image stream - norm2 + MLP + #-----------------------ljf----------- + if not ADALN_FUSE: + img_normed2 = self.img_norm2(hidden_states) + img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2) + else: + img_modulated2, img_gate2 = self.img_norm2(hidden_states, img_mod2) + #--------------------------------------- + + img_mlp_output = self.img_mlp(img_modulated2) + hidden_states = hidden_states + img_gate2 * img_mlp_output + + # Process text stream - norm2 + MLP + #----------------ljf------------------------ + if not ADALN_FUSE: + txt_normed2 = self.txt_norm2(encoder_hidden_states) + txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2) + else: + txt_modulated2, txt_gate2 = self.txt_norm2(encoder_hidden_states, txt_mod2) + #-------------------------------- + + txt_mlp_output = self.txt_mlp(txt_modulated2) + encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output + + # Clip to prevent overflow for fp16 + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + # return encoder_hidden_states, hidden_states + return hidden_states, encoder_hidden_states + + +class QwenImageTransformer2DModel( + ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin +): + """ + The Transformer model introduced in Qwen. + + Args: + patch_size (`int`, defaults to `2`): + Patch size to turn the input data into small patches. + in_channels (`int`, defaults to `64`): + The number of channels in the input. + out_channels (`int`, *optional*, defaults to `None`): + The number of channels in the output. If not specified, it defaults to `in_channels`. + num_layers (`int`, defaults to `60`): + The number of layers of dual stream DiT blocks to use. + attention_head_dim (`int`, defaults to `128`): + The number of dimensions to use for each attention head. + num_attention_heads (`int`, defaults to `24`): + The number of attention heads to use. + joint_attention_dim (`int`, defaults to `3584`): + The number of dimensions to use for the joint attention (embedding/channel dimension of + `encoder_hidden_states`). + guidance_embeds (`bool`, defaults to `False`): + Whether to use guidance embeddings for guidance-distilled variant of the model. + axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`): + The dimensions to use for the rotary positional embeddings. + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["QwenImageTransformerBlock"] + _skip_layerwise_casting_patterns = ["pos_embed", "norm"] + _repeated_blocks = ["QwenImageTransformerBlock"] + # _cp_plan = { + # "": { + # "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + # "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + # "encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False), + # }, + # "pos_embed": { + # 0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True), + # 1: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True), + # }, + # "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3), + # } + + @register_to_config + def __init__( + self, + patch_size: int = 2, + in_channels: int = 64, + out_channels: Optional[int] = 16, + num_layers: int = 60, + attention_head_dim: int = 128, + num_attention_heads: int = 24, + joint_attention_dim: int = 3584, + guidance_embeds: bool = False, # TODO: this should probably be removed + axes_dims_rope: Tuple[int, int, int] = (16, 56, 56), + ): + super().__init__() + self.out_channels = out_channels or in_channels + self.inner_dim = num_attention_heads * attention_head_dim + + self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True) + + self.time_text_embed = QwenTimestepProjEmbeddings(embedding_dim=self.inner_dim) + + self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6) + + self.img_in = nn.Linear(in_channels, self.inner_dim) + self.txt_in = nn.Linear(joint_attention_dim, self.inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + QwenImageTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + ) + for _ in range(num_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + + self.gradient_checkpointing = False + + #-----------------ljf------------- + self.cache_cond = None + self.cache_uncond = None + #------------------------------- + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + encoder_hidden_states_mask: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_shapes: Optional[List[Tuple[int, int, int]]] = None, + txt_seq_lens: Optional[List[int]] = None, + guidance: torch.Tensor = None, # TODO: this should probably be removed + attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_block_samples=None, + return_dict: bool = True, + use_cache: bool = False, #---------------ljf------------ + if_cond: bool = True, #-------------------ljf------------------ + ) -> Union[torch.Tensor, Transformer2DModelOutput]: + """ + The [`QwenTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`): + Input `hidden_states`. + encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`): + Mask of the input conditions. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." + ) + + hidden_states = self.img_in(hidden_states) + + timestep = timestep.to(hidden_states.dtype) + encoder_hidden_states = self.txt_norm(encoder_hidden_states) + encoder_hidden_states = self.txt_in(encoder_hidden_states) + + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) * 1000 + + temb = ( + self.time_text_embed(timestep, hidden_states) + if guidance is None + else self.time_text_embed(timestep, guidance, hidden_states) + ) + + image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device) + + for index_block, block in enumerate(self.transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + encoder_hidden_states_mask, + temb, + image_rotary_emb, + ) + + else: + #--------------------ljf----------- + if not use_cache: + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_mask=encoder_hidden_states_mask, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=attention_kwargs, + ) + else: + if if_cond: + hidden_states, encoder_hidden_states = self.cache_cond.apply( + block, + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_mask=encoder_hidden_states_mask, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=attention_kwargs, + ) + else: + hidden_states, encoder_hidden_states = self.cache_uncond.apply( + block, + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_mask=encoder_hidden_states_mask, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=attention_kwargs, + ) + #----------------------------------- + + # controlnet residual + if controlnet_block_samples is not None: + interval_control = len(self.transformer_blocks) / len(controlnet_block_samples) + interval_control = int(np.ceil(interval_control)) + hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] + + # Use only the image part (hidden_states) from the dual-stream blocks + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) \ No newline at end of file