Skip to content

关于flux第一阶段的训练 #219

@panganqi

Description

@panganqi

我尝试在 x-flux 的训练框架下加入了 PuLID 的 ID former（IDFormer）和 pulid_ca 模块，并训练这两个模块。损失函数就是 FLUX 的 diffusion loss，结果发现第一阶段训练完全不收敛。请问有人尝试训练过并且成功了吗？我的模型代码如下：
class Flux(nn.Module):
    """
    Transformer model for flow matching on sequences.

    Besides the standard FLUX double-/single-stream blocks, this variant can
    inject PuLID identity features. This happens when ``self.pulid_ca`` has
    been created (see ``set_pulid_ca``) and ``forward`` receives both
    ``id_cond`` and ``id_vit_hidden``. In that case ``self.pulid_ca`` is
    applied to the image tokens after every 2nd double-stream block and every
    4th single-stream block.
    """

    _supports_gradient_checkpointing = True

    def __init__(self, params: FluxParams):
        super().__init__()

        self.params = params
        self.in_channels = params.in_channels
        self.out_channels = self.in_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(
                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
            )
        # Rotary positional embedding dim per attention head.
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
        )
        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                )
                for _ in range(params.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
                for _ in range(params.depth_single_blocks)
            ]
        )
        print("DoubleStreamBlock", params.depth, "SingleStreamBlock", params.depth_single_blocks)
        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

        # Optional image-prompt projection; when set, forward() passes
        # `image_proj` through it before the transformer blocks.
        self.ip_adapter_proj_model = None

        # PuLID identity module; stays None (PuLID disabled) until
        # set_pulid_ca() is called.
        self.pulid_ca = None

        self.gradient_checkpointing = True

    def _set_gradient_checkpointing(self, module, value=False):
        """Toggle gradient checkpointing on `module` if it supports it."""
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def set_pulid_ca(
        self,
        state_dict=None,
        pretrain_path: str = "/mnt/gyfs_bj/cchaocchen/PuLID/models/pulid_flux_v0.9.0.safetensors",
    ):
        """Create the PuLID module and load pretrained weights into it.

        Args:
            state_dict: Currently unused; kept only so existing callers that
                pass it keep working.
            pretrain_path: Checkpoint handed to ``PuLIDPipeline.load_pretrain``.
                Defaults to the path this method previously hard-coded.

        NOTE(review): the new module replaces ``self.pulid_ca`` in whatever
        device/dtype ``PuLIDPipeline()`` creates it. Any earlier
        ``self.to(...)`` on the model does not apply to it, so call this
        before moving or casting the model.
        """
        self.pulid_ca = PuLIDPipeline()
        self.pulid_ca.load_pretrain(pretrain_path)

    def forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        block_controlnet_hidden_states=None,
        guidance: Tensor = None,
        image_proj: Tensor = None,
        ip_scale: float = 1.0,
        id: Tensor = None,
        id_cond: Tensor = None,
        id_vit_hidden: Tensor = None,
        pulid2: bool = False,
    ) -> Tensor:
        """Predict the flow for the image tokens.

        Args:
            img: 3-D image token tensor (presumably (batch, seq, in_channels)
                — TODO confirm).
            img_ids / txt_ids: Position ids, concatenated as text-then-image
                along dim 1 and fed to the rotary embedder.
            txt: 3-D text token tensor.
            timesteps / guidance: Scalars per sample, embedded with a
                256-dim timestep embedding. `guidance` is required when
                ``params.guidance_embed`` is set.
            y: Pooled conditioning vector added to the modulation vector.
            block_controlnet_hidden_states: Optional sequence of controlnet
                residual streams (e.g. pose, IP). Each stream is indexable per
                block and is reused cyclically over the double blocks.
            image_proj / ip_scale: Image-prompt features and their scale,
                forwarded to every double block.
            id / id_cond / id_vit_hidden: PuLID identity inputs. PuLID is
                applied only when ``self.pulid_ca`` is set and both `id_cond`
                and `id_vit_hidden` are given.
            pulid2: Currently unused; kept for API compatibility.

        Returns:
            Output of ``final_layer`` for the image tokens only.

        Raises:
            ValueError: if `img`/`txt` are not 3-D, or guidance is missing for
                a guidance-distilled model.

        NOTE(review): ``self.pulid_ca``'s return value *replaces* the image
        tokens. Upstream PuLID uses a residual form,
        ``img = img + id_weight * pulid_ca[ca_idx](id, img)``. Confirm that
        ``pulid_ca`` adds the residual itself. Otherwise the base model's
        features are discarded at every injection point, which would prevent
        convergence.
        """
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # Embed image/text tokens and build the modulation vector.
        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError("Didn't get guidance strength for guidance distilled model.")
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        if self.ip_adapter_proj_model is not None:
            image_proj = self.ip_adapter_proj_model(image_proj)

        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        use_checkpointing = self.training and self.gradient_checkpointing
        use_pulid = self.pulid_ca is not None and id_cond is not None and id_vit_hidden is not None
        ca_idx = 0  # index of the next PuLID cross-attention layer
        double_interval = 2
        single_interval = 4

        for index_block, block in enumerate(self.double_blocks):
            if use_checkpointing:
                # Non-reentrant checkpointing keeps gradients flowing to the
                # trainable PuLID layers even though the block itself is frozen.
                img, txt = torch.utils.checkpoint.checkpoint(
                    block, img, txt, vec, pe, image_proj, ip_scale, use_reentrant=False
                )
            else:
                img, txt = block(
                    img=img,
                    txt=txt,
                    vec=vec,
                    pe=pe,
                    image_proj=image_proj,
                    ip_scale=ip_scale,
                )

            # Controlnet residuals. Cycle over each stream's *own* length; the
            # outer length is the number of streams, not the controlnet depth.
            if block_controlnet_hidden_states is not None:
                for residuals in block_controlnet_hidden_states:
                    img = img + residuals[index_block % len(residuals)]

            if use_pulid and index_block % double_interval == 0:
                img = self.pulid_ca(ca_idx, id, id_cond, id_vit_hidden, img)
                ca_idx += 1

        # Single-stream blocks run on the joint [text, image] sequence.
        txt_len = txt.shape[1]
        img = torch.cat((txt, img), 1)
        for index_block, block in enumerate(self.single_blocks):
            if use_checkpointing:
                img = torch.utils.checkpoint.checkpoint(block, img, vec, pe, use_reentrant=False)
            else:
                img = block(img, vec=vec, pe=pe)

            if use_pulid and index_block % single_interval == 0:
                # Inject identity into the image tokens only, then re-join.
                txt_part, img_part = img[:, :txt_len, ...], img[:, txt_len:, ...]
                img_part = self.pulid_ca(ca_idx, id, id_cond, id_vit_hidden, img_part)
                ca_idx += 1
                img = torch.cat((txt_part, img_part), 1)

        img = img[:, txt_len:, ...]

        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
        return img

