From df60d88b2f2a8092ce3d4aa8bbb84f16e45d474f Mon Sep 17 00:00:00 2001
From: Luo-Yihang
Date: Mon, 30 Jun 2025 08:29:39 +0000
Subject: [PATCH] fix norm not training in train_control_lora_flux.py

---
 examples/flux-control/train_control_lora_flux.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/examples/flux-control/train_control_lora_flux.py b/examples/flux-control/train_control_lora_flux.py
index 3c8b75a08808..53ee0f89e280 100644
--- a/examples/flux-control/train_control_lora_flux.py
+++ b/examples/flux-control/train_control_lora_flux.py
@@ -837,11 +837,6 @@ def main(args):
         assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
         flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)
 
-    if args.train_norm_layers:
-        for name, param in flux_transformer.named_parameters():
-            if any(k in name for k in NORM_LAYER_PREFIXES):
-                param.requires_grad = True
-
     if args.lora_layers is not None:
         if args.lora_layers != "all-linear":
             target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
@@ -879,6 +874,11 @@ def main(args):
     )
     flux_transformer.add_adapter(transformer_lora_config)
 
+    if args.train_norm_layers:
+        for name, param in flux_transformer.named_parameters():
+            if any(k in name for k in NORM_LAYER_PREFIXES):
+                param.requires_grad = True
+
     def unwrap_model(model):
         model = accelerator.unwrap_model(model)
         model = model._orig_mod if is_compiled_module(model) else model
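
Why the move matters: diffusers' add_adapter() injects the LoRA layers via PEFT, and PEFT's adapter injection marks only the adapter parameters as trainable, setting requires_grad=False on every other parameter of the model. Norm layers un-frozen before add_adapter() therefore end up frozen again, which is the "norm not training" bug in the subject line; enabling them after the adapter is added keeps them trainable. The sketch below illustrates the ordering on a tiny, randomly initialized FluxTransformer2DModel. The model sizes, LoRA hyper-parameters, and the NORM_LAYER_PREFIXES list are illustrative stand-ins, not values taken from the training script.

# Minimal sketch of the ordering issue fixed by this patch.
# Assumptions: tiny model config, LoRA settings, and NORM_LAYER_PREFIXES
# below are illustrative, not the training script's actual values.
from diffusers import FluxTransformer2DModel
from peft import LoraConfig

NORM_LAYER_PREFIXES = ["norm_q", "norm_k"]  # hypothetical subset for the demo

# Tiny randomly initialized transformer so the example runs without downloading weights.
transformer = FluxTransformer2DModel(
    patch_size=1,
    in_channels=4,
    num_layers=1,
    num_single_layers=1,
    attention_head_dim=8,
    num_attention_heads=2,
    joint_attention_dim=32,
    pooled_projection_dim=32,
    axes_dims_rope=(2, 3, 3),
)
transformer.requires_grad_(False)

def enable_norm_layers(model):
    # Mirrors the block the patch relocates: un-freeze the q/k norm weights.
    for name, param in model.named_parameters():
        if any(k in name for k in NORM_LAYER_PREFIXES):
            param.requires_grad = True

def norm_params_trainable(model):
    return all(
        p.requires_grad
        for n, p in model.named_parameters()
        if any(k in n for k in NORM_LAYER_PREFIXES)
    )

lora_config = LoraConfig(r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"])

# Old order: enable norms first, then add the adapter -> PEFT freezes them again.
enable_norm_layers(transformer)
transformer.add_adapter(lora_config)
print("norms trainable (old order):", norm_params_trainable(transformer))  # False

# New order (what the patch does): add the adapter first, then enable the norms.
enable_norm_layers(transformer)
print("norms trainable (new order):", norm_params_trainable(transformer))  # True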