From d9449c10ecec55f73028adaad45d0745304787f0 Mon Sep 17 00:00:00 2001
From: Anna Kwa
Date: Thu, 12 Feb 2026 15:19:57 -0800
Subject: [PATCH 1/2] Keep preconditioning coefs in float32, use amp for unet call

---
 fme/downscaling/models.py                     |  2 +-
 fme/downscaling/modules/diffusion_registry.py |  4 +--
 fme/downscaling/modules/preconditioners.py    | 30 +++++++++++++++----
 fme/downscaling/modules/unet_diffusion.py     | 29 +++++-------------
 4 files changed, 34 insertions(+), 31 deletions(-)

diff --git a/fme/downscaling/models.py b/fme/downscaling/models.py
index 47b9615df..6807e4eed 100644
--- a/fme/downscaling/models.py
+++ b/fme/downscaling/models.py
@@ -99,7 +99,7 @@ class DiffusionModelConfig:
         num_diffusion_generation_steps: Number of diffusion generation steps
         use_fine_topography: Whether to use fine topography in the model.
         use_amp_bf16: Whether to use automatic mixed precision (bfloat16) in the
-            UNetDiffusionModule.
+            UNet forward pass within EDMPrecond (after sigma and c values are computed).
     """
 
     module: DiffusionModuleRegistrySelector
diff --git a/fme/downscaling/modules/diffusion_registry.py b/fme/downscaling/modules/diffusion_registry.py
index b5a66668a..d64d41afd 100644
--- a/fme/downscaling/modules/diffusion_registry.py
+++ b/fme/downscaling/modules/diffusion_registry.py
@@ -90,8 +90,8 @@ def build(
             EDMPrecond(
                 unet,
                 sigma_data=sigma_data,
+                use_amp_bf16=use_amp_bf16,
             ),
-            use_amp_bf16=use_amp_bf16,
         )
 
 
@@ -154,8 +154,8 @@ def build(
             EDMPrecond(
                 unet,
                 sigma_data=sigma_data,
+                use_amp_bf16=use_amp_bf16,
             ),
-            use_amp_bf16=use_amp_bf16,
             channels_last=self.use_apex_gn,
         )
         return module
diff --git a/fme/downscaling/modules/preconditioners.py b/fme/downscaling/modules/preconditioners.py
index 6ad4521cc..102621c7e 100644
--- a/fme/downscaling/modules/preconditioners.py
+++ b/fme/downscaling/modules/preconditioners.py
@@ -6,6 +6,8 @@
 # fmt: off
 # flake8: noqa
 
+import contextlib
+
 # SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
 # SPDX-FileCopyrightText: All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
@@ -57,6 +59,9 @@ class EDMPrecond(torch.nn.Module):
         Execute the underlying model at FP16 precision?, by default False.
     sigma_data : float
         Expected standard deviation of the training data, by default 0.5.
+    use_amp_bf16 : bool
+        Use automatic mixed precision (bfloat16) for the UNet forward pass only,
+        after the sigma and c values are computed and applied to x. By default False.
 
     Note
     ----
@@ -71,12 +76,23 @@ def __init__(
         self,
         model,
         img_channels,
         label_dim=0,
         use_fp16=False,
         sigma_data=0.5,
+        use_amp_bf16=False,
     ):
         super().__init__()
         self.label_dim = label_dim
         self.use_fp16 = use_fp16
         self.sigma_data = sigma_data
         self.model = model
+        self.use_amp_bf16 = use_amp_bf16
+        if self.use_amp_bf16:
+            device = get_device()
+            if device.type == "mps":
+                raise ValueError("MPS does not support bfloat16 autocast.")
+            self._amp_context = torch.amp.autocast(
+                device.type, dtype=torch.bfloat16
+            )
+        else:
+            self._amp_context = contextlib.nullcontext()
 
     def forward(
         self,
@@ -111,13 +127,15 @@ def forward(
         if condition is not None:
             arg = torch.cat([arg, condition], dim=1)
 
-        F_x = self.model(
-            arg.to(dtype),
-            c_noise.flatten(),
-            class_labels=class_labels,
-        )
+        model_input_dtype = dtype if not self.use_amp_bf16 else torch.float32
+        with self._amp_context:
+            F_x = self.model(
+                arg.to(model_input_dtype),
+                c_noise.flatten(),
+                class_labels=class_labels,
+            )
 
-        if (F_x.dtype != dtype) and not _is_autocast_enabled():
+        if not self.use_amp_bf16 and (F_x.dtype != dtype) and not _is_autocast_enabled():
             raise ValueError(
                 f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead."
             )
diff --git a/fme/downscaling/modules/unet_diffusion.py b/fme/downscaling/modules/unet_diffusion.py
index 93a3d2b79..7136c28e0 100644
--- a/fme/downscaling/modules/unet_diffusion.py
+++ b/fme/downscaling/modules/unet_diffusion.py
@@ -1,5 +1,3 @@
-import contextlib
-
 import torch
 
 from fme.core.device import get_device
@@ -17,8 +15,7 @@ class UNetDiffusionModule(torch.nn.Module):
     grid fields.
 
     Args:
-        unet: The U-Net model.
-        use_amp_bf16: use automatic mixed precision casting to bfloat16 in forward pass
+        unet: The U-Net model (typically EDMPrecond wrapping a SongUNet).
         channels_last: Convert input tensors to channels last format.
             Conversion should only be used for UNet modules compatible with
             Apex GroupNorm, e.g., `SongUNetv2`. Defaults to False for backwards
@@ -28,23 +25,12 @@ def __init__(
     def __init__(
         self,
         unet: torch.nn.Module,
-        use_amp_bf16: bool = False,
         channels_last: bool = False,
     ):
         super().__init__()
         self.unet = unet.to(get_device())
-        self.use_amp_bf16 = use_amp_bf16
         self._memory_format = torch.channels_last if channels_last is True else None
 
-        if self.use_amp_bf16:
-            if get_device().type == "mps":
-                raise ValueError("MPS does not support bfloat16 autocast.")
-            self._amp_context = torch.amp.autocast(
-                get_device().type, dtype=torch.bfloat16
-            )
-        else:
-            self._amp_context = contextlib.nullcontext()
-
     def forward(
         self,
         latent: torch.Tensor,
@@ -60,10 +46,9 @@ def forward(
             noise_level: The noise level of each example in the batch.
""" device = get_device() - with self._amp_context: - return self.unet( - latent.to(device, memory_format=self._memory_format), - conditioning.to(device, memory_format=self._memory_format), - sigma=noise_level.to(device), - class_labels=None, - ) + return self.unet( + latent.to(device, memory_format=self._memory_format), + conditioning.to(device, memory_format=self._memory_format), + sigma=noise_level.to(device), + class_labels=None, + ) From 39e10aca90ab0864c690b0a4270ac8da4c939d9a Mon Sep 17 00:00:00 2001 From: Anna Kwa Date: Thu, 12 Feb 2026 15:31:50 -0800 Subject: [PATCH 2/2] disable autocast for most of precond forward --- fme/downscaling/modules/preconditioners.py | 48 +++++++++++----------- 1 file changed, 25 insertions(+), 23 deletions(-) diff --git a/fme/downscaling/modules/preconditioners.py b/fme/downscaling/modules/preconditioners.py index 6ad4521cc..394e4cd83 100644 --- a/fme/downscaling/modules/preconditioners.py +++ b/fme/downscaling/modules/preconditioners.py @@ -86,30 +86,31 @@ def forward( class_labels=None, force_fp32=False, ): - x = x.to(torch.float32) - sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) - class_labels = ( - None - if self.label_dim == 0 - else torch.zeros([1, self.label_dim], device=x.device) - if class_labels is None - else class_labels.to(torch.float32).reshape(-1, self.label_dim) - ) - dtype = ( - torch.float16 - if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") - else torch.float32 - ) + with torch.amp.autocast(device_type="cuda", enabled=False): + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + class_labels = ( + None + if self.label_dim == 0 + else torch.zeros([1, self.label_dim], device=x.device) + if class_labels is None + else class_labels.to(torch.float32).reshape(-1, self.label_dim) + ) + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) - c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) - c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt() - c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt() - c_noise = sigma.log() / 4 + c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) + c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt() + c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt() + c_noise = sigma.log() / 4 - arg = c_in * x + arg = c_in * x - if condition is not None: - arg = torch.cat([arg, condition], dim=1) + if condition is not None: + arg = torch.cat([arg, condition], dim=1) F_x = self.model( arg.to(dtype), @@ -121,5 +122,6 @@ def forward( raise ValueError( f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." ) - D_x = c_skip * x + c_out * F_x.to(torch.float32) - return D_x + with torch.amp.autocast(device_type=get_device().type, enabled=False): + D_x = c_skip * x + c_out * F_x.to(torch.float32) + return D_x