然后在训练代码中，将 PuLID 相关的参数设为可训练（其余参数冻结），学习率为 1e-5，其他超参数使用 x-flux 官方训练脚本的默认值：
# Freeze the VAE and text encoders; only PuLID layers inside `dit` are trained.
vae.requires_grad_(False)
t5.requires_grad_(False)
clip.requires_grad_(False)
dit.to(accelerator.device)
dit = dit.to(torch.bfloat16)
dit.train()
dit.pulid_ca.to(accelerator.device)
dit.gradient_checkpointing = args.gradient_checkpointing

# Train exactly the parameters whose qualified name contains "pulid".
# NOTE(review): if dit.pulid_ca also owns a pretrained backbone (e.g. a vision
# encoder), this unfreezes it too. Check the "Init layers" printout and the
# parameter count below.
trainable_params = []
for n, param in dit.named_parameters():
    is_trainable = "pulid" in n
    param.requires_grad = is_trainable
    if is_trainable:
        # bf16 keeps only 8 significant bits, so an AdamW step of roughly
        # lr=1e-5 is smaller than half a ULP for most weights. Such a step is
        # rounded away and the weights never move. Keep the trainable (master)
        # weights in fp32. The mixed-dtype forward then needs autocast (e.g.
        # Accelerator(mixed_precision="bf16") with dit passed through
        # accelerator.prepare) — confirm in the surrounding script. Harmless
        # if DeepSpeed bf16 mode already keeps fp32 master weights.
        param.data = param.data.float()
        trainable_params.append(param)
        print("Init layers: ", n)
print(sum(p.numel() for p in trainable_params) / 1000000, 'parameters')

optimizer_cls = torch.optim.AdamW
optimizer = optimizer_cls(
    trainable_params,
    lr=args.learning_rate,
    betas=(args.adam_beta1, args.adam_beta2),
    weight_decay=args.adam_weight_decay,
    eps=args.adam_epsilon,
)

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions