Skip to content

Commit c8d8e6d

Browse files
Merge pull request #1895 from OceanParcels/swap_space_and_time_interpolation_order
Moving the time interpolation down to within _Interpolator functions
2 parents 04a77c5 + 48db381 commit c8d8e6d

File tree

8 files changed

+222
-170
lines changed

8 files changed

+222
-170
lines changed

parcels/_interpolation.py

Lines changed: 82 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import numpy as np
55

66
from parcels._typing import GridIndexingType
7+
from parcels.tools._helpers import should_calculate_next_ti
78

89

910
@dataclass
@@ -14,6 +15,8 @@ class InterpolationContext2D:
1415
----------
1516
data: np.ndarray
1617
field data of shape (time, y, x)
18+
tau: float
19+
time interpolation coordinate in unit length
1720
eta: float
1821
y-direction interpolation coordinate in unit cube (between 0 and 1)
1922
xsi: float
@@ -28,6 +31,7 @@ class InterpolationContext2D:
2831
"""
2932

3033
data: np.ndarray
34+
tau: float
3135
eta: float
3236
xsi: float
3337
ti: int
@@ -45,6 +49,8 @@ class InterpolationContext3D:
4549
field data of shape (time, z, y, x). This needs to be complete in the vertical
4650
direction as some interpolation methods need to know whether they are at the
4751
surface or bottom.
52+
tau: float
53+
time interpolation coordinate in unit length
4854
zeta: float
4955
vertical interpolation coordinate in unit cube
5056
eta: float
@@ -65,6 +71,7 @@ class InterpolationContext3D:
6571
"""
6672

6773
data: np.ndarray
74+
tau: float
6875
zeta: float
6976
eta: float
7077
xsi: float
@@ -110,7 +117,11 @@ def decorator(interpolator: Callable[[InterpolationContext3D], float]):
110117
def _nearest_2d(ctx: InterpolationContext2D) -> float:
111118
xii = ctx.xi if ctx.xsi <= 0.5 else ctx.xi + 1
112119
yii = ctx.yi if ctx.eta <= 0.5 else ctx.yi + 1
113-
return ctx.data[ctx.ti, yii, xii]
120+
ft0 = ctx.data[ctx.ti, yii, xii]
121+
if not should_calculate_next_ti(ctx.ti, ctx.tau, ctx.data.shape[0]):
122+
return ft0
123+
ft1 = ctx.data[ctx.ti + 1, yii, xii]
124+
return (1 - ctx.tau) * ft0 + ctx.tau * ft1
114125

115126

