Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions parcels/_datasets/structured/grid_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
import numpy as np
import xarray as xr

__all__ = ["N", "T", "datasets"]

N = 30
T = 10


def rotated_curvilinear_grid():
def _rotated_curvilinear_grid():
XG = np.arange(N)
YG = np.arange(2 * N)
LON, LAT = np.meshgrid(XG, YG)
Expand Down Expand Up @@ -66,7 +68,7 @@ def _polar_to_cartesian(r, theta):
return x, y


def unrolled_cone_curvilinear_grid():
def _unrolled_cone_curvilinear_grid():
# Not a great unrolled cone, but this is good enough for testing
# you can use matplotlib pcolormesh to plot
XG = np.arange(N)
Expand Down Expand Up @@ -126,7 +128,7 @@ def unrolled_cone_curvilinear_grid():


datasets = {
"2d_left_rotated": rotated_curvilinear_grid(),
"2d_left_rotated": _rotated_curvilinear_grid(),
"ds_2d_left": xr.Dataset(
{
"data_g": (["time", "ZG", "YG", "XG"], np.random.rand(T, 3 * N, 2 * N, N)),
Expand Down Expand Up @@ -165,5 +167,5 @@ def unrolled_cone_curvilinear_grid():
"time": (["time"], np.arange(T), {"axis": "T"}),
},
),
"2d_left_unrolled_cone": unrolled_cone_curvilinear_grid(),
"2d_left_unrolled_cone": _unrolled_cone_curvilinear_grid(),
}
76 changes: 55 additions & 21 deletions parcels/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from parcels._typing import (
Mesh,
VectorType,
assert_valid_mesh,
)
from parcels.tools._helpers import default_repr, field_repr
from parcels.tools.converters import (
Expand All @@ -26,6 +27,7 @@
FieldSamplingError,
_raise_field_out_of_bound_error,
)
from parcels.v4.grid import Grid

from ._index_search import _search_indices_rectilinear, _search_time_index

Expand Down Expand Up @@ -142,16 +144,37 @@ def __init__(
self,
name: str,
data: xr.DataArray | ux.UxDataArray,
grid: ux.Grid | None = None, # TODO Nick : Once parcels.Grid class is added, allow for it to be passed here
grid: ux.Grid | Grid,
mesh_type: Mesh = "flat",
interp_method: Callable | None = None,
allow_time_extrapolation: bool | None = None,
):
if not isinstance(data, (ux.UxDataArray, xr.DataArray)):
raise ValueError(
f"Expected `data` to be a uxarray.UxDataArray or xarray.DataArray object, got {type(data)}."
)
if not isinstance(name, str):
raise ValueError(f"Expected `name` to be a string, got {type(name)}.")
if not isinstance(grid, (ux.Grid, Grid)):
raise ValueError(f"Expected `grid` to be a uxarray.Grid or parcels Grid object, got {type(grid)}.")

assert_valid_mesh(mesh_type)

_assert_compatible_combination(data, grid)

self.name = name
self.data = data
self.grid = grid

_validate_dataarray(data, name)
try:
if isinstance(data, ux.UxDataArray):
_assert_valid_uxdataarray(data)
# TODO: For unstructured grids, validate that `data.uxgrid` is the same as `grid`
else:
pass # TODO v4: Add validation for xr.DataArray objects
except Exception as e:
e.add_note(f"Error validating field {name!r}.")
raise e

self._parent_mesh = data.attrs["mesh"]
self._mesh_type = mesh_type
Expand Down Expand Up @@ -631,47 +654,58 @@ def __getitem__(self, key):
return _deal_with_errors(error, key, vector_type=self.vector_type)


def _validate_dataarray(data, name):
def _assert_valid_uxdataarray(data: ux.UxDataArray):
"""Verifies that all the required attributes are present in the xarray.DataArray or
uxarray.UxDataArray object.
"""
if isinstance(data, ux.UxDataArray):
# Validate dimensions
if not ("nz1" in data.dims or "nz" in data.dims):
raise ValueError(
f"Field {name} is missing a 'nz1' or 'nz' dimension in the field's metadata. "
"This attribute is required for xarray.DataArray objects."
)
# Validate dimensions
if not ("nz1" in data.dims or "nz" in data.dims):
raise ValueError(
"Field is missing a 'nz1' or 'nz' dimension in the field's metadata. "
"This attribute is required for xarray.DataArray objects."
)

