Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 6 additions & 49 deletions parcels/fieldset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import functools

import numpy as np
import uxarray as ux
import xarray as xr

from parcels._typing import Mesh
Expand Down Expand Up @@ -41,30 +40,11 @@ class FieldSet:

"""

def __init__(self, datasets: list[Field | VectorField]):
self.datasets = datasets
def __init__(self, fields: list[Field | VectorField]):
# TODO Nick : Enforce fields to be list of Field or VectorField objects
self.fields = {f.name: f for f in fields}

self._fieldnames = []
time_origin = None
# Create pointers to each (Ux)DataArray
for ds in datasets:
for field in ds.data_vars:
if type(ds[field]) is ux.UxDataArray:
self.add_field(Field(field, ds[field], grid=ds[field].uxgrid), field)
else:
self.add_field(Field(field, ds[field]), field)
self._fieldnames.append(field)

if "time" in ds.coords:
if time_origin is None:
time_origin = ds.time.min().data
else:
time_origin = min(time_origin, ds.time.min().data)
else:
time_origin = 0.0

self.time_origin = time_origin
self._add_UVfield()
# TODO : Nick : Add _getattr_ magic method to allow access to fields by name

@property
def time_interval(self):
Expand Down Expand Up @@ -111,7 +91,7 @@ def dimrange(self, dim):

@property
def gridset_size(self):
return len(self._fieldnames)
return len(self.fields)

def add_field(self, field: Field, name: str | None = None):
"""Add a :class:`parcels.field.Field` object to the FieldSet.
Expand All @@ -137,11 +117,7 @@ def add_field(self, field: Field, name: str | None = None):
if name in self.fields:
raise ValueError(f"FieldSet already has a Field with name '{name}'")

if hasattr(self, name): # check if Field with same name already exists when adding new Field
raise RuntimeError(f"FieldSet already has a Field with name '{name}'")
else:
setattr(self, name, field)
self._fieldnames.append(name)
self.fields[name] = field

def add_constant_field(self, name: str, value, mesh: Mesh = "flat"):
"""Wrapper function to add a Field that is constant in space,
Expand Down Expand Up @@ -179,19 +155,6 @@ def add_constant_field(self, name: str, value, mesh: Mesh = "flat"):
)
)

def add_vector_field(self, vfield):
"""Add a :class:`parcels.field.VectorField` object to the FieldSet.

Parameters
----------
vfield : parcels.VectorField
class:`parcels.FieldSet.VectorField` object to be added
"""
setattr(self, vfield.name, vfield)
for v in vfield.__dict__.values():
if isinstance(v, Field) and (v not in self.get_fields()):
self.add_field(v)

def get_fields(self) -> list[Field | VectorField]:
"""Returns a list of all the :class:`parcels.field.Field` and :class:`parcels.field.VectorField`
objects associated with this FieldSet.
Expand All @@ -203,12 +166,6 @@ def get_fields(self) -> list[Field | VectorField]:
fields.append(v)
return fields

def _add_UVfield(self):
if not hasattr(self, "UV") and hasattr(self, "U") and hasattr(self, "V"):
self.add_vector_field(VectorField("UV", self.U, self.V))
if not hasattr(self, "UVW") and hasattr(self, "W"):
self.add_vector_field(VectorField("UVW", self.U, self.V, self.W))

