1+ from contextlib import nullcontext as does_not_raise
2+ from datetime import datetime , timedelta
3+
14import numpy as np
25import pytest
36
@@ -27,7 +30,20 @@ def fieldset() -> FieldSet:
2730 grid = XGrid .from_dataset (ds , mesh = "flat" )
2831 U = Field ("U" , ds ["U (A grid)" ], grid )
2932 V = Field ("V" , ds ["V (A grid)" ], grid )
30- return FieldSet ([U , V ])
33+ UV = VectorField ("UV" , U , V )
34+ return FieldSet ([U , V , UV ])
35+
36+
37+ @pytest .fixture
38+ def fieldset_no_time_interval () -> FieldSet :
39+ # i.e., no time variation
40+ ds = datasets_structured ["ds_2d_left" ].isel (time = 0 ).drop ("time" )
41+
42+ grid = XGrid .from_dataset (ds , mesh = "flat" )
43+ U = Field ("U" , ds ["U (A grid)" ], grid )
44+ V = Field ("V" , ds ["V (A grid)" ], grid )
45+ UV = VectorField ("UV" , U , V )
46+ return FieldSet ([U , V , UV ])
3147
3248
3349@pytest .fixture
@@ -41,6 +57,98 @@ def zonal_flow_fieldset() -> FieldSet:
4157 return FieldSet ([U , V , UV ])
4258
4359
60+ def test_pset_execute_implicit_dt_one_second (fieldset ):
61+ pset = ParticleSet (fieldset , lon = [0.2 ], lat = [5.0 ], pclass = Particle )
62+ pset .execute (DoNothing , runtime = np .timedelta64 (1 , "s" ))
63+
64+ time = pset .time .copy ()
65+
66+ pset .execute (DoNothing , runtime = np .timedelta64 (1 , "s" ))
67+ np .testing .assert_array_equal (pset .time , time + np .timedelta64 (1 , "s" ))
68+
69+
70+ def test_pset_execute_invalid_arguments (fieldset , fieldset_no_time_interval ):
71+ for dt in [1 , np .timedelta64 (0 , "s" ), np .timedelta64 (None )]:
72+ with pytest .raises (
73+ ValueError ,
74+ match = "dt must be a positive or negative np.timedelta64 object, got .*" ,
75+ ):
76+ ParticleSet (fieldset , lon = [0.2 ], lat = [5.0 ], pclass = Particle ).execute (dt = dt )
77+
78+ with pytest .raises (
79+ ValueError ,
80+ match = "runtime and endtime are mutually exclusive - provide one or the other. Got .*" ,
81+ ):
82+ ParticleSet (fieldset , lon = [0.2 ], lat = [5.0 ], pclass = Particle ).execute (
83+ runtime = np .timedelta64 (1 , "s" ), endtime = np .datetime64 ("2100-01-01" )
84+ )
85+
86+ with pytest .raises (
87+ ValueError ,
88+ match = "The runtime must be a np.timedelta64 object. Got .*" ,
89+ ):
90+ ParticleSet (fieldset , lon = [0.2 ], lat = [5.0 ], pclass = Particle ).execute (runtime = 1 )
91+
92+ msg = """Calculated/provided end time of .* is not in fieldset time interval .* Either reduce your runtime, modify your provided endtime, or change your release timing.*"""
93+ with pytest .raises (
94+ ValueError ,
95+ match = msg ,
96+ ):
97+ ParticleSet (fieldset , lon = [0.2 ], lat = [5.0 ], pclass = Particle ).execute (endtime = np .datetime64 ("1990-01-01" ))
98+
99+ with pytest .raises (
100+ ValueError ,
101+ match = msg ,
102+ ):
103+ ParticleSet (fieldset , lon = [0.2 ], lat = [5.0 ], pclass = Particle ).execute (
104+ endtime = np .datetime64 ("2100-01-01" ), dt = np .timedelta64 (- 1 , "s" )
105+ )
106+
107+ with pytest .raises (
108+ ValueError ,
109+ match = "The endtime must be of the same type as the fieldset.time_interval start time. Got .*" ,
110+ ):
111+ ParticleSet (fieldset , lon = [0.2 ], lat = [5.0 ], pclass = Particle ).execute (endtime = 12345 )
112+
113+ with pytest .raises (
114+ ValueError ,
115+ match = "The runtime must be provided when the time_interval is not defined for a fieldset." ,
116+ ):
117+ ParticleSet (fieldset_no_time_interval , lon = [0.2 ], lat = [5.0 ], pclass = Particle ).execute ()
118+
119+
120+ @pytest .mark .parametrize (
121+ "runtime, expectation" ,
122+ [
123+ (np .timedelta64 (5 , "s" ), does_not_raise ()),
124+ (5.0 , pytest .raises (ValueError )),
125+ (timedelta (seconds = 2 ), pytest .raises (ValueError )),
126+ (np .datetime64 ("2001-01-02T00:00:00" ), pytest .raises (ValueError )),
127+ (datetime (2000 , 1 , 2 , 0 , 0 , 0 ), pytest .raises (ValueError )),
128+ ],
129+ )
130+ def test_particleset_runtime_type (fieldset , runtime , expectation ):
131+ pset = ParticleSet (fieldset , lon = [0.2 ], lat = [5.0 ], depth = [50.0 ], pclass = Particle )
132+ with expectation :
133+ pset .execute (runtime = runtime , dt = np .timedelta64 (10 , "s" ), pyfunc = DoNothing )
134+
135+
136+ @pytest .mark .parametrize (
137+ "endtime, expectation" ,
138+ [
139+ (np .datetime64 ("2000-01-02T00:00:00" ), does_not_raise ()),
140+ (5.0 , pytest .raises (ValueError )),
141+ (np .timedelta64 (5 , "s" ), pytest .raises (ValueError )),
142+ (timedelta (seconds = 2 ), pytest .raises (ValueError )),
143+ (datetime (2000 , 1 , 2 , 0 , 0 , 0 ), pytest .raises (ValueError )),
144+ ],
145+ )
146+ def test_particleset_endtime_type (fieldset , endtime , expectation ):
147+ pset = ParticleSet (fieldset , lon = [0.2 ], lat = [5.0 ], depth = [50.0 ], pclass = Particle )
148+ with expectation :
149+ pset .execute (endtime = endtime , dt = np .timedelta64 (10 , "m" ), pyfunc = DoNothing )
150+
151+
44152def test_pset_remove_particle_in_kernel (fieldset ):
45153 npart = 100
46154 pset = ParticleSet (fieldset , lon = np .linspace (0 , 1 , npart ), lat = np .linspace (1 , 0 , npart ))
@@ -92,7 +200,8 @@ def AddLat(particle, fieldset, time): # pragma: no cover
92200def test_execution_endtime (fieldset , starttime , endtime , dt ):
93201 starttime = fieldset .time_interval .left + np .timedelta64 (starttime , "s" )
94202 endtime = fieldset .time_interval .left + np .timedelta64 (endtime , "s" )
95- dt = np .timedelta64 (dt , "s" )
203+ if dt is not None :
204+ dt = np .timedelta64 (dt , "s" )
96205 pset = ParticleSet (fieldset , time = starttime , lon = 0 , lat = 0 )
97206 pset .execute (DoNothing , endtime = endtime , dt = dt )
98207 assert abs (pset .time_nextloop - endtime ) < np .timedelta64 (1 , "ms" )
@@ -152,10 +261,10 @@ def test_some_particles_throw_outoftime(fieldset):
152261 pset = ParticleSet (fieldset , lon = np .zeros_like (time ), lat = np .zeros_like (time ), time = time )
153262
154263 def FieldAccessOutsideTime (particle , fieldset , time ): # pragma: no cover
155- fieldset .U [particle .time + np .timedelta64 (1 , "D" ), particle .depth , particle .lat , particle .lon , particle ]
264+ fieldset .U [particle .time + np .timedelta64 (400 , "D" ), particle .depth , particle .lat , particle .lon , particle ]
156265
157266 with pytest .raises (TimeExtrapolationError ):
158- pset .execute (FieldAccessOutsideTime , runtime = np .timedelta64 (400 , "D" ), dt = np .timedelta64 (10 , "D" ))
267+ pset .execute (FieldAccessOutsideTime , runtime = np .timedelta64 (1 , "D" ), dt = np .timedelta64 (10 , "D" ))
159268
160269
161270def test_execution_check_stopallexecution (fieldset ):
@@ -200,7 +309,8 @@ def test_execution_runtime(fieldset, starttime, runtime, dt, npart):
200309 starttime = fieldset .time_interval .left + np .timedelta64 (starttime , "s" )
201310 runtime = np .timedelta64 (runtime , "s" )
202311 sign_dt = 1 if dt is None else np .sign (dt )
203- dt = np .timedelta64 (dt , "s" )
312+ if dt is not None :
313+ dt = np .timedelta64 (dt , "s" )
204314 pset = ParticleSet (fieldset , time = starttime , lon = np .zeros (npart ), lat = np .zeros (npart ))
205315 pset .execute (DoNothing , runtime = runtime , dt = dt )
206316 assert all ([abs (p .time_nextloop - starttime - runtime * sign_dt ) < np .timedelta64 (1 , "ms" ) for p in pset ])
0 commit comments