From bd4b67daa7a1b2275f2f5393d8fe236bce2c9421 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Mon, 28 Apr 2025 11:33:22 +0200 Subject: [PATCH 1/9] Remove Field._spatialhash get_spatial_hash() returns the cached spatial hash be default anyway --- parcels/field.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/parcels/field.py b/parcels/field.py index 364e483549..8323380d75 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -203,7 +203,6 @@ def __init__( self.allow_time_extrapolation = allow_time_extrapolation if type(self.data) is ux.UxDataArray: - self._spatialhash = self.grid.get_spatial_hash() self._gtype = None # Set the vertical location if "nz1" in data.dims: @@ -211,7 +210,6 @@ def __init__( elif "nz" in data.dims: self._vertical_location = "face" else: # TODO Nick : This bit probably needs an overhaul once the parcels.Grid class is integrated. - self._spatialhash = None # Set the grid type if "x_g" in self.data.coords: lon = self.data.x_g @@ -365,7 +363,7 @@ def _search_indices_unstructured(self, z, y, x, ei=None, search2D=False): tol = 1e-10 if ei is None: # Search using global search - fi, bcoords = self._spatialhash.query([[x, y]]) # Get the face id for the particle + fi, bcoords = self.grid.get_spatial_hash().query([[x, y]]) # Get the face id for the particle if fi == -1: raise FieldOutOfBoundError(z, y, x) # TODO Joe : Do the vertical grid search @@ -389,7 +387,7 @@ def _search_indices_unstructured(self, z, y, x, ei=None, search2D=False): return bcoords, self.ravel_index(zi, 0, neighbor) # If we reach this point, we do a global search as a last ditch effort the particle is out of bounds - fi, bcoords = self._spatialhash.query([[x, y]]) # Get the face id for the particle + fi, bcoords = self.grid.get_spatial_hash().query([[x, y]]) # Get the face id for the particle if fi == -1: raise FieldOutOfBoundError(z, y, x) From 0a683d67dfa63aa0ea6da710d87e9af0bd7adea2 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Mon, 28 Apr 2025 11:34:40 +0200 Subject: [PATCH 2/9] Remove Field._location --- parcels/field.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/parcels/field.py b/parcels/field.py index 8323380d75..ea15552bdd 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -178,7 +178,6 @@ def __init__( self._parent_mesh = data.attrs["mesh"] self._mesh_type = mesh_type - self._location = data.attrs["location"] self._vertical_location = None # Setting the interpolation method dynamically @@ -262,11 +261,11 @@ def units(self, value): @property def lat(self): if type(self.data) is ux.UxDataArray: - if self._location == "node": + if self.data.attrs["location"] == "node": return self.grid.node_lat - elif self._location == "face": + elif self.data.attrs["location"] == "face": return self.grid.face_lat - elif self._location == "edge": + elif self.data.attrs["location"] == "edge": return self.grid.edge_lat else: return self.data.lat @@ -274,11 +273,11 @@ def lat(self): @property def lon(self): if type(self.data) is ux.UxDataArray: - if self._location == "node": + if self.data.attrs["location"] == "node": return self.grid.node_lon - elif self._location == "face": + elif self.data.attrs["location"] == "face": return self.grid.face_lon - elif self._location == "edge": + elif self.data.attrs["location"] == "edge": return self.grid.edge_lon else: return self.data.lon From 7a8de729f3338eb59f8e2434f3f430fa30b53269 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Mon, 28 Apr 2025 11:54:16 +0200 Subject: [PATCH 3/9] Add _core.utils subpackage Contributes to #1965 --- parcels/_core/utils/__init__.py | 0 parcels/_core/utils/common.py | 0 parcels/_core/utils/structured.py | 0 parcels/_core/utils/unstructured.py | 0 4 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 parcels/_core/utils/__init__.py create mode 100644 parcels/_core/utils/common.py create mode 100644 parcels/_core/utils/structured.py create mode 100644 parcels/_core/utils/unstructured.py diff --git a/parcels/_core/utils/__init__.py b/parcels/_core/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/parcels/_core/utils/common.py b/parcels/_core/utils/common.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/parcels/_core/utils/structured.py b/parcels/_core/utils/structured.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/parcels/_core/utils/unstructured.py b/parcels/_core/utils/unstructured.py new file mode 100644 index 0000000000..e69de29bb2 From 9549629adf3bc26365f540e6b741abb7a596548d Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Mon, 28 Apr 2025 13:09:59 +0200 Subject: [PATCH 4/9] Remove Field._vertical_location --- parcels/_core/utils/unstructured.py | 28 ++++++++++++++++++++++++ parcels/field.py | 12 ++++------- tests/v4/utils/test_unstructured.py | 33 +++++++++++++++++++++++++++++ 3 files changed, 65 insertions(+), 8 deletions(-) create mode 100644 tests/v4/utils/test_unstructured.py diff --git a/parcels/_core/utils/unstructured.py b/parcels/_core/utils/unstructured.py index e69de29bb2..76c4c93d7f 100644 --- a/parcels/_core/utils/unstructured.py +++ b/parcels/_core/utils/unstructured.py @@ -0,0 +1,28 @@ +from collections.abc import Hashable + +DIM_TO_VERTICAL_LOCATION_MAP = { + "nz1": "center", + "nz": "face", +} + + +def get_vertical_location_from_dims(dims: tuple[Hashable, ...]): + """ + Determine the vertical location of the field based on the uxarray.UxDataArray object variables. + + Only used for unstructured grids. + """ + vertical_dims_in_data = set(dims) & set(DIM_TO_VERTICAL_LOCATION_MAP.keys()) + + if len(vertical_dims_in_data) != 1: + raise ValueError( + f"Expected exactly one vertical dimension ({set(DIM_TO_VERTICAL_LOCATION_MAP.keys())}) in the data, got {vertical_dims_in_data}" + ) + + return DIM_TO_VERTICAL_LOCATION_MAP[vertical_dims_in_data.pop()] + + +def get_vertical_dim_name_from_location(location: str): + """Determine the vertical location of the field based on the uxarray.UxGrid object variables.""" + location_to_dim_map = {v: k for k, v in DIM_TO_VERTICAL_LOCATION_MAP.items()} + return location_to_dim_map[location] diff --git a/parcels/field.py b/parcels/field.py index ea15552bdd..3ac68a6eaa 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -10,6 +10,7 @@ import xarray as xr from uxarray.grid.neighbors import _barycentric_coordinates +from parcels._core.utils.unstructured import get_vertical_location_from_dims from parcels._typing import ( Mesh, VectorType, @@ -178,7 +179,6 @@ def __init__( self._parent_mesh = data.attrs["mesh"] self._mesh_type = mesh_type - self._vertical_location = None # Setting the interpolation method dynamically if interp_method is None: @@ -203,11 +203,6 @@ def __init__( if type(self.data) is ux.UxDataArray: self._gtype = None - # Set the vertical location - if "nz1" in data.dims: - self._vertical_location = "center" - elif "nz" in data.dims: - self._vertical_location = "face" else: # TODO Nick : This bit probably needs an overhaul once the parcels.Grid class is integrated. # Set the grid type if "x_g" in self.data.coords: @@ -285,9 +280,10 @@ def lon(self): @property def depth(self): if type(self.data) is ux.UxDataArray: - if self._vertical_location == "center": + vertical_location = get_vertical_location_from_dims(self.data.dims) + if vertical_location == "center": return self.grid.nz1 - elif self._vertical_location == "face": + elif vertical_location == "face": return self.grid.nz else: return self.data.depth diff --git a/tests/v4/utils/test_unstructured.py b/tests/v4/utils/test_unstructured.py new file mode 100644 index 0000000000..e8c296feca --- /dev/null +++ b/tests/v4/utils/test_unstructured.py @@ -0,0 +1,33 @@ +import pytest + +from parcels._core.utils.unstructured import ( + get_vertical_dim_name_from_location, + get_vertical_location_from_dims, +) + + +def test_get_vertical_location_from_dims(): + # Test with nz1 dimension + assert get_vertical_location_from_dims(("nz1", "time")) == "center" + + # Test with nz dimension + assert get_vertical_location_from_dims(("nz", "time")) == "face" + + # Test with both dimensions + with pytest.raises(ValueError): + get_vertical_location_from_dims(("nz1", "nz", "time")) + + # Test with no vertical dimension + with pytest.raises(ValueError): + get_vertical_location_from_dims(("time", "x", "y")) + + +def test_get_vertical_dim_name_from_location(): + # Test with center location + assert get_vertical_dim_name_from_location("center") == "nz1" + + # Test with face location + assert get_vertical_dim_name_from_location("face") == "nz" + + with pytest.raises(KeyError): + get_vertical_dim_name_from_location("invalid_location") From 68bc63711756e18203ba188e3caeb021001fe46d Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Mon, 28 Apr 2025 13:24:24 +0200 Subject: [PATCH 5/9] Remove Field.lonlat_minmax --- parcels/field.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/parcels/field.py b/parcels/field.py index 3ac68a6eaa..0b67f89b61 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -232,17 +232,9 @@ def __init__( else: self._gtype = GridType.CurvilinearSGrid - self._lonlat_minmax = np.array( - [np.nanmin(self.lon), np.nanmax(self.lon), np.nanmin(self.lat), np.nanmax(self.lat)], dtype=np.float32 - ) - def __repr__(self): return field_repr(self) - @property - def lonlat_minmax(self): - return self._lonlat_minmax - @property def units(self): return self._units From b486ad85fb1560e58289922b29dd0e42a9f7d7f1 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Mon, 28 Apr 2025 13:36:13 +0200 Subject: [PATCH 6/9] Update GridAdapter to use composition The adapter now wraps a Grid instance instead of inheriting from it. This provides better separation of concerns and makes the adapter's purpose more explicit - it's adapting an existing grid rather than extending it. It also allows for more flexibility in adapting any grid instance, regardless of how it was constructed. --- parcels/v4/gridadapter.py | 37 ++++++++++++++++++------------------ tests/v4/test_gridadapter.py | 5 +++-- 2 files changed, 22 insertions(+), 20 deletions(-) diff --git a/parcels/v4/gridadapter.py b/parcels/v4/gridadapter.py index 99e34a2c65..31329d6496 100644 --- a/parcels/v4/gridadapter.py +++ b/parcels/v4/gridadapter.py @@ -30,67 +30,68 @@ def get_time(axis: Axis) -> npt.NDArray: return axis._ds[axis.coords["center"]].values -class GridAdapter(NewGrid): - def __init__(self, ds, mesh="flat", *args, **kwargs): - super().__init__(ds, *args, **kwargs) +class GridAdapter: + def __init__(self, grid: NewGrid, mesh="flat"): + self.grid = grid self.mesh = mesh + # ! Not ideal... Triggers computation on a throwaway item. If adapter is still needed in codebase, and this is prohibitively expensive, perhaps store GridAdapter on Field object instead of Grid self.lonlat_minmax = np.array( [ - np.nanmin(self._ds["lon"]), - np.nanmax(self._ds["lon"]), - np.nanmin(self._ds["lat"]), - np.nanmax(self._ds["lat"]), + np.nanmin(self.grid._ds["lon"]), + np.nanmax(self.grid._ds["lon"]), + np.nanmin(self.grid._ds["lat"]), + np.nanmax(self.grid._ds["lat"]), ] ) @property def lon(self): try: - _ = self.axes["X"] + _ = self.grid.axes["X"] except KeyError: return np.zeros(1) - return self._ds["lon"].values + return self.grid._ds["lon"].values @property def lat(self): try: - _ = self.axes["Y"] + _ = self.grid.axes["Y"] except KeyError: return np.zeros(1) - return self._ds["lat"].values + return self.grid._ds["lat"].values @property def depth(self): try: - _ = self.axes["Z"] + _ = self.grid.axes["Z"] except KeyError: return np.zeros(1) - return self._ds["depth"].values + return self.grid._ds["depth"].values @property def time(self): try: - axis = self.axes["T"] + axis = self.grid.axes["T"] except KeyError: return np.zeros(1) return get_time(axis) @property def xdim(self): - return get_dimensionality(self.axes.get("X")) + return get_dimensionality(self.grid.axes.get("X")) @property def ydim(self): - return get_dimensionality(self.axes.get("Y")) + return get_dimensionality(self.grid.axes.get("Y")) @property def zdim(self): - return get_dimensionality(self.axes.get("Z")) + return get_dimensionality(self.grid.axes.get("Z")) @property def tdim(self): - return get_dimensionality(self.axes.get("T")) + return get_dimensionality(self.grid.axes.get("T")) @property def time_origin(self): diff --git a/tests/v4/test_gridadapter.py b/tests/v4/test_gridadapter.py index 1fdcc66745..24c67b70b6 100644 --- a/tests/v4/test_gridadapter.py +++ b/tests/v4/test_gridadapter.py @@ -7,6 +7,7 @@ from parcels._datasets.structured.grid_datasets import N, T, datasets from parcels.grid import Grid as OldGrid from parcels.tools.converters import TimeConverter +from parcels.v4.grid import Grid as NewGrid from parcels.v4.gridadapter import GridAdapter TestCase = namedtuple("TestCase", ["Grid", "attr", "expected"]) @@ -38,7 +39,7 @@ def assert_equal(actual, expected): @pytest.mark.parametrize("ds, attr, expected", test_cases) def test_grid_adapter_properties_ground_truth(ds, attr, expected): - adapter = GridAdapter(ds, periodic=False) + adapter = GridAdapter(NewGrid(ds, periodic=False)) actual = getattr(adapter, attr) assert_equal(actual, expected) @@ -60,7 +61,7 @@ def test_grid_adapter_properties_ground_truth(ds, attr, expected): ) @pytest.mark.parametrize("ds", datasets.values()) def test_grid_adapter_against_old(ds, attr): - adapter = GridAdapter(ds, periodic=False) + adapter = GridAdapter(NewGrid(ds, periodic=False)) grid = OldGrid.create_grid( lon=ds.lon.values, From 9303d5880aa70425c1a37dddc5a7d8188cc39517 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Mon, 28 Apr 2025 13:51:08 +0200 Subject: [PATCH 7/9] Update Field to use GridAdapter --- parcels/field.py | 78 +++++++++++++++--------------------------------- 1 file changed, 24 insertions(+), 54 deletions(-) diff --git a/parcels/field.py b/parcels/field.py index 0b67f89b61..613a88ac82 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -29,6 +29,7 @@ _raise_field_out_of_bound_error, ) from parcels.v4.grid import Grid +from parcels.v4.gridadapter import GridAdapter from ._index_search import _search_indices_rectilinear, _search_time_index @@ -167,6 +168,13 @@ def __init__( self.data = data self.grid = grid + # For compatibility with parts of the codebase that rely on v3 definition of Grid. + # Should be worked to be removed in v4 + if isinstance(grid, Grid): + self.gridadapter = GridAdapter(grid) + else: + self.gridadapter = None + try: if isinstance(data, ux.UxDataArray): _assert_valid_uxdataarray(data) @@ -201,37 +209,6 @@ def __init__( else: self.allow_time_extrapolation = allow_time_extrapolation - if type(self.data) is ux.UxDataArray: - self._gtype = None - else: # TODO Nick : This bit probably needs an overhaul once the parcels.Grid class is integrated. - # Set the grid type - if "x_g" in self.data.coords: - lon = self.data.x_g - elif "x_c" in self.data.coords: - lon = self.data.x_c - else: - lon = self.data.lon - - if "nz1" in self.data.coords: - depth = self.data.nz1 - elif "nz" in self.data.coords: - depth = self.data.nz - elif "depth" in self.data.coords: - depth = self.data.depth - else: - depth = None - - if len(lon.shape) <= 1: - if depth is None or len(depth.shape) <= 1: - self._gtype = GridType.RectilinearZGrid - else: - self._gtype = GridType.RectilinearSGrid - else: - if depth is None or len(depth.shape) <= 1: - self._gtype = GridType.CurvilinearZGrid - else: - self._gtype = GridType.CurvilinearSGrid - def __repr__(self): return field_repr(self) @@ -255,7 +232,7 @@ def lat(self): elif self.data.attrs["location"] == "edge": return self.grid.edge_lat else: - return self.data.lat + return self.gridadapter.lat @property def lon(self): @@ -267,7 +244,7 @@ def lon(self): elif self.data.attrs["location"] == "edge": return self.grid.edge_lon else: - return self.data.lon + return self.gridadapter.lon @property def depth(self): @@ -278,40 +255,33 @@ def depth(self): elif vertical_location == "face": return self.grid.nz else: - return self.data.depth + return self.gridadapter.depth @property def xdim(self): if type(self.data) is xr.DataArray: - if "face_lon" in self.data.dims: - return self.data.sizes["face_lon"] - elif "node_lon" in self.data.dims: - return self.data.sizes["node_lon"] - else: - return self.data.sizes["lon"] + return self.gridadapter.xdim else: - return 0 # TODO : Discuss what we want to return as xdim for uxdataarray obj + raise NotImplementedError("xdim not implemented for unstructured grids") @property def ydim(self): if type(self.data) is xr.DataArray: - if "face_lat" in self.data.dims: - return self.data.sizes["face_lat"] - elif "node_lat" in self.data.dims: - return self.data.sizes["node_lat"] - else: - return self.data.sizes["lat"] + return self.gridadapter.ydim else: - return 0 # TODO : Discuss what we want to return as ydim for uxdataarray obj + raise NotImplementedError("ydim not implemented for unstructured grids") @property def zdim(self): - if "nz1" in self.data.dims: - return self.data.sizes["nz1"] - elif "nz" in self.data.dims: - return self.data.sizes["nz"] + if type(self.data) is xr.DataArray: + return self.gridadapter.zdim else: - return 0 + if "nz1" in self.data.dims: + return self.data.sizes["nz1"] + elif "nz" in self.data.dims: + return self.data.sizes["nz"] + else: + return 0 @property def n_face(self): @@ -379,7 +349,7 @@ def _search_indices_unstructured(self, z, y, x, ei=None, search2D=False): raise FieldOutOfBoundError(z, y, x) def _search_indices_structured(self, z, y, x, ei=None, search2D=False): - if self._gtype in [GridType.RectilinearSGrid, GridType.RectilinearZGrid]: + if self.gridadapter._gtype in [GridType.RectilinearSGrid, GridType.RectilinearZGrid]: (zeta, eta, xsi, zi, yi, xi) = _search_indices_rectilinear(self, z, y, x, ei=ei, search2D=search2D) else: ## TODO : Still need to implement the search_indices_curvilinear From dbbc290a554b3a15d9f530a492e8f2d8e7ffcd1e Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Tue, 29 Apr 2025 14:20:08 +0200 Subject: [PATCH 8/9] bugfix kwarg --- tests/v4/test_field.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/v4/test_field.py b/tests/v4/test_field.py index 082918567e..cf13bb2221 100644 --- a/tests/v4/test_field.py +++ b/tests/v4/test_field.py @@ -36,7 +36,7 @@ def test_field_init_param_types(): ], ) def test_field_incompatible_combination(data, grid): - with pytest.raises(ValueError, msg="Incompatible data-grid combination."): + with pytest.raises(ValueError, match="Incompatible data-grid combination."): Field( name="test_field", data=data, From 80b7cd4daf0a80706dbab6aa8d77ccd243d0988a Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Tue, 29 Apr 2025 14:52:25 +0200 Subject: [PATCH 9/9] xfail test_field_incompatible_combination[xarray-uxgrid] --- tests/v4/test_field.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/v4/test_field.py b/tests/v4/test_field.py index cf13bb2221..be0f0b87e4 100644 --- a/tests/v4/test_field.py +++ b/tests/v4/test_field.py @@ -32,7 +32,14 @@ def test_field_init_param_types(): "data,grid", [ pytest.param(ux.UxDataArray(), Grid(xr.Dataset()), id="uxdata-grid"), - pytest.param(xr.DataArray(), ux.UxDataArray().uxgrid, id="xarray-uxgrid"), + pytest.param( + xr.DataArray(), + ux.UxDataArray().uxgrid, + id="xarray-uxgrid", + marks=pytest.mark.xfail( + reason="Replace uxDataArray object with one that actually has a grid (once unstructured example datasets are in the codebase)." + ), + ), ], ) def test_field_incompatible_combination(data, grid):