
Commit f50b18e

[Modular] Qwen (#12220)
* add qwen modular
1 parent fc337d5 commit f50b18e

File tree

17 files changed: +4275 -9 lines


docs/source/en/api/image_processor.md

Lines changed: 6 additions & 0 deletions
@@ -20,6 +20,12 @@ All pipelines with [`VaeImageProcessor`] accept PIL Image, PyTorch tensor, or Nu
 
 [[autodoc]] image_processor.VaeImageProcessor
 
+## InpaintProcessor
+
+The [`InpaintProcessor`] accepts `mask` and `image` inputs and processes them together. Optionally, it can accept `padding_mask_crop` and apply a mask overlay.
+
+[[autodoc]] image_processor.InpaintProcessor
+
 ## VaeImageProcessorLDM3D
 
 The [`VaeImageProcessorLDM3D`] accepts RGB and depth inputs and returns RGB and depth outputs.

src/diffusers/__init__.py

Lines changed: 8 additions & 0 deletions
@@ -385,6 +385,10 @@
     [
         "FluxAutoBlocks",
         "FluxModularPipeline",
+        "QwenImageAutoBlocks",
+        "QwenImageEditAutoBlocks",
+        "QwenImageEditModularPipeline",
+        "QwenImageModularPipeline",
         "StableDiffusionXLAutoBlocks",
         "StableDiffusionXLModularPipeline",
         "WanAutoBlocks",
@@ -1038,6 +1042,10 @@
     from .modular_pipelines import (
         FluxAutoBlocks,
        FluxModularPipeline,
+        QwenImageAutoBlocks,
+        QwenImageEditAutoBlocks,
+        QwenImageEditModularPipeline,
+        QwenImageModularPipeline,
        StableDiffusionXLAutoBlocks,
        StableDiffusionXLModularPipeline,
        WanAutoBlocks,
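
These entries register the four Qwen names in both the lazy `_import_structure` table and the eager `TYPE_CHECKING` branch, so they resolve from the package root. A quick import smoke test (assuming an environment with this commit installed):

    from diffusers import (
        QwenImageAutoBlocks,
        QwenImageEditAutoBlocks,
        QwenImageEditModularPipeline,
        QwenImageModularPipeline,
    )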

src/diffusers/hooks/_helpers.py

Lines changed: 10 additions & 0 deletions
@@ -108,6 +108,7 @@ def _register_attention_processors_metadata():
     from ..models.attention_processor import AttnProcessor2_0
     from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor
     from ..models.transformers.transformer_flux import FluxAttnProcessor
+    from ..models.transformers.transformer_qwenimage import QwenDoubleStreamAttnProcessor2_0
     from ..models.transformers.transformer_wan import WanAttnProcessor2_0
 
     # AttnProcessor2_0
@@ -140,6 +141,14 @@ def _register_attention_processors_metadata():
         metadata=AttentionProcessorMetadata(skip_processor_output_fn=_skip_proc_output_fn_Attention_FluxAttnProcessor),
     )
 
+    # QwenDoubleStreamAttnProcessor2_0
+    AttentionProcessorRegistry.register(
+        model_class=QwenDoubleStreamAttnProcessor2_0,
+        metadata=AttentionProcessorMetadata(
+            skip_processor_output_fn=_skip_proc_output_fn_Attention_QwenDoubleStreamAttnProcessor2_0
+        ),
+    )
+
 
 def _register_transformer_blocks_metadata():
     from ..models.attention import BasicTransformerBlock
@@ -298,4 +307,5 @@ def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, *
 _skip_proc_output_fn_Attention_WanAttnProcessor2_0 = _skip_attention___ret___hidden_states
 # not sure what this is yet.
 _skip_proc_output_fn_Attention_FluxAttnProcessor = _skip_attention___ret___hidden_states
+_skip_proc_output_fn_Attention_QwenDoubleStreamAttnProcessor2_0 = _skip_attention___ret___hidden_states
 # fmt: on
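
For context, `skip_processor_output_fn` is what hook-based layer skipping calls in place of the real processor. The registry entry above reuses the single-tensor helper `_skip_attention___ret___hidden_states`. A minimal sketch of that behavior (an assumption based on the helper's name and the surrounding code, not the verbatim implementation):

    def skip_attention_return_hidden_states(self, *args, **kwargs):
        # Skip the attention computation and hand back the hidden_states
        # input unchanged, matching processors that return a single tensor.
        return kwargs.get("hidden_states", args[0] if args else None)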

src/diffusers/image_processor.py

Lines changed: 132 additions & 0 deletions
@@ -523,6 +523,7 @@ def resize(
             size=(height, width),
         )
         image = self.pt_to_numpy(image)
+
         return image
 
     def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
@@ -838,6 +839,137 @@ def apply_overlay(
         return image
 
 
+class InpaintProcessor(ConfigMixin):
+    """
+    Image processor for inpainting image and mask.
+    """
+
+    config_name = CONFIG_NAME
+
+    @register_to_config
+    def __init__(
+        self,
+        do_resize: bool = True,
+        vae_scale_factor: int = 8,
+        vae_latent_channels: int = 4,
+        resample: str = "lanczos",
+        reducing_gap: int = None,
+        do_normalize: bool = True,
+        do_binarize: bool = False,
+        do_convert_grayscale: bool = False,
+        mask_do_normalize: bool = False,
+        mask_do_binarize: bool = True,
+        mask_do_convert_grayscale: bool = True,
+    ):
+        super().__init__()
+
+        self._image_processor = VaeImageProcessor(
+            do_resize=do_resize,
+            vae_scale_factor=vae_scale_factor,
+            vae_latent_channels=vae_latent_channels,
+            resample=resample,
+            reducing_gap=reducing_gap,
+            do_normalize=do_normalize,
+            do_binarize=do_binarize,
+            do_convert_grayscale=do_convert_grayscale,
+        )
+        self._mask_processor = VaeImageProcessor(
+            do_resize=do_resize,
+            vae_scale_factor=vae_scale_factor,
+            vae_latent_channels=vae_latent_channels,
+            resample=resample,
+            reducing_gap=reducing_gap,
+            do_normalize=mask_do_normalize,
+            do_binarize=mask_do_binarize,
+            do_convert_grayscale=mask_do_convert_grayscale,
+        )
+
+    def preprocess(
+        self,
+        image: PIL.Image.Image,
+        mask: PIL.Image.Image = None,
+        height: int = None,
+        width: int = None,
+        padding_mask_crop: Optional[int] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Preprocess the image and mask.
+        """
+        if mask is None and padding_mask_crop is not None:
+            raise ValueError("mask must be provided if padding_mask_crop is provided")
+
+        # if mask is None, same behavior as regular image processor
+        if mask is None:
+            return self._image_processor.preprocess(image, height=height, width=width)
+
+        if padding_mask_crop is not None:
+            crops_coords = self._image_processor.get_crop_region(mask, width, height, pad=padding_mask_crop)
+            resize_mode = "fill"
+        else:
+            crops_coords = None
+            resize_mode = "default"
+
+        processed_image = self._image_processor.preprocess(
+            image,
+            height=height,
+            width=width,
+            crops_coords=crops_coords,
+            resize_mode=resize_mode,
+        )
+
+        processed_mask = self._mask_processor.preprocess(
+            mask,
+            height=height,
+            width=width,
+            resize_mode=resize_mode,
+            crops_coords=crops_coords,
+        )
+
+        if crops_coords is not None:
+            postprocessing_kwargs = {
+                "crops_coords": crops_coords,
+                "original_image": image,
+                "original_mask": mask,
+            }
+        else:
+            postprocessing_kwargs = {
+                "crops_coords": None,
+                "original_image": None,
+                "original_mask": None,
+            }
+
+        return processed_image, processed_mask, postprocessing_kwargs
+
+    def postprocess(
+        self,
+        image: torch.Tensor,
+        output_type: str = "pil",
+        original_image: Optional[PIL.Image.Image] = None,
+        original_mask: Optional[PIL.Image.Image] = None,
+        crops_coords: Optional[Tuple[int, int, int, int]] = None,
+    ) -> Tuple[PIL.Image.Image, PIL.Image.Image]:
+        """
+        Postprocess the image, optionally applying the mask overlay.
+        """
+        image = self._image_processor.postprocess(
+            image,
+            output_type=output_type,
+        )
+        # optionally apply the mask overlay
+        if crops_coords is not None and (original_image is None or original_mask is None):
+            raise ValueError("original_image and original_mask must be provided if crops_coords is provided")
+
+        elif crops_coords is not None and output_type != "pil":
+            raise ValueError("output_type must be 'pil' if crops_coords is provided")
+
+        elif crops_coords is not None:
+            image = [
+                self._image_processor.apply_overlay(original_mask, original_image, i, crops_coords) for i in image
+            ]
+
+        return image
+
+
 class VaeImageProcessorLDM3D(VaeImageProcessor):
     """
     Image processor for VAE LDM3D.
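
A minimal round trip with the new class (a sketch: synthetic images stand in for real inputs, and a random tensor takes the place of an actual VAE decode):

    import torch
    from PIL import Image
    from diffusers.image_processor import InpaintProcessor

    processor = InpaintProcessor()

    image = Image.new("RGB", (1024, 1024), "gray")  # stand-in source image
    mask = Image.new("L", (1024, 1024), 0)
    mask.paste(255, (256, 256, 768, 768))           # inpaint the center square

    # With padding_mask_crop set, both inputs are cropped around the masked
    # region, and the kwargs needed to paste the result back are returned.
    image_t, mask_t, post_kwargs = processor.preprocess(
        image, mask, height=1024, width=1024, padding_mask_crop=32
    )

    # ...denoising loop goes here; random data stands in for the decoded image...
    decoded = torch.rand(1, 3, 1024, 1024) * 2 - 1

    # postprocess applies the overlay because post_kwargs carries crops_coords.
    result = processor.postprocess(decoded, output_type="pil", **post_kwargs)[0]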

src/diffusers/modular_pipelines/__init__.py

Lines changed: 12 additions & 0 deletions
@@ -47,6 +47,12 @@
     _import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"]
     _import_structure["wan"] = ["WanAutoBlocks", "WanModularPipeline"]
     _import_structure["flux"] = ["FluxAutoBlocks", "FluxModularPipeline"]
+    _import_structure["qwenimage"] = [
+        "QwenImageAutoBlocks",
+        "QwenImageModularPipeline",
+        "QwenImageEditModularPipeline",
+        "QwenImageEditAutoBlocks",
+    ]
     _import_structure["components_manager"] = ["ComponentsManager"]
 
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -68,6 +74,12 @@
         SequentialPipelineBlocks,
     )
     from .modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, InsertableDict, OutputParam
+    from .qwenimage import (
+        QwenImageAutoBlocks,
+        QwenImageEditAutoBlocks,
+        QwenImageEditModularPipeline,
+        QwenImageModularPipeline,
+    )
     from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline
     from .wan import WanAutoBlocks, WanModularPipeline
 else:
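
Once the lazy table resolves, the subpackage path works too. A quick smoke test (assuming this commit is installed and that `QwenImageAutoBlocks` follows the `AutoPipelineBlocks` pattern shown in the next file):

    from diffusers.modular_pipelines import QwenImageAutoBlocks

    blocks = QwenImageAutoBlocks()
    print(list(blocks.sub_blocks.keys()))  # the sub-block graph assembled in __init__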

src/diffusers/modular_pipelines/modular_pipeline.py

Lines changed: 28 additions & 9 deletions
@@ -56,6 +56,8 @@
         ("stable-diffusion-xl", "StableDiffusionXLModularPipeline"),
         ("wan", "WanModularPipeline"),
         ("flux", "FluxModularPipeline"),
+        ("qwenimage", "QwenImageModularPipeline"),
+        ("qwenimage-edit", "QwenImageEditModularPipeline"),
     ]
 )
 
@@ -64,6 +66,8 @@
         ("StableDiffusionXLModularPipeline", "StableDiffusionXLAutoBlocks"),
         ("WanModularPipeline", "WanAutoBlocks"),
         ("FluxModularPipeline", "FluxAutoBlocks"),
+        ("QwenImageModularPipeline", "QwenImageAutoBlocks"),
+        ("QwenImageEditModularPipeline", "QwenImageEditAutoBlocks"),
     ]
 )
 
@@ -133,8 +137,8 @@ def __getattr__(self, name):
         Allow attribute access to intermediate values. If an attribute is not found in the object, look for it in the
         intermediates dict.
         """
-        if name in self.intermediates:
-            return self.intermediates[name]
+        if name in self.values:
+            return self.values[name]
         raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
 
     def __repr__(self):
@@ -548,8 +552,11 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
 
     def __init__(self):
         sub_blocks = InsertableDict()
-        for block_name, block_cls in zip(self.block_names, self.block_classes):
-            sub_blocks[block_name] = block_cls()
+        for block_name, block in zip(self.block_names, self.block_classes):
+            if inspect.isclass(block):
+                sub_blocks[block_name] = block()
+            else:
+                sub_blocks[block_name] = block
         self.sub_blocks = sub_blocks
         if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)):
             raise ValueError(
@@ -830,7 +837,9 @@ def expected_configs(self):
         return expected_configs
 
     @classmethod
-    def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "SequentialPipelineBlocks":
+    def from_blocks_dict(
+        cls, blocks_dict: Dict[str, Any], description: Optional[str] = None
+    ) -> "SequentialPipelineBlocks":
         """Creates a SequentialPipelineBlocks instance from a dictionary of blocks.
 
         Args:
@@ -852,12 +861,19 @@ def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "SequentialPipelineBlo
         instance.block_classes = [block.__class__ for block in sub_blocks.values()]
         instance.block_names = list(sub_blocks.keys())
         instance.sub_blocks = sub_blocks
+
+        if description is not None:
+            instance.description = description
+
         return instance
 
     def __init__(self):
         sub_blocks = InsertableDict()
-        for block_name, block_cls in zip(self.block_names, self.block_classes):
-            sub_blocks[block_name] = block_cls()
+        for block_name, block in zip(self.block_names, self.block_classes):
+            if inspect.isclass(block):
+                sub_blocks[block_name] = block()
+            else:
+                sub_blocks[block_name] = block
         self.sub_blocks = sub_blocks
 
     def _get_inputs(self):
@@ -1280,8 +1296,11 @@ def outputs(self) -> List[str]:
 
     def __init__(self):
         sub_blocks = InsertableDict()
-        for block_name, block_cls in zip(self.block_names, self.block_classes):
-            sub_blocks[block_name] = block_cls()
+        for block_name, block in zip(self.block_names, self.block_classes):
+            if inspect.isclass(block):
+                sub_blocks[block_name] = block()
+            else:
+                sub_blocks[block_name] = block
         self.sub_blocks = sub_blocks
 
     @classmethod
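
The repeated `__init__` change is easy to gloss over: `block_classes` may now mix classes with pre-built block instances, and only actual classes get instantiated. A self-contained sketch of that resolution logic (placeholder class; `OrderedDict` stands in for `InsertableDict`):

    import inspect
    from collections import OrderedDict

    class TextEncoderStep:  # placeholder for a ModularPipelineBlocks subclass
        pass

    prebuilt = TextEncoderStep()  # an already-configured block instance

    sub_blocks = OrderedDict()
    for name, block in zip(["enc_cls", "enc_inst"], [TextEncoderStep, prebuilt]):
        sub_blocks[name] = block() if inspect.isclass(block) else block

    assert isinstance(sub_blocks["enc_cls"], TextEncoderStep)
    assert sub_blocks["enc_inst"] is prebuilt

The companion change, `from_blocks_dict(..., description=...)`, simply lets callers attach a human-readable description to the assembled `SequentialPipelineBlocks`.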
