Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions examples/optimal-performance/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,25 @@ def image_generation_process(
num_inference_steps=50,
)

start = 100
while True:
try:
start_time = time.time()

x_outputs = stream.stream.txt2img_sd_turbo(batch_size).cpu()
start += 50
# None is the default action, fully random
noise = None

# This will generate seeds that change every iteration (but the video will be the same each time it is run)
# noise = list(range(start, start+batch_size))

# This will generate a constant image stream
# noise = [808] * batch_size

# You can uncomment this with either of the noise lists to generate and pre-process a noise tensor
# noise = stream.stream.noise_from_seeds(noise).neg()

x_outputs = stream.stream.txt2img_sd_turbo(batch_size, noise).cpu()
queue.put(x_outputs, block=False)

fps = 1 / (time.time() - start_time) * batch_size
Expand Down Expand Up @@ -168,7 +182,7 @@ def receive_images(queue: Queue, fps_queue: Queue) -> None:
def main(
prompt: str = "cat with sunglasses and a hat, photoreal, 8K",
model_id_or_path: str = "stabilityai/sd-turbo",
batch_size: int = 12,
batch_size: int = 8,
acceleration: Literal["none", "xformers", "tensorrt"] = "tensorrt",
) -> None:
"""
Expand Down
9 changes: 8 additions & 1 deletion examples/optimal-performance/single.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,14 @@ def image_generation_process(
try:
start_time = time.time()

x_outputs = stream.stream.txt2img_sd_turbo(1).cpu()
# Pick a supported way to seed it, or leave fully random
# (Longer descriptions in multi.py)
noise = None
# noise = 555
# noise = [1234]
# noise = stream.stream.noise_from_seeds(noise).neg()

x_outputs = stream.stream.txt2img_sd_turbo(1, noise).cpu()
queue.put(x_outputs, block=False)

fps = 1 / (time.time() - start_time)
Expand Down
90 changes: 73 additions & 17 deletions src/streamdiffusion/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,8 +438,13 @@ def predict_x0_batch(self, x_t_latent: torch.Tensor) -> torch.Tensor:

@torch.no_grad()
def __call__(
self, x: Union[torch.Tensor, PIL.Image.Image, np.ndarray] = None
self, x: Union[torch.Tensor, PIL.Image.Image, np.ndarray] = None,
noise: Optional[Union[torch.Tensor, int, List[int]]] = None
) -> torch.Tensor:
"""
if x is not None, do img2img with similar image filter
if x is None, do txt2img, with batch_size 1, optionally setting the noise
"""
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
Expand All @@ -454,10 +459,7 @@ def __call__(
return self.prev_image_result
x_t_latent = self.encode_image(x)
else:
# TODO: check the dimension of x_t_latent
x_t_latent = torch.randn((1, 4, self.latent_height, self.latent_width)).to(
device=self.device, dtype=self.dtype
)
x_t_latent = self.build_x_t_latent(batch_size=1, noise=noise)
x_0_pred_out = self.predict_x0_batch(x_t_latent)
x_output = self.decode_image(x_0_pred_out).detach().clone()

Expand All @@ -469,21 +471,17 @@ def __call__(
return x_output

@torch.no_grad()
def txt2img(self, batch_size: int = 1) -> torch.Tensor:
x_0_pred_out = self.predict_x0_batch(
torch.randn((batch_size, 4, self.latent_height, self.latent_width)).to(
device=self.device, dtype=self.dtype
)
)
def txt2img(self, batch_size: int = 1, noise: Optional[Union[torch.Tensor, int, List[int]]] = None) -> torch.Tensor:
x_t_latent = self.build_x_t_latent(batch_size, noise)
x_0_pred_out = self.predict_x0_batch(x_t_latent)
x_output = self.decode_image(x_0_pred_out).detach().clone()
return x_output

def txt2img_sd_turbo(self, batch_size: int = 1) -> torch.Tensor:
x_t_latent = torch.randn(
(batch_size, 4, self.latent_height, self.latent_width),
device=self.device,
dtype=self.dtype,
)
def txt2img_sd_turbo(self,
batch_size: int = 1,
noise: Optional[Union[torch.Tensor, int, List[int]]] = None
) -> torch.Tensor:
x_t_latent = self.build_x_t_latent(batch_size, noise)
model_pred = self.unet(
x_t_latent,
self.sub_timesteps_tensor,
Expand All @@ -494,3 +492,61 @@ def txt2img_sd_turbo(self, batch_size: int = 1) -> torch.Tensor:
x_t_latent - self.beta_prod_t_sqrt * model_pred
) / self.alpha_prod_t_sqrt
return self.decode_image(x_0_pred_out)

def noise_size(self, batch_size: int) -> torch.Size:
    """Return the shape of the initial noise latent for `batch_size` images."""
    return torch.Size([batch_size, 4, self.latent_height, self.latent_width])

def noise_from_seeds(self, seeds: List[int]) -> torch.Tensor:
    """
    Build a seeded noise latent for a whole batch.

    `seeds` should have one entry per batch item; each image is seeded
    with its own deterministically generated noise vector.
    """
    per_image_shape = (4, self.latent_height, self.latent_width)
    noise_slices = []
    for seed in seeds:
        rng = torch.Generator(device=self.device).manual_seed(seed)
        noise_slices.append(
            torch.randn(
                per_image_shape,
                device=self.device,
                dtype=self.dtype,
                generator=rng,
            )
        )
    return torch.stack(noise_slices)

def build_x_t_latent(self,
                     batch_size: int,
                     noise: Optional[Union[torch.Tensor, int, List[int]]] = None) -> torch.Tensor:
    """
    Build the initial noise latent for generation.

    Parameters
    ----------
    batch_size : int
        Number of images in the batch; must be >= 1.
    noise : Optional[Union[torch.Tensor, int, List[int]]]
        The original noise latent to use for generation.
        If a Tensor, it must be of shape `self.noise_size(batch_size)`.
        If List[int], it must be of length `batch_size` and each entry is
        used as the random seed for the corresponding image.
        An int is shorthand for List[int] when batch_size is 1.
        If None (default), a fully random latent is drawn.

    Returns
    -------
    torch.Tensor
        The noise latent of shape (batch_size, 4, latent_height, latent_width).

    Raises
    ------
    ValueError
        If batch_size < 1, the seed list length mismatches batch_size,
        or a tensor has the wrong shape.
    TypeError
        If `noise` (or any seed in the list) has an unsupported type.
    """
    if batch_size < 1:
        raise ValueError(f"Batch size cannot be {batch_size}")

    # Treat a bare int as a single-element seed list.
    # NOTE: bool is a subclass of int, so True/False would be accepted as seeds 1/0.
    if isinstance(noise, int):
        noise = [noise]

    if isinstance(noise, list):
        if len(noise) != batch_size:
            raise ValueError(f"Noise list size {len(noise)} but batch size is {batch_size}")
        # Validate every element, not just the first, so a mixed list like
        # [1, "x"] fails here with a clear message instead of deep inside
        # noise_from_seeds.
        for seed in noise:
            if not isinstance(seed, int):
                raise TypeError(f"Expected List[int] got List[{type(seed).__name__}]")
        return self.noise_from_seeds(noise)
    elif isinstance(noise, torch.Tensor):
        expected_noise = self.noise_size(batch_size)
        if noise.shape != expected_noise:
            raise ValueError(f"Expected noise of {expected_noise}, got {noise.shape}")
        return noise
    elif noise is None:
        # Fully random latent (not reproducible across runs).
        return torch.randn(
            self.noise_size(batch_size),
            device=self.device,
            dtype=self.dtype,
        )
    else:
        raise TypeError(f"Unsupported type of noise: {type(noise).__name__}")
37 changes: 30 additions & 7 deletions utils/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ def __call__(
self,
image: Optional[Union[str, Image.Image, torch.Tensor]] = None,
prompt: Optional[str] = None,
noise: Optional[Union[torch.Tensor, int, List[int]]] = None,
) -> Union[Image.Image, List[Image.Image]]:
"""
Performs img2img or txt2img based on the mode.
Expand All @@ -218,19 +219,27 @@ def __call__(
The image to generate from.
prompt : Optional[str]
The prompt to generate images from.
noise: Optional[Union[torch.Tensor, int, List[int]]]
This is the original noise latent to use for generation.
If a Tensor, it must be the proper size.
If List[int], it must be of length `batch_size` and will be used as random seeds for each image
int is a shorthand for List[int] when batch_size is 1
Throws an exception on improper usage (wrong type or size)

Returns
-------
Union[Image.Image, List[Image.Image]]
The generated image.
"""
if self.mode == "img2img":
return self.img2img(image, prompt)
return self.img2img(image, prompt, noise)
else:
return self.txt2img(prompt)
return self.txt2img(prompt, noise)

def txt2img(
self, prompt: Optional[str] = None
self,
prompt: Optional[str] = None,
noise: Optional[Union[torch.Tensor, int, List[int]]] = None,
) -> Union[Image.Image, List[Image.Image], torch.Tensor, np.ndarray]:
"""
Performs txt2img.
Expand All @@ -239,6 +248,12 @@ def txt2img(
----------
prompt : Optional[str]
The prompt to generate images from.
noise: Optional[Union[torch.Tensor, int, List[int]]]
This is the original noise latent to use for generation.
If a Tensor, it must be the proper size.
If List[int], it must be of length `batch_size` and will be used as random seeds for each image
int is a shorthand for List[int] when batch_size is 1
Throws an exception on improper usage (wrong type or size)

Returns
-------
Expand All @@ -249,9 +264,9 @@ def txt2img(
self.stream.update_prompt(prompt)

if self.sd_turbo:
image_tensor = self.stream.txt2img_sd_turbo(self.batch_size)
image_tensor = self.stream.txt2img_sd_turbo(self.batch_size, noise)
else:
image_tensor = self.stream.txt2img(self.frame_buffer_size)
image_tensor = self.stream.txt2img(self.frame_buffer_size, noise)
image = self.postprocess_image(image_tensor, output_type=self.output_type)

if self.use_safety_checker:
Expand All @@ -267,7 +282,9 @@ def txt2img(
return image

def img2img(
self, image: Union[str, Image.Image, torch.Tensor], prompt: Optional[str] = None
self,
image: Union[str, Image.Image, torch.Tensor], prompt: Optional[str] = None,
noise: Optional[Union[torch.Tensor, int, List[int]]] = None,
) -> Union[Image.Image, List[Image.Image], torch.Tensor, np.ndarray]:
"""
Performs img2img.
Expand All @@ -276,6 +293,12 @@ def img2img(
----------
image : Union[str, Image.Image, torch.Tensor]
The image to generate from.
noise: Optional[Union[torch.Tensor, int, List[int]]]
This is the original noise latent to use for generation - this will be mixed in with the source image rather than random noise
If a Tensor, it must be the proper size.
If List[int], it must be of length `batch_size` and will be used as random seeds for each image
int is a shorthand for List[int] when batch_size is 1
Throws an exception on improper usage (wrong type or size)

Returns
-------
Expand All @@ -288,7 +311,7 @@ def img2img(
if isinstance(image, str) or isinstance(image, Image.Image):
image = self.preprocess_image(image)

image_tensor = self.stream(image)
image_tensor = self.stream(image, noise)
image = self.postprocess_image(image_tensor, output_type=self.output_type)

if self.use_safety_checker:
Expand Down