diff --git a/parcels/fieldset.py b/parcels/fieldset.py index d2ae02b18c..26543e9e8c 100644 --- a/parcels/fieldset.py +++ b/parcels/fieldset.py @@ -1,7 +1,6 @@ import functools import numpy as np -import uxarray as ux import xarray as xr from parcels._typing import Mesh @@ -41,30 +40,11 @@ class FieldSet: """ - def __init__(self, datasets: list[Field | VectorField]): - self.datasets = datasets + def __init__(self, fields: list[Field | VectorField]): + # TODO Nick : Enforce fields to be list of Field or VectorField objects + self.fields = {f.name: f for f in fields} - self._fieldnames = [] - time_origin = None - # Create pointers to each (Ux)DataArray - for ds in datasets: - for field in ds.data_vars: - if type(ds[field]) is ux.UxDataArray: - self.add_field(Field(field, ds[field], grid=ds[field].uxgrid), field) - else: - self.add_field(Field(field, ds[field]), field) - self._fieldnames.append(field) - - if "time" in ds.coords: - if time_origin is None: - time_origin = ds.time.min().data - else: - time_origin = min(time_origin, ds.time.min().data) - else: - time_origin = 0.0 - - self.time_origin = time_origin - self._add_UVfield() + # TODO : Nick : Add _getattr_ magic method to allow access to fields by name @property def time_interval(self): @@ -111,7 +91,7 @@ def dimrange(self, dim): @property def gridset_size(self): - return len(self._fieldnames) + return len(self.fields) def add_field(self, field: Field, name: str | None = None): """Add a :class:`parcels.field.Field` object to the FieldSet. @@ -137,11 +117,7 @@ def add_field(self, field: Field, name: str | None = None): if name in self.fields: raise ValueError(f"FieldSet already has a Field with name '{name}'") - if hasattr(self, name): # check if Field with same name already exists when adding new Field - raise RuntimeError(f"FieldSet already has a Field with name '{name}'") - else: - setattr(self, name, field) - self._fieldnames.append(name) + self.fields[name] = field def add_constant_field(self, name: str, value, mesh: Mesh = "flat"): """Wrapper function to add a Field that is constant in space, @@ -179,19 +155,6 @@ def add_constant_field(self, name: str, value, mesh: Mesh = "flat"): ) ) - def add_vector_field(self, vfield): - """Add a :class:`parcels.field.VectorField` object to the FieldSet. - - Parameters - ---------- - vfield : parcels.VectorField - class:`parcels.FieldSet.VectorField` object to be added - """ - setattr(self, vfield.name, vfield) - for v in vfield.__dict__.values(): - if isinstance(v, Field) and (v not in self.get_fields()): - self.add_field(v) - def get_fields(self) -> list[Field | VectorField]: """Returns a list of all the :class:`parcels.field.Field` and :class:`parcels.field.VectorField` objects associated with this FieldSet. @@ -203,12 +166,6 @@ def get_fields(self) -> list[Field | VectorField]: fields.append(v) return fields - def _add_UVfield(self): - if not hasattr(self, "UV") and hasattr(self, "U") and hasattr(self, "V"): - self.add_vector_field(VectorField("UV", self.U, self.V)) - if not hasattr(self, "UVW") and hasattr(self, "W"): - self.add_vector_field(VectorField("UVW", self.U, self.V, self.W)) - def add_constant(self, name, value): """Add a constant to the FieldSet. Note that all constants are stored as 32-bit floats. diff --git a/parcels/particleset.py b/parcels/particleset.py index 98ce3d0487..a2dd9f7a75 100644 --- a/parcels/particleset.py +++ b/parcels/particleset.py @@ -106,8 +106,6 @@ def __init__( self._interaction_kernel = None self.fieldset = fieldset - self.fieldset._check_complete() - self.time_origin = fieldset.time_origin self._pclass = pclass # ==== first: create a new subclass of the pclass that includes the required variables ==== # @@ -962,7 +960,7 @@ def execute( if runtime is not None and endtime is not None: raise RuntimeError("Only one of (endtime, runtime) can be specified") - mintime, maxtime = self.fieldset.dimrange("time") + mintime, maxtime = self.fieldset.dimrange("time") # TODO : change to fieldset.time_interval default_release_time = mintime if dt >= 0 else maxtime if np.any(np.isnan(self.particledata.data["time"])): diff --git a/tests/v4/test_uxarray_fieldset.py b/tests/v4/test_uxarray_fieldset.py index b9e737beb6..521aab87ca 100644 --- a/tests/v4/test_uxarray_fieldset.py +++ b/tests/v4/test_uxarray_fieldset.py @@ -4,11 +4,13 @@ import uxarray as ux from parcels import ( + Field, FieldSet, Particle, ParticleSet, UXPiecewiseConstantFace, UXPiecewiseLinearNode, + VectorField, download_example_dataset, ) @@ -26,33 +28,63 @@ def ds_fesom_channel() -> ux.UxDataset: return ds -def test_fesom_fieldset(ds_fesom_channel): - fieldset = FieldSet([ds_fesom_channel]) +@pytest.fixture +def uv_fesom_channel(ds_fesom_channel) -> VectorField: + UV = VectorField( + name="UV", + U=Field(name="U", data=ds_fesom_channel.U, grid=ds_fesom_channel.uxgrid, interp_method=UXPiecewiseConstantFace), + V=Field(name="V", data=ds_fesom_channel.V, grid=ds_fesom_channel.uxgrid, interp_method=UXPiecewiseConstantFace), + ) + return UV + + +@pytest.fixture +def uvw_fesom_channel(ds_fesom_channel) -> VectorField: + UVW = VectorField( + name="UVW", + U=Field(name="U", data=ds_fesom_channel.U, grid=ds_fesom_channel.uxgrid, interp_method=UXPiecewiseConstantFace), + V=Field(name="V", data=ds_fesom_channel.V, grid=ds_fesom_channel.uxgrid, interp_method=UXPiecewiseConstantFace), + W=Field(name="W", data=ds_fesom_channel.W, grid=ds_fesom_channel.uxgrid, interp_method=UXPiecewiseLinearNode), + ) + return UVW + + +def test_fesom_fieldset(ds_fesom_channel, uv_fesom_channel): + fieldset = FieldSet([uv_fesom_channel, uv_fesom_channel.U, uv_fesom_channel.V]) # Check that the fieldset has the expected properties - assert fieldset.datasets[0] == ds_fesom_channel + assert (fieldset.fields["U"] == ds_fesom_channel.U).all() + assert (fieldset.fields["V"] == ds_fesom_channel.V).all() -def test_fesom_in_particleset(ds_fesom_channel): - fieldset = FieldSet([ds_fesom_channel]) +@pytest.mark.skip(reason="ParticleSet.__init__ needs major refactoring") +def test_fesom_in_particleset(ds_fesom_channel, uv_fesom_channel): + fieldset = FieldSet([uv_fesom_channel, uv_fesom_channel.U, uv_fesom_channel.V]) # Check that the fieldset has the expected properties - assert fieldset.datasets[0] == ds_fesom_channel + assert (fieldset.fields["U"] == ds_fesom_channel.U).all() + assert (fieldset.fields["V"] == ds_fesom_channel.V).all() pset = ParticleSet(fieldset, pclass=Particle) assert pset.fieldset == fieldset -def test_set_interp_methods(ds_fesom_channel): - fieldset = FieldSet([ds_fesom_channel]) +def test_set_interp_methods(ds_fesom_channel, uv_fesom_channel): + fieldset = FieldSet([uv_fesom_channel, uv_fesom_channel.U, uv_fesom_channel.V]) + # Check that the fieldset has the expected properties + assert (fieldset.fields["U"] == ds_fesom_channel.U).all() + assert (fieldset.fields["V"] == ds_fesom_channel.V).all() + # Set the interpolation method for each field - fieldset.U.interp_method = UXPiecewiseConstantFace - fieldset.V.interp_method = UXPiecewiseConstantFace - fieldset.W.interp_method = UXPiecewiseLinearNode + fieldset.fields["U"].interp_method = UXPiecewiseConstantFace + fieldset.fields["V"].interp_method = UXPiecewiseConstantFace -def test_fesom_channel(ds_fesom_channel): - fieldset = FieldSet([ds_fesom_channel]) - # Set the interpolation method for each field - fieldset.U.interp_method = UXPiecewiseConstantFace - fieldset.V.interp_method = UXPiecewiseConstantFace - fieldset.W.interp_method = UXPiecewiseLinearNode +@pytest.mark.skip(reason="ParticleSet.__init__ needs major refactoring") +def test_fesom_channel(ds_fesom_channel, uvw_fesom_channel): + fieldset = FieldSet([uvw_fesom_channel, uvw_fesom_channel.U, uvw_fesom_channel.V, uvw_fesom_channel.W]) + + # Check that the fieldset has the expected properties + assert (fieldset.fields["U"] == ds_fesom_channel.U).all() + assert (fieldset.fields["V"] == ds_fesom_channel.V).all() + assert (fieldset.fields["W"] == ds_fesom_channel.W).all() + pset = ParticleSet(fieldset, pclass=Particle) pset.execute(endtime=timedelta(days=1), dt=timedelta(hours=1))