44import numpy as np
55
66from parcels ._typing import GridIndexingType
7+ from parcels .tools ._helpers import should_calculate_next_ti
78
89
910@dataclass
@@ -14,6 +15,8 @@ class InterpolationContext2D:
1415 ----------
1516 data: np.ndarray
1617 field data of shape (time, y, x)
18+ tau: float
19+ time interpolation coordinate in unit length
1720 eta: float
1821 y-direction interpolation coordinate in unit cube (between 0 and 1)
1922 xsi: float
@@ -28,6 +31,7 @@ class InterpolationContext2D:
2831 """
2932
3033 data : np .ndarray
34+ tau : float
3135 eta : float
3236 xsi : float
3337 ti : int
@@ -45,6 +49,8 @@ class InterpolationContext3D:
4549 field data of shape (time, z, y, x). This needs to be complete in the vertical
4650 direction as some interpolation methods need to know whether they are at the
4751 surface or bottom.
52+ tau: float
53+ time interpolation coordinate in unit length
4854 zeta: float
4955 vertical interpolation coordinate in unit cube
5056 eta: float
@@ -65,6 +71,7 @@ class InterpolationContext3D:
6571 """
6672
6773 data : np .ndarray
74+ tau : float
6875 zeta : float
6976 eta : float
7077 xsi : float
@@ -110,7 +117,11 @@ def decorator(interpolator: Callable[[InterpolationContext3D], float]):
110117def _nearest_2d (ctx : InterpolationContext2D ) -> float :
111118 xii = ctx .xi if ctx .xsi <= 0.5 else ctx .xi + 1
112119 yii = ctx .yi if ctx .eta <= 0.5 else ctx .yi + 1
113- return ctx .data [ctx .ti , yii , xii ]
120+ ft0 = ctx .data [ctx .ti , yii , xii ]
121+ if not should_calculate_next_ti (ctx .ti , ctx .tau , ctx .data .shape [0 ]):
122+ return ft0
123+ ft1 = ctx .data [ctx .ti + 1 , yii , xii ]
124+ return (1 - ctx .tau ) * ft0 + ctx .tau * ft1
114125
115126
116127def _interp_on_unit_square (* , eta : float , xsi : float , data : np .ndarray , yi : int , xi : int ) -> float :
@@ -128,7 +139,11 @@ def _interp_on_unit_square(*, eta: float, xsi: float, data: np.ndarray, yi: int,
128139@register_2d_interpolator ("partialslip" )
129140@register_2d_interpolator ("freeslip" )
130141def _linear_2d (ctx : InterpolationContext2D ) -> float :
131- return _interp_on_unit_square (eta = ctx .eta , xsi = ctx .xsi , data = ctx .data [ctx .ti , :, :], yi = ctx .yi , xi = ctx .xi )
142+ ft0 = _interp_on_unit_square (eta = ctx .eta , xsi = ctx .xsi , data = ctx .data [ctx .ti , :, :], yi = ctx .yi , xi = ctx .xi )
143+ if not should_calculate_next_ti (ctx .ti , ctx .tau , ctx .data .shape [0 ]):
144+ return ft0
145+ ft1 = _interp_on_unit_square (eta = ctx .eta , xsi = ctx .xsi , data = ctx .data [ctx .ti + 1 , :, :], yi = ctx .yi , xi = ctx .xi )
146+ return (1 - ctx .tau ) * ft0 + ctx .tau * ft1
132147
133148
134149@register_2d_interpolator ("linear_invdist_land_tracer" )
@@ -142,6 +157,13 @@ def _linear_invdist_land_tracer_2d(ctx: InterpolationContext2D) -> float:
142157 land = np .isclose (data [ti , yi : yi + 2 , xi : xi + 2 ], 0.0 )
143158 nb_land = np .sum (land )
144159
160+ def _get_data_temporalinterp (* , ti , yi , xi ):
161+ dt0 = data [ti , yi , xi ]
162+ if not should_calculate_next_ti (ctx .ti , ctx .tau , ctx .data .shape [0 ]):
163+ return dt0
164+ dt1 = data [ti + 1 , yi , xi ]
165+ return (1 - ctx .tau ) * dt0 + ctx .tau * dt1
166+
145167 if nb_land == 4 :
146168 return 0
147169 elif nb_land > 0 :
@@ -154,9 +176,9 @@ def _linear_invdist_land_tracer_2d(ctx: InterpolationContext2D) -> float:
154176 if land [j ][i ] == 1 : # index search led us directly onto land
155177 return 0
156178 else :
157- return data [ ti , yi + j , xi + i ]
179+ return _get_data_temporalinterp ( ti = ti , yi = yi + j , xi = xi + i )
158180 elif land [j ][i ] == 0 :
159- val += data [ ti , yi + j , xi + i ] / distance
181+ val += _get_data_temporalinterp ( ti = ti , yi = yi + j , xi = xi + i ) / distance
160182 w_sum += 1 / distance
161183 return val / w_sum
162184 else :
@@ -166,33 +188,64 @@ def _linear_invdist_land_tracer_2d(ctx: InterpolationContext2D) -> float:
166188@register_2d_interpolator ("cgrid_tracer" )
167189@register_2d_interpolator ("bgrid_tracer" )
168190def _tracer_2d (ctx : InterpolationContext2D ) -> float :
169- return ctx .data [ctx .ti , ctx .yi + 1 , ctx .xi + 1 ]
191+ ft0 = ctx .data [ctx .ti , ctx .yi + 1 , ctx .xi + 1 ]
192+ if not should_calculate_next_ti (ctx .ti , ctx .tau , ctx .data .shape [0 ]):
193+ return ft0
194+ ft1 = ctx .data [ctx .ti + 1 , ctx .yi + 1 , ctx .xi + 1 ]
195+ return (1 - ctx .tau ) * ft0 + ctx .tau * ft1
170196
171197
172198@register_3d_interpolator ("nearest" )
173199def _nearest_3d (ctx : InterpolationContext3D ) -> float :
174200 xii = ctx .xi if ctx .xsi <= 0.5 else ctx .xi + 1
175201 yii = ctx .yi if ctx .eta <= 0.5 else ctx .yi + 1
176202 zii = ctx .zi if ctx .zeta <= 0.5 else ctx .zi + 1
177- return ctx .data [ctx .ti , zii , yii , xii ]
203+ ft0 = ctx .data [ctx .ti , zii , yii , xii ]
204+ if not should_calculate_next_ti (ctx .ti , ctx .tau , ctx .data .shape [0 ]):
205+ return ft0
206+ ft1 = ctx .data [ctx .ti + 1 , zii , yii , xii ]
207+ return (1 - ctx .tau ) * ft0 + ctx .tau * ft1
208+
209+
210+ def _get_cgrid_depth_point (* , zeta : float , data : np .ndarray , zi : int , yi : int , xi : int ) -> float :
211+ f0 = data [zi , yi , xi ]
212+ f1 = data [zi + 1 , yi , xi ]
213+ return (1 - zeta ) * f0 + zeta * f1
178214
179215
180216@register_3d_interpolator ("cgrid_velocity" )
181- def _cgrid_velocity_3d (ctx : InterpolationContext3D ) -> float :
217+ def _cgrid_W_velocity_3d (ctx : InterpolationContext3D ) -> float :
182218 # evaluating W velocity in c_grid
183219 if ctx .gridindexingtype == "nemo" :
184- f0 = ctx .data [ctx .ti , ctx .zi , ctx .yi + 1 , ctx .xi + 1 ]
185- f1 = ctx .data [ctx .ti , ctx .zi + 1 , ctx .yi + 1 , ctx .xi + 1 ]
220+ ft0 = _get_cgrid_depth_point (
221+ zeta = ctx .zeta , data = ctx .data [ctx .ti , :, :, :], zi = ctx .zi , yi = ctx .yi + 1 , xi = ctx .xi + 1
222+ )
186223 elif ctx .gridindexingtype in ["mitgcm" , "croco" ]:
187- f0 = ctx .data [ctx .ti , ctx .zi , ctx .yi , ctx .xi ]
188- f1 = ctx .data [ctx .ti , ctx .zi + 1 , ctx .yi , ctx .xi ]
189- return (1 - ctx .zeta ) * f0 + ctx .zeta * f1
224+ ft0 = _get_cgrid_depth_point (zeta = ctx .zeta , data = ctx .data [ctx .ti , :, :, :], zi = ctx .zi , yi = ctx .yi , xi = ctx .xi )
225+ if not should_calculate_next_ti (ctx .ti , ctx .tau , ctx .data .shape [0 ]):
226+ return ft0
227+
228+ if ctx .gridindexingtype == "nemo" :
229+ ft1 = _get_cgrid_depth_point (
230+ zeta = ctx .zeta , data = ctx .data [ctx .ti + 1 , :, :, :], zi = ctx .zi , yi = ctx .yi + 1 , xi = ctx .xi + 1
231+ )
232+ elif ctx .gridindexingtype in ["mitgcm" , "croco" ]:
233+ ft1 = _get_cgrid_depth_point (zeta = ctx .zeta , data = ctx .data [ctx .ti + 1 , :, :, :], zi = ctx .zi , yi = ctx .yi , xi = ctx .xi )
234+ return (1 - ctx .tau ) * ft0 + ctx .tau * ft1
190235
191236
192237@register_3d_interpolator ("linear_invdist_land_tracer" )
193238def _linear_invdist_land_tracer_3d (ctx : InterpolationContext3D ) -> float :
194239 land = np .isclose (ctx .data [ctx .ti , ctx .zi : ctx .zi + 2 , ctx .yi : ctx .yi + 2 , ctx .xi : ctx .xi + 2 ], 0.0 )
195240 nb_land = np .sum (land )
241+
242+ def _get_data_temporalinterp (* , ti , zi , yi , xi ):
243+ dt0 = ctx .data [ti , zi , yi , xi ]
244+ if not should_calculate_next_ti (ctx .ti , ctx .tau , ctx .data .shape [0 ]):
245+ return dt0
246+ dt1 = data [ti + 1 , zi , yi , xi ]
247+ return (1 - ctx .tau ) * dt0 + ctx .tau * dt1
248+
196249 if nb_land == 8 :
197250 return 0
198251 elif nb_land > 0 :
@@ -206,9 +259,11 @@ def _linear_invdist_land_tracer_3d(ctx: InterpolationContext3D) -> float:
206259 if land [k ][j ][i ] == 1 : # index search led us directly onto land
207260 return 0
208261 else :
209- return ctx .data [ ctx . ti , ctx .zi + k , ctx .yi + j , ctx .xi + i ]
262+ return _get_data_temporalinterp ( ti = ctx .ti , zi = ctx .zi + k , yi = ctx .yi + j , xi = ctx .xi + i )
210263 elif land [k ][j ][i ] == 0 :
211- val += ctx .data [ctx .ti , ctx .zi + k , ctx .yi + j , ctx .xi + i ] / distance
264+ val += (
265+ _get_data_temporalinterp (ti = ctx .ti , zi = ctx .zi + k , yi = ctx .yi + j , xi = ctx .xi + i ) / distance
266+ )
212267 w_sum += 1 / distance
213268 return val / w_sum
214269 else :
@@ -253,9 +308,15 @@ def _z_layer_interp(
253308def _linear_3d (ctx : InterpolationContext3D ) -> float :
254309 zdim = ctx .data .shape [1 ]
255310 data_3d = ctx .data [ctx .ti , :, :, :]
256- f0 , f1 = _get_3d_f0_f1 (eta = ctx .eta , xsi = ctx .xsi , data = data_3d , zi = ctx .zi , yi = ctx .yi , xi = ctx .xi )
311+ fz0 , fz1 = _get_3d_f0_f1 (eta = ctx .eta , xsi = ctx .xsi , data = data_3d , zi = ctx .zi , yi = ctx .yi , xi = ctx .xi )
312+ if should_calculate_next_ti (ctx .ti , ctx .tau , ctx .data .shape [0 ]):
313+ data_3d = ctx .data [ctx .ti + 1 , :, :, :]
314+ fz0_t1 , fz1_t1 = _get_3d_f0_f1 (eta = ctx .eta , xsi = ctx .xsi , data = data_3d , zi = ctx .zi , yi = ctx .yi , xi = ctx .xi )
315+ fz0 = (1 - ctx .tau ) * fz0 + ctx .tau * fz0_t1
316+ if fz1_t1 is not None and fz1 is not None :
317+ fz1 = (1 - ctx .tau ) * fz1 + ctx .tau * fz1_t1
257318
258- return _z_layer_interp (zeta = ctx .zeta , f0 = f0 , f1 = f1 , zi = ctx .zi , zdim = zdim , gridindexingtype = ctx .gridindexingtype )
319+ return _z_layer_interp (zeta = ctx .zeta , f0 = fz0 , f1 = fz1 , zi = ctx .zi , zdim = zdim , gridindexingtype = ctx .gridindexingtype )
259320
260321
261322@register_3d_interpolator ("bgrid_velocity" )
@@ -277,4 +338,8 @@ def _linear_3d_bgrid_w_velocity(ctx: InterpolationContext3D) -> float:
277338@register_3d_interpolator ("bgrid_tracer" )
278339@register_3d_interpolator ("cgrid_tracer" )
279340def _tracer_3d (ctx : InterpolationContext3D ) -> float :
280- return ctx .data [ctx .ti , ctx .zi , ctx .yi + 1 , ctx .xi + 1 ]
341+ ft0 = ctx .data [ctx .ti , ctx .zi , ctx .yi + 1 , ctx .xi + 1 ]
342+ if not should_calculate_next_ti (ctx .ti , ctx .tau , ctx .data .shape [0 ]):
343+ return ft0
344+ ft1 = ctx .data [ctx .ti + 1 , ctx .zi , ctx .yi + 1 , ctx .xi + 1 ]
345+ return (1 - ctx .tau ) * ft0 + ctx .tau * ft1
0 commit comments