diff --git a/fme/downscaling/_deterministic_models.py b/fme/downscaling/_deterministic_models.py
index 87e533865..1dec44037 100644
--- a/fme/downscaling/_deterministic_models.py
+++ b/fme/downscaling/_deterministic_models.py
@@ -12,7 +12,7 @@
 from fme.core.optimization import NullOptimization, Optimization
 from fme.core.packer import Packer
 from fme.core.typing_ import TensorDict
-from fme.downscaling.data import BatchData, PairedBatchData, Topography
+from fme.downscaling.data import BatchData, PairedBatchData, StaticInputs
 from fme.downscaling.metrics_and_maths import filter_tensor_mapping, interpolate
 from fme.downscaling.models import ModelOutputs, PairedNormalizationConfig
 from fme.downscaling.modules.registry import ModuleRegistrySelector
@@ -121,20 +121,20 @@ def modules(self) -> torch.nn.ModuleList:
     def train_on_batch(
         self,
         batch: PairedBatchData,
-        topography: Topography | None,
+        static_inputs: StaticInputs | None,
         optimization: Optimization | NullOptimization,
     ) -> ModelOutputs:
-        return self._run_on_batch(batch, topography, optimization)
+        return self._run_on_batch(batch, static_inputs, optimization)
 
     def generate_on_batch(
         self,
         batch: PairedBatchData,
-        topography: Topography | None,
+        static_inputs: StaticInputs | None,
         n_samples: int = 1,
     ) -> ModelOutputs:
         if n_samples != 1:
             raise ValueError("n_samples must be 1 for deterministic models")
-        result = self._run_on_batch(batch, topography, self.null_optimization)
+        result = self._run_on_batch(batch, static_inputs, self.null_optimization)
         for k, v in result.prediction.items():
             result.prediction[k] = v.unsqueeze(1)  # insert sample dimension
         for k, v in result.target.items():
@@ -144,7 +144,7 @@ def generate_on_batch(
     def generate_on_batch_no_target(
         self,
         batch: BatchData,
-        topography: Topography | None,
+        static_inputs: StaticInputs | None,
         n_samples: int = 1,
     ) -> TensorDict:
         raise NotImplementedError(
@@ -154,7 +154,7 @@
     def _run_on_batch(
         self,
         batch: PairedBatchData,
-        topography: Topography | None,
+        static_inputs: StaticInputs | None,
         optimizer: Optimization | NullOptimization,
     ) -> ModelOutputs:
         coarse, fine = batch.coarse.data, batch.fine.data
@@ -166,14 +166,14 @@ def _run_on_batch(
         interpolated = interpolate(coarse_norm, self.downscale_factor)
 
         if self.config.use_fine_topography:
-            if topography is None:
+            if static_inputs is None:
                 raise ValueError(
                     "Topography must be provided for each batch when use of fine "
                     "topography is enabled."
                 )
             else:
                 # Join the normalized topography to the input (see dataset for details)
-                topo = topography.data.unsqueeze(self._channel_axis)
+                topo = static_inputs.fields[0].data.unsqueeze(self._channel_axis)
                 coarse_norm = torch.concat(
                     [interpolated, topo], axis=self._channel_axis
                 )
diff --git a/fme/downscaling/configs/test_train_config.yaml b/fme/downscaling/configs/test_train_config.yaml
index 5c8b7c34f..5911716c2 100644
--- a/fme/downscaling/configs/test_train_config.yaml
+++ b/fme/downscaling/configs/test_train_config.yaml
@@ -1,4 +1,6 @@
 experiment_dir: /will/be/overwritten
+static_inputs:
+  HGTsfc: /will/be/overwritten
 max_epochs: 2
 validate_using_ema: true
 logging:
diff --git a/fme/downscaling/data/__init__.py b/fme/downscaling/data/__init__.py
index b696f2a2e..bc8f3ae4d 100644
--- a/fme/downscaling/data/__init__.py
+++ b/fme/downscaling/data/__init__.py
@@ -12,7 +12,7 @@
     PairedBatchItem,
     PairedGriddedData,
 )
-from .topography import StaticInputs, Topography, get_normalized_topography
+from .topography import StaticInput, StaticInputs, get_normalized_topography
 from .utils import (
     BatchedLatLonCoordinates,
     ClosedInterval,
diff --git a/fme/downscaling/data/config.py b/fme/downscaling/data/config.py
index 4ce1f928b..533b89be4 100644
--- a/fme/downscaling/data/config.py
+++ b/fme/downscaling/data/config.py
@@ -23,7 +23,6 @@
 )
 from fme.downscaling.data.topography import (
     StaticInputs,
-    Topography,
     get_normalized_topography,
     get_topography_downscale_factor,
 )
@@ -177,41 +176,38 @@ def get_xarray_dataset(
             strict=self.strict_ensemble,
         )
 
-    def build_topography(
+    def build_static_inputs(
         self,
         coarse_coords: LatLonCoordinates,
         requires_topography: bool,
-        static_inputs_from_checkpoint: StaticInputs | None = None,
-    ) -> Topography | None:
+        static_inputs: StaticInputs | None = None,
+    ) -> StaticInputs | None:
         if requires_topography is False:
             return None
-        if static_inputs_from_checkpoint is not None:
-            # TODO: change to use full static inputs list
-            topography = static_inputs_from_checkpoint[0]
+        if static_inputs is not None:
+            full_static_inputs = static_inputs
         else:
-            if self.topography is None:
-                raise ValueError(
-                    "Topography is required for this model, but no topography "
-                    "dataset was specified in the configuration nor provided "
-                    "in model checkpoint."
-                )
-            topography = get_normalized_topography(self.topography)
+            raise ValueError(
+                "Static inputs are required for this model, but no static input "
+                "datasets were specified in the trainer configuration or provided "
+                "in the model checkpoint."
+            )
         # Fine grid boundaries are adjusted to exactly match the coarse grid
         fine_lat_interval = adjust_fine_coord_range(
             self.lat_extent,
             full_coarse_coord=coarse_coords.lat,
-            full_fine_coord=topography.coords.lat,
+            full_fine_coord=full_static_inputs.coords.lat,
         )
         fine_lon_interval = adjust_fine_coord_range(
             self.lon_extent,
             full_coarse_coord=coarse_coords.lon,
-            full_fine_coord=topography.coords.lon,
+            full_fine_coord=full_static_inputs.coords.lon,
         )
-        subset_topography = topography.subset_latlon(
+        subset_static_inputs = full_static_inputs.subset_latlon(
             lat_interval=fine_lat_interval, lon_interval=fine_lon_interval
         )
-        return subset_topography.to_device()
+        return subset_static_inputs.to_device()
 
     def build_batchitem_dataset(
         self,
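
[Editor's note — illustration, not part of the patch. The new `static_inputs` config section (see the YAML change above) maps each static input variable name to the path of the dataset that provides it, and `build_static_inputs` returns the checkpoint's `StaticInputs` subset to the configured lat/lon extent. A minimal sketch of how such a name-to-path mapping could be materialized into the `StaticInputs` container, assuming the `get_normalized_topography` helper shown in `topography.py`; the `load_static_inputs` function here is hypothetical, for illustration only.]

from fme.downscaling.data import StaticInputs, get_normalized_topography

def load_static_inputs(static_inputs: dict[str, str]) -> StaticInputs:
    # One normalized StaticInput per configured variable name/path pair.
    fields = [
        get_normalized_topography(path, topography_name=name)
        for name, path in static_inputs.items()
    ]
    return StaticInputs(fields=fields)
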
@@ -244,7 +241,7 @@ def build(
         self,
         requirements: DataRequirements,
         dist: Distributed | None = None,
-        static_inputs_from_checkpoint: StaticInputs | None = None,
+        static_inputs: StaticInputs | None = None,
     ) -> GriddedData:
-        # TODO: static_inputs_from_checkpoint is currently passed from the model
+        # TODO: static_inputs is currently passed from the model
         # to allow loading fine topography when no fine data is available.
@@ -286,14 +283,14 @@ def build(
             persistent_workers=True if self.num_data_workers > 0 else False,
         )
         example = dataset[0]
-        subset_topography = self.build_topography(
+        subset_static_inputs = self.build_static_inputs(
             coarse_coords=latlon_coords,
             requires_topography=requirements.use_fine_topography,
-            static_inputs_from_checkpoint=static_inputs_from_checkpoint,
+            static_inputs=static_inputs,
         )
         return GriddedData(
             _loader=dataloader,
-            topography=subset_topography,
+            static_inputs=subset_static_inputs,
             shape=example.horizontal_shape,
             dims=example.latlon_coordinates.dims,
             variable_metadata=dataset.variable_metadata,
@@ -397,7 +394,7 @@ def build(
         train: bool,
         requirements: DataRequirements,
         dist: Distributed | None = None,
-        static_inputs_from_checkpoint: StaticInputs | None = None,
+        static_inputs: StaticInputs | None = None,
     ) -> PairedGriddedData:
-        # TODO: static_inputs_from_checkpoint is currently passed from the model
+        # TODO: static_inputs is currently passed from the model
         # to allow loading fine topography when no fine data is available.
@@ -458,9 +455,8 @@ def build(
         )
 
         if requirements.use_fine_topography:
-            if static_inputs_from_checkpoint is not None:
-                # TODO: change to use full static inputs list
-                fine_topography = static_inputs_from_checkpoint[0]
+            if static_inputs is not None:
+                fine_topography = static_inputs
             elif self.topography is None:
                 data_path = self.fine[0].data_path
                 file_pattern = self.fine[0].file_pattern
@@ -469,14 +465,18 @@ def build(
                 raise ValueError(
                     f"No files found matching '{data_path}/{file_pattern}'."
                 )
-            fine_topography = get_normalized_topography(raw_paths[0])
+            fine_topography = StaticInputs(
+                fields=[get_normalized_topography(raw_paths[0])]
+            )
         else:
-            fine_topography = get_normalized_topography(self.topography)
+            fine_topography = StaticInputs(
+                fields=[get_normalized_topography(self.topography)]
+            )
         fine_topography = fine_topography.to_device()
 
         if (
             get_topography_downscale_factor(
-                fine_topography.data.shape,
+                fine_topography.shape,
                 properties_fine.horizontal_coordinates.shape,
             )
             != 1
@@ -558,7 +558,7 @@ def build(
 
         return PairedGriddedData(
             _loader=dataloader,
-            topography=fine_topography,
+            static_inputs=fine_topography,
             coarse_shape=example.coarse.horizontal_shape,
             downscale_factor=example.downscale_factor,
             dims=example.fine.latlon_coordinates.dims,
) - fine_topography = get_normalized_topography(raw_paths[0]) + fine_topography = StaticInputs( + fields=[get_normalized_topography(raw_paths[0])] + ) else: - fine_topography = get_normalized_topography(self.topography) + fine_topography = StaticInputs( + fields=[get_normalized_topography(self.topography)] + ) fine_topography = fine_topography.to_device() if ( get_topography_downscale_factor( - fine_topography.data.shape, + fine_topography.shape, properties_fine.horizontal_coordinates.shape, ) != 1 @@ -558,7 +558,7 @@ def build( return PairedGriddedData( _loader=dataloader, - topography=fine_topography, + static_inputs=fine_topography, coarse_shape=example.coarse.horizontal_shape, downscale_factor=example.downscale_factor, dims=example.fine.latlon_coordinates.dims, diff --git a/fme/downscaling/data/datasets.py b/fme/downscaling/data/datasets.py index eb13893a9..9d0f9d292 100644 --- a/fme/downscaling/data/datasets.py +++ b/fme/downscaling/data/datasets.py @@ -20,7 +20,7 @@ from fme.core.generics.data import SizedMap from fme.core.typing_ import TensorMapping from fme.downscaling.data.patching import Patch, get_patches -from fme.downscaling.data.topography import Topography +from fme.downscaling.data.topography import StaticInputs from fme.downscaling.data.utils import ( BatchedLatLonCoordinates, ClosedInterval, @@ -318,7 +318,7 @@ class GriddedData: dims: list[str] variable_metadata: Mapping[str, VariableMetadata] all_times: xr.CFTimeIndex - topography: Topography | None + static_inputs: StaticInputs | None @property def loader(self) -> DataLoader[BatchItem]: @@ -329,24 +329,24 @@ def on_device(batch: BatchItem) -> BatchItem: @property def topography_downscale_factor(self) -> int | None: - if self.topography: + if self.static_inputs: if ( - self.topography.shape[0] % self.shape[0] != 0 - or self.topography.shape[1] % self.shape[1] != 0 + self.static_inputs.shape[0] % self.shape[0] != 0 + or self.static_inputs.shape[1] % self.shape[1] != 0 ): raise ValueError( "Topography shape must be evenly divisible by data shape. 
" - f"Got topography {self.topography.shape} and data {self.shape}" + f"Got topography {self.static_inputs.shape} and data {self.shape}" ) - return self.topography.shape[0] // self.shape[0] + return self.static_inputs.shape[0] // self.shape[0] else: return None def get_generator( self, - ) -> Iterator[tuple["BatchData", Topography | None]]: + ) -> Iterator[tuple["BatchData", StaticInputs | None]]: for batch in self.loader: - yield (batch, self.topography) + yield (batch, self.static_inputs) def get_patched_generator( self, @@ -354,10 +354,10 @@ def get_patched_generator( overlap: int = 0, drop_partial_patches: bool = True, random_offset: bool = False, - ) -> Iterator[tuple["BatchData", Topography | None]]: + ) -> Iterator[tuple["BatchData", StaticInputs | None]]: patched_generator = patched_batch_gen_from_loader( loader=self.loader, - topography=self.topography, + static_inputs=self.static_inputs, coarse_yx_extent=self.shape, coarse_yx_patch_extent=yx_patch_extent, downscale_factor=self.topography_downscale_factor, @@ -367,7 +367,7 @@ def get_patched_generator( ) return cast( - Iterator[tuple[BatchData, Topography | None]], + Iterator[tuple[BatchData, StaticInputs | None]], patched_generator, ) @@ -380,7 +380,7 @@ class PairedGriddedData: dims: list[str] variable_metadata: Mapping[str, VariableMetadata] all_times: xr.CFTimeIndex - topography: Topography | None + static_inputs: StaticInputs | None @property def loader(self) -> DataLoader[PairedBatchItem]: @@ -391,9 +391,9 @@ def on_device(batch: PairedBatchItem) -> PairedBatchItem: def get_generator( self, - ) -> Iterator[tuple["PairedBatchData", Topography | None]]: + ) -> Iterator[tuple["PairedBatchData", StaticInputs | None]]: for batch in self.loader: - yield (batch, self.topography) + yield (batch, self.static_inputs) def get_patched_generator( self, @@ -402,10 +402,10 @@ def get_patched_generator( drop_partial_patches: bool = True, random_offset: bool = False, shuffle: bool = False, - ) -> Iterator[tuple["PairedBatchData", Topography | None]]: + ) -> Iterator[tuple["PairedBatchData", StaticInputs | None]]: patched_generator = patched_batch_gen_from_paired_loader( self.loader, - self.topography, + self.static_inputs, coarse_yx_extent=self.coarse_shape, coarse_yx_patch_extent=coarse_yx_patch_extent, downscale_factor=self.downscale_factor, @@ -415,7 +415,7 @@ def get_patched_generator( shuffle=shuffle, ) return cast( - Iterator[tuple[PairedBatchData, Topography | None]], + Iterator[tuple[PairedBatchData, StaticInputs | None]], patched_generator, ) @@ -722,7 +722,7 @@ def _get_paired_patches( def patched_batch_gen_from_loader( loader: DataLoader[BatchItem], - topography: Topography | None, + static_inputs: StaticInputs | None, coarse_yx_extent: tuple[int, int], coarse_yx_patch_extent: tuple[int, int], downscale_factor: int | None, @@ -730,7 +730,7 @@ def patched_batch_gen_from_loader( drop_partial_patches: bool = True, random_offset: bool = False, shuffle: bool = False, -) -> Iterator[tuple[BatchData, Topography | None]]: +) -> Iterator[tuple[BatchData, StaticInputs | None]]: for batch in loader: coarse_patches, fine_patches = _get_paired_patches( coarse_yx_extent=coarse_yx_extent, @@ -743,23 +743,23 @@ def patched_batch_gen_from_loader( ) batch_data_patches = batch.generate_from_patches(coarse_patches) - if topography is not None: + if static_inputs is not None: if fine_patches is None: raise ValueError( "Topography provided but downscale_factor is None, cannot " "generate fine patches." 
@@ -743,23 +743,23 @@ def patched_batch_gen_from_loader(
         )
         batch_data_patches = batch.generate_from_patches(coarse_patches)
 
-        if topography is not None:
+        if static_inputs is not None:
             if fine_patches is None:
                 raise ValueError(
                     "Topography provided but downscale_factor is None, cannot "
                     "generate fine patches."
                 )
-            topography_patches = topography.generate_from_patches(fine_patches)
+            static_inputs_patches = static_inputs.generate_from_patches(fine_patches)
         else:
-            topography_patches = null_generator(len(coarse_patches))
+            static_inputs_patches = null_generator(len(coarse_patches))
         # Combine outputs from both generators
-        yield from zip(batch_data_patches, topography_patches)
+        yield from zip(batch_data_patches, static_inputs_patches)
 
 
 def patched_batch_gen_from_paired_loader(
     loader: DataLoader[PairedBatchItem],
-    topography: Topography | None,
+    static_inputs: StaticInputs | None,
     coarse_yx_extent: tuple[int, int],
     coarse_yx_patch_extent: tuple[int, int],
     downscale_factor: int,
@@ -767,7 +767,7 @@ def patched_batch_gen_from_paired_loader(
     drop_partial_patches: bool = True,
     random_offset: bool = False,
     shuffle: bool = False,
-) -> Iterator[tuple[PairedBatchData, Topography | None]]:
+) -> Iterator[tuple[PairedBatchData, StaticInputs | None]]:
     for batch in loader:
        coarse_patches, fine_patches = _get_paired_patches(
            coarse_yx_extent=coarse_yx_extent,
@@ -780,15 +780,15 @@ def patched_batch_gen_from_paired_loader(
         )
         batch_data_patches = batch.generate_from_patches(coarse_patches, fine_patches)
 
-        if topography is not None:
+        if static_inputs is not None:
             if fine_patches is None:
                 raise ValueError(
-                    "Topography provided but downscale_factor is None, cannot "
+                    "Static inputs provided but downscale_factor is None, cannot "
                     "generate fine patches."
                 )
-            topography_patches = topography.generate_from_patches(fine_patches)
+            static_inputs_patches = static_inputs.generate_from_patches(fine_patches)
         else:
-            topography_patches = null_generator(len(coarse_patches))
+            static_inputs_patches = null_generator(len(coarse_patches))
         # Combine outputs from both generators
-        yield from zip(batch_data_patches, topography_patches)
+        yield from zip(batch_data_patches, static_inputs_patches)
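
[Editor's note — illustration, not part of the patch. After this change, every data generator yields `(batch, StaticInputs | None)` pairs instead of `(batch, Topography | None)`, and the patched generators slice the static fields with the same fine patches as the data. A minimal consumer sketch under that assumption, where `data` and `model` stand in for a built `PairedGriddedData` and a downscaling model.]

for i, (batch, static_inputs) in enumerate(data.get_generator()):
    # static_inputs is None unless the model is configured to use static fields
    outputs = model.generate_on_batch(batch, static_inputs, n_samples=1)
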
diff --git a/fme/downscaling/data/test_config.py b/fme/downscaling/data/test_config.py
index a275684f4..0517e6b61 100644
--- a/fme/downscaling/data/test_config.py
+++ b/fme/downscaling/data/test_config.py
@@ -1,6 +1,7 @@
 import dataclasses
 
 import pytest
+import torch
 
 from fme.core.dataset.xarray import XarrayDataConfig
 from fme.downscaling.data.config import (
@@ -8,7 +9,8 @@
     PairedDataLoaderConfig,
     XarrayEnsembleDataConfig,
 )
-from fme.downscaling.data.utils import ClosedInterval
+from fme.downscaling.data.topography import StaticInput, StaticInputs
+from fme.downscaling.data.utils import ClosedInterval, LatLonCoordinates
 from fme.downscaling.requirements import DataRequirements
 from fme.downscaling.test_utils import data_paths_helper
 
@@ -63,7 +65,15 @@ def test_DataLoaderConfig_build(tmp_path, very_fast_only: bool):
         lat_extent=ClosedInterval(1, 4),
         lon_extent=ClosedInterval(0, 3),
     )
-    data = data_config.build(requirements=requirements)
+    static_inputs = StaticInputs(
+        fields=[
+            StaticInput(
+                data=torch.ones((8, 8)),
+                coords=LatLonCoordinates(lat=torch.ones(8), lon=torch.ones(8)),
+            )
+        ]
+    )
+    data = data_config.build(requirements=requirements, static_inputs=static_inputs)
     batch = next(iter(data.loader))
     # lat/lon midpoints are on (0.5, 1.5, ...)
     assert batch.data["var0"].shape == (2, 3, 3)
diff --git a/fme/downscaling/data/test_patching.py b/fme/downscaling/data/test_patching.py
index 12da9b2d2..7630722c4 100644
--- a/fme/downscaling/data/test_patching.py
+++ b/fme/downscaling/data/test_patching.py
@@ -3,7 +3,7 @@
 import torch
 
 from fme.core.device import get_device
-from fme.downscaling.data import PairedBatchData, Topography
+from fme.downscaling.data import PairedBatchData, StaticInput, StaticInputs
 from fme.downscaling.data.datasets import patched_batch_gen_from_paired_loader
 from fme.downscaling.data.patching import (
     _divide_into_slices,
@@ -120,12 +120,14 @@ def test_paired_patches_with_random_offset_consistent(overlap):
         coarse_shape[1] * downscale_factor,
         device=get_device(),
     )
-    topography = Topography(data=topography_data, coords=full_fine_coords[0])
+    topography = StaticInputs(
+        fields=[StaticInput(data=topography_data, coords=full_fine_coords[0])]
+    )
     y_offsets = []
     x_offsets = []
     batch_generator = patched_batch_gen_from_paired_loader(
         loader=loader,
-        topography=topography,
+        static_inputs=topography,
         coarse_yx_extent=coarse_shape,
         coarse_yx_patch_extent=(10, 10),
         downscale_factor=downscale_factor,
@@ -181,11 +183,13 @@ def test_paired_patches_shuffle(shuffle):
         device=get_device(),
     )
     fine_coords = next(iter(loader)).fine.latlon_coordinates[0]
-    topography = Topography(data=topography_data, coords=fine_coords)
+    static_inputs = StaticInputs(
+        fields=[StaticInput(data=topography_data, coords=fine_coords)]
+    )
 
     generator0 = patched_batch_gen_from_paired_loader(
         loader=loader,
-        topography=topography,
+        static_inputs=static_inputs,
         coarse_yx_extent=coarse_shape,
         coarse_yx_patch_extent=(2, 2),
         downscale_factor=downscale_factor,
@@ -196,7 +200,7 @@ def test_paired_patches_shuffle(shuffle):
     )
     generator1 = patched_batch_gen_from_paired_loader(
         loader=loader,
-        topography=topography,
+        static_inputs=static_inputs,
         coarse_yx_extent=coarse_shape,
         coarse_yx_patch_extent=(2, 2),
         downscale_factor=downscale_factor,
@@ -220,8 +224,12 @@ def test_paired_patches_shuffle(shuffle):
 
     data0 = torch.concat([patch.coarse.data["x"] for patch in patches0], dim=0)
     data1 = torch.concat([patch.coarse.data["x"] for patch in patches1], dim=0)
-    topo_concat_0 = torch.concat([t0.data for t0 in topography0], dim=0)
-    topo_concat_1 = torch.concat([t1.data for t1 in topography1], dim=0)
+    topo_concat_0 = torch.concat(
+        [t0.fields[0].data for t0 in topography0 if t0 is not None], dim=0
+    )
+    topo_concat_1 = torch.concat(
+        [t1.fields[0].data for t1 in topography1 if t1 is not None], dim=0
+    )
 
     if shuffle:
         assert not torch.equal(data0, data1)
diff --git a/fme/downscaling/data/test_topography.py b/fme/downscaling/data/test_topography.py
index d1cc425b5..366bcd12a 100644
--- a/fme/downscaling/data/test_topography.py
+++ b/fme/downscaling/data/test_topography.py
@@ -4,7 +4,7 @@
 from fme.core.coordinates import LatLonCoordinates
 from fme.downscaling.data.patching import Patch, _HorizontalSlice
 
-from .topography import StaticInputs, Topography, _range_to_slice
+from .topography import StaticInput, StaticInputs, _range_to_slice
 from .utils import ClosedInterval
 
 
@@ -26,7 +26,7 @@
 )
 def test_Topography_error_cases(init_args):
     with pytest.raises(ValueError):
-        Topography(*init_args)
+        StaticInput(*init_args)
 
 
 def test__range_to_slice():
@@ -43,7 +43,7 @@ def test_subset_latlon():
     coords = LatLonCoordinates(
         lat=torch.linspace(0, 9, 10), lon=torch.linspace(0, 9, 10)
     )
-    topo = Topography(data=data, coords=coords)
+    topo = StaticInput(data=data, coords=coords)
     lat_interval = ClosedInterval(2, 5)
     lon_interval = ClosedInterval(3, 7)
     subset_topo = topo.subset_latlon(lat_interval, lon_interval)
@@ -67,7 +67,7 @@ def test_Topography_generate_from_patches():
             output_slice=output_slice,
         ),
     ]
-    topography = Topography(
+    topography = StaticInput(
         torch.arange(16).reshape(4, 4),
         LatLonCoordinates(torch.arange(4), torch.arange(4)),
     )
@@ -95,11 +95,11 @@ def test_StaticInputs_generate_from_patches():
         ),
     ]
     data = torch.arange(16).reshape(4, 4)
-    topography = Topography(
+    topography = StaticInput(
         data,
         LatLonCoordinates(torch.arange(4), torch.arange(4)),
     )
-    land_frac = Topography(
+    land_frac = StaticInput(
         data * -1.0,
         LatLonCoordinates(torch.arange(4), torch.arange(4)),
     )
@@ -124,13 +124,13 @@
 
 
 def test_StaticInputs_serialize():
     data = torch.arange(16).reshape(4, 4)
-    topography = Topography(
+    topography = StaticInput(
         data,
         LatLonCoordinates(torch.arange(4), torch.arange(4)),
     )
-    land_frac = Topography(
+    land_frac = StaticInput(
         data * -1.0,
         LatLonCoordinates(torch.arange(4), torch.arange(4)),
     )
diff --git a/fme/downscaling/data/topography.py b/fme/downscaling/data/topography.py
index cb36052e3..8a355d225 100644
--- a/fme/downscaling/data/topography.py
+++ b/fme/downscaling/data/topography.py
@@ -18,9 +18,8 @@ def _range_to_slice(coords: torch.Tensor, range: ClosedInterval) -> slice:
     return slice(indices[0].item(), indices[-1].item() + 1)
 
 
-# TODO: rename to StaticInput, make _apply_patch public
 @dataclasses.dataclass
-class Topography:
+class StaticInput:
     data: torch.Tensor
     coords: LatLonCoordinates
 
@@ -47,14 +46,14 @@ def subset_latlon(
         self,
         lat_interval: ClosedInterval,
         lon_interval: ClosedInterval,
-    ) -> "Topography":
+    ) -> "StaticInput":
         lat_slice = _range_to_slice(self.coords.lat, lat_interval)
         lon_slice = _range_to_slice(self.coords.lon, lon_interval)
         return self._latlon_index_slice(lat_slice=lat_slice, lon_slice=lon_slice)
 
-    def to_device(self) -> "Topography":
+    def to_device(self) -> "StaticInput":
         device = get_device()
-        return Topography(
+        return StaticInput(
             data=self.data.to(device),
             coords=LatLonCoordinates(
                 lat=self.coords.lat.to(device),
@@ -71,13 +70,13 @@ def _latlon_index_slice(
         self,
         lat_slice: slice,
         lon_slice: slice,
-    ) -> "Topography":
+    ) -> "StaticInput":
         sliced_data = self.data[lat_slice, lon_slice]
         sliced_latlon = LatLonCoordinates(
             lat=self.coords.lat[lat_slice],
             lon=self.coords.lon[lon_slice],
         )
-        return Topography(
+        return StaticInput(
             data=sliced_data,
             coords=sliced_latlon,
         )
@@ -85,7 +84,7 @@ def _latlon_index_slice(
     def generate_from_patches(
         self,
         patches: list[Patch],
-    ) -> Generator["Topography", None, None]:
+    ) -> Generator["StaticInput", None, None]:
         for patch in patches:
             yield self._apply_patch(patch)
@@ -116,7 +115,7 @@ def get_normalized_topography(path: str, topography_name: str = "HGTsfc"):
 
     topography_normalized = (topography - topography.mean()) / topography.std()
 
-    return Topography(
+    return StaticInput(
         data=torch.tensor(topography_normalized.values, dtype=torch.float32),
         coords=coords,
     )
@@ -151,7 +150,7 @@ def get_topography_downscale_factor(
 
 @dataclasses.dataclass
 class StaticInputs:
-    fields: list[Topography]
+    fields: list[StaticInput]
 
     def __post_init__(self):
         for i, field in enumerate(self.fields[1:]):
@@ -164,13 +163,6 @@ def __post_init__(self):
     def __getitem__(self, index: int):
         return self.fields[index]
 
-    @property
-    def input_tensors(self) -> list[torch.Tensor]:
-        if len(self.fields) > 0:
-            return [field.data for field in self.fields]
-        else:
-            return torch.tensor([])
-
     @property
     def coords(self) -> LatLonCoordinates:
         if len(self.fields) == 0:
@@ -215,7 +207,7 @@ def to_state(self) -> dict:
     def from_state(cls, state: dict) -> "StaticInputs":
         return cls(
             fields=[
-                Topography(
+                StaticInput(
                     data=field_state["data"],
                     coords=LatLonCoordinates(
                         lat=field_state["coords"]["lat"],
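
[Editor's note — illustration, not part of the patch. `StaticInput` is the renamed per-field container (formerly `Topography`), and `StaticInputs` wraps a list of them, with `__post_init__` checking that all fields share coordinates. A small self-contained sketch of construction plus the `to_state`/`from_state` round trip exercised by `test_StaticInputs_serialize`:]

import torch
from fme.core.coordinates import LatLonCoordinates
from fme.downscaling.data import StaticInput, StaticInputs

coords = LatLonCoordinates(lat=torch.arange(4), lon=torch.arange(4))
topography = StaticInput(data=torch.randn(4, 4), coords=coords)
land_frac = StaticInput(data=torch.rand(4, 4), coords=coords)

static_inputs = StaticInputs(fields=[topography, land_frac])
restored = StaticInputs.from_state(static_inputs.to_state())
assert torch.equal(restored.fields[0].data, topography.data)
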
""" model_patch_shape = self.model.fine_shape - actual_shape = tuple(topography.data.shape) + actual_shape = tuple(static_inputs.shape) if model_patch_shape == actual_shape: # short circuit, no patching necessary @@ -111,16 +111,20 @@ def run_output_generation(self, output: DownscalingOutput): total_batches = len(output.data.loader) loaded_item: LoadedSliceWorkItem - topography: Topography - for i, (loaded_item, topography) in enumerate(output.data.get_generator()): + static_inputs: StaticInputs + for i, (loaded_item, static_inputs) in enumerate(output.data.get_generator()): if writer is None: writer = output.get_writer( - latlon_coords=topography.coords, + latlon_coords=static_inputs.coords, output_dir=self.output_dir, ) - writer.initialize_store(topography.data.cpu().numpy().dtype) + writer.initialize_store( + static_inputs.fields[0].data.cpu().numpy().dtype + ) if model is None: - model = self._get_generation_model(topography=topography, output=output) + model = self._get_generation_model( + static_inputs=static_inputs, output=output + ) logging.info( f"[{output.name}] Batch {i+1}/{total_batches}, " @@ -128,7 +132,9 @@ def run_output_generation(self, output: DownscalingOutput): ) output_data = model.generate_on_batch_no_target( - loaded_item.batch, topography=topography, n_samples=loaded_item.n_ens + loaded_item.batch, + static_inputs=static_inputs, + n_samples=loaded_item.n_ens, ) output_np = {key: value.cpu().numpy() for key, value in output_data.items()} insert_slices = loaded_item.dim_insert_slices diff --git a/fme/downscaling/inference/output.py b/fme/downscaling/inference/output.py index a794f103b..452701784 100644 --- a/fme/downscaling/inference/output.py +++ b/fme/downscaling/inference/output.py @@ -227,11 +227,10 @@ def _build_gridded_data( "Downscaling data loader only supports datasets with latlon coords." 
diff --git a/fme/downscaling/inference/output.py b/fme/downscaling/inference/output.py
index a794f103b..452701784 100644
--- a/fme/downscaling/inference/output.py
+++ b/fme/downscaling/inference/output.py
@@ -227,11 +227,10 @@ def _build_gridded_data(
             "Downscaling data loader only supports datasets with latlon coords."
         )
     dataset = loader_config.build_batchitem_dataset(xr_dataset, properties)
-    topography = loader_config.build_topography(
+    topography = loader_config.build_static_inputs(
         coords,
         requires_topography=requirements.use_fine_topography,
-        # TODO: update to support full list of static inputs
-        static_inputs_from_checkpoint=static_inputs_from_checkpoint,
+        static_inputs=static_inputs_from_checkpoint,
     )
     if topography is None:
         raise ValueError("Topography is required for downscaling generation.")
@@ -273,7 +272,7 @@ def _build_gridded_data(
         all_times=xr_dataset.sample_start_times,
         dtype=slice_dataset.dtype,
         max_output_shape=slice_dataset.max_output_shape,
-        topography=topography,
+        static_inputs=topography,
     )
 
 def _build(
""" mock_model.fine_shape = (16, 16) - topo = get_topography(shape=(32, 32)) # Larger than model + topo = get_static_inputs(shape=(32, 32)) # Larger than model mock_output_target.patch = PatchPredictionConfig(divide_generation=False) downscaler = Downscaler( @@ -178,7 +178,7 @@ def test_get_generation_model_raises_when_large_domain_without_patching( with pytest.raises(ValueError): downscaler._get_generation_model( - topography=topo, + static_inputs=topo, output=mock_output_target, ) @@ -194,10 +194,10 @@ def test_run_target_generation_skips_padding_items( mock_work_item.n_ens = 4 mock_work_item.batch = MagicMock() - mock_topo = get_topography(shape=(16, 16)) + static_inputs = get_static_inputs(shape=(16, 16)) mock_gridded_data = SliceWorkItemGriddedData( - [mock_work_item], {}, [0], torch.float32, (1, 4, 16, 16), mock_topo + [mock_work_item], {}, [0], torch.float32, (1, 4, 16, 16), static_inputs ) mock_output_target.data = mock_gridded_data mock_model.fine_shape = (16, 16) diff --git a/fme/downscaling/inference/test_output.py b/fme/downscaling/inference/test_output.py index 7dc11b533..f8790f0da 100644 --- a/fme/downscaling/inference/test_output.py +++ b/fme/downscaling/inference/test_output.py @@ -1,10 +1,12 @@ from unittest.mock import MagicMock import pytest +import torch from fme.core.dataset.time import TimeSlice from fme.core.dataset.xarray import XarrayDataConfig -from fme.downscaling.data import ClosedInterval +from fme.downscaling.data import ClosedInterval, StaticInput, StaticInputs +from fme.downscaling.data.utils import LatLonCoordinates from fme.downscaling.inference.output import ( DownscalingOutput, DownscalingOutputConfig, @@ -101,8 +103,17 @@ def test_event_config_build_creates_output_target_with_single_time( lat_extent=ClosedInterval(2.0, 6.0), lon_extent=ClosedInterval(2.0, 6.0), ) - - output_target = config.build(loader_config, requirements, patch_config) + static_inputs = StaticInputs( + fields=[ + StaticInput( + data=torch.ones((8, 8)), + coords=LatLonCoordinates(lat=torch.ones(8), lon=torch.ones(8)), + ) + ] + ) + output_target = config.build( + loader_config, requirements, patch_config, static_inputs + ) # Verify OutputTarget was created assert isinstance(output_target, DownscalingOutput) @@ -130,8 +141,18 @@ def test_region_config_build_creates_output_target_with_time_range( n_ens=4, save_vars=["var0", "var1"], ) + static_inputs = StaticInputs( + fields=[ + StaticInput( + data=torch.ones((8, 8)), + coords=LatLonCoordinates(lat=torch.ones(8), lon=torch.ones(8)), + ) + ] + ) - output_target = config.build(loader_config, requirements, patch_config) + output_target = config.build( + loader_config, requirements, patch_config, static_inputs + ) # Verify OutputTarget was created assert isinstance(output_target, DownscalingOutput) diff --git a/fme/downscaling/inference/work_items.py b/fme/downscaling/inference/work_items.py index ed40627e0..0b5ef27a5 100644 --- a/fme/downscaling/inference/work_items.py +++ b/fme/downscaling/inference/work_items.py @@ -10,7 +10,7 @@ from fme.core.distributed import Distributed from fme.core.generics.data import SizedMap -from ..data import BatchData, Topography +from ..data import BatchData, StaticInputs from ..data.config import BatchItemDatasetAdapter from .constants import ENSEMBLE_NAME, TIME_NAME @@ -297,7 +297,7 @@ class SliceWorkItemGriddedData: all_times: xr.CFTimeIndex dtype: torch.dtype max_output_shape: tuple[int, ...] 
diff --git a/fme/downscaling/inference/work_items.py b/fme/downscaling/inference/work_items.py
index ed40627e0..0b5ef27a5 100644
--- a/fme/downscaling/inference/work_items.py
+++ b/fme/downscaling/inference/work_items.py
@@ -10,7 +10,7 @@
 from fme.core.distributed import Distributed
 from fme.core.generics.data import SizedMap
 
-from ..data import BatchData, Topography
+from ..data import BatchData, StaticInputs
 from ..data.config import BatchItemDatasetAdapter
 from .constants import ENSEMBLE_NAME, TIME_NAME
 
@@ -297,7 +297,7 @@ class SliceWorkItemGriddedData:
     all_times: xr.CFTimeIndex
     dtype: torch.dtype
     max_output_shape: tuple[int, ...]
-    topography: Topography
+    static_inputs: StaticInputs
 
     # TODO: currently no protocol or ABC for gridded data objects
     # if we want to unify, we will need one and just raise
@@ -310,7 +310,7 @@ def on_device(work_item: LoadedSliceWorkItem) -> LoadedSliceWorkItem:
 
         return SizedMap(on_device, self._loader)
 
-    def get_generator(self) -> Iterator[tuple[LoadedSliceWorkItem, Topography]]:
+    def get_generator(self) -> Iterator[tuple[LoadedSliceWorkItem, StaticInputs]]:
         work_item: LoadedSliceWorkItem
         for work_item in self.loader:
-            yield work_item, self.topography
+            yield work_item, self.static_inputs
diff --git a/fme/downscaling/models.py b/fme/downscaling/models.py
index 47b9615df..aa6a119a6 100644
--- a/fme/downscaling/models.py
+++ b/fme/downscaling/models.py
@@ -13,13 +13,7 @@
 from fme.core.packer import Packer
 from fme.core.rand import randn, randn_like
 from fme.core.typing_ import TensorDict, TensorMapping
-from fme.downscaling.data import (
-    BatchData,
-    PairedBatchData,
-    StaticInputs,
-    Topography,
-    get_normalized_topography,
-)
+from fme.downscaling.data import BatchData, PairedBatchData, StaticInputs
 from fme.downscaling.metrics_and_maths import filter_tensor_mapping, interpolate
 from fme.downscaling.modules.diffusion_registry import DiffusionModuleRegistrySelector
 from fme.downscaling.requirements import DataRequirements
@@ -100,6 +94,8 @@ class DiffusionModelConfig:
         use_fine_topography: Whether to use fine topography in the model.
         use_amp_bf16: Whether to use automatic mixed precision (bfloat16) in the
             UNetDiffusionModule.
+        static_inputs: Mapping from static input variable names to the paths of
+            the datasets providing them.
     """
 
     module: DiffusionModuleRegistrySelector
@@ -116,6 +112,7 @@ class DiffusionModelConfig:
     predict_residual: bool
     use_fine_topography: bool = False
     use_amp_bf16: bool = False
+    static_inputs: dict[str, str] | None = None
 
     def __post_init__(self):
         self._interpolate_input = self.module.expects_interpolated_input
@@ -317,7 +314,7 @@ def _get_fine_shape(self, coarse_shape: tuple[int, int]) -> tuple[int, int]:
         )
 
     def _get_input_from_coarse(
-        self, coarse: TensorMapping, topography: Topography | None
+        self, coarse: TensorMapping, static_inputs: StaticInputs | None
     ) -> torch.Tensor:
         inputs = filter_tensor_mapping(coarse, self.in_packer.names)
         normalized = self.in_packer.pack(
@@ -326,19 +323,20 @@ def _get_input_from_coarse(
         interpolated = interpolate(normalized, self.downscale_factor)
 
         if self.config.use_fine_topography:
-            if topography is None:
+            if static_inputs is None:
                 raise ValueError(
-                    "Topography must be provided for each batch when use of fine "
-                    "topography is enabled."
+                    "Static inputs must be provided for each batch when "
+                    "use_fine_topography is enabled."
                 )
             else:
                 n_batches = normalized.shape[0]
-                # Join the normalized topography to the input (see dataset for details)
-                topo = topography.data.unsqueeze(0).repeat(n_batches, 1, 1)
-                topo = topo.unsqueeze(self._channel_axis)
-                interpolated = torch.concat(
-                    [interpolated, topo], axis=self._channel_axis
-                )
+                # Join normalized static inputs to input (see dataset for details)
+                for field in static_inputs.fields:
+                    topo = field.data.unsqueeze(0).repeat(n_batches, 1, 1)
+                    topo = topo.unsqueeze(self._channel_axis)
+                    interpolated = torch.concat(
+                        [interpolated, topo], axis=self._channel_axis
+                    )
 
         if self.config._interpolate_input:
             return interpolated
@@ -347,12 +345,12 @@
     def train_on_batch(
         self,
         batch: PairedBatchData,
-        topography: Topography | None,
+        static_inputs: StaticInputs | None,
         optimizer: Optimization | NullOptimization,
     ) -> ModelOutputs:
         """Performs a denoising training step on a batch of data."""
         coarse, fine = batch.coarse.data, batch.fine.data
-        inputs_norm = self._get_input_from_coarse(coarse, topography)
+        inputs_norm = self._get_input_from_coarse(coarse, static_inputs)
         targets_norm = self.out_packer.pack(
             self.normalizer.fine.normalize(dict(fine)), axis=self._channel_axis
         )
@@ -397,10 +395,10 @@
     def generate(
         self,
         coarse_data: TensorMapping,
-        topography: torch.Tensor | None,
+        static_inputs: StaticInputs | None,
         n_samples: int = 1,
     ) -> tuple[TensorDict, torch.Tensor, list[torch.Tensor]]:
-        inputs_ = self._get_input_from_coarse(coarse_data, topography)
+        inputs_ = self._get_input_from_coarse(coarse_data, static_inputs)
         # expand samples and fold to
         # [batch * n_samples, output_channels, height, width]
         inputs_ = _repeat_batch_by_samples(inputs_, n_samples)
@@ -449,22 +447,22 @@
     def generate_on_batch_no_target(
         self,
         batch: BatchData,
-        topography: Topography | None,
+        static_inputs: StaticInputs | None,
         n_samples: int = 1,
     ) -> TensorDict:
-        generated, _, _ = self.generate(batch.data, topography, n_samples)
+        generated, _, _ = self.generate(batch.data, static_inputs, n_samples)
         return generated
 
     @torch.no_grad()
     def generate_on_batch(
         self,
         batch: PairedBatchData,
-        topography: Topography | None,
+        static_inputs: StaticInputs | None,
         n_samples: int = 1,
     ) -> ModelOutputs:
         coarse, fine = batch.coarse.data, batch.fine.data
         generated, generated_norm, latent_steps = self.generate(
-            coarse, topography, n_samples
+            coarse, static_inputs, n_samples
         )
 
         targets_norm = self.out_packer.pack(
@@ -611,14 +609,3 @@ def in_names(self):
     @property
     def out_names(self):
         return self._checkpoint["model"]["config"]["out_names"]
-
-    def get_topography(self) -> Topography | None:
-        if self.data_requirements.use_fine_topography:
-            if self.fine_topography_path is None:
-                raise ValueError(
-                    "Topography path must be provided for model configured "
-                    "to use fine topography."
-                )
-            return get_normalized_topography(self.fine_topography_path).to_device()
-        else:
-            return None
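
[Editor's note — illustration, not part of the patch. In `_get_input_from_coarse`, each static field is broadcast across the batch and appended as one extra input channel, so a module's `n_channels_in` must grow by one per configured static field (compare the `n_channels_in=3 if use_fine_topography else 2` change in `test_predict.py` below). A toy shape check of that stacking, with hypothetical sizes:]

import torch

batch, channels, ny, nx = 8, 2, 32, 32
interpolated = torch.randn(batch, channels, ny, nx)
field = torch.randn(ny, nx)  # one static input field on the fine grid

# Broadcast over the batch, add a channel axis, then concatenate.
stacked = torch.concat(
    [interpolated, field.unsqueeze(0).repeat(batch, 1, 1).unsqueeze(1)], dim=1
)
assert stacked.shape == (batch, channels + 1, ny, nx)
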
diff --git a/fme/downscaling/predict.py b/fme/downscaling/predict.py
index 9daaa778c..aa6f7529f 100644
--- a/fme/downscaling/predict.py
+++ b/fme/downscaling/predict.py
@@ -111,7 +111,7 @@ def get_gridded_data(
     )
     return event_data_config.build(
         requirements=requirements,
-        static_inputs_from_checkpoint=static_inputs_from_checkpoint,
+        static_inputs=static_inputs_from_checkpoint,
     )
 
 
@@ -153,7 +153,7 @@ def generation_model(self):
 
     def run(self):
         logging.info(f"Running {self.event_name} event downscaling...")
-        batch, topography = next(iter(self.data.get_generator()))
+        batch, static_inputs = next(iter(self.data.get_generator()))
         coarse_coords = batch[0].latlon_coordinates
         fine_coords = LatLonCoordinates(
             lat=_downscale_coord(coarse_coords.lat, self.model.downscale_factor),
@@ -176,7 +176,7 @@ def run(self):
                 f"for event {self.event_name}"
             )
             outputs = self.model.generate_on_batch_no_target(
-                batch, topography=topography, n_samples=end_idx - start_idx
+                batch, static_inputs=static_inputs, n_samples=end_idx - start_idx
             )
             sample_agg.record_batch(outputs)
         to_log = sample_agg.get_wandb()
@@ -247,8 +247,8 @@ def save_netcdf_data(self, ds: xr.Dataset):
 
     @property
     def _fine_latlon_coordinates(self) -> LatLonCoordinates | None:
-        if self.data.topography is not None:
-            return self.data.topography.coords
+        if self.data.static_inputs is not None:
+            return self.data.static_inputs.coords
         else:
             return None
 
@@ -257,12 +257,12 @@ def run(self):
             downscale_factor=self.model.downscale_factor,
             latlon_coordinates=self._fine_latlon_coordinates,
         )
-        for i, (batch, topography) in enumerate(self.batch_generator):
+        for i, (batch, static_inputs) in enumerate(self.batch_generator):
             with torch.no_grad():
                 logging.info(f"Generating predictions on batch {i + 1}")
                 prediction = self.generation_model.generate_on_batch_no_target(
                     batch=batch,
-                    topography=topography,
+                    static_inputs=static_inputs,
                     n_samples=self.n_samples,
                 )
                 logging.info("Recording diagnostics to aggregator")
@@ -308,7 +308,7 @@ def build(self) -> list[Downscaler | EventDownscaler]:
         model = self.model.build()
         dataset = self.data.build(
             requirements=self.model.data_requirements,
-            static_inputs_from_checkpoint=model.static_inputs,
+            static_inputs=model.static_inputs,
         )
         downscaler = Downscaler(
             data=dataset,
diff --git a/fme/downscaling/predictors/cascade.py b/fme/downscaling/predictors/cascade.py
index 2c9fb9d6b..0f64c73b5 100644
--- a/fme/downscaling/predictors/cascade.py
+++ b/fme/downscaling/predictors/cascade.py
@@ -1,6 +1,5 @@
 import dataclasses
 import math
-from collections.abc import Sequence
 
 import torch
 
@@ -12,7 +11,7 @@
     BatchData,
     ClosedInterval,
     PairedBatchData,
-    Topography,
+    StaticInputs,
     adjust_fine_coord_range,
     scale_tuple,
 )
@@ -47,11 +46,8 @@ def models(self):
         self._models = [cfg.build() for cfg in self.cascade_model_checkpoints]
         return self._models
 
-    def get_topographies(self) -> list[Topography | None]:
-        topographies = []
-        for ckpt in self.cascade_model_checkpoints:
-            topographies.append(ckpt.get_topography())
-        return topographies
+    def get_static_inputs(self) -> list[StaticInputs | None]:
+        return [model.static_inputs for model in self.models]
 
     def build(self):
         for m in range(len(self.models) - 1):
@@ -68,7 +64,7 @@ def build(self):
             )
 
         return CascadePredictor(
-            models=self.models, topographies=self.get_topographies()
+            models=self.models, static_inputs=self.get_static_inputs()
         )
 
     @property
@@ -92,10 +88,10 @@ def _restore_batch_and_sample_dims(data: TensorMapping, n_samples: int):
 
 class CascadePredictor:
     def __init__(
-        self, models: list[DiffusionModel], topographies: list[Topography | None]
+        self, models: list[DiffusionModel], static_inputs: list[StaticInputs | None]
    ):
         self.models = models
-        self._topographies = topographies
+        self._static_inputs = static_inputs
         self.out_packer = self.models[-1].out_packer
         self.normalizer = FineResCoarseResPair(
             coarse=self.models[0].normalizer.coarse,
@@ -124,19 +120,18 @@ class CascadePredictor:
     def generate(
         self,
         coarse: TensorMapping,
         n_samples: int,
-        topographies=list[Topography | None],
+        static_inputs: list[StaticInputs | None],
     ):
         current_coarse = coarse
-        for i, (model, fine_topography) in enumerate(zip(self.models, topographies)):
+        for i, (model, fine_topography) in enumerate(zip(self.models, static_inputs)):
             sample_data = next(iter(current_coarse.values()))
             batch_size = sample_data.shape[0]
             # n_samples are generated for the first step, and subsequent models
             # generate 1 sample
             n_samples_cascade_step = n_samples if i == 0 else 1
-            _fine_topography = fine_topography.data
             generated, generated_norm, latent_steps = model.generate(
-                current_coarse, _fine_topography, n_samples_cascade_step
+                current_coarse, fine_topography, n_samples_cascade_step
             )
             generated = {
                 k: v.reshape(batch_size * n_samples_cascade_step, *v.shape[-2:])
@@ -150,27 +145,27 @@ class CascadePredictor:
     def generate_on_batch_no_target(
         self,
         batch: BatchData,
-        topography: Topography | None,
+        static_inputs: StaticInputs | None,
         n_samples: int = 1,
     ) -> TensorDict:
-        topographies = self._get_subset_topographies(
+        subset_static_inputs = self._get_subset_static_inputs(
             coarse_coords=batch.latlon_coordinates[0]
         )
-        generated, _, _ = self.generate(batch.data, n_samples, topographies)
+        generated, _, _ = self.generate(batch.data, n_samples, subset_static_inputs)
         return generated
 
     @torch.no_grad()
     def generate_on_batch(
         self,
         batch: PairedBatchData,
-        topography: Topography | None,
+        static_inputs: StaticInputs | None,
         n_samples: int = 1,
     ) -> ModelOutputs:
-        topographies = self._get_subset_topographies(
+        subset_static_inputs = self._get_subset_static_inputs(
             coarse_coords=batch.coarse.latlon_coordinates[0]
         )
         generated, _, latent_steps = self.generate(
-            batch.coarse.data, n_samples, topographies
+            batch.coarse.data, n_samples, subset_static_inputs
         )
         targets = filter_tensor_mapping(batch.fine.data, set(self.out_packer.names))
         targets = {k: v.unsqueeze(1) for k, v in targets.items()}
@@ -182,39 +177,42 @@ def generate_on_batch(
             latent_steps=latent_steps,
         )
 
-    def _get_subset_topographies(
+    def _get_subset_static_inputs(
         self,
         coarse_coords: LatLonCoordinates,
-    ) -> Sequence[Topography | None]:
+    ) -> list[StaticInputs | None]:
         # Intermediate topographies are loaded as full range and need to be subset
         # to the matching lat/lon range for each batch.
         # TODO: Will eventually move subsetting into checkpoint model.
-        subset_topographies = []
+        subset_static_inputs: list[StaticInputs | None] = []
         _coarse_coords = coarse_coords
         lat_range = _closed_interval_from_coord(_coarse_coords.lat)
         lon_range = _closed_interval_from_coord(_coarse_coords.lon)
-        for i, full_intermediate_topography in enumerate(self._topographies):
-            if full_intermediate_topography is not None:
+        for i, full_intermediate_static_inputs in enumerate(self._static_inputs):
+            if full_intermediate_static_inputs is not None:
                 _adjusted_lat_range = adjust_fine_coord_range(
                     lat_range,
                     _coarse_coords.lat,
-                    full_intermediate_topography.coords.lat,
+                    full_intermediate_static_inputs.coords.lat,
                     downscale_factor=self.models[i].downscale_factor,
                 )
                 _adjusted_lon_range = adjust_fine_coord_range(
                     lon_range,
                     _coarse_coords.lon,
-                    full_intermediate_topography.coords.lon,
+                    full_intermediate_static_inputs.coords.lon,
                     downscale_factor=self.models[i].downscale_factor,
                 )
-                subset_interm_topo = full_intermediate_topography.subset_latlon(
-                    lat_interval=_adjusted_lat_range, lon_interval=_adjusted_lon_range
+                subset_interm_static_inputs = (
+                    full_intermediate_static_inputs.subset_latlon(
+                        lat_interval=_adjusted_lat_range,
+                        lon_interval=_adjusted_lon_range,
+                    )
                 )
-                _coarse_coords = subset_interm_topo.coords
+                _coarse_coords = subset_interm_static_inputs.coords
                 lat_range = _closed_interval_from_coord(_coarse_coords.lat)
                 lon_range = _closed_interval_from_coord(_coarse_coords.lon)
             else:
-                subset_interm_topo = None
-            subset_topographies.append(subset_interm_topo)
-        return subset_topographies
+                subset_interm_static_inputs = None
+            subset_static_inputs.append(subset_interm_static_inputs)
+        return subset_static_inputs
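
[Editor's note — illustration, not part of the patch. In the cascade, each stage's static inputs are subset to the coordinates produced by the previous stage, and the predictor's overall downscale factor is the product of the per-stage factors. A toy check of that compounding, with hypothetical numbers matching `test_cascade.py`:]

import math

downscale_factors = [2, 2]  # per-stage factors
nside_coarse = 4
expected_nside = math.prod(downscale_factors) * nside_coarse
assert expected_nside == 16  # final fine grid side length
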
diff --git a/fme/downscaling/predictors/composite.py b/fme/downscaling/predictors/composite.py
index cea6fd8c3..8cbd08e83 100644
--- a/fme/downscaling/predictors/composite.py
+++ b/fme/downscaling/predictors/composite.py
@@ -3,7 +3,7 @@
 import torch
 
 from fme.core.typing_ import TensorDict
-from fme.downscaling.data import BatchData, PairedBatchData, Topography, scale_tuple
+from fme.downscaling.data import BatchData, PairedBatchData, StaticInputs, scale_tuple
 from fme.downscaling.data.patching import Patch, get_patches
 from fme.downscaling.data.utils import null_generator
 from fme.downscaling.models import DiffusionModel, ModelOutputs
@@ -106,7 +106,7 @@ def _get_patches(
     def generate_on_batch(
         self,
         batch: PairedBatchData,
-        topography: Topography | None,
+        static_inputs: StaticInputs | None,
         n_samples: int = 1,
     ) -> ModelOutputs:
         predictions = []
@@ -119,14 +119,16 @@ def generate_on_batch(
         batch_generator = batch.generate_from_patches(
             coarse_patches=coarse_patches, fine_patches=fine_patches
         )
-        if topography is not None:
-            topography_generator = topography.generate_from_patches(fine_patches)
+        if static_inputs is not None:
+            static_inputs_generator = static_inputs.generate_from_patches(fine_patches)
         else:
-            topography_generator = null_generator(len(fine_patches))
+            static_inputs_generator = null_generator(len(fine_patches))
 
-        for data_patch, topography_patch in zip(batch_generator, topography_generator):
+        for data_patch, static_inputs_patch in zip(
+            batch_generator, static_inputs_generator
+        ):
             model_output = self.model.generate_on_batch(
-                data_patch, topography_patch, n_samples
+                data_patch, static_inputs_patch, n_samples
             )
             predictions.append(model_output.prediction)
             loss = loss + model_output.loss
@@ -145,7 +147,7 @@ def generate_on_batch(
     def generate_on_batch_no_target(
         self,
         batch: BatchData,
-        topography: Topography | None,
+        static_inputs: StaticInputs | None,
         n_samples: int = 1,
     ) -> TensorDict:
         coarse_yx_extent = batch.horizontal_shape
@@ -155,15 +157,17 @@ def generate_on_batch_no_target(
         )
         predictions = []
         batch_generator = batch.generate_from_patches(coarse_patches)
-        if topography is not None:
-            topography_generator = topography.generate_from_patches(fine_patches)
+        if static_inputs is not None:
+            static_inputs_generator = static_inputs.generate_from_patches(fine_patches)
         else:
-            topography_generator = null_generator(len(fine_patches))
-        for data_patch, topo_patch in zip(batch_generator, topography_generator):
+            static_inputs_generator = null_generator(len(fine_patches))
+        for data_patch, static_inputs_patch in zip(
+            batch_generator, static_inputs_generator
+        ):
             predictions.append(
                 self.model.generate_on_batch_no_target(
                     batch=data_patch,
-                    topography=topo_patch,
+                    static_inputs=static_inputs_patch,
                     n_samples=n_samples,
                 )
             )
diff --git a/fme/downscaling/predictors/test_cascade.py b/fme/downscaling/predictors/test_cascade.py
index 73e9a520d..ec6250b98 100644
--- a/fme/downscaling/predictors/test_cascade.py
+++ b/fme/downscaling/predictors/test_cascade.py
@@ -5,7 +5,7 @@
 from fme.core.device import get_device
 from fme.core.loss import LossConfig
 from fme.core.normalizer import NormalizationConfig
-from fme.downscaling.data import Topography
+from fme.downscaling.data import StaticInput, StaticInputs
 from fme.downscaling.models import DiffusionModelConfig, PairedNormalizationConfig
 from fme.downscaling.modules.diffusion_registry import DiffusionModuleRegistrySelector
 from fme.downscaling.predictors.cascade import CascadePredictor
@@ -48,7 +48,7 @@ def test_CascadePredictor_generate(downscale_factors):
     n_times, n_samples_generate, nside_coarse = 3, 2, 4
     grid_bounds = (0, 100)
     models = []
-    topographies: list[Topography | None] = []
+    static_inputs_list: list[StaticInputs | None] = []
     input_n_cells = nside_coarse
 
     for downscale_factor in downscale_factors:
@@ -58,21 +58,27 @@ def test_CascadePredictor_generate(downscale_factors):
                 downscale_factor=downscale_factor,
             )
         )
-        topographies.append(
-            Topography(
-                data=torch.randn(
-                    input_n_cells * downscale_factor,
-                    input_n_cells * downscale_factor,
-                    device=get_device(),
-                ),
-                coords=_latlon_coords_on_ngrid(
-                    n=input_n_cells * downscale_factor, edges=grid_bounds
-                ),
+        static_inputs_list.append(
+            StaticInputs(
+                fields=[
+                    StaticInput(
+                        data=torch.randn(
+                            input_n_cells * downscale_factor,
+                            input_n_cells * downscale_factor,
+                            device=get_device(),
+                        ),
+                        coords=_latlon_coords_on_ngrid(
+                            n=input_n_cells * downscale_factor, edges=grid_bounds
+                        ),
+                    )
+                ]
             )
         )
         input_n_cells *= downscale_factor
 
-    cascade_predictor = CascadePredictor(models=models, topographies=topographies)
+    cascade_predictor = CascadePredictor(
+        models=models, static_inputs=static_inputs_list
+    )
     coarse_input = {
         "x": torch.randn(
             (n_times, nside_coarse, nside_coarse),
@@ -81,7 +87,9 @@ def test_CascadePredictor_generate(downscale_factors):
         )
     }
     generated, _, _ = cascade_predictor.generate(
-        coarse=coarse_input, n_samples=n_samples_generate, topographies=topographies
+        coarse=coarse_input,
+        n_samples=n_samples_generate,
+        static_inputs=static_inputs_list,
     )
     expected_nside = cascade_predictor.downscale_factor * nside_coarse
     assert generated["x"].shape == (
@@ -97,7 +105,7 @@ def test_CascadePredictor__subset_topographies():
     downscale_factors = [2, 2]
     grid_bounds = (0, 8)
     models = []
-    topographies: list[Topography | None] = []
+    static_inputs_list: list[StaticInputs | None] = []
     input_n_cells = nside_coarse
 
     for downscale_factor in downscale_factors:
@@ -107,36 +115,42 @@ def test_CascadePredictor__subset_topographies():
                 downscale_factor=downscale_factor,
             )
         )
-        topographies.append(
-            Topography(
-                data=torch.randn(
-                    input_n_cells * downscale_factor,
-                    input_n_cells * downscale_factor,
-                    device=get_device(),
-                ),
-                coords=_latlon_coords_on_ngrid(
-                    n=input_n_cells * downscale_factor, edges=grid_bounds
-                ),
+        static_inputs_list.append(
+            StaticInputs(
+                fields=[
+                    StaticInput(
+                        data=torch.randn(
+                            input_n_cells * downscale_factor,
+                            input_n_cells * downscale_factor,
+                            device=get_device(),
+                        ),
+                        coords=_latlon_coords_on_ngrid(
+                            n=input_n_cells * downscale_factor, edges=grid_bounds
+                        ),
+                    )
+                ]
             )
         )
         input_n_cells *= downscale_factor
 
-    cascade_predictor = CascadePredictor(models=models, topographies=topographies)
+    cascade_predictor = CascadePredictor(
+        models=models, static_inputs=static_inputs_list
+    )
 
     # Coarse grid subset has 1.0 grid spacing and midpoints 1.5 ... 4.5
     coarse_coords = _latlon_coords_on_ngrid(n=4, edges=(1, 5))
-    subset_intermediate_topographies = cascade_predictor._get_subset_topographies(
+    subset_intermediate_topographies = cascade_predictor._get_subset_static_inputs(
         coarse_coords=coarse_coords
     )
     # First topography grid 0.5 grid spacing
-    assert isinstance(subset_intermediate_topographies[0], Topography)
+    assert isinstance(subset_intermediate_topographies[0], StaticInputs)
     assert subset_intermediate_topographies[0].shape == (8, 8)
     assert subset_intermediate_topographies[0].coords.lat[0] == 1.25
     assert subset_intermediate_topographies[0].coords.lat[-1] == 4.75
     assert subset_intermediate_topographies[0].coords.lon[0] == 1.25
     assert subset_intermediate_topographies[0].coords.lon[-1] == 4.75
     # Second topography grid has 0.25 grid spacing
-    assert isinstance(subset_intermediate_topographies[1], Topography)
+    assert isinstance(subset_intermediate_topographies[1], StaticInputs)
     assert subset_intermediate_topographies[1].shape == (16, 16)
     assert subset_intermediate_topographies[1].coords.lat[0] == 1.125
     assert subset_intermediate_topographies[1].coords.lat[-1] == 4.875
diff --git a/fme/downscaling/predictors/test_composite.py b/fme/downscaling/predictors/test_composite.py
index 6092c5cd9..1b4ebbb83 100644
--- a/fme/downscaling/predictors/test_composite.py
+++ b/fme/downscaling/predictors/test_composite.py
@@ -6,7 +6,7 @@
 from fme.core.device import get_device
 from fme.core.packer import Packer
 from fme.downscaling.aggregators.shape_helpers import upsample_tensor
-from fme.downscaling.data import BatchData, PairedBatchData, Topography
+from fme.downscaling.data import BatchData, PairedBatchData, StaticInput, StaticInputs
 from fme.downscaling.data.patching import get_patches
 from fme.downscaling.data.utils import BatchedLatLonCoordinates
 from fme.downscaling.models import ModelOutputs
@@ -51,7 +51,7 @@ def __init__(self, coarse_shape, downscale_factor):
         self.out_packer = Packer(["x"])
 
     def generate_on_batch(
-        self, batch: PairedBatchData, topography: Topography | None, n_samples=1
+        self, batch: PairedBatchData, static_inputs: StaticInputs | None, n_samples=1
     ):
         prediction_data = {
             k: v.unsqueeze(1).expand(-1, n_samples, -1, -1)
@@ -62,7 +62,7 @@ def generate_on_batch(
         )
 
     def generate_on_batch_no_target(
-        self, batch: BatchData, topography: Topography | None, n_samples=1
+        self, batch: BatchData, static_inputs: StaticInputs | None, n_samples=1
     ):
         prediction_data = {
             k: upsample_tensor(
@@ -133,11 +133,16 @@ def test_SpatialCompositePredictor_generate_on_batch(patch_size_coarse):
     paired_batch_data = get_paired_test_data(
         *coarse_extent, downscale_factor=downscale_factor, batch_size=batch_size
     )
-    topography = Topography(
-        torch.randn(
-            coarse_extent[0] * downscale_factor, coarse_extent[1] * downscale_factor
-        ),
-        paired_batch_data.fine.latlon_coordinates[0],
+    static_inputs = StaticInputs(
+        fields=[
+            StaticInput(
+                data=torch.randn(
+                    coarse_extent[0] * downscale_factor,
+                    coarse_extent[1] * downscale_factor,
+                ),
+                coords=paired_batch_data.fine.latlon_coordinates[0],
+            )
+        ]
     )
 
     predictor = PatchPredictor(
@@ -147,7 +152,7 @@ def test_SpatialCompositePredictor_generate_on_batch(patch_size_coarse):
     )
     n_samples_generate = 2
     outputs = predictor.generate_on_batch(
-        paired_batch_data, topography, n_samples=n_samples_generate
+        paired_batch_data, static_inputs, n_samples=n_samples_generate
     )
     assert outputs.prediction["x"].shape == (batch_size, n_samples_generate, 8, 8)
     # dummy model predicts same value as fine data for all samples
@@ -169,11 +174,16 @@ def test_SpatialCompositePredictor_generate_on_batch_no_target(patch_size_coarse
     paired_batch_data = get_paired_test_data(
         *coarse_extent, downscale_factor=downscale_factor, batch_size=batch_size
     )
-    topography = Topography(
-        torch.randn(
-            coarse_extent[0] * downscale_factor, coarse_extent[1] * downscale_factor
-        ),
-        paired_batch_data.fine.latlon_coordinates[0],
+    static_inputs = StaticInputs(
+        fields=[
+            StaticInput(
+                data=torch.randn(
+                    coarse_extent[0] * downscale_factor,
+                    coarse_extent[1] * downscale_factor,
+                ),
+                coords=paired_batch_data.fine.latlon_coordinates[0],
+            )
+        ]
     )
     predictor = PatchPredictor(
         DummyModel(coarse_shape=patch_size_coarse, downscale_factor=2),  # type: ignore
@@ -183,6 +193,6 @@ def test_SpatialCompositePredictor_generate_on_batch_no_target(patch_size_coarse
     n_samples_generate = 2
     coarse_batch_data = paired_batch_data.coarse
     prediction = predictor.generate_on_batch_no_target(
-        coarse_batch_data, topography, n_samples=n_samples_generate
+        coarse_batch_data, static_inputs, n_samples=n_samples_generate
     )
     assert prediction["x"].shape == (batch_size, n_samples_generate, 8, 8)
diff --git a/fme/downscaling/predictors/test_composite.py b/fme/downscaling/predictors/test_composite.py
index 6092c5cd9..1b4ebbb83 100644
--- a/fme/downscaling/predictors/test_composite.py
+++ b/fme/downscaling/predictors/test_composite.py
@@ -6,7 +6,7 @@
 from fme.core.device import get_device
 from fme.core.packer import Packer
 from fme.downscaling.aggregators.shape_helpers import upsample_tensor
-from fme.downscaling.data import BatchData, PairedBatchData, Topography
+from fme.downscaling.data import BatchData, PairedBatchData, StaticInput, StaticInputs
 from fme.downscaling.data.patching import get_patches
 from fme.downscaling.data.utils import BatchedLatLonCoordinates
 from fme.downscaling.models import ModelOutputs
@@ -51,7 +51,7 @@ def __init__(self, coarse_shape, downscale_factor):
         self.out_packer = Packer(["x"])

     def generate_on_batch(
-        self, batch: PairedBatchData, topography: Topography | None, n_samples=1
+        self, batch: PairedBatchData, static_inputs: StaticInputs | None, n_samples=1
     ):
         prediction_data = {
             k: v.unsqueeze(1).expand(-1, n_samples, -1, -1)
@@ -62,7 +62,7 @@ def generate_on_batch(
         )

     def generate_on_batch_no_target(
-        self, batch: BatchData, topography: Topography | None, n_samples=1
+        self, batch: BatchData, static_inputs: StaticInputs | None, n_samples=1
     ):
         prediction_data = {
             k: upsample_tensor(
@@ -133,11 +133,16 @@ def test_SpatialCompositePredictor_generate_on_batch(patch_size_coarse):
     paired_batch_data = get_paired_test_data(
         *coarse_extent, downscale_factor=downscale_factor, batch_size=batch_size
     )
-    topography = Topography(
-        torch.randn(
-            coarse_extent[0] * downscale_factor, coarse_extent[1] * downscale_factor
-        ),
-        paired_batch_data.fine.latlon_coordinates[0],
+    static_inputs = StaticInputs(
+        fields=[
+            StaticInput(
+                data=torch.randn(
+                    coarse_extent[0] * downscale_factor,
+                    coarse_extent[1] * downscale_factor,
+                ),
+                coords=paired_batch_data.fine.latlon_coordinates[0],
+            )
+        ]
     )

     predictor = PatchPredictor(
@@ -147,7 +152,7 @@ def test_SpatialCompositePredictor_generate_on_batch(patch_size_coarse):
     )
     n_samples_generate = 2
     outputs = predictor.generate_on_batch(
-        paired_batch_data, topography, n_samples=n_samples_generate
+        paired_batch_data, static_inputs, n_samples=n_samples_generate
     )
     assert outputs.prediction["x"].shape == (batch_size, n_samples_generate, 8, 8)
     # dummy model predicts same value as fine data for all samples
@@ -169,11 +174,16 @@ def test_SpatialCompositePredictor_generate_on_batch_no_target(patch_size_coarse
     paired_batch_data = get_paired_test_data(
         *coarse_extent, downscale_factor=downscale_factor, batch_size=batch_size
     )
-    topography = Topography(
-        torch.randn(
-            coarse_extent[0] * downscale_factor, coarse_extent[1] * downscale_factor
-        ),
-        paired_batch_data.fine.latlon_coordinates[0],
+    static_inputs = StaticInputs(
+        fields=[
+            StaticInput(
+                data=torch.randn(
+                    coarse_extent[0] * downscale_factor,
+                    coarse_extent[1] * downscale_factor,
+                ),
+                coords=paired_batch_data.fine.latlon_coordinates[0],
+            )
+        ]
     )
     predictor = PatchPredictor(
         DummyModel(coarse_shape=patch_size_coarse, downscale_factor=2),  # type: ignore
@@ -183,6 +193,6 @@ def test_SpatialCompositePredictor_generate_on_batch_no_target(patch_size_coarse
     n_samples_generate = 2
     coarse_batch_data = paired_batch_data.coarse
     prediction = predictor.generate_on_batch_no_target(
-        coarse_batch_data, topography, n_samples=n_samples_generate
+        coarse_batch_data, static_inputs, n_samples=n_samples_generate
     )
     assert prediction["x"].shape == (batch_size, n_samples_generate, 8, 8)
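
Across these tests the migration is mechanical: a bare Topography becomes a
single-field StaticInputs. Schematically (topo_tensor and fine_coords are
placeholder names for the tensor and coordinates already in scope at each
call site):

    # before (removed by this diff):
    #   topography = Topography(data=topo_tensor, coords=fine_coords)
    # after: the same field, wrapped in the list-valued container
    static_inputs = StaticInputs(
        fields=[StaticInput(data=topo_tensor, coords=fine_coords)]
    )
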
diff --git a/fme/downscaling/test_models.py b/fme/downscaling/test_models.py
index eb8c3b19b..bbf696655 100644
--- a/fme/downscaling/test_models.py
+++ b/fme/downscaling/test_models.py
@@ -10,7 +10,7 @@
 from fme.core.loss import LossConfig
 from fme.core.normalizer import NormalizationConfig
 from fme.core.optimization import OptimizationConfig
-from fme.downscaling.data import Topography
+from fme.downscaling.data import StaticInput, StaticInputs
 from fme.downscaling.models import (
     DiffusionModel,
     DiffusionModelConfig,
@@ -150,21 +150,26 @@ def test_diffusion_model_train_and_generate(predict_residual, use_fine_topograph
         [batch_size, *coarse_shape], [batch_size, *fine_shape]
     )
     if use_fine_topography:
-        topography = Topography(
-            torch.ones(*fine_shape, device=get_device()),
-            LatLonCoordinates(
-                lat=torch.ones(fine_shape[0]), lon=torch.ones(fine_shape[1])
-            ),
+        static_inputs = StaticInputs(
+            fields=[
+                StaticInput(
+                    torch.ones(*fine_shape, device=get_device()),
+                    LatLonCoordinates(
+                        lat=torch.ones(fine_shape[0]), lon=torch.ones(fine_shape[1])
+                    ),
+                )
+            ]
         )
     else:
-        topography = None
+        static_inputs = None
     optimization = OptimizationConfig().build(modules=[model.module], max_epochs=2)
-    train_outputs = model.train_on_batch(batch, topography, optimization)
+    train_outputs = model.train_on_batch(batch, static_inputs, optimization)
     assert torch.allclose(train_outputs.target["x"], batch.fine.data["x"])

     n_generated_samples = 2
     generated_outputs = [
-        model.generate_on_batch(batch, topography) for _ in range(n_generated_samples)
+        model.generate_on_batch(batch, static_inputs)
+        for _ in range(n_generated_samples)
     ]

     for generated_output in generated_outputs:
@@ -285,7 +290,7 @@ def test_model_error_cases():
     # missing fine topography when model requires it
     batch.fine.topography = None
     with pytest.raises(ValueError):
-        model.generate_on_batch(batch, topography=None)
+        model.generate_on_batch(batch, static_inputs=None)


 def test_DiffusionModel_generate_on_batch_no_target():
@@ -306,13 +311,19 @@ def test_DiffusionModel_generate_on_batch_no_target():
     coarse_batch = get_mock_batch(
         [batch_size, *coarse_shape], topography_scale_factor=downscale_factor
     )
-    topography = Topography(
-        torch.rand(*fine_shape, device=get_device()),
-        LatLonCoordinates(lat=torch.ones(fine_shape[0]), lon=torch.ones(fine_shape[1])),
+    static_inputs = StaticInputs(
+        fields=[
+            StaticInput(
+                torch.rand(*fine_shape, device=get_device()),
+                LatLonCoordinates(
+                    lat=torch.ones(fine_shape[0]), lon=torch.ones(fine_shape[1])
+                ),
+            )
+        ]
     )
     samples = model.generate_on_batch_no_target(
         coarse_batch,
-        topography=topography,
+        static_inputs=static_inputs,
         n_samples=n_generated_samples,
     )

@@ -344,12 +355,18 @@ def test_DiffusionModel_generate_on_batch_no_target_arbitrary_input_size():
         [batch_size, *alternative_input_shape],
         topography_scale_factor=downscale_factor,
     )
-    topography = Topography(
-        torch.rand(*fine_shape, device=get_device()),
-        LatLonCoordinates(torch.ones(fine_shape[0]), torch.ones(fine_shape[1])),
+    static_inputs = StaticInputs(
+        fields=[
+            StaticInput(
+                torch.rand(*fine_shape, device=get_device()),
+                LatLonCoordinates(
+                    torch.ones(fine_shape[0]), torch.ones(fine_shape[1])
+                ),
+            )
+        ]
     )
     samples = model.generate_on_batch_no_target(
-        coarse_batch, n_samples=n_ensemble, topography=topography
+        coarse_batch, n_samples=n_ensemble, static_inputs=static_inputs
     )

     assert samples["x"].shape == (
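
The assertions above pin down the generation contract: sampling inserts a
sample axis immediately after the batch axis. In condensed form (names follow
the tests; fine_shape is the fine-grid (lat, lon) extent):

    samples = model.generate_on_batch_no_target(
        coarse_batch,
        static_inputs=static_inputs,
        n_samples=n_ensemble,
    )
    # every generated variable comes back as (batch, sample, fine_lat, fine_lon)
    assert samples["x"].shape == (batch_size, n_ensemble, *fine_shape)
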
diff --git a/fme/downscaling/test_predict.py b/fme/downscaling/test_predict.py
index e4376068e..66db7ada4 100644
--- a/fme/downscaling/test_predict.py
+++ b/fme/downscaling/test_predict.py
@@ -10,7 +10,7 @@
 from fme.core.normalizer import NormalizationConfig
 from fme.core.testing.wandb import mock_wandb
 from fme.downscaling import predict
-from fme.downscaling.data import StaticInputs, Topography
+from fme.downscaling.data import StaticInput, StaticInputs
 from fme.downscaling.models import DiffusionModelConfig, PairedNormalizationConfig
 from fme.downscaling.modules.diffusion_registry import DiffusionModuleRegistrySelector
 from fme.downscaling.test_models import LinearDownscaling
@@ -22,7 +22,11 @@ def forward(self, latent, coarse, noise_level):  # type: ignore
         return super().forward(coarse)


-def get_model_config(coarse_shape: tuple[int, int], downscale_factor: int):
+def get_model_config(
+    coarse_shape: tuple[int, int],
+    downscale_factor: int,
+    use_fine_topography: bool = True,
+):
     fine_shape = (
         coarse_shape[0] * downscale_factor,
         coarse_shape[1] * downscale_factor,
@@ -34,7 +38,7 @@ def get_model_config(coarse_shape: tuple[int, int], downscale_factor: int):
             "module": LinearDownscalingDiffusion(
                 factor=1,  # will pass coarse input interpolated to fine shape
                 fine_img_shape=fine_shape,
-                n_channels_in=3,
+                n_channels_in=3 if use_fine_topography else 2,
                 n_channels_out=2,
             )
         },
@@ -58,7 +62,7 @@ def get_model_config(coarse_shape: tuple[int, int], downscale_factor: int):
         churn=1,
         num_diffusion_generation_steps=2,
         predict_residual=True,
-        use_fine_topography=True,
+        use_fine_topography=use_fine_topography,
     )


@@ -99,7 +103,7 @@ def create_predictor_config(
     return out_path


-@pytest.mark.parametrize("static_inputs_on_model", [True, False])
+@pytest.mark.parametrize("static_inputs_on_model", [True])
 def test_predictor_runs(static_inputs_on_model, tmp_path, very_fast_only: bool):
     if very_fast_only:
         pytest.skip("Skipping non-fast tests")
@@ -126,7 +130,7 @@ def test_predictor_runs(static_inputs_on_model, tmp_path, very_fast_only: bool):
     topo_data = fine_data["HGTsfc"]
     model.static_inputs = StaticInputs(
         [
-            Topography(
+            StaticInput(
                 data=torch.randn(topo_data.shape[-2:]),
                 coords=LatLonCoordinates(
                     lat=torch.tensor(topo_data.lat.values),
@@ -177,13 +181,16 @@ def test_predictor_renaming(
             "rename": {"var0": "var0_renamed", "var1": "var1_renamed"}
         },
     )
-    model_config = get_model_config(coarse_shape, downscale_factor)
+    model_config = get_model_config(
+        coarse_shape, downscale_factor, use_fine_topography=False
+    )
     model = model_config.build(coarse_shape=coarse_shape, downscale_factor=2)
     with open(predictor_config_path) as f:
         predictor_config = yaml.safe_load(f)
     os.makedirs(
         os.path.join(predictor_config["experiment_dir"], "checkpoints"), exist_ok=True
     )
+
     torch.save(
         {
             "model": model.get_state(),
diff --git a/fme/downscaling/test_train.py b/fme/downscaling/test_train.py
index dc0f70dd6..3e65c8bff 100644
--- a/fme/downscaling/test_train.py
+++ b/fme/downscaling/test_train.py
@@ -86,6 +86,9 @@ def _create_config_dict(
     experiment_dir = tmp_path / "output"
     experiment_dir.mkdir()

+    config["static_inputs"] = {
+        "HGTsfc": str(train_paths.fine) + "/data.nc",
+    }
     config["train_data"]["fine"] = [{"data_path": str(train_paths.fine)}]
     config["train_data"]["coarse"] = [{"data_path": str(train_paths.coarse)}]
     config["validation_data"]["fine"] = [
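
The trainer test above now declares static inputs as a name-to-path mapping,
mirroring the new static_inputs field in the training YAML. A condensed sketch
of how TrainerConfig.build (next file) expands that mapping, assuming
get_normalized_topography returns one normalized StaticInput per entry:

    static_inputs = StaticInputs(
        fields=[
            get_normalized_topography(path, topography_name=name)
            for name, path in (config["static_inputs"] or {}).items()
        ]
    )
    # The same StaticInputs instance is then shared by the train and
    # validation data builders and saved with the model.
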
diff --git a/fme/downscaling/train.py b/fme/downscaling/train.py
index d155adf3a..9aaa7de0d 100755
--- a/fme/downscaling/train.py
+++ b/fme/downscaling/train.py
@@ -13,7 +13,6 @@
 import fme.core.logging_utils as logging_utils
 from fme.core.cli import prepare_directory
-from fme.core.dataset.xarray import get_raw_paths
 from fme.core.device import get_device
 from fme.core.dicts import to_flat_dict
 from fme.core.distributed import Distributed
@@ -185,11 +184,11 @@ def train_one_epoch(self) -> None:
             self.train_data, random_offset=True, shuffle=True
         )
         outputs = None
-        for i, (batch, topography) in enumerate(train_batch_generator):
+        for i, (batch, static_inputs) in enumerate(train_batch_generator):
             self.num_batches_seen += 1
             if i % 10 == 0:
                 logging.info(f"Training on batch {i+1}")
-            outputs = self.model.train_on_batch(batch, topography, self.optimization)
+            outputs = self.model.train_on_batch(batch, static_inputs, self.optimization)
             self.ema(self.model.modules)
             with torch.no_grad():
                 train_aggregator.record_batch(
@@ -261,9 +260,9 @@ def valid_one_epoch(self) -> dict[str, float]:
         validation_batch_generator = self._get_batch_generator(
             self.validation_data, random_offset=False, shuffle=False
         )
-        for batch, topography in validation_batch_generator:
+        for batch, static_inputs in validation_batch_generator:
             outputs = self.model.train_on_batch(
-                batch, topography, self.null_optimization
+                batch, static_inputs, self.null_optimization
             )
             validation_aggregator.record_batch(
                 outputs=outputs,
@@ -272,7 +271,7 @@ def valid_one_epoch(self) -> dict[str, float]:
             )
             generated_outputs = self.model.generate_on_batch(
                 batch,
-                topography=topography,
+                static_inputs=static_inputs,
                 n_samples=self.config.generate_n_samples,
             )
             # Add sample dimension to coarse values for generation comparison
@@ -407,6 +406,7 @@ class TrainerConfig:
     experiment_dir: str
     save_checkpoints: bool
     logging: LoggingConfig
+    static_inputs: dict[str, str] | None = None
     ema: EMAConfig = dataclasses.field(default_factory=EMAConfig)
     validate_using_ema: bool = False
     generate_n_samples: int = 1
@@ -434,11 +434,23 @@ def checkpoint_dir(self) -> str:
         return os.path.join(self.experiment_dir, "checkpoints")

     def build(self) -> Trainer:
+        static_inputs_fields = self.static_inputs or {}
+        static_inputs = StaticInputs(
+            fields=[
+                get_normalized_topography(path, topography_name=key)
+                for key, path in static_inputs_fields.items()
+            ]
+        )
+
         train_data: PairedGriddedData = self.train_data.build(
-            train=True, requirements=self.model.data_requirements
+            train=True,
+            requirements=self.model.data_requirements,
+            static_inputs=static_inputs,
         )
         validation_data: PairedGriddedData = self.validation_data.build(
-            train=False, requirements=self.model.data_requirements
+            train=False,
+            requirements=self.model.data_requirements,
+            static_inputs=static_inputs,
         )
         if self.coarse_patch_extent_lat and self.coarse_patch_extent_lon:
             model_coarse_shape = (
@@ -448,18 +460,10 @@ def build(self) -> Trainer:
         else:
             model_coarse_shape = train_data.coarse_shape

-        # load full spatial range of topography to save with model
-        # TODO: this will be replaced in the future with a more general call
-        # to get normalized static inputs from a model config field
-        full_topography = get_normalized_topography(
-            get_raw_paths(
-                self.train_data.fine[0].data_path, self.train_data.fine[0].file_pattern
-            )[0]
-        )
         downscaling_model = self.model.build(
             model_coarse_shape,
             train_data.downscale_factor,
-            static_inputs=StaticInputs([full_topography]),
+            static_inputs=static_inputs,
         )

         optimization = self.optimization.build(