-
Notifications
You must be signed in to change notification settings - Fork 261
Open
Description
我尝试在 x-flux 的训练框架下加入 PuLID 的 ID former 和 pulid_ca 模块,并训练这两个模块。我的损失函数就是 FLUX 的 diffusion loss,结果在第一阶段就发现完全不收敛。请问有人尝试训练过并且成功过吗?我的训练脚本是这样的:
class Flux(nn.Module):
    """
    Transformer model for flow matching on sequences.

    The model runs a stack of ``DoubleStreamBlock`` modules over separate
    image and text token streams, concatenates the two streams, runs a stack
    of ``SingleStreamBlock`` modules over the joint sequence, and projects
    the image part through ``LastLayer``.

    Optional conditioning hooks:

    * ``ip_adapter_proj_model`` -- when set, it is applied to ``image_proj``
      before the blocks run.
    * ``pulid_ca`` -- when set (see ``set_pulid_ca``) and ID features are
      passed to ``forward``, it is applied to the image tokens after every
      2nd double block and every 4th single block.
    """

    _supports_gradient_checkpointing = True

    # Checkpoint loaded by ``set_pulid_ca`` when no explicit path is given.
    _DEFAULT_PULID_WEIGHTS = '/mnt/gyfs_bj/cchaocchen/PuLID/models/pulid_flux_v0.9.0.safetensors'

    def __init__(self, params: FluxParams):
        """
        Build all sub-modules from ``params``.

        Raises:
            ValueError: if ``params.hidden_size`` is not divisible by
                ``params.num_heads``, or if ``sum(params.axes_dim)`` differs
                from the per-head dimension.
        """
        super().__init__()

        self.params = params
        self.in_channels = params.in_channels
        self.out_channels = self.in_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(
                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
            )
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads

        # Input embedders. Creation order is kept as-is so that random
        # initialisation consumes the RNG in the same sequence as before.
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
        )
        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                )
                for _ in range(params.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
                for _ in range(params.depth_single_blocks)
            ]
        )
        print("DoubleStreamBlock", params.depth, "SingleStreamBlock", params.depth_single_blocks)

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

        # Optional adapters; both stay disabled until assigned from outside.
        self.ip_adapter_proj_model = None
        self.pulid_ca = None
        self.gradient_checkpointing = True

    def _set_gradient_checkpointing(self, module, value=False):
        """Set ``module.gradient_checkpointing`` to ``value`` if the attribute exists."""
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def set_pulid_ca(self, state_dict=None, pretrain_path=None):
        """
        Create the PuLID pipeline and load its pretrained weights.

        Args:
            state_dict: accepted for interface compatibility; this method
                does not read it.
            pretrain_path: checkpoint path handed to
                ``PuLIDPipeline.load_pretrain``. When ``None`` the original
                hard-coded location (``_DEFAULT_PULID_WEIGHTS``) is used, so
                existing callers behave exactly as before.
        """
        if pretrain_path is None:
            pretrain_path = self._DEFAULT_PULID_WEIGHTS
        self.pulid_ca = PuLIDPipeline()
        self.pulid_ca.load_pretrain(pretrain_path)

    @staticmethod
    def _checkpoint_block(block, *inputs):
        """Run ``block(*inputs)`` under non-reentrant activation checkpointing."""
        return torch.utils.checkpoint.checkpoint(block, *inputs, use_reentrant=False)

    def forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        block_controlnet_hidden_states=None,
        guidance: Tensor = None,
        image_proj: Tensor = None,
        ip_scale: float = 1.0,
        id: Tensor = None,
        id_cond: Tensor = None,
        id_vit_hidden: Tensor = None,
        pulid2: bool = False,
    ) -> Tensor:
        """
        Run the transformer and return the output of the final layer for the
        image tokens.

        Args:
            img: image token sequence; must be 3-dimensional.
            img_ids: position ids for the image tokens.
            txt: text token sequence; must be 3-dimensional.
            txt_ids: position ids for the text tokens.
            timesteps: values embedded via ``timestep_embedding(..., 256)``.
            y: vector conditioning fed to ``vector_in``.
            block_controlnet_hidden_states: optional container whose entries
                ``[0]`` and ``[1]`` are each indexed per double block and
                added to ``img`` (marked "pose" and "IP" in the original
                code).
            guidance: guidance strength; required when
                ``params.guidance_embed`` is true.
            image_proj: image-prompt features passed to every double block;
                first projected by ``ip_adapter_proj_model`` when that is set.
            ip_scale: scale passed to every double block with ``image_proj``.
            id: first ID input forwarded to ``pulid_ca``. The name shadows
                the ``id`` builtin but is kept for caller compatibility.
            id_cond: ID conditioning forwarded to ``pulid_ca``.
            id_vit_hidden: ID hidden features forwarded to ``pulid_ca``.
            pulid2: not read by this method.

        Returns:
            Output of ``final_layer`` on the image tokens.

        Raises:
            ValueError: if ``img`` or ``txt`` is not 3-dimensional, or if the
                model uses guidance embedding and ``guidance`` is ``None``.
        """
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError("Didn't get guidance strength for guidance distilled model.")
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        if self.ip_adapter_proj_model is not None:
            # encode image
            image_proj = self.ip_adapter_proj_model(image_proj)

        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        has_controlnet = block_controlnet_hidden_states is not None
        if has_controlnet:
            controlnet_depth = len(block_controlnet_hidden_states)

        # These conditions do not change while the blocks run, so evaluate
        # them once instead of on every loop iteration.
        use_checkpoint = self.training and self.gradient_checkpointing
        use_pulid = (
            self.pulid_ca is not None
            and id_cond is not None
            and id_vit_hidden is not None
        )

        ca_idx = 0  # running index of the next PuLID cross-attention call
        double_interval = 2
        single_interval = 4

        for index_block, block in enumerate(self.double_blocks):
            if use_checkpoint:
                img, txt = self._checkpoint_block(block, img, txt, vec, pe, image_proj, ip_scale)
            else:
                img, txt = block(
                    img=img,
                    txt=txt,
                    vec=vec,
                    pe=pe,
                    image_proj=image_proj,
                    ip_scale=ip_scale,
                )

            # controlnet residual
            if has_controlnet:
                # NOTE(review): the modulus is the length of the OUTER
                # container, while the lookup indexes into entries [0] and
                # [1]. Confirm each entry holds at least that many residuals.
                residual_idx = index_block % controlnet_depth
                img = img + block_controlnet_hidden_states[0][residual_idx]  # pose
                img = img + block_controlnet_hidden_states[1][residual_idx]  # IP

            if use_pulid and index_block % double_interval == 0:
                img = self.pulid_ca(ca_idx, id, id_cond, id_vit_hidden, img)
                ca_idx += 1

        # Single-stream blocks see text and image tokens as one sequence,
        # text first.
        img = torch.cat((txt, img), 1)

        for index_block, block in enumerate(self.single_blocks):
            if use_checkpoint:
                img = self._checkpoint_block(block, img, vec, pe)
            else:
                img = block(img, vec=vec, pe=pe)

            if use_pulid and index_block % single_interval == 0:
                # PuLID is applied to the image tokens only: split the joint
                # sequence, update the image part, and glue it back.
                txt_len = txt.shape[1]
                txt, real_img = img[:, :txt_len, ...], img[:, txt_len:, ...]
                real_img = self.pulid_ca(ca_idx, id, id_cond, id_vit_hidden, real_img)
                ca_idx += 1
                img = torch.cat((txt, real_img), 1)

        # Drop the text tokens before the output projection.
        img = img[:, txt.shape[1]:, ...]

        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
        return img
