Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion parcels/_core/utils/time.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from __future__ import annotations

from datetime import datetime
from typing import TypeVar
from typing import TYPE_CHECKING, TypeVar

import cftime
import numpy as np

T = TypeVar("T", datetime, cftime.datetime)

if TYPE_CHECKING:
from parcels._typing import DatetimeLike


class TimeInterval:
"""A class representing a time interval between two datetime objects.
Expand Down Expand Up @@ -70,3 +73,28 @@ def is_compatible(t1: datetime | cftime.datetime, t2: datetime | cftime.datetime
return False
else:
return True


def get_datetime_type_calendar(
example_datetime: DatetimeLike,
) -> tuple[type, str | None]:
"""Get the type and calendar of a datetime object.

Parameters
----------
example_datetime : datetime, cftime.datetime, or np.datetime64
The datetime object to check.

Returns
-------
tuple[type, str | None]
A tuple containing the type of the datetime object and its calendar.
The calendar will be None if the datetime object is not a cftime datetime object.
"""
calendar = None
try:
calendar = example_datetime.calendar
except AttributeError:
# datetime isn't a cftime datetime object
pass
return type(example_datetime), calendar
6 changes: 5 additions & 1 deletion parcels/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@

import os
from collections.abc import Callable
from datetime import datetime
from typing import Any, Literal, get_args

import numpy as np
from cftime import datetime as cftime_datetime

InterpMethodOption = Literal[
"linear",
"nearest",
Expand All @@ -30,7 +34,7 @@
VectorType = Literal["3D", "3DSigma", "2D"] | None # corresponds with `vector_type`
GridIndexingType = Literal["pop", "mom5", "mitgcm", "nemo", "croco"] # corresponds with `gridindexingtype`
NetcdfEngine = Literal["netcdf4", "xarray", "scipy"]

DatetimeLike = datetime | cftime_datetime | np.datetime64

KernelFunction = Callable[..., None]

Expand Down
28 changes: 27 additions & 1 deletion parcels/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,13 @@ def __init__(
self.name = name
self.data = data
self.grid = grid
self.time_interval = get_time_interval(data)
try:
self.time_interval = get_time_interval(data)
except ValueError as e:
e.add_note(
f"Error getting time interval for field {name!r}. Are you sure that the time dimension on the xarray dataset is stored as datetime or cftime datetime objects?"
)
raise e
Comment on lines +169 to +175
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nicer error message ;)


# For compatibility with parts of the codebase that rely on v3 definition of Grid.
# Should be worked to be removed in v4
Expand Down Expand Up @@ -531,6 +537,13 @@ def __init__(
self.V = V
self.W = W

if W is None:
assert_same_time_interval((U, V))
else:
assert_same_time_interval((U, V, W))

self.time_interval = U.time_interval

if self.W:
self.vector_type = "3D"
else:
Expand Down Expand Up @@ -670,3 +683,16 @@ def get_time_interval(data: xr.DataArray | ux.UxDataArray) -> TimeInterval | Non
return None

return TimeInterval(data.time.values[0], data.time.values[-1])


def assert_same_time_interval(fields: list[Field]) -> None:
if len(fields) == 0:
return

reference_time_interval = fields[0].time_interval

for field in fields[1:]:
if field.time_interval != reference_time_interval:
raise ValueError(
f"Fields must have the same time domain. {fields[0].name}: {reference_time_interval}, {field.name}: {field.time_interval}"
)
38 changes: 38 additions & 0 deletions parcels/fieldset.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
from __future__ import annotations

import functools
from collections.abc import Iterable
from typing import TYPE_CHECKING

import numpy as np
import xarray as xr

from parcels._core.utils.time import get_datetime_type_calendar
from parcels._core.utils.time import is_compatible as datetime_is_compatible
from parcels._reprs import fieldset_repr
from parcels._typing import Mesh
from parcels.field import Field, VectorField
from parcels.v4.grid import Grid

if TYPE_CHECKING:
from parcels._typing import DatetimeLike
__all__ = ["FieldSet"]


Expand Down Expand Up @@ -45,6 +53,7 @@ def __init__(self, fields: list[Field | VectorField]):
for field in fields:
if not isinstance(field, (Field, VectorField)):
raise ValueError(f"Expected `field` to be a Field or VectorField object. Got {field}")
assert_compatible_calendars(fields)

self.fields = {f.name: f for f in fields}
self.constants = {}
Expand Down Expand Up @@ -126,6 +135,7 @@ def add_field(self, field: Field, name: str | None = None):
"""
if not isinstance(field, (Field, VectorField)):
raise ValueError(f"Expected `field` to be a Field or VectorField object. Got {type(field)}")
assert_compatible_calendars((*self.fields.values(), field))

name = field.name if name is None else name

Expand Down Expand Up @@ -235,3 +245,31 @@ def add_constant(self, name, value):
# return nextTime
# else:
# return time + nSteps * dt


class CalendarError(Exception): # TODO: Move to a parcels errors module
"""Exception raised when the calendar of a field is not compatible with the rest of the Fields. The user should ensure that they only add fields to a FieldSet that have compatible CFtime calendars."""


def assert_compatible_calendars(fields: Iterable[Field]):
time_intervals = [f.time_interval for f in fields if f.time_interval is not None]
reference_datetime_object = time_intervals[0].left

for field in fields:
if field.time_interval is None:
continue

if not datetime_is_compatible(reference_datetime_object, field.time_interval.left):
msg = format_calendar_error_message(field, reference_datetime_object)
raise CalendarError(msg)


