Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 18 additions & 104 deletions parcels/fieldset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import functools

import numpy as np
import uxarray as ux
import xarray as xr
Expand Down Expand Up @@ -39,10 +41,9 @@ class FieldSet:

"""

def __init__(self, datasets: list[xr.Dataset | ux.UxDataset]):
def __init__(self, datasets: list[Field | VectorField]):
self.datasets = datasets

self._completed: bool = False
self._fieldnames = []
time_origin = None
# Create pointers to each (Ux)DataArray
Expand All @@ -65,6 +66,18 @@ def __init__(self, datasets: list[xr.Dataset | ux.UxDataset]):
self.time_origin = time_origin
self._add_UVfield()

@property
def time_interval(self):
"""Returns the valid executable time interval of the FieldSet,
which is the intersection of the time intervals of all fields
in the FieldSet.
"""
time_intervals = (f.time_interval for f in self.fields.values())

# Filter out Nones from constant Fields
time_intervals = (t for t in time_intervals if t is not None)
return functools.reduce(lambda x, y: x.intersection(y), time_intervals)

def __repr__(self):
return fieldset_repr(self)

Expand Down Expand Up @@ -96,16 +109,6 @@ def dimrange(self, dim):

return maxleft, minright

# @property
# def particlefile(self):
# return self._particlefile

@staticmethod
def checkvaliddimensionsdict(dims):
for d in dims:
if d not in ["lon", "lat", "depth", "time"]:
raise NameError(f"{d} is not a valid key in the dimensions dictionary")

@property
def gridset_size(self):
return len(self._fieldnames)
Expand All @@ -129,12 +132,11 @@ def add_field(self, field: Field, name: str | None = None):
* `Unit converters <../examples/tutorial_unitconverters.ipynb>`__ (Default value = None)

"""
if self._completed:
raise RuntimeError(
"FieldSet has already been completed. Are you trying to add a Field after you've created the ParticleSet?"
)
name = field.name if name is None else name

if name in self.fields:
raise ValueError(f"FieldSet already has a Field with name '{name}'")

if hasattr(self, name): # check if Field with same name already exists when adding new Field
raise RuntimeError(f"FieldSet already has a Field with name '{name}'")
else:
Expand Down Expand Up @@ -207,94 +209,6 @@ def _add_UVfield(self):
if not hasattr(self, "UVW") and hasattr(self, "W"):
self.add_vector_field(VectorField("UVW", self.U, self.V, self.W))

def _check_complete(self):
assert self.U, 'FieldSet does not have a Field named "U"'
assert self.V, 'FieldSet does not have a Field named "V"'
for attr, value in vars(self).items():
if type(value) is Field:
assert value.name == attr, f"Field {value.name}.name ({attr}) is not consistent"

self._add_UVfield()

self._completed = True

# @classmethod
# def from_nemo(
# cls,
# filenames,
# variables,
# dimensions,
# mesh: Mesh = "spherical",
# allow_time_extrapolation: bool | None = None,
# tracer_interp_method: InterpMethodOption = "cgrid_tracer",
# **kwargs,
# ):

# @classmethod
# def from_mitgcm(
# cls,
# filenames,
# variables,
# dimensions,
# mesh: Mesh = "spherical",
# allow_time_extrapolation: bool | None = None,
# tracer_interp_method: InterpMethodOption = "cgrid_tracer",
# **kwargs,
# ):

# @classmethod
# def from_croco(
# cls,
# filenames,
# variables,
# dimensions,
# hc: float | None = None,
# mesh="spherical",
# allow_time_extrapolation=None,
# tracer_interp_method="cgrid_tracer",
# **kwargs,
# ):

# @classmethod
# def from_c_grid_dataset(
# cls,
# filenames,
# variables,
# dimensions,
# mesh: Mesh = "spherical",
# allow_time_extrapolation: bool | None = None,
# tracer_interp_method: InterpMethodOption = "cgrid_tracer",
# gridindexingtype: GridIndexingType = "nemo",
# **kwargs,
# ):

# @classmethod
# def from_mom5(
# cls,
# filenames,
# variables,
# dimensions,
# mesh: Mesh = "spherical",
# allow_time_extrapolation: bool | None = None,
# tracer_interp_method: InterpMethodOption = "bgrid_tracer",
# **kwargs,
# ):

# @classmethod
# def from_a_grid_dataset(cls, filenames, variables, dimensions, **kwargs):

# @classmethod
# def from_b_grid_dataset(
# cls,
# filenames,
# variables,
# dimensions,
# mesh: Mesh = "spherical",
# allow_time_extrapolation: bool | None = None,
# tracer_interp_method: InterpMethodOption = "bgrid_tracer",
# **kwargs,
# ):

def add_constant(self, name, value):
"""Add a constant to the FieldSet. Note that all constants are
stored as 32-bit floats.
Expand Down
1 change: 0 additions & 1 deletion tests/v4/test_uxarray_fieldset.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ def ds_fesom_channel() -> ux.UxDataset:

def test_fesom_fieldset(ds_fesom_channel):
fieldset = FieldSet([ds_fesom_channel])
fieldset._check_complete()
# Check that the fieldset has the expected properties
assert fieldset.datasets[0] == ds_fesom_channel

Expand Down
Loading