StableDiffusion: Decode latents separately to run larger batches #1150

patrickvonplaten merged 12 commits into huggingface:main
Conversation
The documentation is not available anymore as the PR was closed or merged.
Hey @kig, I don't think we should make this the default as it necessarily makes the execution slower - however it might make a lot of sense to add it as an opt-in option. Could you share an example that shows the benefit? E.g. a code snippet that works with your PR on an 8 GB GPU but not with the current pipeline implementation?
Yeah, I agree on the opt-in option. Here is the test script I used:

```python
from diffusers import StableDiffusionPipeline
import torch

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    revision="fp16",
    torch_dtype=torch.float16,
)
pipe.enable_attention_slicing()
# Disable safety_checker for testing, it's triggered by noise.
pipe.safety_checker = None
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
for samples in [1, 4, 8, 16, 32]:
    print(f"Generating {samples} image{'s' if samples > 1 else ''}")
    images = pipe([prompt] * samples, num_inference_steps=1).images
    if len(images) != samples:
        raise RuntimeError(f"Expected {samples} images, got {len(images)}")
```

I added some simple time.time() profiling around the VAE decode too, from just before the decode call to just after it.
With the current full-batch VAE decode (without this PR), the 8 GB card runs out of memory at 16 images:

```
$ python test_samples.py
Generating 1 image
100%|█| 1/1 [00:01<00:00, 1.70s/it]
VAE decode elapsed 0.10824728012084961
Generating 4 images
100%|█| 1/1 [00:00<00:00, 2.49it/s]
VAE decode elapsed 0.905648946762085
Generating 8 images
100%|█| 1/1 [00:01<00:00, 1.33s/it]
VAE decode elapsed 2.5215229988098145
Generating 16 images
100%|█| 1/1 [00:02<00:00, 2.72s/it]
Traceback (most recent call last):
File "test_samples.py", line 15, in <module>
images = pipe([prompt] * samples, num_inference_steps=1).images
...
RuntimeError: CUDA out of memory. Tried to allocate 2.00 GiB (GPU 0; 8.00 GiB total capacity; 5.53 GiB already allocated; 0 bytes free; 6.54 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
```
With this PR, decoding one latent at a time, all batch sizes up to 32 go through:

```
$ python test_samples.py
Generating 1 image
100%|█| 1/1 [00:01<00:00, 1.69s/it]
VAE decode elapsed 0.11412405967712402
Generating 4 images
100%|█| 1/1 [00:00<00:00, 2.51it/s]
VAE decode elapsed 0.4106144905090332
Generating 8 images
100%|█| 1/1 [00:00<00:00, 1.29it/s]
VAE decode elapsed 0.8257348537445068
Generating 16 images
100%|█| 1/1 [00:01<00:00, 1.52s/it]
VAE decode elapsed 1.6556949615478516
Generating 32 images
100%|█| 1/1 [00:08<00:00, 8.05s/it]
VAE decode elapsed 3.727773666381836
```

These scale more or less linearly. The 32-image batch is starting to hit something, though: I had that decode time fluctuate between 3.5 and 8 seconds. That was not what I was expecting. I thought decoding the full batch at a time would have a small efficiency benefit from avoiding setup work, but it looks like there's something else in play.
Testing on a 24 GB card with the "full batch at a time" approach, the VAE decode time scales linearly, but it runs into an issue at 32 samples:

```
# python test.py
Generating 1 image
100%|█| 1/1 [00:00<00:00, 3.21it/s]
VAE decode elapsed 0.08451700210571289
Generating 4 images
100%|█| 1/1 [00:00<00:00, 3.14it/s]
VAE decode elapsed 0.32993221282958984
Generating 8 images
100%|█| 1/1 [00:00<00:00, 2.01it/s]
VAE decode elapsed 0.6512401103973389
Generating 16 images
100%|█| 1/1 [00:00<00:00, 1.00it/s]
VAE decode elapsed 1.289898157119751
Generating 32 images
100%|█| 1/1 [00:01<00:00, 1.94s/it]
Traceback (most recent call last):
File "test.py", line 15, in <module>
images = pipe([prompt] * samples, num_inference_steps=1).images
File "/opt/conda/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "/app/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py", line 403, in __call__
image = self.vae.decode(latents).sample
File "/app/diffusers/src/diffusers/models/vae.py", line 581, in decode
dec = self.decoder(z)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1185, in _call_impl
return forward_call(*input, **kwargs)
File "/app/diffusers/src/diffusers/models/vae.py", line 217, in forward
sample = up_block(sample)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1185, in _call_impl
return forward_call(*input, **kwargs)
File "/app/diffusers/src/diffusers/models/unet_2d_blocks.py", line 1322, in forward
hidden_states = upsampler(hidden_states)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1185, in _call_impl
return forward_call(*input, **kwargs)
File "/app/diffusers/src/diffusers/models/resnet.py", line 58, in forward
hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
File "/opt/conda/lib/python3.8/site-packages/torch/nn/functional.py", line 3918, in interpolate
return torch._C._nn.upsample_nearest2d(input, output_size, scale_factors)
RuntimeError: upsample_nearest_nhwc only supports output tensors with less than INT_MAX elements
```

[edit] Here, doing the VAE decode one image at a time seems to be 15% faster at batch size 8.
Doing the VAE decode one image at a time seems to be 15% faster at batch size 8 and 2% slower at batch size 1. The difference is small enough that it might be noise: about 10 ms per image, which doesn't move the needle much when an 8-image batch takes 8 seconds to generate. However, you can go a step further and do a tiled VAE decode coupled with xformers to render 8k images on a 24 GB GPU. Prompt: "a beautiful landscape photograph, 8k", 150 seconds per iteration. 3840x2160 runs about 15x faster, so there's a perf cliff somewhere in between.
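For illustration only, a naive tiled decode could look roughly like the sketch below. This is not the implementation referred to above: the tile size and the lack of overlap are assumptions, and because the decoder is convolutional, a real implementation would overlap and blend tiles to hide seams at the boundaries.

```python
import torch

def naive_tiled_decode(vae, latents, tile=64):
    # latents: (B, 4, H, W) in latent space; each latent position maps to an 8x8 image patch.
    # Decode the latent in spatial tiles so the decoder never sees the full-resolution tensor.
    latents = latents / 0.18215
    rows = []
    for y in range(0, latents.shape[2], tile):
        row = []
        for x in range(0, latents.shape[3], tile):
            z = latents[:, :, y:y + tile, x:x + tile]
            row.append(vae.decode(z).sample)
        rows.append(torch.cat(row, dim=3))  # stitch tiles along width
    return torch.cat(rows, dim=2)           # stitch rows along height
```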
btw, I have a speed boost for the decoder here: eliminates a …
Review comment on the changed decode line in `pipeline_stable_diffusion.py`:

```diff
  latents = 1 / 0.18215 * latents
- image = self.vae.decode(latents).sample
+ image = torch.cat([self.vae.decode(latent).sample for latent in latents.split(1)])
```
Could we please add an `enable_vae_slicing()` for this? I don't think all backends, such as MPS, like the "for loop". Happy to add this feature with an `enable_vae_slicing()`.
This list comprehension iterates over a regular Python sequence (the tuple returned by `split`); it isn't indexing into a tensor, so I don't think MPS would have trouble with it. A similar pattern (split a tensor, then cat the splits) is working for me on MPS:
https://github.com/apple/ml-ane-transformers/blob/da64000fa56cc85b0859bc17cb16a3d753b8304a/ane_transformers/reference/multihead_attention.py#L116
Thanks! Added `enable_vae_slicing()` in the latest commit.
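The sliced decode boils down to the pattern in the diff above. A minimal sketch of how such a toggle could look (illustrative only, not the merged diffusers code; the helper name and flag are assumptions):

```python
import torch

def decode_latents(vae, latents, use_slicing: bool = True):
    # Hypothetical helper sketching the sliced-decode toggle.
    latents = 1 / 0.18215 * latents
    if use_slicing and latents.shape[0] > 1:
        # Decode one latent at a time to cap peak decoder memory.
        # Slicing is skipped for single-image batches, where it changes nothing.
        return torch.cat([vae.decode(z).sample for z in latents.split(1)])
    # Fall back to decoding the whole batch at once.
    return vae.decode(latents).sample
```

With something like this in place, `enable_vae_slicing()` just flips the flag that routes decoding through the sliced branch.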
Sorry for the late reply, it's been hectic. I added an enable switch for the sliced decode on the VAE model. Let me know if you prefer it in the pipeline and I can move it there.
Thanks! I tried making the VAE use xformers attention and that did help with memory use. But the ResNet convolution layers in src/diffusers/models/resnet.py …
Hey @kig, great, the API looks very nice to me now :-) Could we do two last things: add some tests and a bit of documentation?

Here are two links that might help:

Let me know if you need more pointers :-)
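A slicing test could, for example, check that sliced and full-batch decoding produce the same images. A hypothetical sketch (the `disable_vae_slicing()` counterpart, the tolerance, and the overall test shape are assumptions, not the merged test):

```python
import torch
from diffusers import StableDiffusionPipeline

def test_vae_slicing_matches_full_decode():
    # Images decoded with and without VAE slicing should match closely.
    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", revision="fp16", torch_dtype=torch.float16
    ).to("cuda")
    prompt = ["a photograph of an astronaut riding a horse"] * 4

    generator = torch.Generator("cuda").manual_seed(0)
    pipe.enable_vae_slicing()
    sliced = pipe(prompt, num_inference_steps=2, generator=generator, output_type="np").images

    generator = torch.Generator("cuda").manual_seed(0)
    pipe.disable_vae_slicing()  # assumed counterpart of enable_vae_slicing()
    full = pipe(prompt, num_inference_steps=2, generator=generator, output_type="np").images

    assert abs(sliced - full).max() < 1e-2
```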
@patrickvonplaten here we go, I added tests and docs. Let me know how they look.
Hey @kig, awesome job :-) Merging this PR!
StableDiffusion: Decode latents separately to run larger batches (huggingface#1150)

* StableDiffusion: Decode latents separately to run larger batches
* Move VAE sliced decode under enable_vae_sliced_decode and vae.enable_sliced_decode
* Rename sliced_decode to slicing
* fix whitespace
* fix quality check and repository consistency
* VAE slicing tests and documentation
* API doc hooks for VAE slicing
* reformat vae slicing tests
* Skip VAE slicing for one-image batches
* Documentation tweaks for VAE slicing

Co-authored-by: Ilmari Heikkinen <ilmari@fhtr.org>

You can use larger batch sizes if you do the VAE decode one image at a time. Currently, the VAE decode runs on the full batch, which limits the maximum batch size and throughput on 8 GB GPUs.

This PR makes the VAE decode run one image at a time, which makes 20-image (512x512) batches possible on 8 GB GPUs with fp16.
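For reference, usage with the `enable_vae_slicing()` API discussed in the review would look roughly like this sketch (the rest mirrors the test script above; batch size 20 is the figure claimed for 8 GB GPUs):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    revision="fp16",
    torch_dtype=torch.float16,
).to("cuda")
pipe.enable_attention_slicing()
pipe.enable_vae_slicing()  # decode latents one image at a time

prompt = "a photo of an astronaut riding a horse on mars"
images = pipe([prompt] * 20).images  # a 20-image 512x512 batch on an 8 GB GPU
```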