diff --git a/.gitignore b/.gitignore
index dd2501a..15b51f2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -128,3 +128,5 @@ dmypy.json
 
 # Pyre type checker
 .pyre/
+
+custom_jupyters/
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1d3ce4a..a5c3a02 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,7 +1,11 @@
-v1.0.2rc
+v1.1.0rc
 -------
 - uses separated tokenizer_path to init tokenizer in T5Embedder
+- PyTorch 2 support
+- more efficient attention
+- FreeU support, adding more details
+- SD-upscaler: fix normalization for the new version of diffusers
 
 v1.0.1
 ------
diff --git a/README.md b/README.md
index 68b6d2b..6b38d41 100644
--- a/README.md
+++ b/README.md
@@ -30,7 +30,7 @@ We introduce DeepFloyd IF, a novel state-of-the-art open-source text-to-image mo
 [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/DeepFloyd/IF)
 
 ```shell
-pip install deepfloyd_if==1.0.2rc0
+pip install deepfloyd_if==1.0.2rc1
 pip install xformers==0.0.16
 pip install git+https://github.com/openai/CLIP.git --no-deps
 ```
diff --git a/deepfloyd_if/model/nn.py b/deepfloyd_if/model/nn.py
index 4f1a0f0..b2170c4 100644
--- a/deepfloyd_if/model/nn.py
+++ b/deepfloyd_if/model/nn.py
@@ -88,17 +88,21 @@ def shape(x):
 
         # (bs*n_heads, class_token_length, length+class_token_length):
         scale = 1 / math.sqrt(math.sqrt(self.dim_per_head))
-        weight = torch.einsum(
-            'bct,bcs->bts', q * scale, k * scale
-        )  # More stable with f16 than dividing afterwards
-        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
-        # (bs*n_heads, dim_per_head, class_token_length)
-        a = torch.einsum('bts,bcs->bct', weight, v)
+        if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
+            q, k, v = map(lambda t: t.permute(0, 2, 1), (q, k, v))
+            a = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0,
+                                                                 is_causal=False)
+            a = a.permute(0, 2, 1)
+        else:
+            weight = torch.einsum(
+                'bct,bcs->bts', q * scale, k * scale
+            )  # More stable with f16 than dividing afterwards
+            weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+            a = torch.einsum('bts,bcs->bct', weight, v)
 
         # (bs, length+1, width)
         a = a.reshape(bs, -1, 1).transpose(1, 2)
-
         return a[:, 0, :]  # cls_token
diff --git a/deepfloyd_if/model/unet.py b/deepfloyd_if/model/unet.py
index bb83590..74a1cf5 100644
--- a/deepfloyd_if/model/unet.py
+++ b/deepfloyd_if/model/unet.py
@@ -310,7 +310,13 @@ def forward(self, qkv, encoder_kv=None):
             k = torch.cat([ek, k], dim=-1)
             v = torch.cat([ev, v], dim=-1)
         scale = 1 / math.sqrt(math.sqrt(ch))
-        if _FORCE_MEM_EFFICIENT_ATTN:
+
+        if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
+            q, k, v = map(lambda t: t.permute(0, 2, 1), (q, k, v))
+            a = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0,
+                                                                 is_causal=False)
+            a = a.permute(0, 2, 1)
+        elif _FORCE_MEM_EFFICIENT_ATTN:
             q, k, v = map(lambda t: t.permute(0, 2, 1).contiguous(), (q, k, v))
             a = memory_efficient_attention(q, k, v)
             a = a.permute(0, 2, 1)
@@ -625,7 +631,8 @@ def __init__(
 
         self.cache = None
 
-    def forward(self, x, timesteps, text_emb, timestep_text_emb=None, aug_emb=None, use_cache=False, **kwargs):
+    def forward(self, x, timesteps, text_emb, timestep_text_emb=None, free_ub=None, free_us=None, free_ur=0.2,
+                aug_emb=None, use_cache=False, **kwargs):
         hs = []
         emb = self.time_embed(timestep_embedding(timesteps, self.model_channels, dtype=self.dtype))
 
@@ -654,12 +661,31 @@ def forward(self, x, timesteps, text_emb, timestep_text_emb=None, aug_emb=None,
             hs.append(h)
         h = self.middle_block(h, emb, encoder_out)
         for module in self.output_blocks:
-            h = torch.cat([h, hs.pop()], dim=1)
+            h = torch.cat([self._h_apply(h, free_ub), self._hs_apply(hs.pop(), free_us, free_ur)], dim=1)
             h = module(h, emb, encoder_out)
         h = h.type(self.dtype)
         h = self.out(h)
         return h
+
+    @staticmethod
+    def _hs_apply(hs, free_us, free_ur=0.2):
+        if free_us is None:
+            return hs
+        dtype = hs.dtype
+        hs = torch.fft.fft(hs)
+        hs_ind = hs.imag.abs() < free_ur*np.pi
+        hs[hs_ind] = free_us * hs[hs_ind]
+        hs = torch.fft.ifft(hs)
+        return hs.to(dtype=dtype)
+
+    @staticmethod
+    def _h_apply(h, free_ub):
+        if free_ub is None:
+            return h
+        ch = h.shape[1]
+        h[:, :ch//2] = h[:, :ch//2]*free_ub
+        return h
 
 
 class SuperResUNetModel(UNetModel):
     """
diff --git a/deepfloyd_if/modules/base.py b/deepfloyd_if/modules/base.py
index c808a3c..2e24f86 100644
--- a/deepfloyd_if/modules/base.py
+++ b/deepfloyd_if/modules/base.py
@@ -80,6 +80,9 @@ def embeddings_to_image(
         guidance_scale=7.0,
         aug_level=0.25,
         positive_mixer=0.15,
+        free_us=None,
+        free_ub=None,
+        free_ur=0.2,
         blur_sigma=None,
         img_size=None,
         img_scale=4.0,
@@ -101,7 +104,7 @@ def embeddings_to_image(
         def model_fn(x_t, ts, **kwargs):
             half = x_t[: len(x_t) // bs_scale]
             combined = torch.cat([half]*bs_scale, dim=0)
-            model_out = self.model(combined, ts, **kwargs)
+            model_out = self.model(combined, ts, free_us=free_us, free_ub=free_ub, free_ur=free_ur, **kwargs)
             eps, rest = model_out[:, :3], model_out[:, 3:]
             if bs_scale == 3:
                 cond_eps, pos_cond_eps, uncond_eps = torch.split(eps, len(eps) // bs_scale, dim=0)
diff --git a/deepfloyd_if/modules/stage_I.py b/deepfloyd_if/modules/stage_I.py
index a9c62cc..3f3a12d 100644
--- a/deepfloyd_if/modules/stage_I.py
+++ b/deepfloyd_if/modules/stage_I.py
@@ -27,7 +27,8 @@ def __init__(self, *args, model_kwargs=None, pil_img_size=64, **kwargs):
 
     def embeddings_to_image(self, t5_embs, style_t5_embs=None, positive_t5_embs=None, negative_t5_embs=None,
                             batch_repeat=1, dynamic_thresholding_p=0.95, sample_loop='ddpm', positive_mixer=0.25,
                             sample_timestep_respacing='150', dynamic_thresholding_c=1.5, guidance_scale=7.0,
-                            aspect_ratio='1:1', progress=True, seed=None, sample_fn=None, **kwargs):
+                            free_us=0.98, free_ub=1.03, free_ur=0.2, aspect_ratio='1:1', progress=True, seed=None,
+                            img_size=64, sample_fn=None, **kwargs):
         return super().embeddings_to_image(
             t5_embs=t5_embs,
@@ -40,11 +41,12 @@ def embeddings_to_image(self, t5_embs, style_t5_embs=None, positive_t5_embs=None
             sample_loop=sample_loop,
             sample_timestep_respacing=sample_timestep_respacing,
             guidance_scale=guidance_scale,
-            img_size=64,
+            img_size=img_size,
             aspect_ratio=aspect_ratio,
             progress=progress,
             seed=seed,
             sample_fn=sample_fn,
             positive_mixer=positive_mixer,
+            free_us=free_us, free_ub=free_ub, free_ur=free_ur,
             **kwargs
         )
diff --git a/deepfloyd_if/modules/stage_II.py b/deepfloyd_if/modules/stage_II.py
index d14b838..53b1225 100644
--- a/deepfloyd_if/modules/stage_II.py
+++ b/deepfloyd_if/modules/stage_II.py
@@ -22,7 +22,7 @@ def embeddings_to_image(
             self, low_res, t5_embs, style_t5_embs=None, positive_t5_embs=None, negative_t5_embs=None, batch_repeat=1,
             aug_level=0.25, dynamic_thresholding_p=0.95, dynamic_thresholding_c=1.0, sample_loop='ddpm',
             sample_timestep_respacing='smart50', guidance_scale=4.0, img_scale=4.0, positive_mixer=0.5,
-            progress=True, seed=None, sample_fn=None, **kwargs):
+            progress=True, seed=None, sample_fn=None, free_us=1.01, free_ub=1.01, free_ur=0.05, **kwargs):
         return super().embeddings_to_image(
             t5_embs=t5_embs,
             low_res=low_res,
@@ -42,5 +42,6 @@ def embeddings_to_image(
             progress=progress,
             seed=seed,
             sample_fn=sample_fn,
+            free_us=free_us, free_ub=free_ub, free_ur=free_ur,
             **kwargs
         )
diff --git a/deepfloyd_if/modules/stage_III_sd_x4.py b/deepfloyd_if/modules/stage_III_sd_x4.py
index 307fad2..b602ca9 100644
--- a/deepfloyd_if/modules/stage_III_sd_x4.py
+++ b/deepfloyd_if/modules/stage_III_sd_x4.py
@@ -77,8 +77,7 @@ def embeddings_to_image(
             'output_type': 'pt',
         }
 
-        images = self.model(**metadata).images
-
+        images = self.model(**metadata).images * 2 - 1
         sample = self._IFBaseModule__validate_generations(images)
         return sample, metadata
diff --git a/requirements.txt b/requirements.txt
index 8fd0cbb..6e29d01 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,15 +1,16 @@
 tqdm
+ftfy
 numpy
-torch<2.0.0
+torch
 torchvision
+sentencepiece
 omegaconf
 matplotlib
+huggingface_hub
+beautifulsoup4
+accelerate
+protobuf
 Pillow>=9.2.0
-huggingface_hub>=0.13.2
-transformers~=4.25.1
-accelerate~=0.15.0
-diffusers~=0.16.0
-tokenizers~=0.13.2
-sentencepiece~=0.1.97
-ftfy~=6.1.1
-beautifulsoup4~=4.11.1
+transformers~=4.34.0
+diffusers~=0.21.4
+tokenizers>=0.13.2
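Note for reviewers: the new attention branch leaves the manual `scale` unused because `scaled_dot_product_attention` applies its default `1/sqrt(d)` scaling internally, which is exactly what multiplying both `q` and `k` by `1/d**0.25` achieved in the einsum path. A minimal standalone sketch of that equivalence (tensor shapes and variable names are illustrative, not part of the patch):

```python
import math
import torch
import torch.nn.functional as F

# Illustrative shapes: (batch*heads, channels_per_head, tokens), channels-first
# as in the patched attention modules.
bs, ch, n_q, n_k = 2, 32, 5, 17
q, k, v = torch.randn(bs, ch, n_q), torch.randn(bs, ch, n_k), torch.randn(bs, ch, n_k)
scale = 1 / math.sqrt(math.sqrt(ch))  # 1/ch**0.25, applied twice below

# Old branch: scale folded into q and k separately (more stable in fp16).
weight = torch.einsum('bct,bcs->bts', q * scale, k * scale)
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
a_einsum = torch.einsum('bts,bcs->bct', weight, v)

# New branch: SDPA expects (batch, tokens, channels) and divides by sqrt(ch)
# itself, which equals the two multiplications by 1/ch**0.25 above.
a_sdpa = F.scaled_dot_product_attention(
    q.permute(0, 2, 1), k.permute(0, 2, 1), v.permute(0, 2, 1),
    attn_mask=None, dropout_p=0.0, is_causal=False,
).permute(0, 2, 1)

print(torch.allclose(a_einsum, a_sdpa, atol=1e-4))  # True, up to float error
```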
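The FreeU hooks can likewise be sanity-checked outside the UNet. The sketch below mirrors the patch's `_h_apply` (amplify the first half of the backbone channels by `free_ub`) and `_hs_apply` (rescale FFT components of the skip features whose imaginary part is below `free_ur * pi` by `free_us`); the free-standing function names and tensor shapes are illustrative only:

```python
import numpy as np
import torch

def h_apply(h, free_ub):
    # Backbone half of the concatenation: amplify the first ch//2 channels.
    if free_ub is None:
        return h
    ch = h.shape[1]
    h[:, :ch // 2] = h[:, :ch // 2] * free_ub
    return h

def hs_apply(hs, free_us, free_ur=0.2):
    # Skip half: rescale spectral components with a small imaginary part.
    if free_us is None:
        return hs
    dtype = hs.dtype
    hs = torch.fft.fft(hs)  # complex spectrum along the last dim
    hs_ind = hs.imag.abs() < free_ur * np.pi
    hs[hs_ind] = free_us * hs[hs_ind]
    hs = torch.fft.ifft(hs)
    # The mask is conjugate-symmetric, so the result is real up to float error;
    # the complex->real cast drops the residual imaginary part (PyTorch warns).
    return hs.to(dtype=dtype)

# Stand-ins for decoder feature maps: (batch, channels, height, width).
h, skip = torch.randn(2, 8, 16, 16), torch.randn(2, 8, 16, 16)
merged = torch.cat([h_apply(h, free_ub=1.03), hs_apply(skip, free_us=0.98)], dim=1)
print(merged.shape)  # torch.Size([2, 16, 16, 16])
```

Per the patch, stage I defaults to `free_ub=1.03, free_us=0.98, free_ur=0.2`, stage II to the gentler `free_ub=1.01, free_us=1.01, free_ur=0.05`, and passing `None` for `free_ub`/`free_us` (the `base.py` default) disables the hooks entirely.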