diff --git a/examples/optimal-performance/multi.py b/examples/optimal-performance/multi.py index c298d59f..d5b41e17 100644 --- a/examples/optimal-performance/multi.py +++ b/examples/optimal-performance/multi.py @@ -85,11 +85,25 @@ def image_generation_process( num_inference_steps=50, ) + start = 100 while True: try: start_time = time.time() - x_outputs = stream.stream.txt2img_sd_turbo(batch_size).cpu() + start += 50 + # None is the default action, fully random + noise = None + + # This will generate seeds that change every iteration (but the video will be the same each time it is run) + # noise = list(range(start, start+batch_size)) + + # This will generate a constant image stream + # noise = [808] * batch_size + + # You can uncomment this with either of the noise lists to generate and pre-process a noise tensor + # noise = stream.stream.noise_from_seeds(noise).neg() + + x_outputs = stream.stream.txt2img_sd_turbo(batch_size, noise).cpu() queue.put(x_outputs, block=False) fps = 1 / (time.time() - start_time) * batch_size @@ -168,7 +182,7 @@ def receive_images(queue: Queue, fps_queue: Queue) -> None: def main( prompt: str = "cat with sunglasses and a hat, photoreal, 8K", model_id_or_path: str = "stabilityai/sd-turbo", - batch_size: int = 12, + batch_size: int = 8, acceleration: Literal["none", "xformers", "tensorrt"] = "tensorrt", ) -> None: """ diff --git a/examples/optimal-performance/single.py b/examples/optimal-performance/single.py index 65594aa8..2baa56d5 100644 --- a/examples/optimal-performance/single.py +++ b/examples/optimal-performance/single.py @@ -55,7 +55,14 @@ def image_generation_process( try: start_time = time.time() - x_outputs = stream.stream.txt2img_sd_turbo(1).cpu() + # Pick a supported way to seed it, or leave fully random + # (Longer descriptions in multi.py) + noise = None + # noise = 555 + # noise = [1234] + # noise = stream.stream.noise_from_seeds(noise).neg() + + x_outputs = stream.stream.txt2img_sd_turbo(1, noise).cpu() 
queue.put(x_outputs, block=False) fps = 1 / (time.time() - start_time) diff --git a/src/streamdiffusion/pipeline.py b/src/streamdiffusion/pipeline.py index 66c08c8f..3f2f16e4 100644 --- a/src/streamdiffusion/pipeline.py +++ b/src/streamdiffusion/pipeline.py @@ -438,8 +438,13 @@ def predict_x0_batch(self, x_t_latent: torch.Tensor) -> torch.Tensor: @torch.no_grad() def __call__( - self, x: Union[torch.Tensor, PIL.Image.Image, np.ndarray] = None + self, x: Union[torch.Tensor, PIL.Image.Image, np.ndarray] = None, + noise: Optional[Union[torch.Tensor, int, List[int]]] = None ) -> torch.Tensor: + """ + if x is not None, do img2img with similar image filter + if x is None, do txt2img, with batch_size 1, optionally setting the noise + """ start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) start.record() @@ -454,10 +459,7 @@ def __call__( return self.prev_image_result x_t_latent = self.encode_image(x) else: - # TODO: check the dimension of x_t_latent - x_t_latent = torch.randn((1, 4, self.latent_height, self.latent_width)).to( - device=self.device, dtype=self.dtype - ) + x_t_latent = self.build_x_t_latent(batch_size=1, noise=noise) x_0_pred_out = self.predict_x0_batch(x_t_latent) x_output = self.decode_image(x_0_pred_out).detach().clone() @@ -469,21 +471,17 @@ def __call__( return x_output @torch.no_grad() - def txt2img(self, batch_size: int = 1) -> torch.Tensor: - x_0_pred_out = self.predict_x0_batch( - torch.randn((batch_size, 4, self.latent_height, self.latent_width)).to( - device=self.device, dtype=self.dtype - ) - ) + def txt2img(self, batch_size: int = 1, noise: Optional[Union[torch.Tensor, int, List[int]]] = None) -> torch.Tensor: + x_t_latent = self.build_x_t_latent(batch_size, noise) + x_0_pred_out = self.predict_x0_batch(x_t_latent) x_output = self.decode_image(x_0_pred_out).detach().clone() return x_output - def txt2img_sd_turbo(self, batch_size: int = 1) -> torch.Tensor: - x_t_latent = torch.randn( - (batch_size, 4, 
self.latent_height, self.latent_width), - device=self.device, - dtype=self.dtype, - ) + def txt2img_sd_turbo(self, + batch_size: int = 1, + noise: Optional[Union[torch.Tensor, int, List[int]]] = None + ) -> torch.Tensor: + x_t_latent = self.build_x_t_latent(batch_size, noise) model_pred = self.unet( x_t_latent, self.sub_timesteps_tensor, @@ -494,3 +492,61 @@ def txt2img_sd_turbo(self, batch_size: int = 1) -> torch.Tensor: x_t_latent - self.beta_prod_t_sqrt * model_pred ) / self.alpha_prod_t_sqrt return self.decode_image(x_0_pred_out) + + def noise_size(self, batch_size) -> torch.Size: + return torch.Size((batch_size, 4, self.latent_height, self.latent_width)) + + def noise_from_seeds(self, seeds: List[int]) -> torch.Tensor: + """ + This generates a seeded noise vector for all batch items. + seeds should be of length `batch_size` + Each image will be seeded with the proper seeded noise vector + """ + tensors = [torch.randn( + (4, self.latent_height, self.latent_width), + device=self.device, + dtype=self.dtype, + generator=torch.Generator(device=self.device).manual_seed(seed)) + for seed in seeds] + return torch.stack(tensors) + + def build_x_t_latent(self, + batch_size: int, + noise: Optional[Union[torch.Tensor, int, List[int]]] = None) -> torch.Tensor: + """ + noise: Optional[Union[torch.Tensor, int, List[int]]] + This is the original noise latent to use for generation. + If a Tensor, it must be the proper size. 
+ If List[int], it must be of length `batch_size` and will be used as random seeds for each image + int is a shorthand for List[int] when batch_size is 1 + Throws an exception on improper usage (wrong type or size) + """ + if batch_size < 1: + raise Exception(f"Batch size cannot be {batch_size}") + + # Treat an int as list of size 1 + if isinstance(noise, int): + noise = [noise] + + # Cover remaining cases + if isinstance(noise, list): + if len(noise) != batch_size: + raise Exception(f"Noise list size {len(noise)} but batch size is {batch_size}") + if not isinstance(noise[0], int): + name = type(noise[0]).__name__ + raise Exception(f"Expected List[int] got List[{name}]") + return self.noise_from_seeds(noise) + elif isinstance(noise, torch.Tensor): + expected_noise = self.noise_size(batch_size) + if noise.shape != expected_noise: + raise Exception(f"Expected noise of {expected_noise}, got {noise.shape}") + return noise + elif noise is None: + return torch.randn( + self.noise_size(batch_size), + device=self.device, + dtype=self.dtype, + ) + else: + name = type(noise).__name__ + raise Exception(f"Unsupported type of noise: {name}") diff --git a/utils/wrapper.py b/utils/wrapper.py index cb49bcc2..f343b318 100644 --- a/utils/wrapper.py +++ b/utils/wrapper.py @@ -208,6 +208,7 @@ def __call__( self, image: Optional[Union[str, Image.Image, torch.Tensor]] = None, prompt: Optional[str] = None, + noise: Optional[Union[torch.Tensor, int, List[int]]] = None, ) -> Union[Image.Image, List[Image.Image]]: """ Performs img2img or txt2img based on the mode. @@ -218,6 +219,12 @@ def __call__( The image to generate from. prompt : Optional[str] The prompt to generate images from. + noise: Optional[Union[torch.Tensor, int, List[int]]] + This is the original noise latent to use for generation. + If a Tensor, it must be the proper size. 
+ If List[int], it must be of length `batch_size` and will be used as random seeds for each image + int is a shorthand for List[int] when batch_size is 1 + Throws an exception on improper usage (wrong type or size) Returns ------- @@ -225,12 +232,14 @@ def __call__( The generated image. """ if self.mode == "img2img": - return self.img2img(image, prompt) + return self.img2img(image, prompt, noise) else: - return self.txt2img(prompt) + return self.txt2img(prompt, noise) def txt2img( - self, prompt: Optional[str] = None + self, + prompt: Optional[str] = None, + noise: Optional[Union[torch.Tensor, int, List[int]]] = None, ) -> Union[Image.Image, List[Image.Image], torch.Tensor, np.ndarray]: """ Performs txt2img. @@ -239,6 +248,12 @@ def txt2img( ---------- prompt : Optional[str] The prompt to generate images from. + noise: Optional[Union[torch.Tensor, int, List[int]]] + This is the original noise latent to use for generation. + If a Tensor, it must be the proper size. + If List[int], it must be of length `batch_size` and will be used as random seeds for each image + int is a shorthand for List[int] when batch_size is 1 + Throws an exception on improper usage (wrong type or size) Returns ------- @@ -249,9 +264,9 @@ def txt2img( self.stream.update_prompt(prompt) if self.sd_turbo: - image_tensor = self.stream.txt2img_sd_turbo(self.batch_size) + image_tensor = self.stream.txt2img_sd_turbo(self.batch_size, noise) else: - image_tensor = self.stream.txt2img(self.frame_buffer_size) + image_tensor = self.stream.txt2img(self.frame_buffer_size, noise) image = self.postprocess_image(image_tensor, output_type=self.output_type) if self.use_safety_checker: @@ -267,7 +282,9 @@ def txt2img( return image def img2img( - self, image: Union[str, Image.Image, torch.Tensor], prompt: Optional[str] = None + self, + image: Union[str, Image.Image, torch.Tensor], prompt: Optional[str] = None, + noise: Optional[Union[torch.Tensor, int, List[int]]] = None, ) -> Union[Image.Image, List[Image.Image], 
torch.Tensor, np.ndarray]: """ Performs img2img. @@ -276,6 +293,12 @@ def img2img( ---------- image : Union[str, Image.Image, torch.Tensor] The image to generate from. + noise: Optional[Union[torch.Tensor, int, List[int]]] + This is the original noise latent to use for generation - this will be mixed in with the source image rather than random noise + If a Tensor, it must be the proper size. + If List[int], it must be of length `batch_size` and will be used as random seeds for each image + int is a shorthand for List[int] when batch_size is 1 + Throws an exception on improper usage (wrong type or size) Returns ------- @@ -288,7 +311,7 @@ def img2img( if isinstance(image, str) or isinstance(image, Image.Image): image = self.preprocess_image(image) - image_tensor = self.stream(image) + image_tensor = self.stream(image, noise) image = self.postprocess_image(image_tensor, output_type=self.output_type) if self.use_safety_checker: