Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,5 @@ dmypy.json

# Pyre type checker
.pyre/

custom_jupyters/
6 changes: 5 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
v1.0.2rc
v1.1.0rc
-------

- use a separate tokenizer_path to initialize the tokenizer in T5Embedder
- PyTorch 2 support
- more efficient attention
- FreeU support, adding more detail
- SD-upscaler: fix normalization for newer versions of diffusers

v1.0.1
------
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ We introduce DeepFloyd IF, a novel state-of-the-art open-source text-to-image mo
[![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/DeepFloyd/IF)

```shell
pip install deepfloyd_if==1.0.2rc0
pip install deepfloyd_if==1.0.2rc1
pip install xformers==0.0.16
pip install git+https://github.com/openai/CLIP.git --no-deps
```
Expand Down
18 changes: 11 additions & 7 deletions deepfloyd_if/model/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,17 +88,21 @@ def shape(x):

# (bs*n_heads, class_token_length, length+class_token_length):
scale = 1 / math.sqrt(math.sqrt(self.dim_per_head))
weight = torch.einsum(
'bct,bcs->bts', q * scale, k * scale
) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)

# (bs*n_heads, dim_per_head, class_token_length)
a = torch.einsum('bts,bcs->bct', weight, v)
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
q, k, v = map(lambda t: t.permute(0, 2, 1), (q, k, v))
a = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0,
is_causal=False)
a = a.permute(0, 2, 1)
else:
weight = torch.einsum(
'bct,bcs->bts', q * scale, k * scale
) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
a = torch.einsum('bts,bcs->bct', weight, v)

# (bs, length+1, width)
a = a.reshape(bs, -1, 1).transpose(1, 2)

return a[:, 0, :] # cls_token


Expand Down
32 changes: 29 additions & 3 deletions deepfloyd_if/model/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,13 @@ def forward(self, qkv, encoder_kv=None):
k = torch.cat([ek, k], dim=-1)
v = torch.cat([ev, v], dim=-1)
scale = 1 / math.sqrt(math.sqrt(ch))
if _FORCE_MEM_EFFICIENT_ATTN:

if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
q, k, v = map(lambda t: t.permute(0, 2, 1), (q, k, v))
a = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0,
is_causal=False)
a = a.permute(0, 2, 1)
elif _FORCE_MEM_EFFICIENT_ATTN:
q, k, v = map(lambda t: t.permute(0, 2, 1).contiguous(), (q, k, v))
a = memory_efficient_attention(q, k, v)
a = a.permute(0, 2, 1)
Expand Down Expand Up @@ -625,7 +631,8 @@ def __init__(

self.cache = None

def forward(self, x, timesteps, text_emb, timestep_text_emb=None, aug_emb=None, use_cache=False, **kwargs):
def forward(self, x, timesteps, text_emb, timestep_text_emb=None, free_ub=None, free_us=None, free_ur=0.2,
aug_emb=None, use_cache=False, **kwargs):
hs = []
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels, dtype=self.dtype))

