diff --git a/fme/downscaling/modules/preconditioners.py b/fme/downscaling/modules/preconditioners.py
index 6ad4521cc..394e4cd83 100644
--- a/fme/downscaling/modules/preconditioners.py
+++ b/fme/downscaling/modules/preconditioners.py
@@ -86,30 +86,31 @@ def forward(
         class_labels=None,
         force_fp32=False,
     ):
-        x = x.to(torch.float32)
-        sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
-        class_labels = (
-            None
-            if self.label_dim == 0
-            else torch.zeros([1, self.label_dim], device=x.device)
-            if class_labels is None
-            else class_labels.to(torch.float32).reshape(-1, self.label_dim)
-        )
-        dtype = (
-            torch.float16
-            if (self.use_fp16 and not force_fp32 and x.device.type == "cuda")
-            else torch.float32
-        )
+        with torch.amp.autocast(device_type=get_device().type, enabled=False):
+            x = x.to(torch.float32)
+            sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
+            class_labels = (
+                None
+                if self.label_dim == 0
+                else torch.zeros([1, self.label_dim], device=x.device)
+                if class_labels is None
+                else class_labels.to(torch.float32).reshape(-1, self.label_dim)
+            )
+            dtype = (
+                torch.float16
+                if (self.use_fp16 and not force_fp32 and x.device.type == "cuda")
+                else torch.float32
+            )
 
-        c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
-        c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt()
-        c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt()
-        c_noise = sigma.log() / 4
+            c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
+            c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt()
+            c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt()
+            c_noise = sigma.log() / 4
 
-        arg = c_in * x
+            arg = c_in * x
 
-        if condition is not None:
-            arg = torch.cat([arg, condition], dim=1)
+            if condition is not None:
+                arg = torch.cat([arg, condition], dim=1)
 
         F_x = self.model(
             arg.to(dtype),
@@ -121,5 +122,6 @@ def forward(
             raise ValueError(
                 f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead."
             )
-        D_x = c_skip * x + c_out * F_x.to(torch.float32)
-        return D_x
+        with torch.amp.autocast(device_type=get_device().type, enabled=False):
+            D_x = c_skip * x + c_out * F_x.to(torch.float32)
+            return D_x
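
The change above wraps the EDM-style preconditioner scaling math (c_skip, c_out, c_in, c_noise) in autocast-disabled regions keyed to the runtime device type, so those coefficients and the final skip connection stay in float32 even when the forward pass runs under an enclosing torch.autocast context, while the inner network call can still execute in float16. Below is a minimal, self-contained sketch of that pattern, assuming only the public torch.amp API; ToyPrecond is hypothetical, not the repository's class, and it omits the class_labels, condition, and c_noise handling:

# Minimal sketch (hypothetical ToyPrecond, not the repository's preconditioner):
# the EDM scaling coefficients are computed with autocast disabled so they stay
# in float32, while the wrapped network call may still run in reduced precision.
import torch


class ToyPrecond(torch.nn.Module):
    def __init__(self, channels: int = 4, sigma_data: float = 0.5):
        super().__init__()
        self.sigma_data = sigma_data
        self.model = torch.nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        device_type = x.device.type
        # Guard the numerically sensitive scaling math against any ambient
        # autocast context.
        with torch.amp.autocast(device_type=device_type, enabled=False):
            x = x.to(torch.float32)
            sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
            c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
            c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt()
            c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt()
        # The network call stays outside the guard, so an enclosing
        # torch.autocast context can still run it in half precision.
        F_x = self.model(c_in * x)
        with torch.amp.autocast(device_type=device_type, enabled=False):
            # Skip connection and output scaling back in float32.
            return c_skip * x + c_out * F_x.to(torch.float32)


if __name__ == "__main__":
    precond = ToyPrecond()
    x = torch.randn(2, 4, 8, 8)
    sigma = torch.rand(2) + 0.1
    # Even under an enabled autocast region, the output remains float32.
    with torch.amp.autocast(device_type="cpu", dtype=torch.bfloat16):
        out = precond(x, sigma)
    print(out.dtype)  # torch.float32

Keying the disabled region to the runtime device type, rather than hard-coding "cuda", also keeps the guard effective when the model runs under CPU bfloat16 autocast, since disabling autocast for one device type does not affect another.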