116127
def _interp_on_unit_square(*, eta: float, xsi: float, data: np.ndarray, yi: int, xi: int) -> float:
@@ -128,7 +139,11 @@ def _interp_on_unit_square(*, eta: float, xsi: float, data: np.ndarray, yi: int,
128139
@register_2d_interpolator("partialslip")
129140
@register_2d_interpolator("freeslip")
130141
def _linear_2d(ctx: InterpolationContext2D) -> float:
131-
return _interp_on_unit_square(eta=ctx.eta, xsi=ctx.xsi, data=ctx.data[ctx.ti, :, :], yi=ctx.yi, xi=ctx.xi)
142+
ft0 = _interp_on_unit_square(eta=ctx.eta, xsi=ctx.xsi, data=ctx.data[ctx.ti, :, :], yi=ctx.yi, xi=ctx.xi)
143+
if not should_calculate_next_ti(ctx.ti, ctx.tau, ctx.data.shape[0]):
144+
return ft0
145+
ft1 = _interp_on_unit_square(eta=ctx.eta, xsi=ctx.xsi, data=ctx.data[ctx.ti + 1, :, :], yi=ctx.yi, xi=ctx.xi)
146+
return (1 - ctx.tau) * ft0 + ctx.tau * ft1
132147

133148

134149
@register_2d_interpolator("linear_invdist_land_tracer")
@@ -142,6 +157,13 @@ def _linear_invdist_land_tracer_2d(ctx: InterpolationContext2D) -> float:
142157
land = np.isclose(data[ti, yi : yi + 2, xi : xi + 2], 0.0)
143158
nb_land = np.sum(land)
144159

160+
def _get_data_temporalinterp(*, ti, yi, xi):
161+
dt0 = data[ti, yi, xi]
162+
if not should_calculate_next_ti(ctx.ti, ctx.tau, ctx.data.shape[0]):
163+
return dt0
164+
dt1 = data[ti + 1, yi, xi]
165+
return (1 - ctx.tau) * dt0 + ctx.tau * dt1
166+
145167
if nb_land == 4:
146168
return 0
147169
elif nb_land > 0:
@@ -154,9 +176,9 @@ def _linear_invdist_land_tracer_2d(ctx: InterpolationContext2D) -> float:
154176
if land[j][i] == 1: # index search led us directly onto land
155177
return 0
156178
else:
157-
return data[ti, yi + j, xi + i]
179+
return _get_data_temporalinterp(ti=ti, yi=yi + j, xi=xi + i)
158180
elif land[j][i] == 0:
159-
val += data[ti, yi + j, xi + i] / distance
181+
val += _get_data_temporalinterp(ti=ti, yi=yi + j, xi=xi + i) / distance
160182
w_sum += 1 / distance
161183
return val / w_sum
162184
else:
@@ -166,33 +188,64 @@ def _linear_invdist_land_tracer_2d(ctx: InterpolationContext2D) -> float:
166188
@register_2d_interpolator("cgrid_tracer")
167189
@register_2d_interpolator("bgrid_tracer")
168190
def _tracer_2d(ctx: InterpolationContext2D) -> float:
169-
return ctx.data[ctx.ti, ctx.yi + 1, ctx.xi + 1]
191+
ft0 = ctx.data[ctx.ti, ctx.yi + 1, ctx.xi + 1]
192+
if not should_calculate_next_ti(ctx.ti, ctx.tau, ctx.data.shape[0]):
193+
return ft0
194+
ft1 = ctx.data[ctx.ti + 1, ctx.yi + 1, ctx.xi + 1]
195+
return (1 - ctx.tau) * ft0 + ctx.tau * ft1
170196

171197

172198
@register_3d_interpolator("nearest")
173199
def _nearest_3d(ctx: InterpolationContext3D) -> float:
174200
xii = ctx.xi if ctx.xsi <= 0.5 else ctx.xi + 1
175201
yii = ctx.yi if ctx.eta <= 0.5 else ctx.yi + 1
176202
zii = ctx.zi if ctx.zeta <= 0.5 else ctx.zi + 1
177-
return ctx.data[ctx.ti, zii, yii, xii]
203+
ft0 = ctx.data[ctx.ti, zii, yii, xii]
204+
if not should_calculate_next_ti(ctx.ti, ctx.tau, ctx.data.shape[0]):
205+
return ft0
206+
ft1 = ctx.data[ctx.ti + 1, zii, yii, xii]
207+
return (1 - ctx.tau) * ft0 + ctx.tau * ft1
208+
209+
210+
def _get_cgrid_depth_point(*, zeta: float, data: np.ndarray, zi: int, yi: int, xi: int) -> float:
211+
f0 = data[zi, yi, xi]
212+
f1 = data[zi + 1, yi, xi]
213+
return (1 - zeta) * f0 + zeta * f1
178214

179215

180216
@register_3d_interpolator("cgrid_velocity")
181-
def _cgrid_velocity_3d(ctx: InterpolationContext3D) -> float:
217+
def _cgrid_W_velocity_3d(ctx: InterpolationContext3D) -> float:
182218
# evaluating W velocity in c_grid
183219
if ctx.gridindexingtype == "nemo":
184-
f0 = ctx.data[ctx.ti, ctx.zi, ctx.yi + 1, ctx.xi + 1]
185-
f1 = ctx.data[ctx.ti, ctx.zi + 1, ctx.yi + 1, ctx.xi + 1]
220+
ft0 = _get_cgrid_depth_point(
221+
zeta=ctx.zeta, data=ctx.data[ctx.ti, :, :, :], zi=ctx.zi, yi=ctx.yi + 1, xi=ctx.xi + 1
222+
)
186223
elif ctx.gridindexingtype in ["mitgcm", "croco"]:
187-
f0 = ctx.data[ctx.ti, ctx.zi, ctx.yi, ctx.xi]
188-
f1 = ctx.data[ctx.ti, ctx.zi + 1, ctx.yi, ctx.xi]
189-
return (1 - ctx.zeta) * f0 + ctx.zeta * f1
224+
ft0 = _get_cgrid_depth_point(zeta=ctx.zeta, data=ctx.data[ctx.ti, :, :, :], zi=ctx.zi, yi=ctx.yi, xi=ctx.xi)
225+
if not should_calculate_next_ti(ctx.ti, ctx.tau, ctx.data.shape[0]):
226+
return ft0
227+
228+
if ctx.gridindexingtype == "nemo":
229+
ft1 = _get_cgrid_depth_point(
230+
zeta=ctx.zeta, data=ctx.data[ctx.ti + 1, :, :, :], zi=ctx.zi, yi=ctx.yi + 1, xi=ctx.xi + 1
231+
)
232+
elif ctx.gridindexingtype in ["mitgcm", "croco"]:
233+
ft1 = _get_cgrid_depth_point(zeta=ctx.zeta, data=ctx.data[ctx.ti + 1, :, :, :], zi=ctx.zi, yi=ctx.yi, xi=ctx.xi)
234+
return (1 - ctx.tau) * ft0 + ctx.tau * ft1
190235

191236

192237
@register_3d_interpolator("linear_invdist_land_tracer")
193238
def _linear_invdist_land_tracer_3d(ctx: InterpolationContext3D) -> float:
194239
land = np.isclose(ctx.data[ctx.ti, ctx.zi : ctx.zi + 2, ctx.yi : ctx.yi + 2, ctx.xi : ctx.xi + 2], 0.0)
195240
nb_land = np.sum(land)
241+
242+
def _get_data_temporalinterp(*, ti, zi, yi, xi):
243+
dt0 = ctx.data[ti, zi, yi, xi]
244+
if not should_calculate_next_ti(ctx.ti, ctx.tau, ctx.data.shape[0]):
245+
return dt0
246+
dt1 = data[ti + 1, zi, yi, xi]
247+
return (1 - ctx.tau) * dt0 + ctx.tau * dt1
248+
196249
if nb_land == 8:
197250
return 0
198251
elif nb_land > 0:
@@ -206,9 +259,11 @@ def _linear_invdist_land_tracer_3d(ctx: InterpolationContext3D) -> float:
206259
if land[k][j][i] == 1: # index search led us directly onto land
207260
return 0
208261
else:
209-
return ctx.data[ctx.ti, ctx.zi + k, ctx.yi + j, ctx.xi + i]
262+
return _get_data_temporalinterp(ti=ctx.ti, zi=ctx.zi + k, yi=ctx.yi + j, xi=ctx.xi + i)
210263
elif land[k][j][i] == 0:
211-
val += ctx.data[ctx.ti, ctx.zi + k, ctx.yi + j, ctx.xi + i] / distance
264+
val += (
265+
_get_data_temporalinterp(ti=ctx.ti, zi=ctx.zi + k, yi=ctx.yi + j, xi=ctx.xi + i) / distance
266+
)
212267
w_sum += 1 / distance
213268
return val / w_sum
214269
else:
@@ -253,9 +308,15 @@ def _z_layer_interp(
253308
def _linear_3d(ctx: InterpolationContext3D) -> float:
254309
zdim = ctx.data.shape[1]
255310
data_3d = ctx.data[ctx.ti, :, :, :]
256-
f0, f1 = _get_3d_f0_f1(eta=ctx.eta, xsi=ctx.xsi, data=data_3d, zi=ctx.zi, yi=ctx.yi, xi=ctx.xi)
311+
fz0, fz1 = _get_3d_f0_f1(eta=ctx.eta, xsi=ctx.xsi, data=data_3d, zi=ctx.zi, yi=ctx.yi, xi=ctx.xi)
312+
if should_calculate_next_ti(ctx.ti, ctx.tau, ctx.data.shape[0]):
313+
data_3d = ctx.data[ctx.ti + 1, :, :, :]
314+
fz0_t1, fz1_t1 = _get_3d_f0_f1(eta=ctx.eta, xsi=ctx.xsi, data=data_3d, zi=ctx.zi, yi=ctx.yi, xi=ctx.xi)
315+
fz0 = (1 - ctx.tau) * fz0 + ctx.tau * fz0_t1
316+
if fz1_t1 is not None and fz1 is not None:
317+
fz1 = (1 - ctx.tau) * fz1 + ctx.tau * fz1_t1
257318

258-
return _z_layer_interp(zeta=ctx.zeta, f0=f0, f1=f1, zi=ctx.zi, zdim=zdim, gridindexingtype=ctx.gridindexingtype)
319+
return _z_layer_interp(zeta=ctx.zeta, f0=fz0, f1=fz1, zi=ctx.zi, zdim=zdim, gridindexingtype=ctx.gridindexingtype)
259320

260321

261322
@register_3d_interpolator("bgrid_velocity")
@@ -277,4 +338,8 @@ def _linear_3d_bgrid_w_velocity(ctx: InterpolationContext3D) -> float:
277338
@register_3d_interpolator("bgrid_tracer")
278339
@register_3d_interpolator("cgrid_tracer")
279340
def _tracer_3d(ctx: InterpolationContext3D) -> float:
280-
return ctx.data[ctx.ti, ctx.zi, ctx.yi + 1, ctx.xi + 1]
341+
ft0 = ctx.data[ctx.ti, ctx.zi, ctx.yi + 1, ctx.xi + 1]
342+
if not should_calculate_next_ti(ctx.ti, ctx.tau, ctx.data.shape[0]):
343+
return ft0
344+
ft1 = ctx.data[ctx.ti + 1, ctx.zi, ctx.yi + 1, ctx.xi + 1]
345+
return (1 - ctx.tau) * ft0 + ctx.tau * ft1

parcels/application_kernels/advection.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -177,16 +177,14 @@ def AdvectionAnalytical(particle, fieldset, time): # pragma: no cover
177177
direction = 1.0 if particle.dt > 0 else -1.0
178178
withW = True if "W" in [f.name for f in fieldset.get_fields()] else False
179179
withTime = True if len(fieldset.U.grid.time_full) > 1 else False
180-
ti = fieldset.U._time_index(time)
180+
tau, zeta, eta, xsi, ti, zi, yi, xi = fieldset.U._search_indices(
181+
time, particle.depth, particle.lat, particle.lon, particle=particle
182+
)
181183
ds_t = particle.dt
182184
if withTime:
183-
tau = (time - fieldset.U.grid.time[ti]) / (fieldset.U.grid.time[ti + 1] - fieldset.U.grid.time[ti])
184185
time_i = np.linspace(0, fieldset.U.grid.time[ti + 1] - fieldset.U.grid.time[ti], I_s)
185186
ds_t = min(ds_t, time_i[np.where(time - fieldset.U.grid.time[ti] < time_i)[0][0]])
186187

187-
zeta, eta, xsi, zi, yi, xi = fieldset.U._search_indices(
188-
time, particle.depth, particle.lat, particle.lon, ti, particle=particle
189-
)
190188
if withW:
191189
if abs(xsi - 1) < tol:
192190
if fieldset.U.data[0, zi + 1, yi + 1, xi + 1] > 0:

0 commit comments

Comments
 (0)