Expand Down Expand Up @@ -654,12 +661,31 @@ def forward(self, x, timesteps, text_emb, timestep_text_emb=None, aug_emb=None,
hs.append(h)
h = self.middle_block(h, emb, encoder_out)
for module in self.output_blocks:
h = torch.cat([h, hs.pop()], dim=1)
h = torch.cat([self._h_apply(h, free_ub), self._hs_apply(hs.pop(), free_us, free_ur)], dim=1)
h = module(h, emb, encoder_out)
h = h.type(self.dtype)
h = self.out(h)
return h

@staticmethod
def _hs_apply(hs, free_us, free_ur=0.2):
    """Rescale selected Fourier components of a skip-connection tensor.

    The tensor is transformed with a 1-D FFT along its last dimension,
    every component whose imaginary part is smaller in magnitude than
    ``free_ur * pi`` is multiplied by ``free_us``, and the result is
    transformed back.

    Args:
        hs: tensor taken from the skip-connection stack.
        free_us: multiplier for the selected components; ``None`` disables
            the filter and ``hs`` is returned as is (same object).
        free_ur: selection threshold, expressed as a fraction of pi.

    Returns:
        A tensor with the shape and dtype of ``hs``.
    """
    if free_us is None:
        return hs
    dtype = hs.dtype
    # FFT kernels are not available for bfloat16 and only partially for
    # float16 (none on CPU, power-of-two sizes only on CUDA), so reduced
    # precision inputs are processed in float32 and cast back at the end.
    work = hs.float() if dtype in (torch.float16, torch.bfloat16) else hs
    spectrum = torch.fft.fft(work)
    selected = spectrum.imag.abs() < free_ur * np.pi
    spectrum = torch.where(selected, spectrum * free_us, spectrum)
    restored = torch.fft.ifft(spectrum)
    if not torch.is_complex(hs):
        # Take the real part explicitly: casting a complex tensor straight
        # to a real dtype drops the imaginary part and raises a UserWarning.
        restored = restored.real
    return restored.to(dtype=dtype)

@staticmethod
def _h_apply(h, free_ub):
    """Scale the first half of the channels (dimension 1) of ``h``.

    Args:
        h: tensor with at least two dimensions; dimension 1 is split at
            ``h.shape[1] // 2``.
        free_ub: multiplier applied to the first half of dimension 1;
            ``None`` disables scaling and ``h`` is returned as is.

    Returns:
        ``h`` itself when ``free_ub`` is ``None``; otherwise a new tensor
        with the shape and dtype of ``h``. The input is never modified, so
        the call has no side effects on the caller's tensor and remains
        valid when ``h`` takes part in autograd.
    """
    if free_ub is None:
        return h
    half = h.shape[1] // 2
    # Keep the dtype of ``h`` even if ``free_ub`` would promote the product.
    scaled = (h[:, :half] * free_ub).to(h.dtype)
    return torch.cat([scaled, h[:, half:]], dim=1)


class SuperResUNetModel(UNetModel):
"""
Expand Down
5 changes: 4 additions & 1 deletion deepfloyd_if/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ def embeddings_to_image(
guidance_scale=7.0,
aug_level=0.25,
positive_mixer=0.15,
free_us=None,
free_ub=None,
free_ur=0.2,
blur_sigma=None,
img_size=None,
img_scale=4.0,
Expand All @@ -101,7 +104,7 @@ def embeddings_to_image(
def model_fn(x_t, ts, **kwargs):
half = x_t[: len(x_t) // bs_scale]
combined = torch.cat([half]*bs_scale, dim=0)
model_out = self.model(combined, ts, **kwargs)
model_out = self.model(combined, ts, free_us=free_us, free_ub=free_ub, free_ur=free_ur, **kwargs)
eps, rest = model_out[:, :3], model_out[:, 3:]
if bs_scale == 3:
cond_eps, pos_cond_eps, uncond_eps = torch.split(eps, len(eps) // bs_scale, dim=0)
Expand Down
6 changes: 4 additions & 2 deletions deepfloyd_if/modules/stage_I.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ def __init__(self, *args, model_kwargs=None, pil_img_size=64, **kwargs):
def embeddings_to_image(self, t5_embs, style_t5_embs=None, positive_t5_embs=None, negative_t5_embs=None,
batch_repeat=1, dynamic_thresholding_p=0.95, sample_loop='ddpm', positive_mixer=0.25,
sample_timestep_respacing='150', dynamic_thresholding_c=1.5, guidance_scale=7.0,
aspect_ratio='1:1', progress=True, seed=None, sample_fn=None, **kwargs):
free_us=0.98, free_ub=1.03, free_ur=0.2, aspect_ratio='1:1', progress=True, seed=None,
img_size=64, sample_fn=None, **kwargs):

return super().embeddings_to_image(
t5_embs=t5_embs,
Expand All @@ -40,11 +41,12 @@ def embeddings_to_image(self, t5_embs, style_t5_embs=None, positive_t5_embs=None
sample_loop=sample_loop,
sample_timestep_respacing=sample_timestep_respacing,
guidance_scale=guidance_scale,
img_size=64,
img_size=img_size,
aspect_ratio=aspect_ratio,
progress=progress,
seed=seed,
sample_fn=sample_fn,
positive_mixer=positive_mixer,
free_us=free_us, free_ub=free_ub, free_ur=free_ur,
**kwargs
)
3 changes: 2 additions & 1 deletion deepfloyd_if/modules/stage_II.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def embeddings_to_image(
self, low_res, t5_embs, style_t5_embs=None, positive_t5_embs=None, negative_t5_embs=None, batch_repeat=1,
aug_level=0.25, dynamic_thresholding_p=0.95, dynamic_thresholding_c=1.0, sample_loop='ddpm',
sample_timestep_respacing='smart50', guidance_scale=4.0, img_scale=4.0, positive_mixer=0.5,
progress=True, seed=None, sample_fn=None, **kwargs):
progress=True, seed=None, sample_fn=None, free_us=1.01, free_ub=1.01, free_ur=0.05, **kwargs):
return super().embeddings_to_image(
t5_embs=t5_embs,
low_res=low_res,
Expand All @@ -42,5 +42,6 @@ def embeddings_to_image(
progress=progress,
seed=seed,
sample_fn=sample_fn,
free_us=free_us, free_ub=free_ub, free_ur=free_ur,
**kwargs
)
3 changes: 1 addition & 2 deletions deepfloyd_if/modules/stage_III_sd_x4.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,7 @@ def embeddings_to_image(
'output_type': 'pt',
}

images = self.model(**metadata).images

images = self.model(**metadata).images * 2 - 1
sample = self._IFBaseModule__validate_generations(images)

return sample, metadata
19 changes: 10 additions & 9 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
tqdm
ftfy
numpy
torch<2.0.0
torch
torchvision
sentencepiece
omegaconf
matplotlib
huggingface_hub
beautifulsoup4
accelerate
protobuf
Pillow>=9.2.0
huggingface_hub>=0.13.2
transformers~=4.25.1
accelerate~=0.15.0
diffusers~=0.16.0
tokenizers~=0.13.2
sentencepiece~=0.1.97
ftfy~=6.1.1
beautifulsoup4~=4.11.1
transformers~=4.34.0
diffusers~=0.21.4
tokenizers>=0.13.2