18 changes: 9 additions & 9 deletions fme/downscaling/_deterministic_models.py
@@ -12,7 +12,7 @@
from fme.core.optimization import NullOptimization, Optimization
from fme.core.packer import Packer
from fme.core.typing_ import TensorDict
from fme.downscaling.data import BatchData, PairedBatchData, Topography
from fme.downscaling.data import BatchData, PairedBatchData, StaticInputs
from fme.downscaling.metrics_and_maths import filter_tensor_mapping, interpolate
from fme.downscaling.models import ModelOutputs, PairedNormalizationConfig
from fme.downscaling.modules.registry import ModuleRegistrySelector
@@ -121,20 +121,20 @@ def modules(self) -> torch.nn.ModuleList:
def train_on_batch(
self,
batch: PairedBatchData,
topography: Topography | None,
static_inputs: StaticInputs | None,
optimization: Optimization | NullOptimization,
) -> ModelOutputs:
return self._run_on_batch(batch, topography, optimization)
return self._run_on_batch(batch, static_inputs, optimization)

def generate_on_batch(
self,
batch: PairedBatchData,
topography: Topography | None,
static_inputs: StaticInputs | None,
n_samples: int = 1,
) -> ModelOutputs:
if n_samples != 1:
raise ValueError("n_samples must be 1 for deterministic models")
result = self._run_on_batch(batch, topography, self.null_optimization)
result = self._run_on_batch(batch, static_inputs, self.null_optimization)
for k, v in result.prediction.items():
result.prediction[k] = v.unsqueeze(1) # insert sample dimension
for k, v in result.target.items():
@@ -144,7 +144,7 @@ def generate_on_batch(
def generate_on_batch_no_target(
self,
batch: BatchData,
topography: Topography | None,
static_inputs: StaticInputs | None,
n_samples: int = 1,
) -> TensorDict:
raise NotImplementedError(
@@ -154,7 +154,7 @@ def _run_on_batch(
def _run_on_batch(
self,
batch: PairedBatchData,
topography: Topography | None,
static_inputs: StaticInputs | None,
optimizer: Optimization | NullOptimization,
) -> ModelOutputs:
coarse, fine = batch.coarse.data, batch.fine.data
@@ -166,14 +166,14 @@ def _run_on_batch(
interpolated = interpolate(coarse_norm, self.downscale_factor)

if self.config.use_fine_topography:
if topography is None:
if static_inputs is None:
raise ValueError(
"Topography must be provided for each batch when use of fine "
"topography is enabled."
)
else:
# Join the normalized topography to the input (see dataset for details)
topo = topography.data.unsqueeze(self._channel_axis)
topo = static_inputs.fields[0].data.unsqueeze(self._channel_axis)
coarse_norm = torch.concat(
[interpolated, topo], axis=self._channel_axis
)
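For reference, here is a minimal sketch of the concatenation that `_run_on_batch` performs with the new container, assuming `StaticInputs` holds a list of `StaticInput` fields whose `.data` tensors are already normalized and batched. The stand-in classes, channel layout, and shapes below are illustrative assumptions, not the real implementation in `fme.downscaling.data.topography`.

```python
import torch

# Illustrative stand-ins for the assumed containers; the real StaticInput /
# StaticInputs classes live in fme.downscaling.data.topography and may differ.
class StaticInput:
    def __init__(self, data: torch.Tensor):
        self.data = data  # assumed normalized and batched: (batch, lat, lon)

class StaticInputs:
    def __init__(self, fields: list[StaticInput]):
        self.fields = fields

channel_axis = 1  # assumed (batch, channel, lat, lon) layout

# Interpolated coarse input: 2 samples, 4 channels on a 32x32 fine grid.
interpolated = torch.randn(2, 4, 32, 32)
static_inputs = StaticInputs(fields=[StaticInput(torch.randn(2, 32, 32))])

# Mirror of the concatenation above: the first static field gains a channel
# dimension and is appended to the network input.
topo = static_inputs.fields[0].data.unsqueeze(channel_axis)          # (2, 1, 32, 32)
model_input = torch.concat([interpolated, topo], axis=channel_axis)  # (2, 5, 32, 32)
assert model_input.shape == (2, 5, 32, 32)
```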
2 changes: 2 additions & 0 deletions fme/downscaling/configs/test_train_config.yaml
@@ -1,4 +1,6 @@
experiment_dir: /will/be/overwritten
static_inputs:
HGTsfc: /will/be/overwritten
max_epochs: 2
validate_using_ema: true
logging:
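The new `static_inputs` entry maps a static variable name to its dataset path. As a hedged illustration of how such a mapping could be consumed (the actual trainer wiring may differ), it can be turned into a `StaticInputs` object with the helpers already used in this PR; the path below is a hypothetical placeholder.

```python
from fme.downscaling.data import StaticInputs, get_normalized_topography

# Hypothetical config mapping: static variable name -> dataset path.
static_inputs_config = {"HGTsfc": "/path/to/fine_topography.nc"}

# Load and normalize each configured static field, then wrap the collection.
fields = [get_normalized_topography(path) for path in static_inputs_config.values()]
static_inputs = StaticInputs(fields=fields)
```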
2 changes: 1 addition & 1 deletion fme/downscaling/data/__init__.py
@@ -12,7 +12,7 @@
PairedBatchItem,
PairedGriddedData,
)
from .topography import StaticInputs, Topography, get_normalized_topography
from .topography import StaticInput, StaticInputs, get_normalized_topography
from .utils import (
BatchedLatLonCoordinates,
ClosedInterval,
58 changes: 29 additions & 29 deletions fme/downscaling/data/config.py
@@ -23,7 +23,6 @@
)
from fme.downscaling.data.topography import (
StaticInputs,
Topography,
get_normalized_topography,
get_topography_downscale_factor,
)
@@ -177,41 +176,39 @@ def get_xarray_dataset(
strict=self.strict_ensemble,
)

def build_topography(
def build_static_inputs(
self,
coarse_coords: LatLonCoordinates,
requires_topography: bool,
static_inputs_from_checkpoint: StaticInputs | None = None,
) -> Topography | None:
static_inputs: StaticInputs | None = None,
) -> StaticInputs | None:
if requires_topography is False:
return None
if static_inputs_from_checkpoint is not None:
if static_inputs is not None:
# TODO: change to use full static inputs list
topography = static_inputs_from_checkpoint[0]
full_static_inputs = static_inputs
else:
if self.topography is None:
raise ValueError(
"Topography is required for this model, but no topography "
"dataset was specified in the configuration nor provided "
"in model checkpoint."
)
topography = get_normalized_topography(self.topography)
raise ValueError(
"Static inputs required for this model, but no static inputs "
"datasets were specified in the trainer configuration or provided "
"in model checkpoint."
)

# Fine grid boundaries are adjusted to exactly match the coarse grid
fine_lat_interval = adjust_fine_coord_range(
self.lat_extent,
full_coarse_coord=coarse_coords.lat,
full_fine_coord=topography.coords.lat,
full_fine_coord=full_static_inputs.coords.lat,
)
fine_lon_interval = adjust_fine_coord_range(
self.lon_extent,
full_coarse_coord=coarse_coords.lon,
full_fine_coord=topography.coords.lon,
full_fine_coord=full_static_inputs.coords.lon,
)
subset_topography = topography.subset_latlon(
subset_static_inputs = full_static_inputs.subset_latlon(
lat_interval=fine_lat_interval, lon_interval=fine_lon_interval
)
return subset_topography.to_device()
return subset_static_inputs.to_device()

def build_batchitem_dataset(
self,
@@ -244,7 +241,7 @@ def build(
self,
requirements: DataRequirements,
dist: Distributed | None = None,
static_inputs_from_checkpoint: StaticInputs | None = None,
static_inputs: StaticInputs | None = None,
) -> GriddedData:
# TODO: static_inputs_from_checkpoint is currently passed from the model
# to allow loading fine topography when no fine data is available.
@@ -286,14 +283,14 @@ def build(
persistent_workers=True if self.num_data_workers > 0 else False,
)
example = dataset[0]
subset_topography = self.build_topography(
subset_static_inputs = self.build_static_inputs(
coarse_coords=latlon_coords,
requires_topography=requirements.use_fine_topography,
static_inputs_from_checkpoint=static_inputs_from_checkpoint,
static_inputs=static_inputs,
)
return GriddedData(
_loader=dataloader,
topography=subset_topography,
static_inputs=subset_static_inputs,
shape=example.horizontal_shape,
dims=example.latlon_coordinates.dims,
variable_metadata=dataset.variable_metadata,
@@ -397,7 +394,7 @@ def build(
train: bool,
requirements: DataRequirements,
dist: Distributed | None = None,
static_inputs_from_checkpoint: StaticInputs | None = None,
static_inputs: StaticInputs | None = None,
) -> PairedGriddedData:
# TODO: static_inputs_from_checkpoint is currently passed from the model
# to allow loading fine topography when no fine data is available.
@@ -458,9 +455,8 @@ def build(
)

if requirements.use_fine_topography:
if static_inputs_from_checkpoint is not None:
# TODO: change to use full static inputs list
fine_topography = static_inputs_from_checkpoint[0]
if static_inputs is not None:
fine_topography = static_inputs
elif self.topography is None:
data_path = self.fine[0].data_path
file_pattern = self.fine[0].file_pattern
@@ -469,14 +465,18 @@
raise ValueError(
f"No files found matching '{data_path}/{file_pattern}'."
)
fine_topography = get_normalized_topography(raw_paths[0])
fine_topography = StaticInputs(
fields=[get_normalized_topography(raw_paths[0])]
)
else:
fine_topography = get_normalized_topography(self.topography)
fine_topography = StaticInputs(
fields=[get_normalized_topography(self.topography)]
)

fine_topography = fine_topography.to_device()
if (
get_topography_downscale_factor(
fine_topography.data.shape,
fine_topography.shape,
properties_fine.horizontal_coordinates.shape,
)
!= 1
@@ -558,7 +558,7 @@ def build(

return PairedGriddedData(
_loader=dataloader,
topography=fine_topography,
static_inputs=fine_topography,
coarse_shape=example.coarse.horizontal_shape,
downscale_factor=example.downscale_factor,
dims=example.fine.latlon_coordinates.dims,
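To summarize the resolution order for fine static inputs in `PairedDataConfig.build`, the condensed sketch below restates the logic from the hunk above. The helper name and the glob-based file discovery (which sits in a collapsed part of the diff) are assumptions, not the exact implementation.

```python
import glob

from fme.downscaling.data import StaticInputs, get_normalized_topography

def resolve_fine_static_inputs(
    static_inputs: StaticInputs | None,
    configured_topography: str | None,
    fine_data_path: str,
    file_pattern: str,
) -> StaticInputs:
    """Condensed restatement of the fine-static-input precedence in build()."""
    if static_inputs is not None:
        # 1. Prefer static inputs handed in, e.g. restored from a model checkpoint.
        return static_inputs
    if configured_topography is None:
        # 2. Otherwise look for topography stored alongside the fine dataset
        #    (file discovery shown here as a plain glob; assumption).
        raw_paths = sorted(glob.glob(f"{fine_data_path}/{file_pattern}"))
        if len(raw_paths) == 0:
            raise ValueError(
                f"No files found matching '{fine_data_path}/{file_pattern}'."
            )
        return StaticInputs(fields=[get_normalized_topography(raw_paths[0])])
    # 3. Fall back to the explicitly configured topography dataset.
    return StaticInputs(fields=[get_normalized_topography(configured_topography)])
```

In the diff itself, the resolved object is then moved to device and its grid is checked against the fine coordinates via `get_topography_downscale_factor` before being passed to `PairedGriddedData` as `static_inputs`.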