def add_constant(self, name, value):
"""Add a constant to the FieldSet. Note that all constants are
stored as 32-bit floats.
Expand Down
4 changes: 1 addition & 3 deletions parcels/particleset.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,6 @@ def __init__(
self._interaction_kernel = None

self.fieldset = fieldset
self.fieldset._check_complete()
self.time_origin = fieldset.time_origin
self._pclass = pclass

# ==== first: create a new subclass of the pclass that includes the required variables ==== #
Expand Down Expand Up @@ -962,7 +960,7 @@ def execute(
if runtime is not None and endtime is not None:
raise RuntimeError("Only one of (endtime, runtime) can be specified")

mintime, maxtime = self.fieldset.dimrange("time")
mintime, maxtime = self.fieldset.dimrange("time") # TODO : change to fieldset.time_interval

default_release_time = mintime if dt >= 0 else maxtime
if np.any(np.isnan(self.particledata.data["time"])):
Expand Down
66 changes: 49 additions & 17 deletions tests/v4/test_uxarray_fieldset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
import uxarray as ux

from parcels import (
Field,
FieldSet,
Particle,
ParticleSet,
UXPiecewiseConstantFace,
UXPiecewiseLinearNode,
VectorField,
download_example_dataset,
)

Expand All @@ -26,33 +28,63 @@ def ds_fesom_channel() -> ux.UxDataset:
return ds


def test_fesom_fieldset(ds_fesom_channel):
fieldset = FieldSet([ds_fesom_channel])
@pytest.fixture
def uv_fesom_channel(ds_fesom_channel) -> VectorField:
UV = VectorField(
name="UV",
U=Field(name="U", data=ds_fesom_channel.U, grid=ds_fesom_channel.uxgrid, interp_method=UXPiecewiseConstantFace),
V=Field(name="V", data=ds_fesom_channel.V, grid=ds_fesom_channel.uxgrid, interp_method=UXPiecewiseConstantFace),
)
return UV


@pytest.fixture
def uvw_fesom_channel(ds_fesom_channel) -> VectorField:
UVW = VectorField(
name="UVW",
U=Field(name="U", data=ds_fesom_channel.U, grid=ds_fesom_channel.uxgrid, interp_method=UXPiecewiseConstantFace),
V=Field(name="V", data=ds_fesom_channel.V, grid=ds_fesom_channel.uxgrid, interp_method=UXPiecewiseConstantFace),
W=Field(name="W", data=ds_fesom_channel.W, grid=ds_fesom_channel.uxgrid, interp_method=UXPiecewiseLinearNode),
)
return UVW


def test_fesom_fieldset(ds_fesom_channel, uv_fesom_channel):
fieldset = FieldSet([uv_fesom_channel, uv_fesom_channel.U, uv_fesom_channel.V])
# Check that the fieldset has the expected properties
assert fieldset.datasets[0] == ds_fesom_channel
assert (fieldset.fields["U"] == ds_fesom_channel.U).all()
assert (fieldset.fields["V"] == ds_fesom_channel.V).all()


def test_fesom_in_particleset(ds_fesom_channel):
fieldset = FieldSet([ds_fesom_channel])
@pytest.mark.skip(reason="ParticleSet.__init__ needs major refactoring")
def test_fesom_in_particleset(ds_fesom_channel, uv_fesom_channel):
fieldset = FieldSet([uv_fesom_channel, uv_fesom_channel.U, uv_fesom_channel.V])
# Check that the fieldset has the expected properties
assert fieldset.datasets[0] == ds_fesom_channel
assert (fieldset.fields["U"] == ds_fesom_channel.U).all()
assert (fieldset.fields["V"] == ds_fesom_channel.V).all()
pset = ParticleSet(fieldset, pclass=Particle)
assert pset.fieldset == fieldset


def test_set_interp_methods(ds_fesom_channel):
fieldset = FieldSet([ds_fesom_channel])
def test_set_interp_methods(ds_fesom_channel, uv_fesom_channel):
fieldset = FieldSet([uv_fesom_channel, uv_fesom_channel.U, uv_fesom_channel.V])
# Check that the fieldset has the expected properties
assert (fieldset.fields["U"] == ds_fesom_channel.U).all()
assert (fieldset.fields["V"] == ds_fesom_channel.V).all()

# Set the interpolation method for each field
fieldset.U.interp_method = UXPiecewiseConstantFace
fieldset.V.interp_method = UXPiecewiseConstantFace
fieldset.W.interp_method = UXPiecewiseLinearNode
fieldset.fields["U"].interp_method = UXPiecewiseConstantFace
fieldset.fields["V"].interp_method = UXPiecewiseConstantFace


def test_fesom_channel(ds_fesom_channel):
fieldset = FieldSet([ds_fesom_channel])
# Set the interpolation method for each field
fieldset.U.interp_method = UXPiecewiseConstantFace
fieldset.V.interp_method = UXPiecewiseConstantFace
fieldset.W.interp_method = UXPiecewiseLinearNode
@pytest.mark.skip(reason="ParticleSet.__init__ needs major refactoring")
def test_fesom_channel(ds_fesom_channel, uvw_fesom_channel):
fieldset = FieldSet([uvw_fesom_channel, uvw_fesom_channel.U, uvw_fesom_channel.V, uvw_fesom_channel.W])

# Check that the fieldset has the expected properties
assert (fieldset.fields["U"] == ds_fesom_channel.U).all()
assert (fieldset.fields["V"] == ds_fesom_channel.V).all()
assert (fieldset.fields["W"] == ds_fesom_channel.W).all()

pset = ParticleSet(fieldset, pclass=Particle)
pset.execute(endtime=timedelta(days=1), dt=timedelta(hours=1))
Loading