Skip to content

Commit 9e91ee0

Browse files
Merge branch 'v4-dev' into testing_nemo_spatialhash
2 parents 3eb2803 + b16fcc8 commit 9e91ee0

File tree

5 files changed

+197
-118
lines changed

5 files changed

+197
-118
lines changed

docs/v4/index.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ The key goals of this update are
77
1. to support `Fields` on unstructured grids;
88
2. to allow for user-defined interpolation methods (somewhat similar to user-defined kernels);
99
3. to make the codebase more modular, easier to extend, and more maintainable;
10-
4. to align Parcels more with other tools in the [Pangeo ecosystemand](https://www.pangeo.io/#ecosystem), particularly by leveraging `xarray` more; and
10+
4. to align Parcels more with other tools in the [Pangeo ecosystem](https://www.pangeo.io/#ecosystem), particularly by leveraging `xarray` more; and
1111
5. to improve the performance of Parcels.
1212

1313
The timeline for the release of Parcels v4 is not yet fixed, but we are aiming for a release of an 'alpha' version in September 2025. This v4-alpha will have support for unstructured grids and user-defined interpolation methods, but is not yet performance-optimised.

parcels/particleset.py

Lines changed: 68 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import sys
22
import warnings
33
from collections.abc import Iterable
4+
from typing import Literal
45

56
import numpy as np
67
import xarray as xr
@@ -506,59 +507,17 @@ def execute(
506507
if output_file:
507508
output_file.metadata["parcels_kernels"] = self._kernel.name
508509

509-
if (dt is not None) and (not isinstance(dt, np.timedelta64)):
510-
raise TypeError("dt must be a np.timedelta64 object")
511-
if dt is None or np.isnat(dt):
510+
if dt is None:
512511
dt = np.timedelta64(1, "s")
513-
self._data["dt"][:] = dt
514-
sign_dt = np.sign(dt).astype(int)
515-
if sign_dt not in [-1, 1]:
516-
raise ValueError("dt must be a positive or negative np.timedelta64 object")
517512

518-
if self.fieldset.time_interval is None:
519-
start_time = np.timedelta64(0, "s") # For the execution loop, we need a start time as a timedelta object
520-
if runtime is None:
521-
raise TypeError("The runtime must be provided when the time_interval is not defined for a fieldset.")
513+
if not isinstance(dt, np.timedelta64) or np.isnat(dt) or (sign_dt := np.sign(dt).astype(int)) not in [-1, 1]:
514+
raise ValueError(f"dt must be a positive or negative np.timedelta64 object, got {dt=!r}")
522515

523-
else:
524-
if isinstance(runtime, np.timedelta64):
525-
end_time = runtime
526-
else:
527-
raise TypeError("The runtime must be a np.timedelta64 object")
516+
self._data["dt"][:] = dt
528517

529-
else:
530-
if not np.isnat(self.time_nextloop).any():
531-
if sign_dt > 0:
532-
start_time = self.time_nextloop.min()
533-
else:
534-
start_time = self.time_nextloop.max()
535-
else:
536-
if sign_dt > 0:
537-
start_time = self.fieldset.time_interval.left
538-
else:
539-
start_time = self.fieldset.time_interval.right
540-
541-
if runtime is None:
542-
if endtime is None:
543-
raise ValueError(
544-
"Must provide either runtime or endtime when time_interval is defined for a fieldset."
545-
)
546-
# Ensure that the endtime uses the same type as the start_time
547-
if isinstance(endtime, self.fieldset.time_interval.left.__class__):
548-
if sign_dt > 0:
549-
if endtime < self.fieldset.time_interval.left:
550-
raise ValueError("The endtime must be after the start time of the fieldset.time_interval")
551-
end_time = min(endtime, self.fieldset.time_interval.right)
552-
else:
553-
if endtime > self.fieldset.time_interval.right:
554-
raise ValueError(
555-
"The endtime must be before the end time of the fieldset.time_interval when dt < 0"
556-
)
557-
end_time = max(endtime, self.fieldset.time_interval.left)
558-
else:
559-
raise TypeError("The endtime must be of the same type as the fieldset.time_interval start time.")
560-
else:
561-
end_time = start_time + runtime * sign_dt
518+
start_time, end_time = _get_simulation_start_and_end_times(
519+
self.fieldset.time_interval, self._data["time_nextloop"], runtime, endtime, sign_dt
520+
)
562521

563522
# Set the time of the particles if it hadn't been set on initialisation
564523
if np.isnat(self._data["time"]).any():
@@ -619,15 +578,69 @@ def _warn_particle_times_outside_fieldset_time_bounds(release_times: np.ndarray,
619578

620579
if isinstance(time.left, np.datetime64) and isinstance(release_times[0], np.timedelta64):
621580
release_times = np.array([t + time.left for t in release_times])
622-
if np.any(release_times < time.left):
581+
if np.any(release_times < time.left) or np.any(release_times > time.right):
623582
warnings.warn(
624583
"Some particles are set to be released outside the FieldSet's executable time domain.",
625584
ParticleSetWarning,
626585
stacklevel=2,
627586
)
628-
if np.any(release_times > time.right):
629-
warnings.warn(
630-
"Some particles are set to be released after the fieldset's last time and the fields are not constant in time.",
631-
ParticleSetWarning,
632-
stacklevel=2,
587+
588+
589+
def _get_simulation_start_and_end_times(
590+
time_interval: TimeInterval,
591+
particle_release_times: np.ndarray,
592+
runtime: np.timedelta64 | None,
593+
endtime: np.datetime64 | None,
594+
sign_dt: Literal[-1, 1],
595+
) -> tuple[np.datetime64, np.datetime64]:
596+
if runtime is not None and endtime is not None:
597+
raise ValueError(
598+
f"runtime and endtime are mutually exclusive - provide one or the other. Got {runtime=!r}, {endtime=!r}"
633599
)
600+
601+
if runtime is None and time_interval is None:
602+
raise ValueError("The runtime must be provided when the time_interval is not defined for a fieldset.")
603+
604+
if sign_dt == 1:
605+
first_release_time = particle_release_times.min()
606+
else:
607+
first_release_time = particle_release_times.max()
608+
609+
start_time = _get_start_time(first_release_time, time_interval, sign_dt, runtime)
610+
611+
if endtime is None:
612+
if not isinstance(runtime, np.timedelta64):
613+
raise ValueError(f"The runtime must be a np.timedelta64 object. Got {type(runtime)}")
614+
615+
endtime = start_time + sign_dt * runtime
616+
617+
if time_interval is not None:
618+
if type(endtime) != type(time_interval.left): # noqa: E721
619+
raise ValueError(
620+
f"The endtime must be of the same type as the fieldset.time_interval start time. Got {endtime=!r} with {time_interval=!r}"
621+
)
622+
if endtime not in time_interval:
623+
msg = (
624+
f"Calculated/provided end time of {endtime!r} is not in fieldset time interval {time_interval!r}. Either reduce your runtime, modify your "
625+
"provided endtime, or change your release timing."
626+
"Important info:\n"
627+
f" First particle release: {first_release_time!r}\n"
628+
f" runtime: {runtime!r}\n"
629+
f" (calculated) endtime: {endtime!r}"
630+
)
631+
raise ValueError(msg)
632+
633+
return start_time, endtime
634+
635+
636+
def _get_start_time(first_release_time, time_interval, sign_dt, runtime):
637+
if time_interval is None:
638+
time_interval = TimeInterval(left=np.timedelta64(0, "s"), right=runtime)
639+
640+
if sign_dt == 1:
641+
fieldset_start = time_interval.left
642+
else:
643+
fieldset_start = time_interval.right
644+
645+
start_time = first_release_time if not np.isnat(first_release_time) else fieldset_start
646+
return start_time

tests/v4/test_particleset.py

Lines changed: 0 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -114,21 +114,6 @@ def test_pset_create_outside_time(fieldset):
114114
ParticleSet(fieldset, pclass=Particle, lon=[0] * len(time), lat=[0] * len(time), time=time)
115115

116116

117-
@pytest.mark.parametrize(
118-
"dt, expectation",
119-
[
120-
(np.timedelta64(5, "s"), does_not_raise()),
121-
(5.0, pytest.raises(TypeError)),
122-
(np.datetime64("2000-01-02T00:00:00"), pytest.raises(TypeError)),
123-
(timedelta(seconds=2), pytest.raises(TypeError)),
124-
],
125-
)
126-
def test_particleset_dt_type(fieldset, dt, expectation):
127-
pset = ParticleSet(fieldset, lon=[0.2], lat=[5.0], depth=[50.0], pclass=Particle)
128-
with expectation:
129-
pset.execute(runtime=np.timedelta64(10, "s"), dt=dt, pyfunc=DoNothing)
130-
131-
132117
def test_pset_starttime_not_multiple_dt(fieldset):
133118
times = [0, 1, 2]
134119
datetimes = [fieldset.time_interval.left + np.timedelta64(t, "s") for t in times]
@@ -141,38 +126,6 @@ def Addlon(particle, fieldset, time): # pragma: no cover
141126
assert np.allclose([p.lon_nextloop for p in pset], [8 - t for t in times])
142127

143128

144-
@pytest.mark.parametrize(
145-
"runtime, expectation",
146-
[
147-
(np.timedelta64(5, "s"), does_not_raise()),
148-
(5.0, pytest.raises(TypeError)),
149-
(timedelta(seconds=2), pytest.raises(TypeError)),
150-
(np.datetime64("2001-01-02T00:00:00"), pytest.raises(TypeError)),
151-
(datetime(2000, 1, 2, 0, 0, 0), pytest.raises(TypeError)),
152-
],
153-
)
154-
def test_particleset_runtime_type(fieldset, runtime, expectation):
155-
pset = ParticleSet(fieldset, lon=[0.2], lat=[5.0], depth=[50.0], pclass=Particle)
156-
with expectation:
157-
pset.execute(runtime=runtime, dt=np.timedelta64(10, "s"), pyfunc=DoNothing)
158-
159-
160-
@pytest.mark.parametrize(
161-
"endtime, expectation",
162-
[
163-
(np.datetime64("2000-01-02T00:00:00"), does_not_raise()),
164-
(5.0, pytest.raises(TypeError)),
165-
(np.timedelta64(5, "s"), pytest.raises(TypeError)),
166-
(timedelta(seconds=2), pytest.raises(TypeError)),
167-
(datetime(2000, 1, 2, 0, 0, 0), pytest.raises(TypeError)),
168-
],
169-
)
170-
def test_particleset_endtime_type(fieldset, endtime, expectation):
171-
pset = ParticleSet(fieldset, lon=[0.2], lat=[5.0], depth=[50.0], pclass=Particle)
172-
with expectation:
173-
pset.execute(endtime=endtime, dt=np.timedelta64(10, "m"), pyfunc=DoNothing)
174-
175-
176129
def test_pset_add_explicit(fieldset):
177130
npart = 11
178131
lon = np.linspace(0, 1, npart)

tests/v4/test_particleset_execute.py

Lines changed: 115 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
from contextlib import nullcontext as does_not_raise
2+
from datetime import datetime, timedelta
3+
14
import numpy as np
25
import pytest
36

@@ -27,7 +30,20 @@ def fieldset() -> FieldSet:
2730
grid = XGrid.from_dataset(ds, mesh="flat")
2831
U = Field("U", ds["U (A grid)"], grid)
2932
V = Field("V", ds["V (A grid)"], grid)
30-
return FieldSet([U, V])
33+
UV = VectorField("UV", U, V)
34+
return FieldSet([U, V, UV])
35+
36+
37+
@pytest.fixture
38+
def fieldset_no_time_interval() -> FieldSet:
39+
# i.e., no time variation
40+
ds = datasets_structured["ds_2d_left"].isel(time=0).drop("time")
41+
42+
grid = XGrid.from_dataset(ds, mesh="flat")
43+
U = Field("U", ds["U (A grid)"], grid)
44+
V = Field("V", ds["V (A grid)"], grid)
45+
UV = VectorField("UV", U, V)
46+
return FieldSet([U, V, UV])
3147

3248

3349
@pytest.fixture
@@ -41,6 +57,98 @@ def zonal_flow_fieldset() -> FieldSet:
4157
return FieldSet([U, V, UV])
4258

4359

60+
def test_pset_execute_implicit_dt_one_second(fieldset):
61+
pset = ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle)
62+
pset.execute(DoNothing, runtime=np.timedelta64(1, "s"))
63+
64+
time = pset.time.copy()
65+
66+
pset.execute(DoNothing, runtime=np.timedelta64(1, "s"))
67+
np.testing.assert_array_equal(pset.time, time + np.timedelta64(1, "s"))
68+
69+
70+
def test_pset_execute_invalid_arguments(fieldset, fieldset_no_time_interval):
71+
for dt in [1, np.timedelta64(0, "s"), np.timedelta64(None)]:
72+
with pytest.raises(
73+
ValueError,
74+
match="dt must be a positive or negative np.timedelta64 object, got .*",
75+
):
76+
ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute(dt=dt)
77+
78+
with pytest.raises(
79+
ValueError,
80+
match="runtime and endtime are mutually exclusive - provide one or the other. Got .*",
81+
):
82+
ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute(
83+
runtime=np.timedelta64(1, "s"), endtime=np.datetime64("2100-01-01")
84+
)
85+
86+
with pytest.raises(
87+
ValueError,
88+
match="The runtime must be a np.timedelta64 object. Got .*",
89+
):
90+
ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute(runtime=1)
91+
92+
msg = """Calculated/provided end time of .* is not in fieldset time interval .* Either reduce your runtime, modify your provided endtime, or change your release timing.*"""
93+
with pytest.raises(
94+
ValueError,
95+
match=msg,
96+
):
97+
ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute(endtime=np.datetime64("1990-01-01"))
98+
99+
with pytest.raises(
100+
ValueError,
101+
match=msg,
102+
):
103+
ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute(
104+
endtime=np.datetime64("2100-01-01"), dt=np.timedelta64(-1, "s")
105+
)
106+
107+
with pytest.raises(
108+
ValueError,
109+
match="The endtime must be of the same type as the fieldset.time_interval start time. Got .*",
110+
):
111+
ParticleSet(fieldset, lon=[0.2], lat=[5.0], pclass=Particle).execute(endtime=12345)
112+
113+
with pytest.raises(
114+
ValueError,
115+
match="The runtime must be provided when the time_interval is not defined for a fieldset.",
116+
):
117+
ParticleSet(fieldset_no_time_interval, lon=[0.2], lat=[5.0], pclass=Particle).execute()
118+
119+
120+
@pytest.mark.parametrize(
121+
"runtime, expectation",
122+
[
123+
(np.timedelta64(5, "s"), does_not_raise()),
124+
(5.0, pytest.raises(ValueError)),
125+
(timedelta(seconds=2), pytest.raises(ValueError)),
126+
(np.datetime64("2001-01-02T00:00:00"), pytest.raises(ValueError)),
127+
(datetime(2000, 1, 2, 0, 0, 0), pytest.raises(ValueError)),
128+
],
129+
)
130+
def test_particleset_runtime_type(fieldset, runtime, expectation):
131+
pset = ParticleSet(fieldset, lon=[0.2], lat=[5.0], depth=[50.0], pclass=Particle)
132+
with expectation:
133+
pset.execute(runtime=runtime, dt=np.timedelta64(10, "s"), pyfunc=DoNothing)
134+
135+
136+
@pytest.mark.parametrize(
137+
"endtime, expectation",
138+
[
139+
(np.datetime64("2000-01-02T00:00:00"), does_not_raise()),
140+
(5.0, pytest.raises(ValueError)),
141+
(np.timedelta64(5, "s"), pytest.raises(ValueError)),
142+
(timedelta(seconds=2), pytest.raises(ValueError)),
143+
(datetime(2000, 1, 2, 0, 0, 0), pytest.raises(ValueError)),
144+
],
145+
)
146+
def test_particleset_endtime_type(fieldset, endtime, expectation):
147+
pset = ParticleSet(fieldset, lon=[0.2], lat=[5.0], depth=[50.0], pclass=Particle)
148+
with expectation:
149+
pset.execute(endtime=endtime, dt=np.timedelta64(10, "m"), pyfunc=DoNothing)
150+
151+
44152
def test_pset_remove_particle_in_kernel(fieldset):
45153
npart = 100
46154
pset = ParticleSet(fieldset, lon=np.linspace(0, 1, npart), lat=np.linspace(1, 0, npart))
@@ -92,7 +200,8 @@ def AddLat(particle, fieldset, time): # pragma: no cover
92200
def test_execution_endtime(fieldset, starttime, endtime, dt):
93201
starttime = fieldset.time_interval.left + np.timedelta64(starttime, "s")
94202
endtime = fieldset.time_interval.left + np.timedelta64(endtime, "s")
95-
dt = np.timedelta64(dt, "s")
203+
if dt is not None:
204+
dt = np.timedelta64(dt, "s")
96205
pset = ParticleSet(fieldset, time=starttime, lon=0, lat=0)
97206
pset.execute(DoNothing, endtime=endtime, dt=dt)
98207
assert abs(pset.time_nextloop - endtime) < np.timedelta64(1, "ms")
@@ -152,10 +261,10 @@ def test_some_particles_throw_outoftime(fieldset):
152261
pset = ParticleSet(fieldset, lon=np.zeros_like(time), lat=np.zeros_like(time), time=time)
153262

154263
def FieldAccessOutsideTime(particle, fieldset, time): # pragma: no cover
155-
fieldset.U[particle.time + np.timedelta64(1, "D"), particle.depth, particle.lat, particle.lon, particle]
264+
fieldset.U[particle.time + np.timedelta64(400, "D"), particle.depth, particle.lat, particle.lon, particle]
156265

157266
with pytest.raises(TimeExtrapolationError):
158-
pset.execute(FieldAccessOutsideTime, runtime=np.timedelta64(400, "D"), dt=np.timedelta64(10, "D"))
267+
pset.execute(FieldAccessOutsideTime, runtime=np.timedelta64(1, "D"), dt=np.timedelta64(10, "D"))
159268

160269

161270
def test_execution_check_stopallexecution(fieldset):
@@ -200,7 +309,8 @@ def test_execution_runtime(fieldset, starttime, runtime, dt, npart):
200309
starttime = fieldset.time_interval.left + np.timedelta64(starttime, "s")
201310
runtime = np.timedelta64(runtime, "s")
202311
sign_dt = 1 if dt is None else np.sign(dt)
203-
dt = np.timedelta64(dt, "s")
312+
if dt is not None:
313+
dt = np.timedelta64(dt, "s")
204314
pset = ParticleSet(fieldset, time=starttime, lon=np.zeros(npart), lat=np.zeros(npart))
205315
pset.execute(DoNothing, runtime=runtime, dt=dt)
206316
assert all([abs(p.time_nextloop - starttime - runtime * sign_dt) < np.timedelta64(1, "ms") for p in pset])

0 commit comments

Comments
 (0)