Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions parcels/_core/utils/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import TypeVar

import cftime
import numpy as np

T = TypeVar("T", datetime, cftime.datetime)

Expand All @@ -24,10 +25,10 @@ class TimeInterval:
"""

def __init__(self, left: T, right: T) -> None:
if not isinstance(left, (datetime, cftime.datetime)):
raise ValueError(f"Expected left to be a datetime or cftime.datetime, got {type(left)}.")
if not isinstance(right, (datetime, cftime.datetime)):
raise ValueError(f"Expected right to be a datetime or cftime.datetime, got {type(right)}.")
if not isinstance(left, (datetime, cftime.datetime, np.datetime64)):
raise ValueError(f"Expected right to be a datetime, cftime.datetime, or np.datetime64. Got {type(left)}.")
if not isinstance(right, (datetime, cftime.datetime, np.datetime64)):
raise ValueError(f"Expected right to be a datetime, cftime.datetime, or np.datetime64. Got {type(right)}.")
if left >= right:
raise ValueError(f"Expected left to be strictly less than right, got left={left} and right={right}.")
if not is_compatible(left, right):
Expand Down
40 changes: 20 additions & 20 deletions parcels/_datasets/structured/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
__all__ = ["N", "T", "datasets"]

N = 30
T = 10
T = 13


def _rotated_curvilinear_grid():
Expand All @@ -22,12 +22,12 @@ def _rotated_curvilinear_grid():

return xr.Dataset(
{
"data_g": (["ZG", "YG", "XG"], np.random.rand(3 * N, 2 * N, N)),
"data_c": (["ZC", "YC", "XC"], np.random.rand(3 * N, 2 * N, N)),
"U (A grid)": (["ZG", "YG", "XG"], np.random.rand(3 * N, 2 * N, N)),
"V (A grid)": (["ZG", "YG", "XG"], np.random.rand(3 * N, 2 * N, N)),
"U (C grid)": (["ZG", "YC", "XG"], np.random.rand(3 * N, 2 * N, N)),
"V (C grid)": (["ZG", "YG", "XC"], np.random.rand(3 * N, 2 * N, N)),
"data_g": (["time", "ZG", "YG", "XG"], np.random.rand(T, 3 * N, 2 * N, N)),
"data_c": (["time", "ZC", "YC", "XC"], np.random.rand(T, 3 * N, 2 * N, N)),
"U (A grid)": (["time", "ZG", "YG", "XG"], np.random.rand(T, 3 * N, 2 * N, N)),
"V (A grid)": (["time", "ZG", "YG", "XG"], np.random.rand(T, 3 * N, 2 * N, N)),
"U (C grid)": (["time", "ZG", "YC", "XG"], np.random.rand(T, 3 * N, 2 * N, N)),
"V (C grid)": (["time", "ZG", "YG", "XC"], np.random.rand(T, 3 * N, 2 * N, N)),
},
coords={
"XG": (["XG"], XG, {"axis": "X", "c_grid_axis_shift": -0.5}),
Expand All @@ -45,7 +45,7 @@ def _rotated_curvilinear_grid():
{"axis": "Z"},
),
"depth": (["ZG"], np.arange(3 * N), {"axis": "Z"}),
"time": (["time"], np.arange(T), {"axis": "T"}),
"time": (["time"], xr.date_range("2000", "2001", T), {"axis": "T"}),
"lon": (
["YG", "XG"],
LON,
Expand Down Expand Up @@ -97,12 +97,12 @@ def _unrolled_cone_curvilinear_grid():

return xr.Dataset(
{
"data_g": (["ZG", "YG", "XG"], np.random.rand(3 * N, 2 * N, N)),
"data_c": (["ZC", "YC", "XC"], np.random.rand(3 * N, 2 * N, N)),
"U (A grid)": (["ZG", "YG", "XG"], np.random.rand(3 * N, 2 * N, N)),
"V (A grid)": (["ZG", "YG", "XG"], np.random.rand(3 * N, 2 * N, N)),
"U (C grid)": (["ZG", "YC", "XG"], np.random.rand(3 * N, 2 * N, N)),
"V (C grid)": (["ZG", "YG", "XC"], np.random.rand(3 * N, 2 * N, N)),
"data_g": (["time", "ZG", "YG", "XG"], np.random.rand(T, 3 * N, 2 * N, N)),
"data_c": (["time", "ZC", "YC", "XC"], np.random.rand(T, 3 * N, 2 * N, N)),
"U (A grid)": (["time", "ZG", "YG", "XG"], np.random.rand(T, 3 * N, 2 * N, N)),
"V (A grid)": (["time", "ZG", "YG", "XG"], np.random.rand(T, 3 * N, 2 * N, N)),
"U (C grid)": (["time", "ZG", "YC", "XG"], np.random.rand(T, 3 * N, 2 * N, N)),
"V (C grid)": (["time", "ZG", "YG", "XC"], np.random.rand(T, 3 * N, 2 * N, N)),
},
coords={
"XG": (["XG"], XG, {"axis": "X", "c_grid_axis_shift": -0.5}),
Expand All @@ -120,7 +120,7 @@ def _unrolled_cone_curvilinear_grid():
{"axis": "Z"},
),
"depth": (["ZG"], np.arange(3 * N), {"axis": "Z"}),
"time": (["time"], np.arange(T), {"axis": "T"}),
"time": (["time"], xr.date_range("2000", "2001", T), {"axis": "T"}),
"lon": (
["YG", "XG"],
LON,
Expand All @@ -141,10 +141,10 @@ def _unrolled_cone_curvilinear_grid():
{
"data_g": (["time", "ZG", "YG", "XG"], np.random.rand(T, 3 * N, 2 * N, N)),
"data_c": (["time", "ZC", "YC", "XC"], np.random.rand(T, 3 * N, 2 * N, N)),
"U (A grid)": (["ZG", "YG", "XG"], np.random.rand(3 * N, 2 * N, N)),
"V (A grid)": (["ZG", "YG", "XG"], np.random.rand(3 * N, 2 * N, N)),
"U (C grid)": (["ZG", "YC", "XG"], np.random.rand(3 * N, 2 * N, N)),
"V (C grid)": (["ZG", "YG", "XC"], np.random.rand(3 * N, 2 * N, N)),
"U (A grid)": (["time", "ZG", "YG", "XG"], np.random.rand(T, 3 * N, 2 * N, N)),
"V (A grid)": (["time", "ZG", "YG", "XG"], np.random.rand(T, 3 * N, 2 * N, N)),
"U (C grid)": (["time", "ZG", "YC", "XG"], np.random.rand(T, 3 * N, 2 * N, N)),
"V (C grid)": (["time", "ZG", "YG", "XC"], np.random.rand(T, 3 * N, 2 * N, N)),
},
coords={
"XG": (
Expand Down Expand Up @@ -176,7 +176,7 @@ def _unrolled_cone_curvilinear_grid():
"lon": (["XG"], 2 * np.pi / N * np.arange(0, N)),
"lat": (["YG"], 2 * np.pi / (2 * N) * np.arange(0, 2 * N)),
"depth": (["ZG"], np.arange(3 * N)),
"time": (["time"], np.arange(T), {"axis": "T"}),
"time": (["time"], xr.date_range("2000", "2001", T), {"axis": "T"}),
},
),
"2d_left_unrolled_cone": _unrolled_cone_curvilinear_grid(),
Expand Down
15 changes: 11 additions & 4 deletions parcels/field.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
from __future__ import annotations

import inspect
import warnings
from collections.abc import Callable
from datetime import datetime
from enum import IntEnum
from typing import TYPE_CHECKING

import numpy as np
import uxarray as ux
import xarray as xr
from uxarray.grid.neighbors import _barycentric_coordinates

from parcels._core.utils.time import TimeInterval
from parcels._core.utils.unstructured import get_vertical_location_from_dims
from parcels._reprs import default_repr, field_repr
from parcels._typing import (
Expand All @@ -33,9 +35,6 @@

from ._index_search import _search_indices_rectilinear, _search_time_index

if TYPE_CHECKING:
pass

__all__ = ["Field", "GridType", "VectorField"]


Expand Down Expand Up @@ -167,6 +166,7 @@ def __init__(
self.name = name
self.data = data
self.grid = grid
self.time_interval = get_time_interval(data)

# For compatibility with parts of the codebase that rely on v3 definition of Grid.
# Should be worked to be removed in v4
Expand Down Expand Up @@ -663,3 +663,10 @@ def _assert_compatible_combination(data: xr.DataArray | ux.UxDataArray, grid: ux
raise ValueError(
f"Incompatible data-grid combination. Data is a xarray.DataArray, expected `grid` to be a parcels Grid object, got {type(grid)}."
)


def get_time_interval(data: xr.DataArray | ux.UxDataArray) -> TimeInterval | None:
if "time" not in data.dims:
return None

return TimeInterval(data.time.values[0], data.time.values[-1])
19 changes: 16 additions & 3 deletions tests/v4/test_field.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
import pytest
import uxarray as ux
import xarray as xr
Expand Down Expand Up @@ -69,6 +70,21 @@ def test_field_structured_grid_creation(data, grid):
assert field.grid == grid


@pytest.mark.parametrize(
"data,grid",
[
pytest.param(
structured_datasets["ds_2d_left"]["data_g"], Grid(structured_datasets["ds_2d_left"]), id="ds_2d_left"
),
],
)
def test_field_time_interval(data, grid):
"""Test creating a field."""
field = Field(name="test_field", data=data, grid=grid, mesh_type="flat")
assert field.time_interval.left == np.datetime64("2000-01-01")
assert field.time_interval.right == np.datetime64("2001-01-01")


def test_field_unstructured_grid_creation(): ...


Expand All @@ -79,6 +95,3 @@ def test_field_interpolation_out_of_spatial_bounds(): ...


def test_field_interpolation_out_of_time_bounds(): ...


def test_field_allow_time_extrapolation(): ...
18 changes: 17 additions & 1 deletion tests/v4/test_fieldset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from datetime import timedelta

import numpy as np
import pytest

from parcels._datasets.structured.generic import datasets as datasets_structured
Expand Down Expand Up @@ -70,4 +73,17 @@ def test_fieldset_gridset_size(fieldset):
assert fieldset.gridset_size == 1


def test_fieldset_executable_domain(): ...
def test_fieldset_time_interval():
grid1 = Grid(ds)
field1 = Field("field1", ds["U (A grid)"], grid1, mesh_type="flat")

ds2 = ds.copy()
ds2["time"] = ds2["time"] + np.timedelta64(timedelta(days=1))
grid2 = Grid(ds2)
field2 = Field("field2", ds2["U (A grid)"], grid2, mesh_type="flat")

fieldset = FieldSet([field1, field2])
fieldset.add_constant_field("constant_field", 1.0)

assert fieldset.time_interval.left == np.datetime64("2000-01-02")
assert fieldset.time_interval.right == np.datetime64("2001-01-01")
26 changes: 18 additions & 8 deletions tests/v4/utils/test_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,38 @@

from datetime import datetime, timedelta

import numpy as np
import pytest
from cftime import datetime as cftime_datetime
from hypothesis import given
from hypothesis import strategies as st

from parcels._core.utils.time import TimeInterval

calendar_strategy = st.sampled_from(["gregorian", "proleptic_gregorian", "365_day", "360_day", "julian", "366_day"])
calendar_strategy = st.sampled_from(
["gregorian", "proleptic_gregorian", "365_day", "360_day", "julian", "366_day", np.datetime64, datetime]
)


@st.composite
def cftime_datetime_strategy(draw, calendar=None):
def datetime_strategy(draw, calendar=None):
year = draw(st.integers(1900, 2100))
month = draw(st.integers(1, 12))
day = draw(st.integers(1, 28))
if calendar is None:
calendar = draw(calendar_strategy)
if calendar is datetime:
return datetime(year, month, day)
if calendar is np.datetime64:
return np.datetime64(datetime(year, month, day))

return cftime_datetime(year, month, day, calendar=calendar)


@st.composite
def cftime_interval_strategy(draw, left=None, calendar=None):
def time_interval_strategy(draw, left=None, calendar=None):
if left is None:
left = draw(cftime_datetime_strategy(calendar=calendar))
left = draw(datetime_strategy(calendar=calendar))
right = left + draw(
st.timedeltas(
min_value=timedelta(seconds=1),
Expand All @@ -41,6 +49,8 @@ def cftime_interval_strategy(draw, left=None, calendar=None):
(cftime_datetime(2023, 1, 1, calendar="gregorian"), cftime_datetime(2023, 1, 2, calendar="gregorian")),
(cftime_datetime(2023, 6, 1, calendar="365_day"), cftime_datetime(2023, 6, 2, calendar="365_day")),
(cftime_datetime(2023, 12, 1, calendar="360_day"), cftime_datetime(2023, 12, 2, calendar="360_day")),
(datetime(2023, 12, 1), datetime(2023, 12, 2)),
(np.datetime64(datetime(2023, 12, 1)), np.datetime64(datetime(2023, 12, 2))),
],
)
def test_time_interval_initialization(left, right):
Expand All @@ -53,7 +63,7 @@ def test_time_interval_initialization(left, right):
TimeInterval(right, left)


@given(cftime_interval_strategy())
@given(time_interval_strategy())
def test_time_interval_contains(interval):
left = interval.left
right = interval.right
Expand All @@ -64,12 +74,12 @@ def test_time_interval_contains(interval):
assert middle in interval


@given(cftime_interval_strategy(calendar="365_day"), cftime_interval_strategy(calendar="365_day"))
@given(time_interval_strategy(calendar="365_day"), time_interval_strategy(calendar="365_day"))
def test_time_interval_intersection_commutative(interval1, interval2):
assert interval1.intersection(interval2) == interval2.intersection(interval1)


@given(cftime_interval_strategy())
@given(time_interval_strategy())
def test_time_interval_intersection_with_self(interval):
assert interval.intersection(interval) == interval

Expand All @@ -81,7 +91,7 @@ def test_time_interval_repr():
assert repr(interval) == expected


@given(cftime_interval_strategy())
@given(time_interval_strategy())
def test_time_interval_equality(interval):
assert interval == interval

Expand Down