From 8ca8305b486c4f94410fe8f4bb7f6c8ce926f5c6 Mon Sep 17 00:00:00 2001
From: Pierre Pereira
Date: Mon, 19 Jun 2023 17:22:21 +0200
Subject: [PATCH] fix: resize the support noise and inpainting mask to match
 the image size

---
 deepfloyd_if/modules/base.py | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/deepfloyd_if/modules/base.py b/deepfloyd_if/modules/base.py
index c808a3c..d1f0fbd 100644
--- a/deepfloyd_if/modules/base.py
+++ b/deepfloyd_if/modules/base.py
@@ -177,6 +177,24 @@ def model_fn(x_t, ts, **kwargs):
             noise = torch.randn(
                 (batch_size * bs_scale, 3, image_h, image_w), device=self.device, dtype=self.model.dtype)
         else:
+            if support_noise.shape != (1, 3, image_h, image_w):
+                # Resize support noise and mask to image size.
+                support_noise = torch.nn.functional.interpolate(
+                    support_noise,
+                    size=(image_h, image_w),
+                    mode="bilinear",
+                    align_corners=False,
+                )
+                if inpainting_mask is not None:
+                    inpainting_mask = torch.nn.functional.interpolate(
+                        inpainting_mask.float(),
+                        size=(image_h, image_w),
+                        mode="bilinear",
+                        align_corners=False,
+                    )
+                    inpainting_mask = inpainting_mask > 0.5  # Back to bool.
+                    inpainting_mask = inpainting_mask.float()  # Back to 0/1.
+
             assert support_noise_less_qsample_steps < len(diffusion.timestep_map) - 1
             assert support_noise.shape == (1, 3, image_h, image_w)
             q_sample_steps = torch.tensor([int(len(diffusion.timestep_map) - 1 - support_noise_less_qsample_steps)])
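
Note: for reference, a minimal standalone sketch of the resizing behavior the patch adds, isolated from the DeepFloyd IF pipeline. The 64x64 input resolution, the 256x256 target size, and the single-channel mask shape are assumptions chosen for illustration; the interpolate calls and the 0.5 re-threshold mirror the lines added above.

    # Standalone sketch (assumed shapes); mirrors the resize logic in this patch.
    import torch
    import torch.nn.functional as F

    image_h, image_w = 256, 256  # assumed target resolution

    # Support noise and mask supplied at a smaller, mismatched resolution.
    support_noise = torch.randn(1, 3, 64, 64)
    inpainting_mask = (torch.rand(1, 1, 64, 64) > 0.5).float()

    if support_noise.shape != (1, 3, image_h, image_w):
        support_noise = F.interpolate(
            support_noise, size=(image_h, image_w), mode="bilinear", align_corners=False)
        inpainting_mask = F.interpolate(
            inpainting_mask.float(), size=(image_h, image_w), mode="bilinear", align_corners=False)
        # Bilinear interpolation yields soft values; threshold back to a hard 0/1 mask.
        inpainting_mask = (inpainting_mask > 0.5).float()

    assert support_noise.shape == (1, 3, image_h, image_w)
    assert inpainting_mask.shape == (1, 1, image_h, image_w)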