diff --git a/parcels/_datasets/structured/grid_datasets.py b/parcels/_datasets/structured/grid_datasets.py index 7f95aff504..92fa51ba25 100644 --- a/parcels/_datasets/structured/grid_datasets.py +++ b/parcels/_datasets/structured/grid_datasets.py @@ -3,11 +3,13 @@ import numpy as np import xarray as xr +__all__ = ["N", "T", "datasets"] + N = 30 T = 10 -def rotated_curvilinear_grid(): +def _rotated_curvilinear_grid(): XG = np.arange(N) YG = np.arange(2 * N) LON, LAT = np.meshgrid(XG, YG) @@ -66,7 +68,7 @@ def _polar_to_cartesian(r, theta): return x, y -def unrolled_cone_curvilinear_grid(): +def _unrolled_cone_curvilinear_grid(): # Not a great unrolled cone, but this is good enough for testing # you can use matplotlib pcolormesh to plot XG = np.arange(N) @@ -126,7 +128,7 @@ def unrolled_cone_curvilinear_grid(): datasets = { - "2d_left_rotated": rotated_curvilinear_grid(), + "2d_left_rotated": _rotated_curvilinear_grid(), "ds_2d_left": xr.Dataset( { "data_g": (["time", "ZG", "YG", "XG"], np.random.rand(T, 3 * N, 2 * N, N)), @@ -165,5 +167,5 @@ def unrolled_cone_curvilinear_grid(): "time": (["time"], np.arange(T), {"axis": "T"}), }, ), - "2d_left_unrolled_cone": unrolled_cone_curvilinear_grid(), + "2d_left_unrolled_cone": _unrolled_cone_curvilinear_grid(), } diff --git a/parcels/field.py b/parcels/field.py index d739b3f86c..364e483549 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -13,6 +13,7 @@ from parcels._typing import ( Mesh, VectorType, + assert_valid_mesh, ) from parcels.tools._helpers import default_repr, field_repr from parcels.tools.converters import ( @@ -26,6 +27,7 @@ FieldSamplingError, _raise_field_out_of_bound_error, ) +from parcels.v4.grid import Grid from ._index_search import _search_indices_rectilinear, _search_time_index @@ -142,16 +144,37 @@ def __init__( self, name: str, data: xr.DataArray | ux.UxDataArray, - grid: ux.Grid | None = None, # TODO Nick : Once parcels.Grid class is added, allow for it to be passed here + grid: ux.Grid | Grid, mesh_type: Mesh = "flat", interp_method: Callable | None = None, allow_time_extrapolation: bool | None = None, ): + if not isinstance(data, (ux.UxDataArray, xr.DataArray)): + raise ValueError( + f"Expected `data` to be a uxarray.UxDataArray or xarray.DataArray object, got {type(data)}." + ) + if not isinstance(name, str): + raise ValueError(f"Expected `name` to be a string, got {type(name)}.") + if not isinstance(grid, (ux.Grid, Grid)): + raise ValueError(f"Expected `grid` to be a uxarray.Grid or parcels Grid object, got {type(grid)}.") + + assert_valid_mesh(mesh_type) + + _assert_compatible_combination(data, grid) + self.name = name self.data = data self.grid = grid - _validate_dataarray(data, name) + try: + if isinstance(data, ux.UxDataArray): + _assert_valid_uxdataarray(data) + # TODO: For unstructured grids, validate that `data.uxgrid` is the same as `grid` + else: + pass # TODO v4: Add validation for xr.DataArray objects + except Exception as e: + e.add_note(f"Error validating field {name!r}.") + raise e self._parent_mesh = data.attrs["mesh"] self._mesh_type = mesh_type @@ -631,47 +654,58 @@ def __getitem__(self, key): return _deal_with_errors(error, key, vector_type=self.vector_type) -def _validate_dataarray(data, name): +def _assert_valid_uxdataarray(data: ux.UxDataArray): """Verifies that all the required attributes are present in the xarray.DataArray or uxarray.UxDataArray object. """ - if isinstance(data, ux.UxDataArray): - # Validate dimensions - if not ("nz1" in data.dims or "nz" in data.dims): - raise ValueError( - f"Field {name} is missing a 'nz1' or 'nz' dimension in the field's metadata. " - "This attribute is required for xarray.DataArray objects." - ) + # Validate dimensions + if not ("nz1" in data.dims or "nz" in data.dims): + raise ValueError( + "Field is missing a 'nz1' or 'nz' dimension in the field's metadata. " + "This attribute is required for xarray.DataArray objects." + ) - if "time" not in data.dims: - raise ValueError( - f"Field {name} is missing a 'time' dimension in the field's metadata. " - "This attribute is required for xarray.DataArray objects." - ) + if "time" not in data.dims: + raise ValueError( + "Field is missing a 'time' dimension in the field's metadata. " + "This attribute is required for xarray.DataArray objects." + ) # Validate attributes required_keys = ["location", "mesh"] for key in required_keys: if key not in data.attrs.keys(): raise ValueError( - f"Field {name} is missing a '{key}' attribute in the field's metadata. " + f"Field is missing a '{key}' attribute in the field's metadata. " "This attribute is required for xarray.DataArray objects." ) - if type(data) is ux.UxDataArray: - _validate_uxgrid(data.uxgrid, name) + _assert_valid_uxgrid(data.uxgrid) -def _validate_uxgrid(grid, name): +def _assert_valid_uxgrid(grid): """Verifies that all the required attributes are present in the uxarray.UxDataArray.UxGrid object.""" if "Conventions" not in grid.attrs.keys(): raise ValueError( - f"Field {name} is missing a 'Conventions' attribute in the field's metadata. " + "Field is missing a 'Conventions' attribute in the field's metadata. " "This attribute is required for uxarray.UxDataArray objects." ) if grid.attrs["Conventions"] != "UGRID-1.0": raise ValueError( - f"Field {name} has a 'Conventions' attribute that is not 'UGRID-1.0'. " + "Field has a 'Conventions' attribute that is not 'UGRID-1.0'. " "This attribute is required for uxarray.UxDataArray objects." "See https://ugrid-conventions.github.io/ugrid-conventions/ for more information." ) + + +def _assert_compatible_combination(data: xr.DataArray | ux.UxDataArray, grid: ux.Grid | Grid): + if isinstance(data, ux.UxDataArray): + if not isinstance(grid, ux.Grid): + raise ValueError( + f"Incompatible data-grid combination. Data is a uxarray.UxDataArray, expected `grid` to be a uxarray.Grid object, got {type(grid)}." + ) + elif isinstance(data, xr.DataArray): + if not isinstance(grid, Grid): + raise ValueError( + f"Incompatible data-grid combination. Data is a xarray.DataArray, expected `grid` to be a parcels Grid object, got {type(grid)}." + ) diff --git a/tests/v4/test_field.py b/tests/v4/test_field.py new file mode 100644 index 0000000000..082918567e --- /dev/null +++ b/tests/v4/test_field.py @@ -0,0 +1,85 @@ +import pytest +import uxarray as ux +import xarray as xr + +from parcels import Field +from parcels._datasets.structured.grid_datasets import datasets as structured_datasets +from parcels.v4.grid import Grid + + +def test_field_init_param_types(): + data = xr.DataArray( + attrs={ + "location": "node", + "mesh": "flat", + } + ) + grid = Grid(data) + with pytest.raises(ValueError, match="Expected `name` to be a string"): + Field(name=123, data=data, grid=grid) + + with pytest.raises(ValueError, match="Expected `data` to be a uxarray.UxDataArray or xarray.DataArray"): + Field(name="test", data=123, grid=grid) + + with pytest.raises(ValueError, match="Expected `grid` to be a uxarray.Grid or parcels Grid"): + Field(name="test", data=data, grid=123) + + with pytest.raises(ValueError, match="Invalid value 'invalid'. Valid options are.*"): + Field(name="test", data=data, grid=grid, mesh_type="invalid") + + +@pytest.mark.parametrize( + "data,grid", + [ + pytest.param(ux.UxDataArray(), Grid(xr.Dataset()), id="uxdata-grid"), + pytest.param(xr.DataArray(), ux.UxDataArray().uxgrid, id="xarray-uxgrid"), + ], +) +def test_field_incompatible_combination(data, grid): + with pytest.raises(ValueError, msg="Incompatible data-grid combination."): + Field( + name="test_field", + data=data, + grid=grid, + ) + + +@pytest.mark.parametrize( + "data,grid", + [ + pytest.param( + structured_datasets["ds_2d_left"]["data_g"], Grid(structured_datasets["ds_2d_left"]), id="ds_2d_left" + ), + ], +) +@pytest.mark.xfail(reason="Structured grid creation is not implemented yet") +def test_field_structured_grid_creation(data, grid): + """Test creating a field.""" + field = Field( + name="test_field", + data=data, + grid=grid, + ) + assert field.name == "test_field" + assert field.data == data + assert field.grid == grid + + +def test_field_structured_grid_creation_spherical(): + # Field(..., mesh_type="spherical") + ... + + +def test_field_unstructured_grid_creation(): ... + + +def test_field_interpolation(): ... + + +def test_field_interpolation_out_of_spatial_bounds(): ... + + +def test_field_interpolation_out_of_time_bounds(): ... + + +def test_field_allow_time_extrapolation(): ...