From 04f0e261329b761e9a19b20a2ca4ec0c8be62264 Mon Sep 17 00:00:00 2001 From: Erik van Sebille Date: Wed, 12 Mar 2025 16:43:13 +0100 Subject: [PATCH 1/2] Moving the _search_time_index function to _index_search.py --- parcels/_index_search.py | 30 +++++++++++++++++++++++++++++ parcels/field.py | 37 ++---------------------------------- parcels/tools/statuscodes.py | 5 +++++ tests/test_fieldset.py | 4 +++- 4 files changed, 40 insertions(+), 36 deletions(-) diff --git a/parcels/_index_search.py b/parcels/_index_search.py index a52baf56ff..35abc3f37e 100644 --- a/parcels/_index_search.py +++ b/parcels/_index_search.py @@ -14,6 +14,7 @@ _raise_field_out_of_bound_error, _raise_field_out_of_bound_surface_error, _raise_field_sampling_error, + _raise_time_extrapolation_error, ) from .grid import GridType @@ -23,6 +24,35 @@ from .grid import Grid +def _search_time_index(grid: Grid, time: float, allow_time_extrapolation=True): + """Find and return the index and relative coordinate in the time array associated with a given time. + + Note that we normalize to either the first or the last index + if the sampled value is outside the time value range. + """ + if not allow_time_extrapolation and (time < grid.time[0] or time > grid.time[-1]): + _raise_time_extrapolation_error(time, field=None) + time_index = grid.time <= time + + if time_index.all(): + # If given time > last known field time, use + # the last field frame without interpolation + ti = len(grid.time) - 1 + elif np.logical_not(time_index).all(): + # If given time < any time in the field, use + # the first field frame without interpolation + ti = 0 + else: + ti = time_index.argmin() - 1 if time_index.any() else 0 + if grid.tdim == 1: + tau = 0 + elif ti == len(grid.time) - 1: + tau = 1 + else: + tau = (time - grid.time[ti]) / (grid.time[ti + 1] - grid.time[ti]) if grid.time[ti] != grid.time[ti + 1] else 0 + return tau, ti + + def search_indices_vertical_z(grid: Grid, gridindexingtype: GridIndexingType, z: float): if grid.depth[-1] > grid.depth[0]: if z < grid.depth[0]: diff --git a/parcels/field.py b/parcels/field.py index 0952143112..e4159b2bcd 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -36,12 +36,11 @@ FieldOutOfBoundError, FieldOutOfBoundSurfaceError, FieldSamplingError, - TimeExtrapolationError, _raise_field_out_of_bound_error, ) from parcels.tools.warnings import FieldSetWarning -from ._index_search import _search_indices_curvilinear, _search_indices_rectilinear +from ._index_search import _search_indices_curvilinear, _search_indices_rectilinear, _search_time_index from .fieldfilebuffer import ( NetcdfFileBuffer, ) @@ -543,7 +542,7 @@ def _reshape(self, data): return data def _search_indices(self, time, z, y, x, particle=None, search2D=False): - tau, ti = self._search_time_index(time) + tau, ti = _search_time_index(self.grid, time, self.allow_time_extrapolation) if self.grid._gtype in [GridType.RectilinearSGrid, GridType.RectilinearZGrid]: (zeta, eta, xsi, zi, yi, xi) = _search_indices_rectilinear( @@ -602,38 +601,6 @@ def _interpolate(self, time, z, y, x, particle=None): e = add_note(e, f"Error interpolating field '{self.name}'.", before=True) raise e - def _search_time_index(self, time): - """Find and return the index and relative coordinate in the time array associated with a given time. - - Note that we normalize to either the first or the last index - if the sampled value is outside the time value range. - """ - if not self.allow_time_extrapolation and (time < self.grid.time[0] or time > self.grid.time[-1]): - raise TimeExtrapolationError(time, field=self) - time_index = self.grid.time <= time - - if time_index.all(): - # If given time > last known field time, use - # the last field frame without interpolation - ti = len(self.grid.time) - 1 - elif np.logical_not(time_index).all(): - # If given time < any time in the field, use - # the first field frame without interpolation - ti = 0 - else: - ti = time_index.argmin() - 1 if time_index.any() else 0 - if self.grid.tdim == 1: - tau = 0 - elif ti == len(self.grid.time) - 1: - tau = 1 - else: - tau = ( - (time - self.grid.time[ti]) / (self.grid.time[ti + 1] - self.grid.time[ti]) - if self.grid.time[ti] != self.grid.time[ti + 1] - else 0 - ) - return tau, ti - def _check_velocitysampling(self): if self.name in ["U", "V", "W"]: warnings.warn( diff --git a/parcels/tools/statuscodes.py b/parcels/tools/statuscodes.py index 7a9799dc7c..501741ec2c 100644 --- a/parcels/tools/statuscodes.py +++ b/parcels/tools/statuscodes.py @@ -10,6 +10,7 @@ "_raise_field_out_of_bound_error", "_raise_field_out_of_bound_surface_error", "_raise_field_sampling_error", + "_raise_time_extrapolation_error", ] @@ -77,6 +78,10 @@ def __init__(self, time, field=None): super().__init__(message) +def _raise_time_extrapolation_error(time: float, field=None): + raise TimeExtrapolationError(time, field) + + class KernelError(RuntimeError): """General particle kernel error with optional custom message.""" diff --git a/tests/test_fieldset.py b/tests/test_fieldset.py index 26618605df..78b7317a50 100644 --- a/tests/test_fieldset.py +++ b/tests/test_fieldset.py @@ -450,9 +450,11 @@ def test_fieldset_write(tmp_zarrfile): fieldset.U.to_write = True def UpdateU(particle, fieldset, time): # pragma: no cover + from parcels._index_search import _search_time_index + tmp1, tmp2 = fieldset.UV[particle] _, yi, xi = fieldset.U.unravel_index(particle.ei) - _, ti = fieldset.U._search_time_index(time) + _, ti = _search_time_index(fieldset.U.grid, time) fieldset.U.data[ti, yi, xi] += 1 fieldset.U.grid.time[0] = time From d9d4f8debee354e6ebab0ad90e3b5640f5f6e378 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Thu, 13 Mar 2025 13:40:59 +0100 Subject: [PATCH 2/2] fix mypy --- parcels/_index_search.py | 2 +- parcels/grid.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/parcels/_index_search.py b/parcels/_index_search.py index 35abc3f37e..ddfbadd173 100644 --- a/parcels/_index_search.py +++ b/parcels/_index_search.py @@ -43,7 +43,7 @@ def _search_time_index(grid: Grid, time: float, allow_time_extrapolation=True): # the first field frame without interpolation ti = 0 else: - ti = time_index.argmin() - 1 if time_index.any() else 0 + ti = int(time_index.argmin() - 1) if time_index.any() else 0 if grid.tdim == 1: tau = 0 elif ti == len(grid.time) - 1: diff --git a/parcels/grid.py b/parcels/grid.py index 813dea2c70..f74c126a32 100644 --- a/parcels/grid.py +++ b/parcels/grid.py @@ -58,6 +58,7 @@ def __init__( self._lon = lon self._lat = lat self.time = time + self.tdim = time.size self.time_full = self.time # needed for deferred_loaded Fields self._time_origin = TimeConverter() if time_origin is None else time_origin assert isinstance(self.time_origin, TimeConverter), "time_origin needs to be a TimeConverter object" @@ -183,7 +184,6 @@ def __init__(self, lon, lat, time, time_origin, mesh: Mesh): assert len(time.shape) == 1, "time is not a vector" super().__init__(lon, lat, time, time_origin, mesh) - self.tdim = self.time.size @property def xdim(self): @@ -327,7 +327,6 @@ def __init__( lon = lon.squeeze() lat = lat.squeeze() super().__init__(lon, lat, time, time_origin, mesh) - self.tdim = self.time.size @property def xdim(self):