From 1c4ab232e3658ecf91d9cbe57746d2b56905071c Mon Sep 17 00:00:00 2001 From: Joe Date: Tue, 29 Apr 2025 10:15:40 -0400 Subject: [PATCH 1/4] Remove "completed" function; add "time_interval" to fieldset --- parcels/fieldset.py | 121 +++++++------------------------------------- 1 file changed, 17 insertions(+), 104 deletions(-) diff --git a/parcels/fieldset.py b/parcels/fieldset.py index c0f5842258..1ed90f815c 100644 --- a/parcels/fieldset.py +++ b/parcels/fieldset.py @@ -1,7 +1,7 @@ import numpy as np import uxarray as ux import xarray as xr - +import functools from parcels._typing import Mesh from parcels.field import Field, VectorField from parcels.tools._helpers import fieldset_repr @@ -42,7 +42,6 @@ class FieldSet: def __init__(self, datasets: list[xr.Dataset | ux.UxDataset]): self.datasets = datasets - self._completed: bool = False self._fieldnames = [] time_origin = None # Create pointers to each (Ux)DataArray @@ -65,6 +64,19 @@ def __init__(self, datasets: list[xr.Dataset | ux.UxDataset]): self.time_origin = time_origin self._add_UVfield() + + @property + def time_interval(self): + """"Returns the valid executable time interval of the FieldSet, + which is the intersection of the time intervals of all fields + in the FieldSet. + """ + time_intervals = (f.time_interval for f in self.fields.values()) + + # Filter out Nones from constant Fields + time_intervals = (t for t in time_intervals if t is not None) + return functools.reduce(lambda x, y: x.intersection(y), time_intervals) + def __repr__(self): return fieldset_repr(self) @@ -96,16 +108,6 @@ def dimrange(self, dim): return maxleft, minright - # @property - # def particlefile(self): - # return self._particlefile - - @staticmethod - def checkvaliddimensionsdict(dims): - for d in dims: - if d not in ["lon", "lat", "depth", "time"]: - raise NameError(f"{d} is not a valid key in the dimensions dictionary") - @property def gridset_size(self): return len(self._fieldnames) @@ -129,12 +131,11 @@ def add_field(self, field: Field, name: str | None = None): * `Unit converters <../examples/tutorial_unitconverters.ipynb>`__ (Default value = None) """ - if self._completed: - raise RuntimeError( - "FieldSet has already been completed. Are you trying to add a Field after you've created the ParticleSet?" - ) name = field.name if name is None else name + if name in self.fields: + raise ValueError(f"FieldSet already has a Field with name '{name}'") + if hasattr(self, name): # check if Field with same name already exists when adding new Field raise RuntimeError(f"FieldSet already has a Field with name '{name}'") else: @@ -207,94 +208,6 @@ def _add_UVfield(self): if not hasattr(self, "UVW") and hasattr(self, "W"): self.add_vector_field(VectorField("UVW", self.U, self.V, self.W)) - def _check_complete(self): - assert self.U, 'FieldSet does not have a Field named "U"' - assert self.V, 'FieldSet does not have a Field named "V"' - for attr, value in vars(self).items(): - if type(value) is Field: - assert value.name == attr, f"Field {value.name}.name ({attr}) is not consistent" - - self._add_UVfield() - - self._completed = True - - # @classmethod - # def from_nemo( - # cls, - # filenames, - # variables, - # dimensions, - # mesh: Mesh = "spherical", - # allow_time_extrapolation: bool | None = None, - # tracer_interp_method: InterpMethodOption = "cgrid_tracer", - # **kwargs, - # ): - - # @classmethod - # def from_mitgcm( - # cls, - # filenames, - # variables, - # dimensions, - # mesh: Mesh = "spherical", - # allow_time_extrapolation: bool | None = None, - # tracer_interp_method: InterpMethodOption = "cgrid_tracer", - # **kwargs, - # ): - - # @classmethod - # def from_croco( - # cls, - # filenames, - # variables, - # dimensions, - # hc: float | None = None, - # mesh="spherical", - # allow_time_extrapolation=None, - # tracer_interp_method="cgrid_tracer", - # **kwargs, - # ): - - # @classmethod - # def from_c_grid_dataset( - # cls, - # filenames, - # variables, - # dimensions, - # mesh: Mesh = "spherical", - # allow_time_extrapolation: bool | None = None, - # tracer_interp_method: InterpMethodOption = "cgrid_tracer", - # gridindexingtype: GridIndexingType = "nemo", - # **kwargs, - # ): - - # @classmethod - # def from_mom5( - # cls, - # filenames, - # variables, - # dimensions, - # mesh: Mesh = "spherical", - # allow_time_extrapolation: bool | None = None, - # tracer_interp_method: InterpMethodOption = "bgrid_tracer", - # **kwargs, - # ): - - # @classmethod - # def from_a_grid_dataset(cls, filenames, variables, dimensions, **kwargs): - - # @classmethod - # def from_b_grid_dataset( - # cls, - # filenames, - # variables, - # dimensions, - # mesh: Mesh = "spherical", - # allow_time_extrapolation: bool | None = None, - # tracer_interp_method: InterpMethodOption = "bgrid_tracer", - # **kwargs, - # ): - def add_constant(self, name, value): """Add a constant to the FieldSet. Note that all constants are stored as 32-bit floats. From 654426513c252894b6873546e29300ab33269b64 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 29 Apr 2025 14:53:11 +0000 Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- parcels/fieldset.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/parcels/fieldset.py b/parcels/fieldset.py index 1ed90f815c..51d3c9e6f0 100644 --- a/parcels/fieldset.py +++ b/parcels/fieldset.py @@ -1,7 +1,9 @@ +import functools + import numpy as np import uxarray as ux import xarray as xr -import functools + from parcels._typing import Mesh from parcels.field import Field, VectorField from parcels.tools._helpers import fieldset_repr @@ -64,11 +66,10 @@ def __init__(self, datasets: list[xr.Dataset | ux.UxDataset]): self.time_origin = time_origin self._add_UVfield() - @property def time_interval(self): - """"Returns the valid executable time interval of the FieldSet, - which is the intersection of the time intervals of all fields + """ "Returns the valid executable time interval of the FieldSet, + which is the intersection of the time intervals of all fields in the FieldSet. """ time_intervals = (f.time_interval for f in self.fields.values()) @@ -76,7 +77,7 @@ def time_interval(self): # Filter out Nones from constant Fields time_intervals = (t for t in time_intervals if t is not None) return functools.reduce(lambda x, y: x.intersection(y), time_intervals) - + def __repr__(self): return fieldset_repr(self) From c956c288c6ee0d966c56717c1178be0fb1089f8e Mon Sep 17 00:00:00 2001 From: Joe Date: Tue, 29 Apr 2025 11:15:26 -0400 Subject: [PATCH 3/4] Fix formatting --- parcels/fieldset.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/parcels/fieldset.py b/parcels/fieldset.py index 1ed90f815c..d2ae02b18c 100644 --- a/parcels/fieldset.py +++ b/parcels/fieldset.py @@ -1,7 +1,9 @@ +import functools + import numpy as np import uxarray as ux import xarray as xr -import functools + from parcels._typing import Mesh from parcels.field import Field, VectorField from parcels.tools._helpers import fieldset_repr @@ -39,7 +41,7 @@ class FieldSet: """ - def __init__(self, datasets: list[xr.Dataset | ux.UxDataset]): + def __init__(self, datasets: list[Field | VectorField]): self.datasets = datasets self._fieldnames = [] @@ -64,11 +66,10 @@ def __init__(self, datasets: list[xr.Dataset | ux.UxDataset]): self.time_origin = time_origin self._add_UVfield() - @property def time_interval(self): - """"Returns the valid executable time interval of the FieldSet, - which is the intersection of the time intervals of all fields + """Returns the valid executable time interval of the FieldSet, + which is the intersection of the time intervals of all fields in the FieldSet. """ time_intervals = (f.time_interval for f in self.fields.values()) @@ -76,7 +77,7 @@ def time_interval(self): # Filter out Nones from constant Fields time_intervals = (t for t in time_intervals if t is not None) return functools.reduce(lambda x, y: x.intersection(y), time_intervals) - + def __repr__(self): return fieldset_repr(self) From 82408e14490bfe4c00893d5e256c0adb1a723097 Mon Sep 17 00:00:00 2001 From: Joe Date: Tue, 29 Apr 2025 11:18:26 -0400 Subject: [PATCH 4/4] Remove call to _check_complete --- tests/v4/test_uxarray_fieldset.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/v4/test_uxarray_fieldset.py b/tests/v4/test_uxarray_fieldset.py index 0e168fd098..b9e737beb6 100644 --- a/tests/v4/test_uxarray_fieldset.py +++ b/tests/v4/test_uxarray_fieldset.py @@ -28,7 +28,6 @@ def ds_fesom_channel() -> ux.UxDataset: def test_fesom_fieldset(ds_fesom_channel): fieldset = FieldSet([ds_fesom_channel]) - fieldset._check_complete() # Check that the fieldset has the expected properties assert fieldset.datasets[0] == ds_fesom_channel