Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions parcels/_index_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
_raise_field_out_of_bound_error,
_raise_field_out_of_bound_surface_error,
_raise_field_sampling_error,
_raise_time_extrapolation_error,
)

from .grid import GridType
Expand All @@ -23,6 +24,35 @@
from .grid import Grid


def _search_time_index(grid: Grid, time: float, allow_time_extrapolation=True):
"""Find and return the index and relative coordinate in the time array associated with a given time.

Note that we normalize to either the first or the last index
if the sampled value is outside the time value range.
"""
if not allow_time_extrapolation and (time < grid.time[0] or time > grid.time[-1]):
_raise_time_extrapolation_error(time, field=None)
time_index = grid.time <= time

if time_index.all():
# If given time > last known field time, use
# the last field frame without interpolation
ti = len(grid.time) - 1
elif np.logical_not(time_index).all():
# If given time < any time in the field, use
# the first field frame without interpolation
ti = 0
else:
ti = int(time_index.argmin() - 1) if time_index.any() else 0
if grid.tdim == 1:
tau = 0
elif ti == len(grid.time) - 1:
tau = 1
else:
tau = (time - grid.time[ti]) / (grid.time[ti + 1] - grid.time[ti]) if grid.time[ti] != grid.time[ti + 1] else 0
return tau, ti


def search_indices_vertical_z(grid: Grid, gridindexingtype: GridIndexingType, z: float):
if grid.depth[-1] > grid.depth[0]:
if z < grid.depth[0]:
Expand Down
37 changes: 2 additions & 35 deletions parcels/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,11 @@
FieldOutOfBoundError,
FieldOutOfBoundSurfaceError,
FieldSamplingError,
TimeExtrapolationError,
_raise_field_out_of_bound_error,
)
from parcels.tools.warnings import FieldSetWarning

from ._index_search import _search_indices_curvilinear, _search_indices_rectilinear
from ._index_search import _search_indices_curvilinear, _search_indices_rectilinear, _search_time_index
from .fieldfilebuffer import (
NetcdfFileBuffer,
)
Expand Down Expand Up @@ -543,7 +542,7 @@ def _reshape(self, data):
return data

def _search_indices(self, time, z, y, x, particle=None, search2D=False):
tau, ti = self._search_time_index(time)
tau, ti = _search_time_index(self.grid, time, self.allow_time_extrapolation)

if self.grid._gtype in [GridType.RectilinearSGrid, GridType.RectilinearZGrid]:
(zeta, eta, xsi, zi, yi, xi) = _search_indices_rectilinear(
Expand Down Expand Up @@ -602,38 +601,6 @@ def _interpolate(self, time, z, y, x, particle=None):
e = add_note(e, f"Error interpolating field '{self.name}'.", before=True)
raise e

def _search_time_index(self, time):
"""Find and return the index and relative coordinate in the time array associated with a given time.

Note that we normalize to either the first or the last index
if the sampled value is outside the time value range.
"""
if not self.allow_time_extrapolation and (time < self.grid.time[0] or time > self.grid.time[-1]):
raise TimeExtrapolationError(time, field=self)
time_index = self.grid.time <= time

if time_index.all():
# If given time > last known field time, use
# the last field frame without interpolation
ti = len(self.grid.time) - 1
elif np.logical_not(time_index).all():
# If given time < any time in the field, use
# the first field frame without interpolation
ti = 0
else:
ti = time_index.argmin() - 1 if time_index.any() else 0
if self.grid.tdim == 1:
tau = 0
elif ti == len(self.grid.time) - 1:
tau = 1
else:
tau = (
(time - self.grid.time[ti]) / (self.grid.time[ti + 1] - self.grid.time[ti])
if self.grid.time[ti] != self.grid.time[ti + 1]
else 0
)
return tau, ti

def _check_velocitysampling(self):
if self.name in ["U", "V", "W"]:
warnings.warn(
Expand Down
3 changes: 1 addition & 2 deletions parcels/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def __init__(
self._lon = lon
self._lat = lat
self.time = time
self.tdim = time.size
self._time_origin = TimeConverter() if time_origin is None else time_origin
assert isinstance(self.time_origin, TimeConverter), "time_origin needs to be a TimeConverter object"
assert_valid_mesh(mesh)
Expand Down Expand Up @@ -182,7 +183,6 @@ def __init__(self, lon, lat, time, time_origin, mesh: Mesh):
assert len(time.shape) == 1, "time is not a vector"

super().__init__(lon, lat, time, time_origin, mesh)
self.tdim = self.time.size

@property
def xdim(self):
Expand Down Expand Up @@ -326,7 +326,6 @@ def __init__(
lon = lon.squeeze()
lat = lat.squeeze()
super().__init__(lon, lat, time, time_origin, mesh)
self.tdim = self.time.size

@property
def xdim(self):
Expand Down
5 changes: 5 additions & 0 deletions parcels/tools/statuscodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"_raise_field_out_of_bound_error",
"_raise_field_out_of_bound_surface_error",
"_raise_field_sampling_error",
"_raise_time_extrapolation_error",
]


Expand Down Expand Up @@ -77,6 +78,10 @@ def __init__(self, time, field=None):
super().__init__(message)


def _raise_time_extrapolation_error(time: float, field=None):
raise TimeExtrapolationError(time, field)


class KernelError(RuntimeError):
"""General particle kernel error with optional custom message."""

Expand Down
4 changes: 3 additions & 1 deletion tests/test_fieldset.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,9 +450,11 @@ def test_fieldset_write(tmp_zarrfile):
fieldset.U.to_write = True

def UpdateU(particle, fieldset, time): # pragma: no cover
from parcels._index_search import _search_time_index

Comment thread
VeckoTheGecko marked this conversation as resolved.
tmp1, tmp2 = fieldset.UV[particle]
_, yi, xi = fieldset.U.unravel_index(particle.ei)
_, ti = fieldset.U._search_time_index(time)
_, ti = _search_time_index(fieldset.U.grid, time)
fieldset.U.data[ti, yi, xi] += 1
fieldset.U.grid.time[0] = time

Expand Down
Loading