diff --git a/fme/downscaling/data/config.py b/fme/downscaling/data/config.py index 4ce1f928b..697196b7b 100644 --- a/fme/downscaling/data/config.py +++ b/fme/downscaling/data/config.py @@ -182,6 +182,7 @@ def build_topography( coarse_coords: LatLonCoordinates, requires_topography: bool, static_inputs_from_checkpoint: StaticInputs | None = None, + downscale_factor: int | None = None, ) -> Topography | None: if requires_topography is False: return None @@ -202,11 +203,13 @@ def build_topography( self.lat_extent, full_coarse_coord=coarse_coords.lat, full_fine_coord=topography.coords.lat, + downscale_factor=downscale_factor, ) fine_lon_interval = adjust_fine_coord_range( self.lon_extent, full_coarse_coord=coarse_coords.lon, full_fine_coord=topography.coords.lon, + downscale_factor=downscale_factor, ) subset_topography = topography.subset_latlon( lat_interval=fine_lat_interval, lon_interval=fine_lon_interval diff --git a/fme/downscaling/data/datasets.py b/fme/downscaling/data/datasets.py index eb13893a9..314926b61 100644 --- a/fme/downscaling/data/datasets.py +++ b/fme/downscaling/data/datasets.py @@ -154,7 +154,6 @@ def __init__( f"lon wraparound not implemented, received lon_min {lon_min} but " f"expected lon_min < {self.lon_interval.start + 360.0}" ) - assert lats.numel() > 0, "No latitudes found in the specified range." assert lons.numel() > 0, "No longitudes found in the specified range." diff --git a/fme/downscaling/inference/inference.py b/fme/downscaling/inference/inference.py index cf1ce055c..5143a1943 100644 --- a/fme/downscaling/inference/inference.py +++ b/fme/downscaling/inference/inference.py @@ -243,6 +243,7 @@ def build(self) -> Downscaler: requirements=self.model.data_requirements, patch=self.patch, static_inputs_from_checkpoint=model.static_inputs, + downscale_factor=model.downscale_factor, ) for output_cfg in self.outputs ] diff --git a/fme/downscaling/inference/output.py b/fme/downscaling/inference/output.py index a794f103b..ec5da7d9c 100644 --- a/fme/downscaling/inference/output.py +++ b/fme/downscaling/inference/output.py @@ -217,6 +217,7 @@ def _build_gridded_data( requirements: DataRequirements, dist: Distributed | None = None, static_inputs_from_checkpoint: StaticInputs | None = None, + downscale_factor: int | None = None, ) -> SliceWorkItemGriddedData: xr_dataset, properties = loader_config.get_xarray_dataset( names=requirements.coarse_names, n_timesteps=1 @@ -232,6 +233,7 @@ def _build_gridded_data( requires_topography=requirements.use_fine_topography, # TODO: update to support full list of static inputs static_inputs_from_checkpoint=static_inputs_from_checkpoint, + downscale_factor=downscale_factor, ) if topography is None: raise ValueError("Topography is required for downscaling generation.") @@ -286,6 +288,7 @@ def _build( patch: PatchPredictionConfig, coarse: list[XarrayDataConfig], static_inputs_from_checkpoint: StaticInputs | None = None, + downscale_factor: int | None = None, ) -> DownscalingOutput: updated_loader_config = self._replace_loader_config( time, @@ -299,6 +302,7 @@ def _build( updated_loader_config, requirements, static_inputs_from_checkpoint=static_inputs_from_checkpoint, + downscale_factor=downscale_factor, ) if self.zarr_chunks is None: @@ -386,6 +390,7 @@ def build( requirements: DataRequirements, patch: PatchPredictionConfig, static_inputs_from_checkpoint: StaticInputs | None = None, + downscale_factor: int | None = None, ) -> DownscalingOutput: # Convert single time to TimeSlice time: Slice | TimeSlice @@ -409,6 +414,7 @@ def build( patch=patch, coarse=coarse, static_inputs_from_checkpoint=static_inputs_from_checkpoint, + downscale_factor=downscale_factor, ) @@ -469,6 +475,7 @@ def build( requirements: DataRequirements, patch: PatchPredictionConfig, static_inputs_from_checkpoint: StaticInputs | None = None, + downscale_factor: int | None = None, ) -> DownscalingOutput: coarse = self._single_xarray_config(loader_config.coarse) return self._build( @@ -480,4 +487,5 @@ def build( patch=patch, coarse=coarse, static_inputs_from_checkpoint=static_inputs_from_checkpoint, + downscale_factor=downscale_factor, )