Skip to content
12 changes: 12 additions & 0 deletions parcels/_datasets/structured/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ def _rotated_curvilinear_grid():
{
"data_g": (["ZG", "YG", "XG"], np.random.rand(3 * N, 2 * N, N)),
"data_c": (["ZC", "YC", "XC"], np.random.rand(3 * N, 2 * N, N)),
"U (A grid)": (["ZG", "YG", "XG"], np.random.rand(3 * N, 2 * N, N)),
"V (A grid)": (["ZG", "YG", "XG"], np.random.rand(3 * N, 2 * N, N)),
"U (C grid)": (["ZG", "YC", "XG"], np.random.rand(3 * N, 2 * N, N)),
"V (C grid)": (["ZG", "YG", "XC"], np.random.rand(3 * N, 2 * N, N)),
},
coords={
"XG": (["XG"], XG, {"axis": "X", "c_grid_axis_shift": -0.5}),
Expand Down Expand Up @@ -95,6 +99,10 @@ def _unrolled_cone_curvilinear_grid():
{
"data_g": (["ZG", "YG", "XG"], np.random.rand(3 * N, 2 * N, N)),
"data_c": (["ZC", "YC", "XC"], np.random.rand(3 * N, 2 * N, N)),
"U (A grid)": (["ZG", "YG", "XG"], np.random.rand(3 * N, 2 * N, N)),
"V (A grid)": (["ZG", "YG", "XG"], np.random.rand(3 * N, 2 * N, N)),
"U (C grid)": (["ZG", "YC", "XG"], np.random.rand(3 * N, 2 * N, N)),
"V (C grid)": (["ZG", "YG", "XC"], np.random.rand(3 * N, 2 * N, N)),
},
coords={
"XG": (["XG"], XG, {"axis": "X", "c_grid_axis_shift": -0.5}),
Expand Down Expand Up @@ -133,6 +141,10 @@ def _unrolled_cone_curvilinear_grid():
{
"data_g": (["time", "ZG", "YG", "XG"], np.random.rand(T, 3 * N, 2 * N, N)),
"data_c": (["time", "ZC", "YC", "XC"], np.random.rand(T, 3 * N, 2 * N, N)),
"U (A grid)": (["ZG", "YG", "XG"], np.random.rand(3 * N, 2 * N, N)),
"V (A grid)": (["ZG", "YG", "XG"], np.random.rand(3 * N, 2 * N, N)),
"U (C grid)": (["ZG", "YC", "XG"], np.random.rand(3 * N, 2 * N, N)),
"V (C grid)": (["ZG", "YG", "XC"], np.random.rand(3 * N, 2 * N, N)),
},
coords={
"XG": (
Expand Down
84 changes: 84 additions & 0 deletions parcels/_reprs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""Parcels reprs"""

from __future__ import annotations

import textwrap
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from parcels import Field, FieldSet, ParticleSet

Check warning on line 9 in parcels/_reprs.py

View check run for this annotation

Codecov / codecov/patch

parcels/_reprs.py#L9

Added line #L9 was not covered by tests


def field_repr(field: Field) -> str:
"""Return a pretty repr for Field"""
out = f"""<{type(field).__name__}>

Check warning on line 14 in parcels/_reprs.py

View check run for this annotation

Codecov / codecov/patch

parcels/_reprs.py#L14

Added line #L14 was not covered by tests
name : {field.name!r}
data : {field.data!r}
extrapolate time: {field.allow_time_extrapolation!r}
"""
return textwrap.dedent(out).strip()

Check warning on line 19 in parcels/_reprs.py

View check run for this annotation

Codecov / codecov/patch

parcels/_reprs.py#L19

Added line #L19 was not covered by tests


def _format_list_items_multiline(items: list[str], level: int = 1) -> str:
"""Given a list of strings, formats them across multiple lines.

Uses indentation levels of 4 spaces provided by ``level``.

Example
-------
>>> output = _format_list_items_multiline(["item1", "item2", "item3"], 4)
>>> f"my_items: {output}"
my_items: [
item1,
item2,
item3,
]
"""
if len(items) == 0:
return "[]"

Check warning on line 38 in parcels/_reprs.py

View check run for this annotation

Codecov / codecov/patch

parcels/_reprs.py#L37-L38

Added lines #L37 - L38 were not covered by tests

assert level >= 1, "Indentation level >=1 supported"
indentation_str = level * 4 * " "
indentation_str_end = (level - 1) * 4 * " "

Check warning on line 42 in parcels/_reprs.py

View check run for this annotation

Codecov / codecov/patch

parcels/_reprs.py#L40-L42

Added lines #L40 - L42 were not covered by tests

items_str = ",\n".join([textwrap.indent(i, indentation_str) for i in items])
return f"[\n{items_str}\n{indentation_str_end}]"

Check warning on line 45 in parcels/_reprs.py

View check run for this annotation

Codecov / codecov/patch

parcels/_reprs.py#L44-L45

Added lines #L44 - L45 were not covered by tests


def particleset_repr(pset: ParticleSet) -> str:
"""Return a pretty repr for ParticleSet"""
if len(pset) < 10:
particles = [repr(p) for p in pset]

Check warning on line 51 in parcels/_reprs.py

View check run for this annotation

Codecov / codecov/patch

parcels/_reprs.py#L50-L51

Added lines #L50 - L51 were not covered by tests
else:
particles = [repr(pset[i]) for i in range(7)] + ["..."]

Check warning on line 53 in parcels/_reprs.py

View check run for this annotation

Codecov / codecov/patch

parcels/_reprs.py#L53

Added line #L53 was not covered by tests

out = f"""<{type(pset).__name__}>

Check warning on line 55 in parcels/_reprs.py

View check run for this annotation

Codecov / codecov/patch

parcels/_reprs.py#L55

Added line #L55 was not covered by tests
fieldset :
{textwrap.indent(repr(pset.fieldset), " " * 8)}
pclass : {pset.pclass}
repeatdt : {pset.repeatdt}
# particles: {len(pset)}
particles : {_format_list_items_multiline(particles, level=2)}
"""
return textwrap.dedent(out).strip()

Check warning on line 63 in parcels/_reprs.py

View check run for this annotation

Codecov / codecov/patch

parcels/_reprs.py#L63

Added line #L63 was not covered by tests


def fieldset_repr(fieldset: FieldSet) -> str:
"""Return a pretty repr for FieldSet"""
fields_repr = "\n".join([repr(f) for f in fieldset.get_fields()])

out = f"""<{type(fieldset).__name__}>
fields:
{textwrap.indent(fields_repr, 8 * " ")}
"""
return textwrap.dedent(out).strip()


def default_repr(obj: Any):
if is_builtin_object(obj):
return repr(obj)
return object.__repr__(obj)

Check warning on line 80 in parcels/_reprs.py

View check run for this annotation

Codecov / codecov/patch

parcels/_reprs.py#L78-L80

Added lines #L78 - L80 were not covered by tests


def is_builtin_object(obj):
return obj.__class__.__module__ == "builtins"

Check warning on line 84 in parcels/_reprs.py

View check run for this annotation

Codecov / codecov/patch

parcels/_reprs.py#L84

Added line #L84 was not covered by tests
3 changes: 1 addition & 2 deletions parcels/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
from uxarray.grid.neighbors import _barycentric_coordinates

from parcels._core.utils.unstructured import get_vertical_location_from_dims
from parcels._reprs import default_repr, field_repr
from parcels._typing import (
Mesh,
VectorType,
assert_valid_mesh,
)
from parcels.tools._helpers import default_repr, field_repr
from parcels.tools.converters import (
UnitConverter,
unitconverters_map,
Expand Down Expand Up @@ -185,7 +185,6 @@ def __init__(
e.add_note(f"Error validating field {name!r}.")
raise e

self._parent_mesh = data.attrs["mesh"]
self._mesh_type = mesh_type

# Setting the interpolation method dynamically
Expand Down
50 changes: 36 additions & 14 deletions parcels/fieldset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import numpy as np
import xarray as xr

from parcels._reprs import fieldset_repr
from parcels._typing import Mesh
from parcels.field import Field, VectorField
from parcels.tools._helpers import fieldset_repr
from parcels.v4.grid import Grid

__all__ = ["FieldSet"]

Expand Down Expand Up @@ -41,10 +42,21 @@
"""

def __init__(self, fields: list[Field | VectorField]):
# TODO Nick : Enforce fields to be list of Field or VectorField objects
for field in fields:
if not isinstance(field, (Field, VectorField)):
raise ValueError(f"Expected `field` to be a Field or VectorField object. Got {field}")
Comment thread
VeckoTheGecko marked this conversation as resolved.

self.fields = {f.name: f for f in fields}
self.constants = {}

# TODO : Nick : Add _getattr_ magic method to allow access to fields by name
def __getattr__(self, name):
"""Get the field by name. If the field is not found, check if it's a constant."""
if name in self.fields:
return self.fields[name]
elif name in self.constants:
return self.constants[name]
else:
raise AttributeError(f"FieldSet has no attribute '{name}'")

Check warning on line 59 in parcels/fieldset.py

View check run for this annotation

Codecov / codecov/patch

parcels/fieldset.py#L59

Added line #L59 was not covered by tests

@property
def time_interval(self):
Expand Down Expand Up @@ -112,6 +124,9 @@
* `Unit converters <../examples/tutorial_unitconverters.ipynb>`__ (Default value = None)

"""
if not isinstance(field, (Field, VectorField)):
raise ValueError(f"Expected `field` to be a Field or VectorField object. Got {type(field)}")

name = field.name if name is None else name

if name in self.fields:
Expand All @@ -137,21 +152,25 @@
correction for zonal velocity U near the poles.
2. flat: No conversion, lat/lon are assumed to be in m.
"""
time = 0.0
values = np.full((1, 1, 1, 1), value)
data = xr.DataArray(
data=values,
name=name,
dims="null",
coords=[time, [0], [0], [0]],
attrs=dict(description="null", units="null", location="node", mesh="constant", mesh_type=mesh),
da = xr.DataArray(
data=np.full((1, 1, 1, 1), value),
dims=["T", "ZG", "YG", "XG"],
coords={
"ZG": (["ZG"], np.arange(1), {"axis": "Z"}),
"YG": (["YG"], np.arange(1), {"axis": "Y"}),
"XG": (["XG"], np.arange(1), {"axis": "X"}),
"lon": (["XG"], np.arange(1), {"axis": "X"}),
"lat": (["YG"], np.arange(1), {"axis": "Y"}),
"depth": (["ZG"], np.arange(1), {"axis": "Z"}),
},
)
grid = Grid(da)
Comment thread
VeckoTheGecko marked this conversation as resolved.
self.add_field(
Field(
name,
data,
da,
grid,
interp_method=None, # TODO : Need to define an interpolation method for constants
allow_time_extrapolation=True,
)
)

Expand Down Expand Up @@ -185,7 +204,10 @@
`Diffusion <../examples/tutorial_diffusion.ipynb>`__
`Periodic boundaries <../examples/tutorial_periodic_boundaries.ipynb>`__
"""
setattr(self, name, value)
if name in self.constants:
raise ValueError(f"FieldSet already has a constant with name '{name}'")

Check warning on line 208 in parcels/fieldset.py

View check run for this annotation

Codecov / codecov/patch

parcels/fieldset.py#L208

Added line #L208 was not covered by tests

self.constants[name] = np.float32(value)
Comment thread
VeckoTheGecko marked this conversation as resolved.

# def computeTimeChunk(self, time=0.0, dt=1):
# """Load a chunk of three data time steps into the FieldSet.
Expand Down
3 changes: 2 additions & 1 deletion parcels/particlefile.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@

import parcels
from parcels._compat import MPI
from parcels.tools._helpers import default_repr, timedelta_to_float
from parcels._reprs import default_repr
from parcels.tools._helpers import timedelta_to_float
from parcels.tools.warnings import FileWarning

__all__ = ["ParticleFile"]
Expand Down
3 changes: 2 additions & 1 deletion parcels/particleset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from tqdm import tqdm

from parcels._compat import MPI
from parcels._reprs import particleset_repr
from parcels.application_kernels.advection import AdvectionRK4
from parcels.field import Field
from parcels.grid import GridType
Expand All @@ -25,7 +26,7 @@
from parcels.particle import Particle, Variable
from parcels.particledata import ParticleData, ParticleDataIterator
from parcels.particlefile import ParticleFile
from parcels.tools._helpers import particleset_repr, timedelta_to_float
from parcels.tools._helpers import timedelta_to_float
from parcels.tools.converters import _get_cftime_calendars, convert_to_flat_array
from parcels.tools.loggers import logger
from parcels.tools.statuscodes import StatusCode
Expand Down
80 changes: 0 additions & 80 deletions parcels/tools/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,12 @@
from __future__ import annotations

import functools
import textwrap
import warnings
from collections.abc import Callable
from datetime import timedelta
from typing import TYPE_CHECKING, Any

import numpy as np

if TYPE_CHECKING:
from parcels import Field, FieldSet, ParticleSet

PACKAGE = "Parcels"


Expand Down Expand Up @@ -68,81 +63,6 @@ def patch_docstring(obj: Callable, extra: str) -> None:
obj.__doc__ = f"{obj.__doc__ or ''}{extra}".strip()


def field_repr(field: Field) -> str:
"""Return a pretty repr for Field"""
out = f"""<{type(field).__name__}>
name : {field.name!r}
data : {field.data!r}
extrapolate time: {field.allow_time_extrapolation!r}
"""
return textwrap.dedent(out).strip()


def _format_list_items_multiline(items: list[str], level: int = 1) -> str:
"""Given a list of strings, formats them across multiple lines.

Uses indentation levels of 4 spaces provided by ``level``.

Example
-------
>>> output = _format_list_items_multiline(["item1", "item2", "item3"], 4)
>>> f"my_items: {output}"
my_items: [
item1,
item2,
item3,
]
"""
if len(items) == 0:
return "[]"

assert level >= 1, "Indentation level >=1 supported"
indentation_str = level * 4 * " "
indentation_str_end = (level - 1) * 4 * " "

items_str = ",\n".join([textwrap.indent(i, indentation_str) for i in items])
return f"[\n{items_str}\n{indentation_str_end}]"


def particleset_repr(pset: ParticleSet) -> str:
"""Return a pretty repr for ParticleSet"""
if len(pset) < 10:
particles = [repr(p) for p in pset]
else:
particles = [repr(pset[i]) for i in range(7)] + ["..."]

out = f"""<{type(pset).__name__}>
fieldset :
{textwrap.indent(repr(pset.fieldset), " " * 8)}
pclass : {pset.pclass}
repeatdt : {pset.repeatdt}
# particles: {len(pset)}
particles : {_format_list_items_multiline(particles, level=2)}
"""
return textwrap.dedent(out).strip()


def fieldset_repr(fieldset: FieldSet) -> str:
"""Return a pretty repr for FieldSet"""
fields_repr = "\n".join([repr(f) for f in fieldset.get_fields()])

out = f"""<{type(fieldset).__name__}>
fields:
{textwrap.indent(fields_repr, 8 * " ")}
"""
return textwrap.dedent(out).strip()


def default_repr(obj: Any):
if is_builtin_object(obj):
return repr(obj)
return object.__repr__(obj)


def is_builtin_object(obj):
return obj.__class__.__module__ == "builtins"


def timedelta_to_float(dt: float | timedelta | np.timedelta64) -> float:
"""Convert a timedelta to a float in seconds."""
if isinstance(dt, timedelta):
Expand Down
10 changes: 2 additions & 8 deletions tests/v4/test_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,9 @@ def test_field_incompatible_combination(data, grid):
[
pytest.param(
structured_datasets["ds_2d_left"]["data_g"], Grid(structured_datasets["ds_2d_left"]), id="ds_2d_left"
),
), # TODO: Perhaps this test should be expanded to cover more datasets?
],
)
@pytest.mark.xfail(reason="Structured grid creation is not implemented yet")
def test_field_structured_grid_creation(data, grid):
"""Test creating a field."""
field = Field(
Expand All @@ -66,15 +65,10 @@ def test_field_structured_grid_creation(data, grid):
grid=grid,
)
assert field.name == "test_field"
assert field.data == data
assert field.data.equals(data)
assert field.grid == grid


def test_field_structured_grid_creation_spherical():
# Field(..., mesh_type="spherical")
...


def test_field_unstructured_grid_creation(): ...


Expand Down
Loading
Loading