diff --git a/deepfloyd_if/modules/base.py b/deepfloyd_if/modules/base.py index c808a3c..d1f0fbd 100644 --- a/deepfloyd_if/modules/base.py +++ b/deepfloyd_if/modules/base.py @@ -177,6 +177,24 @@ def model_fn(x_t, ts, **kwargs): noise = torch.randn( (batch_size * bs_scale, 3, image_h, image_w), device=self.device, dtype=self.model.dtype) else: + if support_noise.shape != (1, 3, image_h, image_w): + # Resize support noise and mask to image size. + support_noise = torch.nn.functional.interpolate( + support_noise, + size=(image_h, image_w), + mode="bilinear", + align_corners=False, + ) + if inpainting_mask is not None: + inpainting_mask = torch.nn.functional.interpolate( + inpainting_mask.float(), + size=(image_h, image_w), + mode="bilinear", + align_corners=False, + ) + inpainting_mask = inpainting_mask > 0.5 # Back to bool. + inpainting_mask = inpainting_mask.float() # Back to 0/1. + assert support_noise_less_qsample_steps < len(diffusion.timestep_map) - 1 assert support_noise.shape == (1, 3, image_h, image_w) q_sample_steps = torch.tensor([int(len(diffusion.timestep_map) - 1 - support_noise_less_qsample_steps)])