def format_calendar_error_message(field: Field, reference_datetime: DatetimeLike) -> str:
def datetime_to_msg(example_datetime: DatetimeLike) -> str:
datetime_type, calendar = get_datetime_type_calendar(example_datetime)
msg = str(datetime_type)
if calendar is not None:
msg += f" with cftime calendar {calendar}'"
return msg

return f"Expected field {field.name!r} to have calendar compatible with datetime object {datetime_to_msg(reference_datetime)}. Got field with calendar {datetime_to_msg(field.time_interval.left)}. Have you considered using xarray to update the time dimension of the dataset to have a compatible calendar?"
37 changes: 31 additions & 6 deletions tests/v4/test_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
import xarray as xr

from parcels import Field
from parcels._datasets.structured.generic import datasets as structured_datasets
from parcels._datasets.unstructured.generic import datasets as unstructured_datasets
from parcels._datasets.structured.generic import T as T_structured
from parcels._datasets.structured.generic import datasets as datasets_structured
from parcels._datasets.unstructured.generic import datasets as datasets_unstructured
from parcels.v4.grid import Grid


Expand Down Expand Up @@ -36,7 +37,7 @@ def test_field_init_param_types():
pytest.param(ux.UxDataArray(), Grid(xr.Dataset()), id="uxdata-grid"),
pytest.param(
xr.DataArray(),
unstructured_datasets["stommel_gyre_delaunay"].uxgrid,
datasets_unstructured["stommel_gyre_delaunay"].uxgrid,
id="xarray-uxgrid",
),
],
Expand All @@ -54,11 +55,11 @@ def test_field_incompatible_combination(data, grid):
"data,grid",
[
pytest.param(
structured_datasets["ds_2d_left"]["data_g"], Grid(structured_datasets["ds_2d_left"]), id="ds_2d_left"
datasets_structured["ds_2d_left"]["data_g"], Grid(datasets_structured["ds_2d_left"]), id="ds_2d_left"
), # TODO: Perhaps this test should be expanded to cover more datasets?
],
)
def test_field_structured_grid_creation(data, grid):
def test_field_init_structured_grid(data, grid):
"""Test creating a field."""
field = Field(
name="test_field",
Expand All @@ -70,11 +71,30 @@ def test_field_structured_grid_creation(data, grid):
assert field.grid == grid


@pytest.mark.parametrize("numpy_dtype", ["timedelta64[s]", "float64"])
def test_field_init_fail_on_bad_time_type(numpy_dtype):
"""Tests that field initialisation fails when the time isn't given as datetime object (i.e., is float or timedelta)."""
ds = datasets_structured["ds_2d_left"].copy()
ds["time"] = np.arange(0, T_structured, dtype=numpy_dtype)

data = ds["data_g"]
grid = Grid(ds)
with pytest.raises(
ValueError,
match="Error getting time interval.*. Are you sure that the time dimension on the xarray dataset is stored as datetime or cftime datetime objects\?",
):
Field(
name="test_field",
data=data,
grid=grid,
)


@pytest.mark.parametrize(
"data,grid",
[
pytest.param(
structured_datasets["ds_2d_left"]["data_g"], Grid(structured_datasets["ds_2d_left"]), id="ds_2d_left"
datasets_structured["ds_2d_left"]["data_g"], Grid(datasets_structured["ds_2d_left"]), id="ds_2d_left"
),
],
)
Expand All @@ -85,6 +105,11 @@ def test_field_time_interval(data, grid):
assert field.time_interval.right == np.datetime64("2001-01-01")


def test_vectorfield_init_different_time_intervals():
# Tests that a VectorField raises a ValueError if the component fields have different time domains.
...


def test_field_unstructured_grid_creation(): ...


Expand Down
30 changes: 30 additions & 0 deletions tests/v4/test_fieldset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import numpy as np
import pytest
import xarray as xr

from parcels._datasets.structured.generic import T as T_structured
from parcels._datasets.structured.generic import datasets as datasets_structured
from parcels.field import Field, VectorField
from parcels.fieldset import FieldSet
Expand Down Expand Up @@ -87,3 +89,31 @@ def test_fieldset_time_interval():

assert fieldset.time_interval.left == np.datetime64("2000-01-02")
assert fieldset.time_interval.right == np.datetime64("2001-01-01")


def test_fieldset_init_incompatible_calendars():
ds1 = ds.copy()
ds1["time"] = xr.date_range("2000", "2001", T_structured, calendar="365_day", use_cftime=True)

grid = Grid(ds1)
U = Field("U", ds1["U (A grid)"], grid, mesh_type="flat")
V = Field("V", ds1["V (A grid)"], grid, mesh_type="flat")
UV = VectorField("UV", U, V)

ds2 = ds.copy()
ds2["time"] = xr.date_range("2000", "2001", T_structured, calendar="360_day", use_cftime=True)
grid2 = Grid(ds2)
incompatible_calendar = Field("test", ds2["data_g"], grid2, mesh_type="flat")

with pytest.raises(ValueError):
FieldSet([U, V, UV, incompatible_calendar])


def test_fieldset_add_field_incompatible_calendars(fieldset):
ds_test = ds.copy()
ds_test["time"] = xr.date_range("2000", "2001", T_structured, calendar="360_day", use_cftime=True)
grid = Grid(ds_test)
field = Field("test_field", ds_test["data_g"], grid, mesh_type="flat")

with pytest.raises(ValueError):
fieldset.add_field(field, "test_field")
Loading