diff --git a/nodes/patch_lib/FluxPatch.py b/nodes/patch_lib/FluxPatch.py index 31e370f..dc73f78 100644 --- a/nodes/patch_lib/FluxPatch.py +++ b/nodes/patch_lib/FluxPatch.py @@ -174,6 +174,12 @@ def block_wrap(args): if i < len(control_o): add = control_o[i] if add is not None: + img_slice = img[:, txt.shape[1]:, ...] + if img_slice.shape[1] != add.shape[1]: + padding_size = img_slice.shape[1] - add.shape[1] + if padding_size > 0: + padding = torch.zeros(add.shape[0], padding_size, add.shape[2], device=add.device, dtype=add.dtype) + add = torch.cat([add, padding], dim=1) img[:, txt.shape[1]:, ...] += add return img @@ -270,6 +276,11 @@ def block_wrap(args): if i < len(control_i): add = control_i[i] if add is not None: + if img.shape[1] != add.shape[1]: + padding_size = img.shape[1] - add.shape[1] + if padding_size > 0: + padding = torch.zeros(add.shape[0], padding_size, add.shape[2], device=add.device, dtype=add.dtype) + add = torch.cat([add, padding], dim=1) img += add del blocks_replace