diff --git a/parcels/_datasets/structured/generic.py b/parcels/_datasets/structured/generic.py index 92fa51ba25..4260eefab8 100644 --- a/parcels/_datasets/structured/generic.py +++ b/parcels/_datasets/structured/generic.py @@ -24,6 +24,10 @@ def _rotated_curvilinear_grid(): { "data_g": (["ZG", "YG", "XG"], np.random.rand(3 * N, 2 * N, N)), "data_c": (["ZC", "YC", "XC"], np.random.rand(3 * N, 2 * N, N)), + "U (A grid)": (["ZG", "YG", "XG"], np.random.rand(3 * N, 2 * N, N)), + "V (A grid)": (["ZG", "YG", "XG"], np.random.rand(3 * N, 2 * N, N)), + "U (C grid)": (["ZG", "YC", "XG"], np.random.rand(3 * N, 2 * N, N)), + "V (C grid)": (["ZG", "YG", "XC"], np.random.rand(3 * N, 2 * N, N)), }, coords={ "XG": (["XG"], XG, {"axis": "X", "c_grid_axis_shift": -0.5}), @@ -95,6 +99,10 @@ def _unrolled_cone_curvilinear_grid(): { "data_g": (["ZG", "YG", "XG"], np.random.rand(3 * N, 2 * N, N)), "data_c": (["ZC", "YC", "XC"], np.random.rand(3 * N, 2 * N, N)), + "U (A grid)": (["ZG", "YG", "XG"], np.random.rand(3 * N, 2 * N, N)), + "V (A grid)": (["ZG", "YG", "XG"], np.random.rand(3 * N, 2 * N, N)), + "U (C grid)": (["ZG", "YC", "XG"], np.random.rand(3 * N, 2 * N, N)), + "V (C grid)": (["ZG", "YG", "XC"], np.random.rand(3 * N, 2 * N, N)), }, coords={ "XG": (["XG"], XG, {"axis": "X", "c_grid_axis_shift": -0.5}), @@ -133,6 +141,10 @@ def _unrolled_cone_curvilinear_grid(): { "data_g": (["time", "ZG", "YG", "XG"], np.random.rand(T, 3 * N, 2 * N, N)), "data_c": (["time", "ZC", "YC", "XC"], np.random.rand(T, 3 * N, 2 * N, N)), + "U (A grid)": (["ZG", "YG", "XG"], np.random.rand(3 * N, 2 * N, N)), + "V (A grid)": (["ZG", "YG", "XG"], np.random.rand(3 * N, 2 * N, N)), + "U (C grid)": (["ZG", "YC", "XG"], np.random.rand(3 * N, 2 * N, N)), + "V (C grid)": (["ZG", "YG", "XC"], np.random.rand(3 * N, 2 * N, N)), }, coords={ "XG": ( diff --git a/parcels/_reprs.py b/parcels/_reprs.py new file mode 100644 index 0000000000..fe77d4ecb0 --- /dev/null +++ b/parcels/_reprs.py @@ -0,0 +1,84 @@ +"""Parcels reprs""" + +from __future__ import annotations + +import textwrap +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from parcels import Field, FieldSet, ParticleSet + + +def field_repr(field: Field) -> str: + """Return a pretty repr for Field""" + out = f"""<{type(field).__name__}> + name : {field.name!r} + data : {field.data!r} + extrapolate time: {field.allow_time_extrapolation!r} +""" + return textwrap.dedent(out).strip() + + +def _format_list_items_multiline(items: list[str], level: int = 1) -> str: + """Given a list of strings, formats them across multiple lines. + + Uses indentation levels of 4 spaces provided by ``level``. + + Example + ------- + >>> output = _format_list_items_multiline(["item1", "item2", "item3"], 4) + >>> f"my_items: {output}" + my_items: [ + item1, + item2, + item3, + ] + """ + if len(items) == 0: + return "[]" + + assert level >= 1, "Indentation level >=1 supported" + indentation_str = level * 4 * " " + indentation_str_end = (level - 1) * 4 * " " + + items_str = ",\n".join([textwrap.indent(i, indentation_str) for i in items]) + return f"[\n{items_str}\n{indentation_str_end}]" + + +def particleset_repr(pset: ParticleSet) -> str: + """Return a pretty repr for ParticleSet""" + if len(pset) < 10: + particles = [repr(p) for p in pset] + else: + particles = [repr(pset[i]) for i in range(7)] + ["..."] + + out = f"""<{type(pset).__name__}> + fieldset : +{textwrap.indent(repr(pset.fieldset), " " * 8)} + pclass : {pset.pclass} + repeatdt : {pset.repeatdt} + # particles: {len(pset)} + particles : {_format_list_items_multiline(particles, level=2)} +""" + return textwrap.dedent(out).strip() + + +def fieldset_repr(fieldset: FieldSet) -> str: + """Return a pretty repr for FieldSet""" + fields_repr = "\n".join([repr(f) for f in fieldset.get_fields()]) + + out = f"""<{type(fieldset).__name__}> + fields: +{textwrap.indent(fields_repr, 8 * " ")} +""" + return textwrap.dedent(out).strip() + + +def default_repr(obj: Any): + if is_builtin_object(obj): + return repr(obj) + return object.__repr__(obj) + + +def is_builtin_object(obj): + return obj.__class__.__module__ == "builtins" diff --git a/parcels/field.py b/parcels/field.py index 613a88ac82..42fd157396 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -11,12 +11,12 @@ from uxarray.grid.neighbors import _barycentric_coordinates from parcels._core.utils.unstructured import get_vertical_location_from_dims +from parcels._reprs import default_repr, field_repr from parcels._typing import ( Mesh, VectorType, assert_valid_mesh, ) -from parcels.tools._helpers import default_repr, field_repr from parcels.tools.converters import ( UnitConverter, unitconverters_map, @@ -185,7 +185,6 @@ def __init__( e.add_note(f"Error validating field {name!r}.") raise e - self._parent_mesh = data.attrs["mesh"] self._mesh_type = mesh_type # Setting the interpolation method dynamically diff --git a/parcels/fieldset.py b/parcels/fieldset.py index 26543e9e8c..ba76526794 100644 --- a/parcels/fieldset.py +++ b/parcels/fieldset.py @@ -3,9 +3,10 @@ import numpy as np import xarray as xr +from parcels._reprs import fieldset_repr from parcels._typing import Mesh from parcels.field import Field, VectorField -from parcels.tools._helpers import fieldset_repr +from parcels.v4.grid import Grid __all__ = ["FieldSet"] @@ -41,10 +42,21 @@ class FieldSet: """ def __init__(self, fields: list[Field | VectorField]): - # TODO Nick : Enforce fields to be list of Field or VectorField objects + for field in fields: + if not isinstance(field, (Field, VectorField)): + raise ValueError(f"Expected `field` to be a Field or VectorField object. Got {field}") + self.fields = {f.name: f for f in fields} + self.constants = {} - # TODO : Nick : Add _getattr_ magic method to allow access to fields by name + def __getattr__(self, name): + """Get the field by name. If the field is not found, check if it's a constant.""" + if name in self.fields: + return self.fields[name] + elif name in self.constants: + return self.constants[name] + else: + raise AttributeError(f"FieldSet has no attribute '{name}'") @property def time_interval(self): @@ -112,6 +124,9 @@ def add_field(self, field: Field, name: str | None = None): * `Unit converters <../examples/tutorial_unitconverters.ipynb>`__ (Default value = None) """ + if not isinstance(field, (Field, VectorField)): + raise ValueError(f"Expected `field` to be a Field or VectorField object. Got {type(field)}") + name = field.name if name is None else name if name in self.fields: @@ -137,21 +152,25 @@ def add_constant_field(self, name: str, value, mesh: Mesh = "flat"): correction for zonal velocity U near the poles. 2. flat: No conversion, lat/lon are assumed to be in m. """ - time = 0.0 - values = np.full((1, 1, 1, 1), value) - data = xr.DataArray( - data=values, - name=name, - dims="null", - coords=[time, [0], [0], [0]], - attrs=dict(description="null", units="null", location="node", mesh="constant", mesh_type=mesh), + da = xr.DataArray( + data=np.full((1, 1, 1, 1), value), + dims=["T", "ZG", "YG", "XG"], + coords={ + "ZG": (["ZG"], np.arange(1), {"axis": "Z"}), + "YG": (["YG"], np.arange(1), {"axis": "Y"}), + "XG": (["XG"], np.arange(1), {"axis": "X"}), + "lon": (["XG"], np.arange(1), {"axis": "X"}), + "lat": (["YG"], np.arange(1), {"axis": "Y"}), + "depth": (["ZG"], np.arange(1), {"axis": "Z"}), + }, ) + grid = Grid(da) self.add_field( Field( name, - data, + da, + grid, interp_method=None, # TODO : Need to define an interpolation method for constants - allow_time_extrapolation=True, ) ) @@ -185,7 +204,10 @@ def add_constant(self, name, value): `Diffusion <../examples/tutorial_diffusion.ipynb>`__ `Periodic boundaries <../examples/tutorial_periodic_boundaries.ipynb>`__ """ - setattr(self, name, value) + if name in self.constants: + raise ValueError(f"FieldSet already has a constant with name '{name}'") + + self.constants[name] = np.float32(value) # def computeTimeChunk(self, time=0.0, dt=1): # """Load a chunk of three data time steps into the FieldSet. diff --git a/parcels/particlefile.py b/parcels/particlefile.py index 7926caef93..511fb6d395 100644 --- a/parcels/particlefile.py +++ b/parcels/particlefile.py @@ -10,7 +10,8 @@ import parcels from parcels._compat import MPI -from parcels.tools._helpers import default_repr, timedelta_to_float +from parcels._reprs import default_repr +from parcels.tools._helpers import timedelta_to_float from parcels.tools.warnings import FileWarning __all__ = ["ParticleFile"] diff --git a/parcels/particleset.py b/parcels/particleset.py index a2dd9f7a75..fa09d48451 100644 --- a/parcels/particleset.py +++ b/parcels/particleset.py @@ -11,6 +11,7 @@ from tqdm import tqdm from parcels._compat import MPI +from parcels._reprs import particleset_repr from parcels.application_kernels.advection import AdvectionRK4 from parcels.field import Field from parcels.grid import GridType @@ -25,7 +26,7 @@ from parcels.particle import Particle, Variable from parcels.particledata import ParticleData, ParticleDataIterator from parcels.particlefile import ParticleFile -from parcels.tools._helpers import particleset_repr, timedelta_to_float +from parcels.tools._helpers import timedelta_to_float from parcels.tools.converters import _get_cftime_calendars, convert_to_flat_array from parcels.tools.loggers import logger from parcels.tools.statuscodes import StatusCode diff --git a/parcels/tools/_helpers.py b/parcels/tools/_helpers.py index 74f4bdfb33..4e3500e89e 100644 --- a/parcels/tools/_helpers.py +++ b/parcels/tools/_helpers.py @@ -3,17 +3,12 @@ from __future__ import annotations import functools -import textwrap import warnings from collections.abc import Callable from datetime import timedelta -from typing import TYPE_CHECKING, Any import numpy as np -if TYPE_CHECKING: - from parcels import Field, FieldSet, ParticleSet - PACKAGE = "Parcels" @@ -68,81 +63,6 @@ def patch_docstring(obj: Callable, extra: str) -> None: obj.__doc__ = f"{obj.__doc__ or ''}{extra}".strip() -def field_repr(field: Field) -> str: - """Return a pretty repr for Field""" - out = f"""<{type(field).__name__}> - name : {field.name!r} - data : {field.data!r} - extrapolate time: {field.allow_time_extrapolation!r} -""" - return textwrap.dedent(out).strip() - - -def _format_list_items_multiline(items: list[str], level: int = 1) -> str: - """Given a list of strings, formats them across multiple lines. - - Uses indentation levels of 4 spaces provided by ``level``. - - Example - ------- - >>> output = _format_list_items_multiline(["item1", "item2", "item3"], 4) - >>> f"my_items: {output}" - my_items: [ - item1, - item2, - item3, - ] - """ - if len(items) == 0: - return "[]" - - assert level >= 1, "Indentation level >=1 supported" - indentation_str = level * 4 * " " - indentation_str_end = (level - 1) * 4 * " " - - items_str = ",\n".join([textwrap.indent(i, indentation_str) for i in items]) - return f"[\n{items_str}\n{indentation_str_end}]" - - -def particleset_repr(pset: ParticleSet) -> str: - """Return a pretty repr for ParticleSet""" - if len(pset) < 10: - particles = [repr(p) for p in pset] - else: - particles = [repr(pset[i]) for i in range(7)] + ["..."] - - out = f"""<{type(pset).__name__}> - fieldset : -{textwrap.indent(repr(pset.fieldset), " " * 8)} - pclass : {pset.pclass} - repeatdt : {pset.repeatdt} - # particles: {len(pset)} - particles : {_format_list_items_multiline(particles, level=2)} -""" - return textwrap.dedent(out).strip() - - -def fieldset_repr(fieldset: FieldSet) -> str: - """Return a pretty repr for FieldSet""" - fields_repr = "\n".join([repr(f) for f in fieldset.get_fields()]) - - out = f"""<{type(fieldset).__name__}> - fields: -{textwrap.indent(fields_repr, 8 * " ")} -""" - return textwrap.dedent(out).strip() - - -def default_repr(obj: Any): - if is_builtin_object(obj): - return repr(obj) - return object.__repr__(obj) - - -def is_builtin_object(obj): - return obj.__class__.__module__ == "builtins" - - def timedelta_to_float(dt: float | timedelta | np.timedelta64) -> float: """Convert a timedelta to a float in seconds.""" if isinstance(dt, timedelta): diff --git a/tests/v4/test_field.py b/tests/v4/test_field.py index f0e0313009..8ca612bc90 100644 --- a/tests/v4/test_field.py +++ b/tests/v4/test_field.py @@ -54,10 +54,9 @@ def test_field_incompatible_combination(data, grid): [ pytest.param( structured_datasets["ds_2d_left"]["data_g"], Grid(structured_datasets["ds_2d_left"]), id="ds_2d_left" - ), + ), # TODO: Perhaps this test should be expanded to cover more datasets? ], ) -@pytest.mark.xfail(reason="Structured grid creation is not implemented yet") def test_field_structured_grid_creation(data, grid): """Test creating a field.""" field = Field( @@ -66,15 +65,10 @@ def test_field_structured_grid_creation(data, grid): grid=grid, ) assert field.name == "test_field" - assert field.data == data + assert field.data.equals(data) assert field.grid == grid -def test_field_structured_grid_creation_spherical(): - # Field(..., mesh_type="spherical") - ... - - def test_field_unstructured_grid_creation(): ... diff --git a/tests/v4/test_fieldset.py b/tests/v4/test_fieldset.py new file mode 100644 index 0000000000..cfe6c1ca5e --- /dev/null +++ b/tests/v4/test_fieldset.py @@ -0,0 +1,73 @@ +import pytest + +from parcels._datasets.structured.generic import datasets as datasets_structured +from parcels.field import Field, VectorField +from parcels.fieldset import FieldSet +from parcels.v4.grid import Grid + +ds = datasets_structured["ds_2d_left"] + + +@pytest.fixture +def fieldset() -> FieldSet: + """Fixture to create a FieldSet object for testing.""" + grid = Grid(ds) + U = Field("U", ds["U (A grid)"], grid, mesh_type="flat") + V = Field("V", ds["V (A grid)"], grid, mesh_type="flat") + UV = VectorField("UV", U, V) + + return FieldSet( + [U, V, UV], + ) + + +def test_fieldset_init_wrong_types(): + with pytest.raises(ValueError, match="Expected `field` to be a Field or VectorField object. Got .*"): + FieldSet([1.0, 2.0, 3.0]) + + +def test_fieldset_add_constant(fieldset): + fieldset.add_constant("test_constant", 1.0) + assert fieldset.test_constant == 1.0 + + +def test_fieldset_add_constant_field(fieldset): + fieldset.add_constant_field("test_constant_field", 1.0) + + # Get a point in the domain + time = ds["time"].mean() + depth = ds["depth"].mean() + lat = ds["lat"].mean() + lon = ds["lon"].mean() + + pytest.xfail(reason="Not yet implemented interpolation.") + assert fieldset.test_constant_field[time, depth, lat, lon] == 1.0 + + +def test_fieldset_add_field(fieldset): + grid = Grid(ds) + field = Field("test_field", ds["U (A grid)"], grid, mesh_type="flat") + fieldset.add_field(field) + assert fieldset.test_field == field + + +def test_fieldset_add_field_wrong_type(fieldset): + not_a_field = 1.0 + with pytest.raises(ValueError, match="Expected `field` to be a Field or VectorField object. Got .*"): + fieldset.add_field(not_a_field, "test_field") + + +def test_fieldset_add_field_already_exists(fieldset): + grid = Grid(ds) + field = Field("test_field", ds["U (A grid)"], grid, mesh_type="flat") + fieldset.add_field(field, "test_field") + with pytest.raises(ValueError, match="FieldSet already has a Field with name 'test_field'"): + fieldset.add_field(field, "test_field") + + +@pytest.mark.xfail(reason="FieldSet doesn't yet correctly handle duplicate grids.") +def test_fieldset_gridset_size(fieldset): + assert fieldset.gridset_size == 1 + + +def test_fieldset_executable_domain(): ... diff --git a/tests/v4/test_uxarray_fieldset.py b/tests/v4/test_uxarray_fieldset.py index 521aab87ca..75bc24a02f 100644 --- a/tests/v4/test_uxarray_fieldset.py +++ b/tests/v4/test_uxarray_fieldset.py @@ -52,16 +52,16 @@ def uvw_fesom_channel(ds_fesom_channel) -> VectorField: def test_fesom_fieldset(ds_fesom_channel, uv_fesom_channel): fieldset = FieldSet([uv_fesom_channel, uv_fesom_channel.U, uv_fesom_channel.V]) # Check that the fieldset has the expected properties - assert (fieldset.fields["U"] == ds_fesom_channel.U).all() - assert (fieldset.fields["V"] == ds_fesom_channel.V).all() + assert (fieldset.U == ds_fesom_channel.U).all() + assert (fieldset.V == ds_fesom_channel.V).all() @pytest.mark.skip(reason="ParticleSet.__init__ needs major refactoring") def test_fesom_in_particleset(ds_fesom_channel, uv_fesom_channel): fieldset = FieldSet([uv_fesom_channel, uv_fesom_channel.U, uv_fesom_channel.V]) # Check that the fieldset has the expected properties - assert (fieldset.fields["U"] == ds_fesom_channel.U).all() - assert (fieldset.fields["V"] == ds_fesom_channel.V).all() + assert (fieldset.U == ds_fesom_channel.U).all() + assert (fieldset.V == ds_fesom_channel.V).all() pset = ParticleSet(fieldset, pclass=Particle) assert pset.fieldset == fieldset @@ -69,12 +69,12 @@ def test_fesom_in_particleset(ds_fesom_channel, uv_fesom_channel): def test_set_interp_methods(ds_fesom_channel, uv_fesom_channel): fieldset = FieldSet([uv_fesom_channel, uv_fesom_channel.U, uv_fesom_channel.V]) # Check that the fieldset has the expected properties - assert (fieldset.fields["U"] == ds_fesom_channel.U).all() - assert (fieldset.fields["V"] == ds_fesom_channel.V).all() + assert (fieldset.U == ds_fesom_channel.U).all() + assert (fieldset.V == ds_fesom_channel.V).all() # Set the interpolation method for each field - fieldset.fields["U"].interp_method = UXPiecewiseConstantFace - fieldset.fields["V"].interp_method = UXPiecewiseConstantFace + fieldset.U.interp_method = UXPiecewiseConstantFace + fieldset.V.interp_method = UXPiecewiseConstantFace @pytest.mark.skip(reason="ParticleSet.__init__ needs major refactoring") @@ -82,9 +82,9 @@ def test_fesom_channel(ds_fesom_channel, uvw_fesom_channel): fieldset = FieldSet([uvw_fesom_channel, uvw_fesom_channel.U, uvw_fesom_channel.V, uvw_fesom_channel.W]) # Check that the fieldset has the expected properties - assert (fieldset.fields["U"] == ds_fesom_channel.U).all() - assert (fieldset.fields["V"] == ds_fesom_channel.V).all() - assert (fieldset.fields["W"] == ds_fesom_channel.W).all() + assert (fieldset.U == ds_fesom_channel.U).all() + assert (fieldset.V == ds_fesom_channel.V).all() + assert (fieldset.W == ds_fesom_channel.W).all() pset = ParticleSet(fieldset, pclass=Particle) pset.execute(endtime=timedelta(days=1), dt=timedelta(hours=1))