if "time" not in data.dims:
raise ValueError(
f"Field {name} is missing a 'time' dimension in the field's metadata. "
"This attribute is required for xarray.DataArray objects."
)
if "time" not in data.dims:
raise ValueError(
"Field is missing a 'time' dimension in the field's metadata. "
"This attribute is required for xarray.DataArray objects."
)

# Validate attributes
required_keys = ["location", "mesh"]
for key in required_keys:
if key not in data.attrs.keys():
raise ValueError(
f"Field {name} is missing a '{key}' attribute in the field's metadata. "
f"Field is missing a '{key}' attribute in the field's metadata. "
"This attribute is required for xarray.DataArray objects."
)

if type(data) is ux.UxDataArray:
_validate_uxgrid(data.uxgrid, name)
_assert_valid_uxgrid(data.uxgrid)


def _validate_uxgrid(grid, name):
def _assert_valid_uxgrid(grid):
"""Verifies that all the required attributes are present in the uxarray.UxDataArray.UxGrid object."""
if "Conventions" not in grid.attrs.keys():
raise ValueError(
f"Field {name} is missing a 'Conventions' attribute in the field's metadata. "
"Field is missing a 'Conventions' attribute in the field's metadata. "
"This attribute is required for uxarray.UxDataArray objects."
)
if grid.attrs["Conventions"] != "UGRID-1.0":
raise ValueError(
f"Field {name} has a 'Conventions' attribute that is not 'UGRID-1.0'. "
"Field has a 'Conventions' attribute that is not 'UGRID-1.0'. "
"This attribute is required for uxarray.UxDataArray objects."
"See https://ugrid-conventions.github.io/ugrid-conventions/ for more information."
)


def _assert_compatible_combination(data: xr.DataArray | ux.UxDataArray, grid: ux.Grid | Grid):
if isinstance(data, ux.UxDataArray):
if not isinstance(grid, ux.Grid):
raise ValueError(
f"Incompatible data-grid combination. Data is a uxarray.UxDataArray, expected `grid` to be a uxarray.Grid object, got {type(grid)}."
)
elif isinstance(data, xr.DataArray):
if not isinstance(grid, Grid):
raise ValueError(
f"Incompatible data-grid combination. Data is a xarray.DataArray, expected `grid` to be a parcels Grid object, got {type(grid)}."
)
85 changes: 85 additions & 0 deletions tests/v4/test_field.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import pytest
import uxarray as ux
import xarray as xr

from parcels import Field
from parcels._datasets.structured.grid_datasets import datasets as structured_datasets
from parcels.v4.grid import Grid


def test_field_init_param_types():
data = xr.DataArray(
attrs={
"location": "node",
"mesh": "flat",
}
)
grid = Grid(data)
with pytest.raises(ValueError, match="Expected `name` to be a string"):
Field(name=123, data=data, grid=grid)

with pytest.raises(ValueError, match="Expected `data` to be a uxarray.UxDataArray or xarray.DataArray"):
Field(name="test", data=123, grid=grid)

with pytest.raises(ValueError, match="Expected `grid` to be a uxarray.Grid or parcels Grid"):
Field(name="test", data=data, grid=123)

with pytest.raises(ValueError, match="Invalid value 'invalid'. Valid options are.*"):
Field(name="test", data=data, grid=grid, mesh_type="invalid")


@pytest.mark.parametrize(
"data,grid",
[
pytest.param(ux.UxDataArray(), Grid(xr.Dataset()), id="uxdata-grid"),
pytest.param(xr.DataArray(), ux.UxDataArray().uxgrid, id="xarray-uxgrid"),
],
)
def test_field_incompatible_combination(data, grid):
with pytest.raises(ValueError, msg="Incompatible data-grid combination."):
Field(
name="test_field",
data=data,
grid=grid,
)


@pytest.mark.parametrize(
"data,grid",
[
pytest.param(
structured_datasets["ds_2d_left"]["data_g"], Grid(structured_datasets["ds_2d_left"]), id="ds_2d_left"
),
],
)
@pytest.mark.xfail(reason="Structured grid creation is not implemented yet")
def test_field_structured_grid_creation(data, grid):
"""Test creating a field."""
field = Field(
name="test_field",
data=data,
grid=grid,
)
assert field.name == "test_field"
assert field.data == data
assert field.grid == grid


def test_field_structured_grid_creation_spherical():
# Field(..., mesh_type="spherical")
...


def test_field_unstructured_grid_creation(): ...


def test_field_interpolation(): ...


def test_field_interpolation_out_of_spatial_bounds(): ...


def test_field_interpolation_out_of_time_bounds(): ...


def test_field_allow_time_extrapolation(): ...
Loading