然后在训练代码中,将 PuLID 相关的参数设置为可训练。我的学习率是 1e-5,其余使用 x-flux 官方训练脚本的默认参数:
# The auxiliary encoders are never trained.
for frozen_model in (vae, t5, clip):
    frozen_model.requires_grad_(False)

# Move the transformer to the accelerator device, cast it to bfloat16 and
# switch it to training mode.
dit.to(accelerator.device)
dit = dit.to(torch.bfloat16)
dit.train()

dit.pulid_ca.requires_grad_(True)
dit.pulid_ca.to(accelerator.device)
dit.gradient_checkpointing = args.gradient_checkpointing

optimizer_cls = torch.optim.AdamW

# Only parameters whose name contains 'pulid' stay trainable; every other
# parameter of the transformer is frozen.
for param_name, parameter in dit.named_parameters():
    is_pulid_param = 'pulid' in param_name
    parameter.requires_grad = is_pulid_param
    if is_pulid_param:
        print("Init layers: ", param_name)

# Collect the trainable parameters once and reuse the list for both the
# size report and the optimizer.
allp = [p for p in dit.parameters() if p.requires_grad]
print(sum(p.numel() for p in allp) / 1000000, 'parameters')

optimizer = optimizer_cls(
    allp,
    lr=args.learning_rate,
    betas=(args.adam_beta1, args.adam_beta2),
    weight_decay=args.adam_weight_decay,
    eps=args.adam_epsilon,
)
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels