From 4be4c0fbb36909961f5a09a781e4722c530318c1 Mon Sep 17 00:00:00 2001 From: Anna Kwa Date: Wed, 28 Jan 2026 14:31:58 -0800 Subject: [PATCH 1/3] pass downscale factor through build methods --- fme/downscaling/data/config.py | 3 +++ fme/downscaling/data/datasets.py | 1 - fme/downscaling/inference/inference.py | 1 + fme/downscaling/inference/output.py | 8 ++++++++ fme/downscaling/samplers.py | 13 +++++++++---- 5 files changed, 21 insertions(+), 5 deletions(-) diff --git a/fme/downscaling/data/config.py b/fme/downscaling/data/config.py index 4ce1f928b..697196b7b 100644 --- a/fme/downscaling/data/config.py +++ b/fme/downscaling/data/config.py @@ -182,6 +182,7 @@ def build_topography( coarse_coords: LatLonCoordinates, requires_topography: bool, static_inputs_from_checkpoint: StaticInputs | None = None, + downscale_factor: int | None = None, ) -> Topography | None: if requires_topography is False: return None @@ -202,11 +203,13 @@ def build_topography( self.lat_extent, full_coarse_coord=coarse_coords.lat, full_fine_coord=topography.coords.lat, + downscale_factor=downscale_factor, ) fine_lon_interval = adjust_fine_coord_range( self.lon_extent, full_coarse_coord=coarse_coords.lon, full_fine_coord=topography.coords.lon, + downscale_factor=downscale_factor, ) subset_topography = topography.subset_latlon( lat_interval=fine_lat_interval, lon_interval=fine_lon_interval diff --git a/fme/downscaling/data/datasets.py b/fme/downscaling/data/datasets.py index eb13893a9..314926b61 100644 --- a/fme/downscaling/data/datasets.py +++ b/fme/downscaling/data/datasets.py @@ -154,7 +154,6 @@ def __init__( f"lon wraparound not implemented, received lon_min {lon_min} but " f"expected lon_min < {self.lon_interval.start + 360.0}" ) - assert lats.numel() > 0, "No latitudes found in the specified range." assert lons.numel() > 0, "No longitudes found in the specified range." 
diff --git a/fme/downscaling/inference/inference.py b/fme/downscaling/inference/inference.py index cf1ce055c..5143a1943 100644 --- a/fme/downscaling/inference/inference.py +++ b/fme/downscaling/inference/inference.py @@ -243,6 +243,7 @@ def build(self) -> Downscaler: requirements=self.model.data_requirements, patch=self.patch, static_inputs_from_checkpoint=model.static_inputs, + downscale_factor=model.downscale_factor, ) for output_cfg in self.outputs ] diff --git a/fme/downscaling/inference/output.py b/fme/downscaling/inference/output.py index a794f103b..ec5da7d9c 100644 --- a/fme/downscaling/inference/output.py +++ b/fme/downscaling/inference/output.py @@ -217,6 +217,7 @@ def _build_gridded_data( requirements: DataRequirements, dist: Distributed | None = None, static_inputs_from_checkpoint: StaticInputs | None = None, + downscale_factor: int | None = None, ) -> SliceWorkItemGriddedData: xr_dataset, properties = loader_config.get_xarray_dataset( names=requirements.coarse_names, n_timesteps=1 @@ -232,6 +233,7 @@ def _build_gridded_data( requires_topography=requirements.use_fine_topography, # TODO: update to support full list of static inputs static_inputs_from_checkpoint=static_inputs_from_checkpoint, + downscale_factor=downscale_factor, ) if topography is None: raise ValueError("Topography is required for downscaling generation.") @@ -286,6 +288,7 @@ def _build( patch: PatchPredictionConfig, coarse: list[XarrayDataConfig], static_inputs_from_checkpoint: StaticInputs | None = None, + downscale_factor: int | None = None, ) -> DownscalingOutput: updated_loader_config = self._replace_loader_config( time, @@ -299,6 +302,7 @@ def _build( updated_loader_config, requirements, static_inputs_from_checkpoint=static_inputs_from_checkpoint, + downscale_factor=downscale_factor, ) if self.zarr_chunks is None: @@ -386,6 +390,7 @@ def build( requirements: DataRequirements, patch: PatchPredictionConfig, static_inputs_from_checkpoint: StaticInputs | None = None, + downscale_factor: 
int | None = None, ) -> DownscalingOutput: # Convert single time to TimeSlice time: Slice | TimeSlice @@ -409,6 +414,7 @@ def build( patch=patch, coarse=coarse, static_inputs_from_checkpoint=static_inputs_from_checkpoint, + downscale_factor=downscale_factor, ) @@ -469,6 +475,7 @@ def build( requirements: DataRequirements, patch: PatchPredictionConfig, static_inputs_from_checkpoint: StaticInputs | None = None, + downscale_factor: int | None = None, ) -> DownscalingOutput: coarse = self._single_xarray_config(loader_config.coarse) return self._build( @@ -480,4 +487,5 @@ def build( patch=patch, coarse=coarse, static_inputs_from_checkpoint=static_inputs_from_checkpoint, + downscale_factor=downscale_factor, ) diff --git a/fme/downscaling/samplers.py b/fme/downscaling/samplers.py index 7c9139d46..078c906d5 100644 --- a/fme/downscaling/samplers.py +++ b/fme/downscaling/samplers.py @@ -104,8 +104,13 @@ def stochastic_sampler( f"{img_lr.shape[0]} vs {latents.shape[0]}." ) + # Use float32 for MPS since it doesn't support float64 + high_precision_dtype = ( + torch.float32 if latents.device.type == "mps" else torch.float64 + ) + # Time step discretization. - step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device) + step_indices = torch.arange(num_steps, dtype=high_precision_dtype, device=latents.device) t_steps = ( sigma_max ** (1 / rho) + step_indices @@ -119,7 +124,7 @@ def stochastic_sampler( x_lr = img_lr # Main sampling loop. 
- x_next = latents.to(torch.float64) * t_steps[0] + x_next = latents.to(high_precision_dtype) * t_steps[0] latent_steps = [x_next.to(latents.dtype)] for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 x_cur = x_next @@ -141,7 +146,7 @@ def stochastic_sampler( x_hat_batch, x_lr, t_hat, - ).to(torch.float64) + ).to(high_precision_dtype) d_cur = (x_hat - denoised) / t_hat x_next = x_hat + (t_next - t_hat) * d_cur @@ -155,7 +160,7 @@ def stochastic_sampler( x_next_batch, x_lr, t_next, - ).to(torch.float64) + ).to(high_precision_dtype) d_prime = (x_next - denoised) / t_next x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) latent_steps.append(x_next.to(latents.dtype)) From cbb508a7abf270defda42e22c59d93065ca6b326 Mon Sep 17 00:00:00 2001 From: Anna Kwa Date: Wed, 28 Jan 2026 23:26:25 +0000 Subject: [PATCH 2/3] revert samplers changes --- fme/downscaling/samplers.py | 195 ++++++++++-------------------------- 1 file changed, 54 insertions(+), 141 deletions(-) diff --git a/fme/downscaling/samplers.py b/fme/downscaling/samplers.py index 078c906d5..c062fbdba 100644 --- a/fme/downscaling/samplers.py +++ b/fme/downscaling/samplers.py @@ -1,167 +1,80 @@ """ -This file is vendorized from physicsnemo/physicsnemo/utis/generative/stochastic_sampler.py which you can find here: -https://github.com/NVIDIA/physicsnemo/blob/327d9928abc17983ad7aa3df94da9566c197c468/physicsnemo/utils/generative/stochastic_sampler.py +This file is vendorized from edm/generate.py which you can find here: +https://github.com/NVlabs/edm/blob/008a4e5316c8e3bfe61a62f874bddba254295afb/generate.py """ # fmt: off # flake8: noqa -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# This work is licensed under a Creative Commons +# Attribution-NonCommercial-ShareAlike 4.0 International License. +# You should have received a copy of the license along with this +# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ +"""Generate random images using the techniques described in the paper +"Elucidating the Design Space of Diffusion-Based Generative Models".""" -from typing import Callable +from typing import Tuple +import numpy as np import torch -from torch import Tensor - - -def stochastic_sampler( - net: torch.nn.Module, - latents: Tensor, - img_lr: Tensor, - randn_like: Callable[[Tensor], Tensor] = torch.randn_like, - num_steps: int = 18, - sigma_min: float = 0.002, - sigma_max: float = 80.0, - rho: float = 7.0, - S_churn: float = 0.0, - S_min: float = 0.0, - S_max: float = float("inf"), - S_noise: float = 1.0, -) -> Tensor: - """ - Proposed EDM sampler (Algorithm 2) with minor changes to enable - super-resolution and patch-based diffusion. - - Parameters - ---------- - net : torch.nn.Module - The neural network model that generates denoised images from noisy - inputs. 
- Expected signature: `net(x, x_lr, t_hat)`, - where: - x (torch.Tensor): Noisy input of shape (batch_size, C_out, H, W) - x_lr (torch.Tensor): Conditioning input of shape (batch_size, C_cond, H, W) - t_hat (torch.Tensor): Noise level of shape (batch_size, 1, 1, 1) or scalar - Returns: - torch.Tensor: Denoised prediction of shape (batch_size, C_out, H, W) - latents : Tensor - The latent variables (e.g., noise) used as the initial input for the - sampler. Has shape (batch_size, C_out, img_shape_y, img_shape_x). - img_lr : Tensor - Low-resolution input image for conditioning the super-resolution - process. Must have shape (batch_size, C_lr, img_lr_ shape_y, - img_lr_shape_x). - randn_like : Callable[[Tensor], Tensor] - Function to generate random noise with the same shape as the input - tensor. - By default torch.randn_like. - num_steps : int - Number of time steps for the sampler. By default 18. - sigma_min : float - Minimum noise level. By default 0.002. - sigma_max : float - Maximum noise level. By default 800. - rho : float - Exponent used in the time step discretization. By default 7. - S_churn : float - Churn parameter controlling the level of noise added in each step. By - default 0. - S_min : float - Minimum time step for applying churn. By default 0. - S_max : float - Maximum time step for applying churn. By default float("inf"). - S_noise : float - Noise scaling factor applied during the churn step. By default 1. - - Returns - ------- - Tensor - The final denoised image produced by the sampler. Same shape as - `latents`: (batch_size, C_out, img_shape_y, img_shape_x). - - """ - - # img_lr and latents must also have the same batch_size, otherwise mismatch - # when processed by the network - if img_lr.shape[0] != latents.shape[0]: - raise ValueError( - f"img_lr and latents must have the same batch size, but found " - f"{img_lr.shape[0]} vs {latents.shape[0]}." 
- ) - - # Use float32 for MPS since it doesn't support float64 - high_precision_dtype = ( - torch.float32 if latents.device.type == "mps" else torch.float64 - ) + +def edm_sampler( + net, latents, coarse, randn_like=torch.randn_like, + num_steps=18, sigma_min=0.002, sigma_max=80.0, rho=7, + S_churn=0.0, S_min=0.0, S_max=float('inf'), S_noise=1, +) -> Tuple[torch.Tensor, torch.Tensor]: # Time step discretization. - step_indices = torch.arange(num_steps, dtype=high_precision_dtype, device=latents.device) - t_steps = ( - sigma_max ** (1 / rho) - + step_indices - / (num_steps - 1) - * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho)) - ) ** rho - t_steps = torch.cat( - [t_steps, torch.zeros_like(t_steps[:1])] - ) # t_N = 0 - - x_lr = img_lr + step_indices = torch.arange(num_steps, dtype=torch.float32, device=latents.device) + t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho + t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0 # Main sampling loop. - x_next = latents.to(high_precision_dtype) * t_steps[0] - latent_steps = [x_next.to(latents.dtype)] - for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 + x_next = latents.to(torch.float32) * t_steps[0] + latent_steps = [x_next] + for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 x_cur = x_next + # Increase noise temporarily. - gamma = S_churn / num_steps if S_min <= t_cur <= S_max else 0 + gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0 t_hat = t_cur + gamma * t_cur + x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur) - x_hat = x_cur + (t_hat**2 - t_cur**2).sqrt() * S_noise * randn_like(x_cur) - - # Euler step. 
Perform patching operation on score tensor if patch-based - # generation is used denoised = net(x_hat, t_hat, - # ).to(torch.float64) - - x_hat_batch = (x_hat).to( - latents.device - ) - x_lr = x_lr.to(latents.device) - denoised = net( - x_hat_batch, - x_lr, - t_hat, - ).to(high_precision_dtype) - + # Euler step. + denoised = net(x_hat, coarse, t_hat).to(torch.float32) d_cur = (x_hat - denoised) / t_hat x_next = x_hat + (t_next - t_hat) * d_cur # Apply 2nd order correction. if i < num_steps - 1: - x_next_batch = (x_next).to( - latents.device - ) - denoised = net( - x_next_batch, - x_lr, - t_next, - ).to(high_precision_dtype) + denoised = net(x_next, coarse, t_next).to(torch.float32) d_prime = (x_next - denoised) / t_next x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) - latent_steps.append(x_next.to(latents.dtype)) - return x_next.to(latents.dtype), latent_steps + + latent_steps.append(x_next) + + return x_next, latent_steps + + +#---------------------------------------------------------------------------- +# Wrapper for torch.Generator that allows specifying a different random seed +# for each sample in a minibatch. 
+ +class StackedRandomGenerator: + def __init__(self, device, seeds): + super().__init__() + self.generators = [torch.Generator(device).manual_seed(int(seed) % (1 << 32)) for seed in seeds] + + def randn(self, size, **kwargs): + assert size[0] == len(self.generators) + return torch.stack([torch.randn(size[1:], generator=gen, **kwargs) for gen in self.generators]) + + def randn_like(self, input): + return self.randn(input.shape, dtype=input.dtype, layout=input.layout, device=input.device) + + def randint(self, *args, size, **kwargs): + assert size[0] == len(self.generators) + return torch.stack([torch.randint(*args, size=size[1:], generator=gen, **kwargs) for gen in self.generators]) From a108fa0ad0e95868c02fa77f4ef4d42c58c13449 Mon Sep 17 00:00:00 2001 From: Anna Kwa Date: Wed, 28 Jan 2026 23:34:25 +0000 Subject: [PATCH 3/3] fix samplers --- fme/downscaling/samplers.py | 190 ++++++++++++++++++++++++++---------- 1 file changed, 136 insertions(+), 54 deletions(-) diff --git a/fme/downscaling/samplers.py b/fme/downscaling/samplers.py index c062fbdba..7c9139d46 100644 --- a/fme/downscaling/samplers.py +++ b/fme/downscaling/samplers.py @@ -1,80 +1,162 @@ """ -This file is vendorized from edm/generate.py which you can find here: -https://github.com/NVlabs/edm/blob/008a4e5316c8e3bfe61a62f874bddba254295afb/generate.py +This file is vendorized from physicsnemo/physicsnemo/utis/generative/stochastic_sampler.py which you can find here: +https://github.com/NVIDIA/physicsnemo/blob/327d9928abc17983ad7aa3df94da9566c197c468/physicsnemo/utils/generative/stochastic_sampler.py """ # fmt: off # flake8: noqa -# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 # -# This work is licensed under a Creative Commons -# Attribution-NonCommercial-ShareAlike 4.0 International License. 
-# You should have received a copy of the license along with this -# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. -"""Generate random images using the techniques described in the paper -"Elucidating the Design Space of Diffusion-Based Generative Models".""" -from typing import Tuple +from typing import Callable -import numpy as np import torch +from torch import Tensor + + +def stochastic_sampler( + net: torch.nn.Module, + latents: Tensor, + img_lr: Tensor, + randn_like: Callable[[Tensor], Tensor] = torch.randn_like, + num_steps: int = 18, + sigma_min: float = 0.002, + sigma_max: float = 80.0, + rho: float = 7.0, + S_churn: float = 0.0, + S_min: float = 0.0, + S_max: float = float("inf"), + S_noise: float = 1.0, +) -> Tensor: + """ + Proposed EDM sampler (Algorithm 2) with minor changes to enable + super-resolution and patch-based diffusion. + + Parameters + ---------- + net : torch.nn.Module + The neural network model that generates denoised images from noisy + inputs. 
+ Expected signature: `net(x, x_lr, t_hat)`, + where: + x (torch.Tensor): Noisy input of shape (batch_size, C_out, H, W) + x_lr (torch.Tensor): Conditioning input of shape (batch_size, C_cond, H, W) + t_hat (torch.Tensor): Noise level of shape (batch_size, 1, 1, 1) or scalar + Returns: + torch.Tensor: Denoised prediction of shape (batch_size, C_out, H, W) + latents : Tensor + The latent variables (e.g., noise) used as the initial input for the + sampler. Has shape (batch_size, C_out, img_shape_y, img_shape_x). + img_lr : Tensor + Low-resolution input image for conditioning the super-resolution + process. Must have shape (batch_size, C_lr, img_lr_ shape_y, + img_lr_shape_x). + randn_like : Callable[[Tensor], Tensor] + Function to generate random noise with the same shape as the input + tensor. + By default torch.randn_like. + num_steps : int + Number of time steps for the sampler. By default 18. + sigma_min : float + Minimum noise level. By default 0.002. + sigma_max : float + Maximum noise level. By default 800. + rho : float + Exponent used in the time step discretization. By default 7. + S_churn : float + Churn parameter controlling the level of noise added in each step. By + default 0. + S_min : float + Minimum time step for applying churn. By default 0. + S_max : float + Maximum time step for applying churn. By default float("inf"). + S_noise : float + Noise scaling factor applied during the churn step. By default 1. + + Returns + ------- + Tensor + The final denoised image produced by the sampler. Same shape as + `latents`: (batch_size, C_out, img_shape_y, img_shape_x). + + """ + + # img_lr and latents must also have the same batch_size, otherwise mismatch + # when processed by the network + if img_lr.shape[0] != latents.shape[0]: + raise ValueError( + f"img_lr and latents must have the same batch size, but found " + f"{img_lr.shape[0]} vs {latents.shape[0]}." 
+ ) - -def edm_sampler( - net, latents, coarse, randn_like=torch.randn_like, - num_steps=18, sigma_min=0.002, sigma_max=80.0, rho=7, - S_churn=0.0, S_min=0.0, S_max=float('inf'), S_noise=1, -) -> Tuple[torch.Tensor, torch.Tensor]: # Time step discretization. - step_indices = torch.arange(num_steps, dtype=torch.float32, device=latents.device) - t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho - t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0 + step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device) + t_steps = ( + sigma_max ** (1 / rho) + + step_indices + / (num_steps - 1) + * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho)) + ) ** rho + t_steps = torch.cat( + [t_steps, torch.zeros_like(t_steps[:1])] + ) # t_N = 0 + + x_lr = img_lr # Main sampling loop. - x_next = latents.to(torch.float32) * t_steps[0] - latent_steps = [x_next] - for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 + x_next = latents.to(torch.float64) * t_steps[0] + latent_steps = [x_next.to(latents.dtype)] + for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 x_cur = x_next - # Increase noise temporarily. - gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0 + gamma = S_churn / num_steps if S_min <= t_cur <= S_max else 0 t_hat = t_cur + gamma * t_cur - x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur) - # Euler step. - denoised = net(x_hat, coarse, t_hat).to(torch.float32) + x_hat = x_cur + (t_hat**2 - t_cur**2).sqrt() * S_noise * randn_like(x_cur) + + # Euler step. 
Perform patching operation on score tensor if patch-based + # generation is used denoised = net(x_hat, t_hat, + # ).to(torch.float64) + + x_hat_batch = (x_hat).to( + latents.device + ) + x_lr = x_lr.to(latents.device) + denoised = net( + x_hat_batch, + x_lr, + t_hat, + ).to(torch.float64) + d_cur = (x_hat - denoised) / t_hat x_next = x_hat + (t_next - t_hat) * d_cur # Apply 2nd order correction. if i < num_steps - 1: - denoised = net(x_next, coarse, t_next).to(torch.float32) + x_next_batch = (x_next).to( + latents.device + ) + denoised = net( + x_next_batch, + x_lr, + t_next, + ).to(torch.float64) d_prime = (x_next - denoised) / t_next x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) - - latent_steps.append(x_next) - - return x_next, latent_steps - - -#---------------------------------------------------------------------------- -# Wrapper for torch.Generator that allows specifying a different random seed -# for each sample in a minibatch. - -class StackedRandomGenerator: - def __init__(self, device, seeds): - super().__init__() - self.generators = [torch.Generator(device).manual_seed(int(seed) % (1 << 32)) for seed in seeds] - - def randn(self, size, **kwargs): - assert size[0] == len(self.generators) - return torch.stack([torch.randn(size[1:], generator=gen, **kwargs) for gen in self.generators]) - - def randn_like(self, input): - return self.randn(input.shape, dtype=input.dtype, layout=input.layout, device=input.device) - - def randint(self, *args, size, **kwargs): - assert size[0] == len(self.generators) - return torch.stack([torch.randint(*args, size=size[1:], generator=gen, **kwargs) for gen in self.generators]) + latent_steps.append(x_next.to(latents.dtype)) + return x_next.to(latents.dtype), latent_steps