Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions fme/downscaling/data/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ def build_topography(
coarse_coords: LatLonCoordinates,
requires_topography: bool,
static_inputs_from_checkpoint: StaticInputs | None = None,
downscale_factor: int | None = None,
) -> Topography | None:
if requires_topography is False:
return None
Expand All @@ -202,11 +203,13 @@ def build_topography(
self.lat_extent,
full_coarse_coord=coarse_coords.lat,
full_fine_coord=topography.coords.lat,
downscale_factor=downscale_factor,
)
fine_lon_interval = adjust_fine_coord_range(
self.lon_extent,
full_coarse_coord=coarse_coords.lon,
full_fine_coord=topography.coords.lon,
downscale_factor=downscale_factor,
)
subset_topography = topography.subset_latlon(
lat_interval=fine_lat_interval, lon_interval=fine_lon_interval
Expand Down
1 change: 0 additions & 1 deletion fme/downscaling/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,6 @@ def __init__(
f"lon wraparound not implemented, received lon_min {lon_min} but "
f"expected lon_min < {self.lon_interval.start + 360.0}"
)

assert lats.numel() > 0, "No latitudes found in the specified range."
assert lons.numel() > 0, "No longitudes found in the specified range."

Expand Down
1 change: 1 addition & 0 deletions fme/downscaling/inference/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ def build(self) -> Downscaler:
requirements=self.model.data_requirements,
patch=self.patch,
static_inputs_from_checkpoint=model.static_inputs,
downscale_factor=model.downscale_factor,
)
for output_cfg in self.outputs
]
Expand Down
8 changes: 8 additions & 0 deletions fme/downscaling/inference/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ def _build_gridded_data(
requirements: DataRequirements,
dist: Distributed | None = None,
static_inputs_from_checkpoint: StaticInputs | None = None,
downscale_factor: int | None = None,
) -> SliceWorkItemGriddedData:
xr_dataset, properties = loader_config.get_xarray_dataset(
names=requirements.coarse_names, n_timesteps=1
Expand All @@ -232,6 +233,7 @@ def _build_gridded_data(
requires_topography=requirements.use_fine_topography,
# TODO: update to support full list of static inputs
static_inputs_from_checkpoint=static_inputs_from_checkpoint,
downscale_factor=downscale_factor,
)
if topography is None:
raise ValueError("Topography is required for downscaling generation.")
Expand Down Expand Up @@ -286,6 +288,7 @@ def _build(
patch: PatchPredictionConfig,
coarse: list[XarrayDataConfig],
static_inputs_from_checkpoint: StaticInputs | None = None,
downscale_factor: int | None = None,
) -> DownscalingOutput:
updated_loader_config = self._replace_loader_config(
time,
Expand All @@ -299,6 +302,7 @@ def _build(
updated_loader_config,
requirements,
static_inputs_from_checkpoint=static_inputs_from_checkpoint,
downscale_factor=downscale_factor,
)

if self.zarr_chunks is None:
Expand Down Expand Up @@ -386,6 +390,7 @@ def build(
requirements: DataRequirements,
patch: PatchPredictionConfig,
static_inputs_from_checkpoint: StaticInputs | None = None,
downscale_factor: int | None = None,
) -> DownscalingOutput:
# Convert single time to TimeSlice
time: Slice | TimeSlice
Expand All @@ -409,6 +414,7 @@ def build(
patch=patch,
coarse=coarse,
static_inputs_from_checkpoint=static_inputs_from_checkpoint,
downscale_factor=downscale_factor,
)


Expand Down Expand Up @@ -469,6 +475,7 @@ def build(
requirements: DataRequirements,
patch: PatchPredictionConfig,
static_inputs_from_checkpoint: StaticInputs | None = None,
downscale_factor: int | None = None,
) -> DownscalingOutput:
coarse = self._single_xarray_config(loader_config.coarse)
return self._build(
Expand All @@ -480,4 +487,5 @@ def build(
patch=patch,
coarse=coarse,
static_inputs_from_checkpoint=static_inputs_from_checkpoint,
downscale_factor=downscale_factor,
)