diff --git a/.gitignore b/.gitignore index ac1b3b1980..50fb0120e0 100644 --- a/.gitignore +++ b/.gitignore @@ -27,6 +27,7 @@ parcels.egg-info/* dist/parcels*.egg parcels/_version_setup.py .pytest_cache +.hypothesis .coverage # pixi environments diff --git a/environment.yml b/environment.yml index 021a29d751..4b8694f0c4 100644 --- a/environment.yml +++ b/environment.yml @@ -28,6 +28,7 @@ dependencies: #! Keep in sync with [tool.pixi.dependencies] in pyproject.toml - pytest - pytest-html - coverage + - hypothesis # Typing - mypy diff --git a/parcels/_core/utils/time.py b/parcels/_core/utils/time.py new file mode 100644 index 0000000000..917be1a4d9 --- /dev/null +++ b/parcels/_core/utils/time.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +from datetime import datetime +from typing import TypeVar + +import cftime + +T = TypeVar("T", datetime, cftime.datetime) + + +class TimeInterval: + """A class representing a time interval between two datetime objects. + + Parameters + ---------- + left : datetime or cftime.datetime + The left endpoint of the interval. + right : datetime or cftime.datetime + The right endpoint of the interval. + + Notes + ----- + For the purposes of this codebase, the interval can be thought of as closed on the left and right. + """ + + def __init__(self, left: T, right: T) -> None: + if not isinstance(left, (datetime, cftime.datetime)): + raise ValueError(f"Expected left to be a datetime or cftime.datetime, got {type(left)}.") + if not isinstance(right, (datetime, cftime.datetime)): + raise ValueError(f"Expected right to be a datetime or cftime.datetime, got {type(right)}.") + if left >= right: + raise ValueError(f"Expected left to be strictly less than right, got left={left} and right={right}.") + if not is_compatible(left, right): + raise ValueError(f"Expected left and right to be compatible, got left={left} and right={right}.") + + self.left = left + self.right = right + + def __contains__(self, item: T) -> bool: + return self.left <= item <= self.right + + def __repr__(self) -> str: + return f"TimeInterval(left={self.left!r}, right={self.right!r})" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, TimeInterval): + return False + return self.left == other.left and self.right == other.right + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) + + def intersection(self, other: TimeInterval) -> TimeInterval | None: + """Return the intersection of two time intervals. Returns None if there is no overlap.""" + if not is_compatible(self.left, other.left): + raise ValueError("TimeIntervals are not compatible.") + + start = max(self.left, other.left) + end = min(self.right, other.right) + + return TimeInterval(start, end) if start <= end else None + + +def is_compatible(t1: datetime | cftime.datetime, t2: datetime | cftime.datetime) -> bool: + """Checks whether two (cftime.)datetime objects are compatible.""" + try: + t1 - t2 + except Exception: + return False + else: + return True diff --git a/pyproject.toml b/pyproject.toml index fddc2387db..0acf33ee75 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,6 +79,7 @@ trajan = "*" # Testing nbval = "*" pytest = "*" +hypothesis = "*" pytest-html = "*" coverage = "*" diff --git a/tests/v4/test_gridadapter.py b/tests/v4/test_gridadapter.py index 24c67b70b6..9f1ead0ac7 100644 --- a/tests/v4/test_gridadapter.py +++ b/tests/v4/test_gridadapter.py @@ -10,18 +10,18 @@ from parcels.v4.grid import Grid as NewGrid from parcels.v4.gridadapter import GridAdapter -TestCase = namedtuple("TestCase", ["Grid", "attr", "expected"]) +GridTestCase = namedtuple("GridTestCase", ["Grid", "attr", "expected"]) test_cases = [ - TestCase(datasets["ds_2d_left"], "lon", datasets["ds_2d_left"].XG.values), - TestCase(datasets["ds_2d_left"], "lat", datasets["ds_2d_left"].YG.values), - TestCase(datasets["ds_2d_left"], "depth", datasets["ds_2d_left"].ZG.values), - TestCase(datasets["ds_2d_left"], "time", datasets["ds_2d_left"].time.values), - TestCase(datasets["ds_2d_left"], "xdim", N), - TestCase(datasets["ds_2d_left"], "ydim", 2 * N), - TestCase(datasets["ds_2d_left"], "zdim", 3 * N), - TestCase(datasets["ds_2d_left"], "tdim", T), - TestCase(datasets["ds_2d_left"], "time_origin", TimeConverter(datasets["ds_2d_left"].time.values[0])), + GridTestCase(datasets["ds_2d_left"], "lon", datasets["ds_2d_left"].XG.values), + GridTestCase(datasets["ds_2d_left"], "lat", datasets["ds_2d_left"].YG.values), + GridTestCase(datasets["ds_2d_left"], "depth", datasets["ds_2d_left"].ZG.values), + GridTestCase(datasets["ds_2d_left"], "time", datasets["ds_2d_left"].time.values), + GridTestCase(datasets["ds_2d_left"], "xdim", N), + GridTestCase(datasets["ds_2d_left"], "ydim", 2 * N), + GridTestCase(datasets["ds_2d_left"], "zdim", 3 * N), + GridTestCase(datasets["ds_2d_left"], "tdim", T), + GridTestCase(datasets["ds_2d_left"], "time_origin", TimeConverter(datasets["ds_2d_left"].time.values[0])), ] diff --git a/tests/v4/utils/test_time.py b/tests/v4/utils/test_time.py new file mode 100644 index 0000000000..c572df2eef --- /dev/null +++ b/tests/v4/utils/test_time.py @@ -0,0 +1,158 @@ +from __future__ import annotations + +from datetime import datetime, timedelta + +import pytest +from cftime import datetime as cftime_datetime +from hypothesis import given +from hypothesis import strategies as st + +from parcels._core.utils.time import TimeInterval + +calendar_strategy = st.sampled_from(["gregorian", "proleptic_gregorian", "365_day", "360_day", "julian", "366_day"]) + + +@st.composite +def cftime_datetime_strategy(draw, calendar=None): + year = draw(st.integers(1900, 2100)) + month = draw(st.integers(1, 12)) + day = draw(st.integers(1, 28)) + if calendar is None: + calendar = draw(calendar_strategy) + return cftime_datetime(year, month, day, calendar=calendar) + + +@st.composite +def cftime_interval_strategy(draw, left=None, calendar=None): + if left is None: + left = draw(cftime_datetime_strategy(calendar=calendar)) + right = left + draw( + st.timedeltas( + min_value=timedelta(seconds=1), + max_value=timedelta(days=100 * 365), + ) + ) + return TimeInterval(left, right) + + +@pytest.mark.parametrize( + "left,right", + [ + (cftime_datetime(2023, 1, 1, calendar="gregorian"), cftime_datetime(2023, 1, 2, calendar="gregorian")), + (cftime_datetime(2023, 6, 1, calendar="365_day"), cftime_datetime(2023, 6, 2, calendar="365_day")), + (cftime_datetime(2023, 12, 1, calendar="360_day"), cftime_datetime(2023, 12, 2, calendar="360_day")), + ], +) +def test_time_interval_initialization(left, right): + """Test that TimeInterval can be initialized with valid inputs.""" + interval = TimeInterval(left, right) + assert interval.left == left + assert interval.right == right + + with pytest.raises(ValueError): + TimeInterval(right, left) + + +@given(cftime_interval_strategy()) +def test_time_interval_contains(interval): + left = interval.left + right = interval.right + middle = left + (right - left) / 2 + + assert left in interval + assert right in interval + assert middle in interval + + +@given(cftime_interval_strategy(calendar="365_day"), cftime_interval_strategy(calendar="365_day")) +def test_time_interval_intersection_commutative(interval1, interval2): + assert interval1.intersection(interval2) == interval2.intersection(interval1) + + +@given(cftime_interval_strategy()) +def test_time_interval_intersection_with_self(interval): + assert interval.intersection(interval) == interval + + +def test_time_interval_repr(): + """Test the string representation of TimeInterval.""" + interval = TimeInterval(datetime(2023, 1, 1, 12, 0), datetime(2023, 1, 2, 12, 0)) + expected = "TimeInterval(left=datetime.datetime(2023, 1, 1, 12, 0), right=datetime.datetime(2023, 1, 2, 12, 0))" + assert repr(interval) == expected + + +@given(cftime_interval_strategy()) +def test_time_interval_equality(interval): + assert interval == interval + + +@pytest.mark.parametrize( + "interval1,interval2,expected", + [ + pytest.param( + TimeInterval( + cftime_datetime(2023, 1, 1, calendar="gregorian"), cftime_datetime(2023, 1, 3, calendar="gregorian") + ), + TimeInterval( + cftime_datetime(2023, 1, 2, calendar="gregorian"), cftime_datetime(2023, 1, 4, calendar="gregorian") + ), + TimeInterval( + cftime_datetime(2023, 1, 2, calendar="gregorian"), cftime_datetime(2023, 1, 3, calendar="gregorian") + ), + id="overlapping intervals", + ), + pytest.param( + TimeInterval( + cftime_datetime(2023, 1, 1, calendar="gregorian"), cftime_datetime(2023, 1, 3, calendar="gregorian") + ), + TimeInterval( + cftime_datetime(2023, 1, 5, calendar="gregorian"), cftime_datetime(2023, 1, 6, calendar="gregorian") + ), + None, + id="non-overlapping intervals", + ), + pytest.param( + TimeInterval( + cftime_datetime(2023, 1, 1, calendar="gregorian"), cftime_datetime(2023, 1, 3, calendar="gregorian") + ), + TimeInterval( + cftime_datetime(2023, 1, 1, calendar="gregorian"), cftime_datetime(2023, 1, 2, calendar="gregorian") + ), + TimeInterval( + cftime_datetime(2023, 1, 1, calendar="gregorian"), cftime_datetime(2023, 1, 2, calendar="gregorian") + ), + id="intervals with same start time", + ), + pytest.param( + TimeInterval( + cftime_datetime(2023, 1, 1, calendar="gregorian"), cftime_datetime(2023, 1, 3, calendar="gregorian") + ), + TimeInterval( + cftime_datetime(2023, 1, 2, calendar="gregorian"), cftime_datetime(2023, 1, 3, calendar="gregorian") + ), + TimeInterval( + cftime_datetime(2023, 1, 2, calendar="gregorian"), cftime_datetime(2023, 1, 3, calendar="gregorian") + ), + id="intervals with same end time", + ), + ], +) +def test_time_interval_intersection(interval1, interval2, expected): + """Test the intersection of two time intervals.""" + result = interval1.intersection(interval2) + if expected is None: + assert result is None + else: + assert result.left == expected.left + assert result.right == expected.right + + +def test_time_interval_intersection_different_calendars(): + interval1 = TimeInterval( + cftime_datetime(2023, 1, 1, calendar="gregorian"), cftime_datetime(2023, 1, 3, calendar="gregorian") + ) + interval2 = TimeInterval( + cftime_datetime(2023, 1, 1, calendar="365_day"), cftime_datetime(2023, 1, 3, calendar="365_day") + ) + with pytest.raises(ValueError, match="TimeIntervals are not compatible."): + interval1.intersection(interval2)