
Conversation

yiyixuxu
Collaborator

@yiyixuxu yiyixuxu commented Aug 22, 2025

Qwen-Image

  • Text2Image
  • controlnet
  • inpaint
  • controlnet + inpaint
  • img2img
  • controlnet + img2img
  • diffdiff
Test Script for Qwen-Image Auto Pipeline
import os
import torch

from diffusers.modular_pipelines import SequentialPipelineBlocks
from diffusers.modular_pipelines.qwenimage import ALL_BLOCKS

from diffusers.utils import load_image
from image_gen_aux import DepthPreprocessor
import numpy as np
from PIL import Image

device = "cuda:2"

modular_repo = "YiYiXu/QwenImage-modular"
qwen_blocks = SequentialPipelineBlocks.from_blocks_dict(ALL_BLOCKS["auto"])
pipeline = qwen_blocks.init_pipeline(modular_repo)
pipeline.load_default_components(torch_dtype=torch.bfloat16)
pipeline.to(device)

print("pipeline loaded")
print(pipeline)
print(f" ")
print(f"pipeline.blocks")
print(pipeline.blocks)
print(f" ")


# text2image

# prompt = "A painting of a squirrel eating a burger"
prompt = "现实主义风格的人像摄影作品,画面主体是一位容貌惊艳的女性面部特写。她拥有一头自然微卷的短发,发丝根根分明,蓬松的刘海修饰着额头,增添俏皮感。头上佩戴一顶绿色格子蕾丝边头巾,增添复古与柔美气息。身着一件简约绿色背心裙,在纯白色背景下格外突出。两只手分别握着半个红色桃子,双手轻轻贴在脸颊两侧,营造出可爱又富有创意的视觉效果。  人物表情生动,一只眼睛睁开,另一只微微闭合,展现出调皮与自信的神态。整体构图采用个性视角、非对称构图,聚焦人物主体,增强现场感和既视感。背景虚化处理,层次丰富,景深效果强烈,营造出低光氛围下浓厚的情绪张力。  画面细节精致,色彩生动饱满却不失柔和,呈现出富士胶片独有的温润质感。光影运用充满美学张力,带有轻微超现实的光效处理,提升整体画面高级感。整体风格为现实主义人像摄影,强调细腻的纹理与艺术化的光线表现,堪称一幅细节丰富、氛围拉满的杰作。超清,4K,电影级构图"
inputs = {
    "prompt": prompt,
    "generator": torch.manual_seed(0),
    "negative_prompt": " ",
    "height": 1328,
    "width": 1328,
    "num_inference_steps": 50,
    "num_images_per_prompt": 1,
}

output_images = pipeline(**inputs, output="images")
output_images[0].save(f"test_qwen_modular_output_1_text2image.png")
print(f"image saved at {os.path.abspath(f'test_qwen_modular_output_1_text2image.png')}")

# inpaint

prompt = "cat wizard with red hat, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney"
negative_prompt = " "
source = load_image("https://github.com/Trgtuan10/Image_storage/blob/main/cute_cat.png?raw=true")
mask = load_image("https://github.com/Trgtuan10/Image_storage/blob/main/mask_cat.png?raw=true")

strengths = [0.9]

for strength in strengths:
    image = pipeline(
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=source.size[1],
        width=source.size[0],
        image=source,
        mask_image=mask,
        strength=strength,
        num_inference_steps=35,
        generator=torch.Generator(device="cuda").manual_seed(42),
        output="images"
    )[0]
    image.save(f"test_qwen_modular_output_2_inpaint_{strength}.png")
    print(f"image saved at {os.path.abspath(f'test_qwen_modular_output_2_inpaint_{strength}.png')}")


# test controlnet

# canny
control_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/qwencond_input.png")
prompt = "Aesthetics art, traditional asian pagoda, elaborate golden accents, sky blue and white color palette, swirling cloud pattern, digital illustration, east asian architecture, ornamental rooftop, intricate detailing on building, cultural representation."
controlnet_conditioning_scale = 1.0

images = pipeline(
    prompt=prompt,
    negative_prompt=" ",
    control_image=control_image,
    controlnet_conditioning_scale=controlnet_conditioning_scale,
    width=control_image.size[0],
    height=control_image.size[1],
    generator=torch.Generator(device="cuda").manual_seed(42),
    output="images"
)
images[0].save(f"test_qwen_modular_output_3_controlnet.png")
print(f"image saved at {os.path.abspath(f'test_qwen_modular_output_3_controlnet.png')}")

# 4 controlnet + inpaint

prompt = "a blue robot singing opera with human-like expressions"
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")

head_mask = np.zeros_like(image)
head_mask[65:580,300:642] = 255
mask_image = Image.fromarray(head_mask)

processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
control_image = processor(image)[0].convert("RGB")


image_output = pipeline(
    prompt=prompt,
    image=image,
    mask_image=mask_image,
    control_image=control_image,
    strength=0.9,
    num_inference_steps=30,
    output="images",
)[0]
image_output.save("test_qwen_modular_output_4_controlnet_inpaint.png")
print(f"image saved at {os.path.abspath(f'test_qwen_modular_output_4_controlnet_inpaint.png')}")



# update guider (PAG)

from diffusers import LayerSkipConfig, PerturbedAttentionGuidance

config = LayerSkipConfig(indices=[2, 9], skip_attention=False, skip_attention_scores=True, skip_ff=False)
guider = PerturbedAttentionGuidance(
    guidance_scale=5.0, perturbed_guidance_scale=2.5, perturbed_guidance_config=config
)
pipeline.update_components(guider=guider)

print("pipeline.guider")
print(pipeline.guider)


prompt = "现实主义风格的人像摄影作品,画面主体是一位容貌惊艳的女性面部特写。她拥有一头自然微卷的短发,发丝根根分明,蓬松的刘海修饰着额头,增添俏皮感。头上佩戴一顶绿色格子蕾丝边头巾,增添复古与柔美气息。身着一件简约绿色背心裙,在纯白色背景下格外突出。两只手分别握着半个红色桃子,双手轻轻贴在脸颊两侧,营造出可爱又富有创意的视觉效果。  人物表情生动,一只眼睛睁开,另一只微微闭合,展现出调皮与自信的神态。整体构图采用个性视角、非对称构图,聚焦人物主体,增强现场感和既视感。背景虚化处理,层次丰富,景深效果强烈,营造出低光氛围下浓厚的情绪张力。  画面细节精致,色彩生动饱满却不失柔和,呈现出富士胶片独有的温润质感。光影运用充满美学张力,带有轻微超现实的光效处理,提升整体画面高级感。整体风格为现实主义人像摄影,强调细腻的纹理与艺术化的光线表现,堪称一幅细节丰富、氛围拉满的杰作。超清,4K,电影级构图"
inputs = {
    "prompt": prompt,
    "generator": torch.manual_seed(0),
    "negative_prompt": " ",
    "height": 1328,
    "width": 1328,
    "num_inference_steps": 50,
    "num_images_per_prompt": 1,
}

output_images = pipeline(**inputs, output="images")
output_images[0].save(f"test_qwen_modular_output_5_guider.png")
print(f"image saved at {os.path.abspath(f'test_qwen_modular_output_5_guider.png')}")

QwenImage Edit

  • Edit
  • Edit + Inpaint
Test script for QwenImage-Edit in Modular
import os
import torch

from diffusers.modular_pipelines import SequentialPipelineBlocks
from diffusers.modular_pipelines.qwenimage import ALL_BLOCKS

from diffusers.utils import load_image
from image_gen_aux import DepthPreprocessor
import numpy as np
from PIL import Image

device = "cuda:2"
output_name_prefix = "test_qwen_edit_output"

modular_repo = "YiYiXu/QwenImage-edit-modular"
qwen_blocks = SequentialPipelineBlocks.from_blocks_dict(ALL_BLOCKS["edit_auto"])
pipeline = qwen_blocks.init_pipeline(modular_repo)
pipeline.load_components(torch_dtype=torch.bfloat16)
pipeline.to(device)

print("pipeline loaded")
print(pipeline)
print(f" ")
print(f"pipeline.blocks")
print(pipeline.blocks)
print(f" ")


prompt = "change the hat to red"
negative_prompt = " "
source = load_image("https://github.com/Trgtuan10/Image_storage/blob/main/cute_cat.png?raw=true")
mask = load_image("https://github.com/Trgtuan10/Image_storage/blob/main/mask_cat.png?raw=true")

# edit
output_images = pipeline(
    prompt=prompt,
    negative_prompt=negative_prompt,
    image=source,
    num_inference_steps=35,
    generator=torch.Generator(device="cuda").manual_seed(42),
).images
output_images[0].save(f"{output_name_prefix}_1_edit.png")
print(f"image saved at {os.path.abspath(f'{output_name_prefix}_1_edit.png')}")

# inpaint

strengths = [0.9, 1.0]

for strength in strengths:
    image_output = pipeline(
        prompt=prompt,
        negative_prompt=negative_prompt,
        image=source,
        mask_image=mask,
        strength=strength,
        num_inference_steps=35,
        generator=torch.Generator(device="cuda").manual_seed(42),
    ).images[0]
    image_output.save(f"{output_name_prefix}_2_inpaint_{strength}.png")
    print(f"image saved at {os.path.abspath(f'{output_name_prefix}_2_inpaint_{strength}.png')}")

Load from standard repo

import torch
from diffusers import ModularPipeline, ComponentsManager

repo_id = "Qwen/Qwen-Image"
# repo_id = "Qwen/Qwen-Image-Edit"

components = ComponentsManager()
components.enable_auto_cpu_offload(device="cuda")
pipeline = ModularPipeline.from_pretrained(repo_id, components_manager=components)
pipeline.load_components(torch_dtype=torch.float16)
print(pipeline)
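
Once the components are loaded, the pipeline built from the standard repo can be called the same way as in the modular-repo examples above. A minimal sketch (assuming the text2image path and the output="images" convention used earlier; the filename is made up):

images = pipeline(
    prompt="A painting of a squirrel eating a burger",
    negative_prompt=" ",
    num_inference_steps=50,
    generator=torch.manual_seed(0),
    output="images",
)
images[0].save("qwen_from_standard_repo.png")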


@@ -539,8 +540,11 @@ class AutoPipelineBlocks(ModularPipelineBlocks):

     def __init__(self):
         sub_blocks = InsertableDict()
-        for block_name, block_cls in zip(self.block_names, self.block_classes):
-            sub_blocks[block_name] = block_cls()
+        for block_name, block in zip(self.block_names, self.block_classes):
Collaborator Author

I started to make dynamic blocks that can be configured during __init__

a simple made-up example would be

inpaint_vae_encoder = DynamicVaeEncoder(input_name="mask_image", output_name="mask_image_latents")
vae_encoder = DynamicVaeEncoder(input_name="image", output_name="image_latents")

the first one takes the input mask_image and returns the intermediate output mask_image_latents
the second one takes image and returns the intermediate output image_latents

this change is to support this use case
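
For illustration, a dynamic block along these lines could look roughly like this (a minimal sketch with made-up names; the base-class details, import paths, and property names are assumptions, not the actual implementation in this PR):

from diffusers.modular_pipelines import ModularPipelineBlocks, InputParam, OutputParam

class DynamicVaeEncoder(ModularPipelineBlocks):
    # made-up example: the input/output names are chosen when the block is instantiated
    def __init__(self, input_name: str = "image", output_name: str = "image_latents"):
        self._input_name = input_name
        self._output_name = output_name
        # attributes are set before super().__init__() so the base class can read the properties below
        super().__init__()

    @property
    def inputs(self):
        return [InputParam(self._input_name)]

    @property
    def intermediate_outputs(self):
        return [OutputParam(self._output_name)]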

Member

Nice.

(nit): Maybe we can have a comment explaining this because otherwise, this pattern seems a bit concerning:

if inspect.isclass(block):
    sub_blocks[block_name] = block()
else:
    sub_blocks[block_name] = block

# Prepare Latents steps


class QwenImagePackLatentsDynamicStep(ModularPipelineBlocks):
Collaborator Author

example of a dynamic block; this one can be used to pack different latents: image_latents, control_image_latents, noise, etc.

@yiyixuxu yiyixuxu requested review from DN6, sayakpaul and asomoza August 27, 2025 10:46
Member

@sayakpaul sayakpaul left a comment

Thank you!

@@ -298,4 +307,5 @@ def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, *
 _skip_proc_output_fn_Attention_WanAttnProcessor2_0 = _skip_attention___ret___hidden_states
 # not sure what this is yet.
 _skip_proc_output_fn_Attention_FluxAttnProcessor = _skip_attention___ret___hidden_states
+_skip_proc_output_fn_Attention_QwenDoubleStreamAttnProcessor2_0 = _skip_attention___ret___hidden_states
Member

For my understanding, what is this one for?

@@ -838,6 +838,134 @@ def apply_overlay(
return image


class InpaintProcessor(ConfigMixin):
Member

Really nice!

(not for this PR, but we could attempt to have an example of the processor for an inpaint pipeline)


final_batch_size = block_state.batch_size * block_state.num_images_per_prompt

for input_name in self._latents_input_names:
Member

Clean ❤️


    @property
    def description(self) -> str:
        return "Step that patchifies latents and expands batch dimension. Works with outputs from QwenImageVaeEncoderDynamicStep."
Member

(nit): would add a line to comment on the dynamic nature of this block.

return latents


class QwenImageDecodeDynamicStep(ModularPipelineBlocks):
Member

Does the decoding step need to be dynamic? 👀

ComponentSpec(
    "guider",
    ClassifierFreeGuidance,
    config=FrozenDict({"guidance_scale": 4.0}),
Member

For the QwenImage pipeline, guidance_scale is akin to the one we have in Flux. However, I think we want to enable CFG with this, which is done through true_cfg_scale. Should this be taken into account?

Collaborator Author

good questions,

the true_cfg_scale in flux/qwen is actually just guidance_scale in every other pipeline - it is part of the guider and should be set on the guider

we had to use a different name (true_cfg_scale) for flux because guidance_scale was already taken as an input for the distilled model. I think it would have been a lot better if we had given the distilled guidance a different name so that we could keep the definition of guidance_scale consistent across all pipelines

I'd like to fix it here in modular. IMO it won't confuse users too much because, as it is, they won't be able to use guidance_scale or true_cfg_scale at runtime in modular, so they will have to take some time to figure out how to use guidance properly and we will have a chance to explain.

cc @DN6 @asomoza too, let me know if you have any thoughts around this
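
For context, in modular this would mean the classic CFG strength is configured on the guider component rather than passed at call time. A minimal sketch, mirroring the PAG example in the test script above (the exact import path is an assumption):

from diffusers import ClassifierFreeGuidance

# set the CFG scale on the guider instead of passing guidance_scale/true_cfg_scale at runtime
guider = ClassifierFreeGuidance(guidance_scale=4.0)
pipeline.update_components(guider=guider)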

Member

I think it would have been a lot better if we had given the distilled guidance a different name so that we could keep the definition of guidance_scale consistent across all pipelines

I like this point a lot! However, we have guidance_scale in Flux (without the use of the Guider component):

InputParam("guidance_scale", default=3.5),

Maybe we could change that to something better suited (something like distilled_guidance_scale). This way, we can keep the meaning of guidance_scale consistent across the pipelines.

Member

I completely agree, let's keep the guidance_scale consistent and use a different one for the distilled models.

Collaborator

So the proposal is that guidance scale would always imply CFG guidance scale?

I would argue that keeping guidance_scale for all guidance methods makes sense, since it implies how large of a step you want to take in the guidance direction.

Alternatively, we could introduce the concept of a DistilledGuidance guider, which is effectively a no-op. That would make it more explicit exactly what's happening with the latents, rather than having to introduce new scale parameters, internal checks for negative embeds, or checks like self._is_cfg_enabled.

Member

So ideally, we should always use the same approach for all the models, which is to just consider guidance_scale for whether or not to use CFG; whether or not you provide a negative prompt shouldn't be the condition for using it. This is true for almost all the pipelines except Flux and QwenImage, but with modular we should keep the API consistent. That's why it's also ideal to use a separate guidance for the distilled models.

Collaborator Author

@yiyixuxu yiyixuxu Sep 1, 2025

I think the challenge here is the use case where both true CFG and distilled guidance are used, potentially with different scales (currently allowed in our pipelines)

Collaborator Author

@yiyixuxu yiyixuxu Sep 1, 2025

But this use case is not meaningful (even though we have been functionally allowing it), and yeah, then I think we will be able to keep the same parameter for both if we only need to use one or the other

Collaborator

@DN6 DN6 Sep 2, 2025

So ideally, we should always use the same approach for all the models, which is to just consider guidance_scale for whether or not to use CFG; whether or not you provide a negative prompt shouldn't be the condition for using it. This is true for almost all the pipelines except Flux and QwenImage, but with modular we should keep the API consistent. That's why it's also ideal to use a separate guidance for the distilled models.

Hmm, so then perhaps introduce a DistilledGuidance guider? It's then very clear what is being applied to the model, and if you want to swap out the method, you just change the guider without having to introduce an additional parameter term. And like @yiyixuxu said, doing both distilled guidance and CFG guidance simultaneously is probably not doing anything meaningful.

Collaborator Author

I have seen people using both at the same time though, both CFG and distilled guidance with different scales.
cc @linoytsaban to confirm & educate us a bit, since she is the OG of true_cfg

@DN6 DN6 added the roadmap Add to current release roadmap label Sep 2, 2025
@jferments

Will there be any features added soon for training Qwen models with HF libraries? What are the major barriers right now to making this happen?

@sayakpaul
Member

This is not the right PR to discuss Qwen training.

You can check out https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_qwen.md as well as https://github.com/ostris/ai-toolkit for training with HF libs.
