Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions fme/downscaling/_deterministic_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from fme.core.optimization import NullOptimization, Optimization
from fme.core.packer import Packer
from fme.core.typing_ import TensorDict
from fme.downscaling.data import BatchData, PairedBatchData, Topography
from fme.downscaling.data import BatchData, PairedBatchData, StaticInputs
from fme.downscaling.metrics_and_maths import filter_tensor_mapping, interpolate
from fme.downscaling.models import ModelOutputs, PairedNormalizationConfig
from fme.downscaling.modules.registry import ModuleRegistrySelector
Expand Down Expand Up @@ -121,20 +121,20 @@ def modules(self) -> torch.nn.ModuleList:
def train_on_batch(
self,
batch: PairedBatchData,
topography: Topography | None,
static_inputs: StaticInputs | None,
optimization: Optimization | NullOptimization,
) -> ModelOutputs:
return self._run_on_batch(batch, topography, optimization)
return self._run_on_batch(batch, static_inputs, optimization)

def generate_on_batch(
self,
batch: PairedBatchData,
topography: Topography | None,
static_inputs: StaticInputs | None,
n_samples: int = 1,
) -> ModelOutputs:
if n_samples != 1:
raise ValueError("n_samples must be 1 for deterministic models")
result = self._run_on_batch(batch, topography, self.null_optimization)
result = self._run_on_batch(batch, static_inputs, self.null_optimization)
for k, v in result.prediction.items():
result.prediction[k] = v.unsqueeze(1) # insert sample dimension
for k, v in result.target.items():
Expand All @@ -144,7 +144,7 @@ def generate_on_batch(
def generate_on_batch_no_target(
self,
batch: BatchData,
topography: Topography | None,
static_inputs: StaticInputs | None,
n_samples: int = 1,
) -> TensorDict:
raise NotImplementedError(
Expand All @@ -154,7 +154,7 @@ def generate_on_batch_no_target(
def _run_on_batch(
self,
batch: PairedBatchData,
topography: Topography | None,
static_inputs: StaticInputs | None,
optimizer: Optimization | NullOptimization,
) -> ModelOutputs:
coarse, fine = batch.coarse.data, batch.fine.data
Expand All @@ -166,14 +166,14 @@ def _run_on_batch(
interpolated = interpolate(coarse_norm, self.downscale_factor)

if self.config.use_fine_topography:
if topography is None:
if static_inputs is None:
raise ValueError(
"Topography must be provided for each batch when use of fine "
"topography is enabled."
)
else:
# Join the normalized topography to the input (see dataset for details)
topo = topography.data.unsqueeze(self._channel_axis)
topo = static_inputs.fields[0].data.unsqueeze(self._channel_axis)
coarse_norm = torch.concat(
[interpolated, topo], axis=self._channel_axis
)
Expand Down
2 changes: 2 additions & 0 deletions fme/downscaling/configs/test_train_config.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
experiment_dir: /will/be/overwritten
static_inputs:
HGTsfc: /will/be/overwritten
max_epochs: 2
validate_using_ema: true
logging:
Expand Down
2 changes: 1 addition & 1 deletion fme/downscaling/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
PairedBatchItem,
PairedGriddedData,
)
from .topography import StaticInputs, Topography, get_normalized_topography
from .topography import StaticInput, StaticInputs, get_normalized_topography
from .utils import (
BatchedLatLonCoordinates,
ClosedInterval,
Expand Down
58 changes: 29 additions & 29 deletions fme/downscaling/data/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
)
from fme.downscaling.data.topography import (
StaticInputs,
Topography,
get_normalized_topography,
get_topography_downscale_factor,
)
Expand Down Expand Up @@ -177,41 +176,39 @@ def get_xarray_dataset(
strict=self.strict_ensemble,
)

def build_topography(
def build_static_inputs(
self,
coarse_coords: LatLonCoordinates,
requires_topography: bool,
static_inputs_from_checkpoint: StaticInputs | None = None,
) -> Topography | None:
static_inputs: StaticInputs | None = None,
) -> StaticInputs | None:
if requires_topography is False:
return None
if static_inputs_from_checkpoint is not None:
if static_inputs is not None:
# TODO: change to use full static inputs list
topography = static_inputs_from_checkpoint[0]
full_static_inputs = static_inputs
else:
if self.topography is None:
raise ValueError(
"Topography is required for this model, but no topography "
"dataset was specified in the configuration nor provided "
"in model checkpoint."
)
topography = get_normalized_topography(self.topography)
raise ValueError(
"Static inputs required for this model, but no static inputs "
"datasets were specified in the trainer configuration or provided "
"in model checkpoint."
)

# Fine grid boundaries are adjusted to exactly match the coarse grid
fine_lat_interval = adjust_fine_coord_range(
self.lat_extent,
full_coarse_coord=coarse_coords.lat,
full_fine_coord=topography.coords.lat,
full_fine_coord=full_static_inputs.coords.lat,
)
fine_lon_interval = adjust_fine_coord_range(
self.lon_extent,
full_coarse_coord=coarse_coords.lon,
full_fine_coord=topography.coords.lon,
full_fine_coord=full_static_inputs.coords.lon,
)
subset_topography = topography.subset_latlon(
subset_static_inputs = full_static_inputs.subset_latlon(
lat_interval=fine_lat_interval, lon_interval=fine_lon_interval
)
return subset_topography.to_device()
return subset_static_inputs.to_device()

def build_batchitem_dataset(
self,
Expand Down Expand Up @@ -244,7 +241,7 @@ def build(
self,
requirements: DataRequirements,
dist: Distributed | None = None,
static_inputs_from_checkpoint: StaticInputs | None = None,
static_inputs: StaticInputs | None = None,
) -> GriddedData:
# TODO: static_inputs_from_checkpoint is currently passed from the model
# to allow loading fine topography when no fine data is available.
Expand Down Expand Up @@ -286,14 +283,14 @@ def build(
persistent_workers=True if self.num_data_workers > 0 else False,
)
example = dataset[0]
subset_topography = self.build_topography(
subset_static_inputs = self.build_static_inputs(
coarse_coords=latlon_coords,
requires_topography=requirements.use_fine_topography,
static_inputs_from_checkpoint=static_inputs_from_checkpoint,
static_inputs=static_inputs,
)
return GriddedData(
_loader=dataloader,
topography=subset_topography,
static_inputs=subset_static_inputs,
shape=example.horizontal_shape,
dims=example.latlon_coordinates.dims,
variable_metadata=dataset.variable_metadata,
Expand Down Expand Up @@ -397,7 +394,7 @@ def build(
train: bool,
requirements: DataRequirements,
dist: Distributed | None = None,
static_inputs_from_checkpoint: StaticInputs | None = None,
static_inputs: StaticInputs | None = None,
) -> PairedGriddedData:
# TODO: static_inputs_from_checkpoint is currently passed from the model
# to allow loading fine topography when no fine data is available.
Expand Down Expand Up @@ -458,9 +455,8 @@ def build(
)

if requirements.use_fine_topography:
if static_inputs_from_checkpoint is not None:
# TODO: change to use full static inputs list
fine_topography = static_inputs_from_checkpoint[0]
if static_inputs is not None:
fine_topography = static_inputs
elif self.topography is None:
data_path = self.fine[0].data_path
file_pattern = self.fine[0].file_pattern
Expand All @@ -469,14 +465,18 @@ def build(
raise ValueError(
f"No files found matching '{data_path}/{file_pattern}'."
)
fine_topography = get_normalized_topography(raw_paths[0])
fine_topography = StaticInputs(
fields=[get_normalized_topography(raw_paths[0])]
)
else:
fine_topography = get_normalized_topography(self.topography)
fine_topography = StaticInputs(
fields=[get_normalized_topography(self.topography)]
)

fine_topography = fine_topography.to_device()
if (
get_topography_downscale_factor(
fine_topography.data.shape,
fine_topography.shape,
properties_fine.horizontal_coordinates.shape,
)
!= 1
Expand Down Expand Up @@ -558,7 +558,7 @@ def build(

return PairedGriddedData(
_loader=dataloader,
topography=fine_topography,
static_inputs=fine_topography,
coarse_shape=example.coarse.horizontal_shape,
downscale_factor=example.downscale_factor,
dims=example.fine.latlon_coordinates.dims,
Expand Down
Loading