-from typing import Dict
+from typing import Dict, Tuple

import torch

from invokeai.backend.patches.layers.lokr_layer import LoKRLayer
from invokeai.backend.patches.layers.lora_layer import LoRALayer
from invokeai.backend.patches.layers.norm_layer import NormLayer
-from invokeai.backend.patches.layers.diffusers_ada_ln_lora_layer import DiffusersAdaLN_LoRALayer


def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> BaseLayerPatch:
@@ -36,8 +35,70 @@ def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> BaseL
    raise ValueError(f"Unsupported lora format: {state_dict.keys()}")


-def diffusers_adaLN_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> DiffusersAdaLN_LoRALayer:
+
+def swap_shift_scale_for_linear_weight(weight: torch.Tensor) -> torch.Tensor:
+    """Swap the shift and scale halves of a linear layer's weight (the swap is its own inverse)."""
+    # The SD3 and Flux implementations of AdaLayerNormContinuous split the linear projection output into
+    # (shift, scale), while diffusers splits it into (scale, shift). This flips the two halves around.
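+    # e.g. for a weight of shape (2 * dim, in_features), rows [0:dim] and rows [dim:2*dim] trade places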
+    chunk1, chunk2 = weight.chunk(2, dim=0)
+    return torch.cat([chunk2, chunk1], dim=0)
+
+def decomposite_weight_matric_with_rank(
+    delta: torch.Tensor,
+    rank: int,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Decompose the given matrix into a low-rank (up, down) factor pair of the specified rank via truncated SVD."""
+    U, S, V = torch.svd(delta)
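+    # note: torch.svd returns V (so delta ~= U @ diag(S) @ V.T); the newer torch.linalg.svd returns Vh instead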
+
+    # Truncate to rank r:
+    U_r = U[:, :rank]
+    S_r = S[:rank]
+    V_r = V[:, :rank]
+
+    S_sqrt = torch.sqrt(S_r)
+
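+    # split sqrt(S_r) between the two factors so that up @ down == U_r @ diag(S_r) @ V_r.T,
+    # the best rank-r approximation of delta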
+    up = torch.matmul(U_r, torch.diag(S_sqrt))
+    down = torch.matmul(torch.diag(S_sqrt), V_r.T)
+
+    return up, down
+
+
+def approximate_flux_adaLN_lora_layer_from_diffusers_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRALayer:
+    """Approximate a diffusers AdaLN LoRA layer as a LoRALayer for our Flux model."""
+
    if "lora_up.weight" not in state_dict:
-        raise ValueError(f"Unsupported lora format: {state_dict.keys()}")
+        raise ValueError(f"Unsupported lora format: {state_dict.keys()}, missing lora_up")

-    return DiffusersAdaLN_LoRALayer.from_state_dict_values(state_dict)
+    if "lora_down.weight" not in state_dict:
+        raise ValueError(f"Unsupported lora format: {state_dict.keys()}, missing lora_down")
+
+    up = state_dict.pop('lora_up.weight')
+    down = state_dict.pop('lora_down.weight')
+
+    dtype = up.dtype
+    device = up.device
+    up_shape = up.shape
+    down_shape = down.shape
+
+    # desired low rank
+    rank = up_shape[1]
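+    # (lora_up has shape [out_features, rank], so its second dimension is the original LoRA rank)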
+
+    # upcast to double precision for a more accurate SVD
+    up = up.double()
+    down = down.double()
+    weight = up.reshape(up.shape[0], -1) @ down.reshape(down.shape[0], -1)
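+    # the full weight delta encoded by the diffusers LoRA pair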
+
+    # swap the two halves to match our (shift, scale) linear layout
+    swapped = swap_shift_scale_for_linear_weight(weight)
+
+    _up, _down = decomposite_weight_matric_with_rank(swapped, rank)
+
+    assert _up.shape == up_shape
+    assert _down.shape == down_shape
+
+    # cast back to the original dtype and device
+    state_dict['lora_up.weight'] = _up.to(dtype).to(device=device)
+    state_dict['lora_down.weight'] = _down.to(dtype).to(device=device)
+
+    return LoRALayer.from_state_dict_values(state_dict)
+