@@ -2,10 +2,14 @@

 import torch

-from invokeai.backend.patches.layers.lora_layer import LoRALayer
 from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
+from invokeai.backend.patches.layers.lora_layer import LoRALayer
 from invokeai.backend.patches.layers.merged_layer_patch import MergedLayerPatch, Range
-from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict, swap_shift_scale_for_linear_weight, decomposite_weight_matric_with_rank
+from invokeai.backend.patches.layers.utils import (
+    any_lora_layer_from_state_dict,
+    decomposite_weight_matric_with_rank,
+    swap_shift_scale_for_linear_weight,
+)
 from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
 from invokeai.backend.patches.model_patch_raw import ModelPatchRaw

@@ -30,46 +34,47 @@ def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Te

     return all_keys_in_peft_format and all_expected_keys_present

+
 def approximate_flux_adaLN_lora_layer_from_diffusers_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRALayer:
-    '''Approximate given diffusers AdaLN loRA layer in our Flux model'''
+    """Approximate the given diffusers AdaLN LoRA layer in our Flux model."""

-    if not "lora_up.weight" in state_dict:
+    if "lora_up.weight" not in state_dict:
         raise ValueError(f"Unsupported lora format: {state_dict.keys()}, missing lora_up")
-
-    if not "lora_down.weight" in state_dict:
+
+    if "lora_down.weight" not in state_dict:
         raise ValueError(f"Unsupported lora format: {state_dict.keys()}, missing lora_down")
-
-    up = state_dict.pop('lora_up.weight')
-    down = state_dict.pop('lora_down.weight')

-    # layer-patcher upcast things to f32,
+    up = state_dict.pop("lora_up.weight")
+    down = state_dict.pop("lora_down.weight")
+
+    # layer-patcher upcasts things to f32,
     # we want to maintain a better precision for this one
     dtype = torch.float32

     device = up.device
     up_shape = up.shape
     down_shape = down.shape
-
+
     # desired low rank
     rank = up_shape[1]

     # upcast to float32 for a more precise computation
     up = up.to(torch.float32)
     down = down.to(torch.float32)

-    weight = up.reshape(up_shape[0], -1) @ down.reshape(down_shape[0], -1)
+    weight = up.reshape(up_shape[0], -1) @ down.reshape(down_shape[0], -1)

     # swap to our linear format
     swapped = swap_shift_scale_for_linear_weight(weight)

     _up, _down = decomposite_weight_matric_with_rank(swapped, rank)

-    assert (_up.shape == up_shape)
-    assert (_down.shape == down_shape)
+    assert _up.shape == up_shape
+    assert _down.shape == down_shape

     # cast back to the original dtype and device
-    state_dict['lora_up.weight'] = _up.to(dtype).to(device=device)
-    state_dict['lora_down.weight'] = _down.to(dtype).to(device=device)
+    state_dict["lora_up.weight"] = _up.to(dtype).to(device=device)
+    state_dict["lora_down.weight"] = _down.to(dtype).to(device=device)

     return LoRALayer.from_state_dict_values(state_dict)

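For context, the helper above multiplies the LoRA factors back into a dense weight delta, reorders the shift/scale halves of the output rows to match InvokeAI's Flux linear layout, and then re-factorizes the result at the original rank. Below is a minimal standalone sketch of that pipeline; the swap convention and the SVD-based re-decomposition are assumptions about what swap_shift_scale_for_linear_weight and decomposite_weight_matric_with_rank do, not their actual implementations.

import torch

def swap_shift_scale(weight: torch.Tensor) -> torch.Tensor:
    # Assumed convention: diffusers orders the AdaLN output as (shift, scale),
    # while the Flux model expects (scale, shift); swap the two row halves.
    shift, scale = weight.chunk(2, dim=0)
    return torch.cat([scale, shift], dim=0)

def decompose_with_rank(weight: torch.Tensor, rank: int) -> tuple[torch.Tensor, torch.Tensor]:
    # Rank-constrained factorization via truncated SVD: weight ~= up @ down.
    u, s, vh = torch.linalg.svd(weight, full_matrices=False)
    return u[:, :rank] * s[:rank], vh[:rank, :]

# Toy shapes: a rank-4 LoRA for a 32 -> 64 AdaLN projection (64 = 2 * 32 for shift and scale).
up, down = torch.randn(64, 4), torch.randn(4, 32)
weight = up @ down                      # dense delta, shape (64, 32)
swapped = swap_shift_scale(weight)      # a row permutation, so the rank is preserved
new_up, new_down = decompose_with_rank(swapped, rank=4)
assert new_up.shape == up.shape and new_down.shape == down.shape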
@@ -131,7 +136,7 @@ def add_adaLN_lora_layer_if_present(src_key: str, dst_key: str) -> None:
         src_layer_dict = grouped_state_dict.pop(src_key)
         values = get_lora_layer_values(src_layer_dict)
         layers[dst_key] = approximate_flux_adaLN_lora_layer_from_diffusers_state_dict(values)
-
+
     def add_qkv_lora_layer_if_present(
         src_keys: list[str],
         src_weight_shapes: list[tuple[int, int]],
@@ -274,8 +279,8 @@ def add_qkv_lora_layer_if_present(
     # Final layer.
     add_lora_layer_if_present("proj_out", "final_layer.linear")
     add_adaLN_lora_layer_if_present(
-        'norm_out.linear',
-        'final_layer.adaLN_modulation.1',
+        "norm_out.linear",
+        "final_layer.adaLN_modulation.1",
     )

     # Assert that all keys were processed.
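For reference, a hedged usage sketch of the AdaLN helper on a toy state dict. The module path is assumed from the package layout in the imports above, and real Flux AdaLN shapes are much larger than these.

import torch

# Assumed module path; only the package prefix is confirmed by the imports above.
from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import (
    approximate_flux_adaLN_lora_layer_from_diffusers_state_dict,
)

# Toy "lora_up"/"lora_down" pair standing in for a diffusers norm_out.linear LoRA.
state_dict = {
    "lora_up.weight": torch.randn(64, 4),
    "lora_down.weight": torch.randn(4, 32),
}
layer = approximate_flux_adaLN_lora_layer_from_diffusers_state_dict(state_dict)
print(type(layer).__name__)  # LoRALayer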