diff --git a/tests/conftest.py b/tests/conftest.py index 1e23dc6..4886252 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,7 @@ # conftest: some configuration for the tests +from pathlib import Path + import pytest @@ -35,3 +37,18 @@ def pytest_collection_modifyitems(config, items): # if envnames: # if item.config.getoption("-E") not in envnames: # pytest.skip(f"test requires env in {envnames!r}") + +EXAMPLE_DATA = Path(__file__).parent / 'example_data' + +UGRID_FILES = [EXAMPLE_DATA / 'SFBOFS_subset1.nc', + EXAMPLE_DATA / 'small_ugrid_zero_based.nc', + EXAMPLE_DATA / 'tris_and_bounds.nc', + ] + +SGRID_FILES = [EXAMPLE_DATA / 'arakawa_c_test_grid.nc', + ] + +RGRID_FILES = [EXAMPLE_DATA / '2D-rectangular_grid_wind.nc', + EXAMPLE_DATA / 'rectangular_grid_decreasing.nc', + EXAMPLE_DATA / 'AMSEAS-subset.nc', + ] diff --git a/tests/example_data/2D-rectangular_grid_wind.nc b/tests/example_data/2D-rectangular_grid_wind.nc new file mode 100644 index 0000000..1be0a50 Binary files /dev/null and b/tests/example_data/2D-rectangular_grid_wind.nc differ diff --git a/docs/examples/example_data/SFBOFS_subset1.nc b/tests/example_data/SFBOFS_subset1.nc similarity index 100% rename from docs/examples/example_data/SFBOFS_subset1.nc rename to tests/example_data/SFBOFS_subset1.nc diff --git a/tests/example_data/rectangular_grid_decreasing.nc b/tests/example_data/rectangular_grid_decreasing.nc new file mode 100644 index 0000000..dc1f7c0 Binary files /dev/null and b/tests/example_data/rectangular_grid_decreasing.nc differ diff --git a/docs/examples/example_data/small_ugrid_zero_based.nc b/tests/example_data/small_ugrid_zero_based.nc similarity index 100% rename from docs/examples/example_data/small_ugrid_zero_based.nc rename to tests/example_data/small_ugrid_zero_based.nc diff --git a/docs/examples/example_data/tris_and_bounds.nc b/tests/example_data/tris_and_bounds.nc similarity index 100% rename from docs/examples/example_data/tris_and_bounds.nc rename to tests/example_data/tris_and_bounds.nc diff --git a/tests/test_grids/test_regular_grid.py b/tests/test_grids/test_regular_grid.py index 772b5ed..95f1abe 100644 --- a/tests/test_grids/test_regular_grid.py +++ b/tests/test_grids/test_regular_grid.py @@ -4,127 +4,274 @@ from pathlib import Path -try: - import fsspec -except ImportError: - fsspec = None +import numpy as np +import pytest + +# only needed if you want to hit AWS servers. +# try: +# import fsspec +# except ImportError: +# fsspec = None import xarray as xr +from tests.conftest import RGRID_FILES, SGRID_FILES, UGRID_FILES from xarray_subset_grid.grids.regular_grid import RegularGrid -TEST_DATA = Path(__file__).parent.parent / "example_data" +EXAMPLE_DATA = Path(__file__).parent.parent / "example_data" -TEST_FILE1 = TEST_DATA / "AMSEAS-subset.nc" # NGOFS2_RGRID.nc is a small subset of the regridded NGOFS2 model. # It was created by the "OFS subsetter" - -def test_recognise(): +@pytest.mark.parametrize("test_file", RGRID_FILES) +def test_recognize(test_file): """ works for at least one file ... """ - ds = xr.open_dataset(TEST_FILE1) + ds = xr.open_dataset(test_file) assert RegularGrid.recognize(ds) -def test_recognise_not(): +@pytest.mark.parametrize("test_file", UGRID_FILES + SGRID_FILES) +def test_recognize_not(test_file): """ - should not recognise an SGrid + should not recognize an SGrid """ - ds = xr.open_dataset(TEST_DATA / "arakawa_c_test_grid.nc") + ds = xr.open_dataset(test_file) assert not RegularGrid.recognize(ds) -####### -# These from the ugrid tests -- need to be adapted -####### +def create_synthetic_rectangular_grid_dataset(decreasing=False): + """ + Create a synthetic dataset with regular grid. -# def test_grid_vars(): -# """ -# Check if the grid vars are defined properly -# """ -# ds = xr.open_dataset(EXAMPLE_DATA / "SFBOFS_subset1.nc") + Can be either decreasing or increasing in latitude + """ -# ds = ugrid.assign_ugrid_topology(ds, **grid_topology) + lon = np.linspace(-100, -80, 21) + if decreasing: + lat = np.linspace(50, 30, 21) + else: + lat = np.linspace(30, 50, 21) -# grid_vars = ds.xsg.grid_vars + data = np.random.rand(21, 21) -# # ['mesh', 'nv', 'lon', 'lat', 'lonc', 'latc'] -# assert grid_vars == set(["mesh", "nv", "nbe", "lon", "lat", "lonc", "latc"]) + ds = xr.Dataset( + data_vars={ + "temp": (("lat", "lon"), data), + "salt": (("lat", "lon"), data), + }, + coords={ + "lat": lat, + "lon": lon, + }, + ) + # Add cf attributes + ds.lat.attrs = {"standard_name": "latitude", "units": "degrees_north"} + ds.lon.attrs = {"standard_name": "longitude", "units": "degrees_east"} + ds.temp.attrs = {"standard_name": "sea_water_temperature"} + return ds -# def test_data_vars(): -# """ -# Check if the grid vars are defined properly -# This is not currently working correctly! -# """ -# ds = xr.open_dataset(EXAMPLE_DATA / "SFBOFS_subset1.nc") -# ds = ugrid.assign_ugrid_topology(ds, **grid_topology) -# data_vars = ds.xsg.data_vars - -# assert set(data_vars) == set( -# [ -# "h", -# "zeta", -# "temp", -# "salinity", -# "u", -# "v", -# "uwind_speed", -# "vwind_speed", -# "wet_nodes", -# "wet_cells", -# ] -# ) +def test_grid_vars(): + """ + Check if the grid vars are defined properly + """ + ds = xr.open_dataset(EXAMPLE_DATA / "AMSEAS-subset.nc") -# def test_extra_vars(): -# """ -# Check if the extra vars are defined properly + grid_vars = ds.xsg.grid_vars -# This is not currently working correctly! -# """ -# ds = xr.open_dataset(EXAMPLE_DATA / "SFBOFS_subset1.nc") -# ds = ugrid.assign_ugrid_topology(ds, **grid_topology) + # ['mesh', 'nv', 'lon', 'lat', 'lonc', 'latc'] + assert grid_vars == {'lat', 'lon'} -# extra_vars = ds.xsg.extra_vars -# print([*ds]) -# print(f"{extra_vars=}") -# assert extra_vars == set( -# [ -# "nf_type", -# "Times", -# ] -# ) +def test_data_vars(): + """ + Check if the data vars are defined properly + This is not currently working correctly! -# def test_coords(): -# ds = xr.open_dataset(EXAMPLE_DATA / "SFBOFS_subset1.nc") -# ds = ugrid.assign_ugrid_topology(ds, **grid_topology) + it finds extra stuff + """ + ds = xr.open_dataset(EXAMPLE_DATA / "AMSEAS-subset.nc") + + data_vars = ds.xsg.data_vars + + # the extra "time" variables are not using the grid + # so they should not be listed as data_vars + assert data_vars == { + 'water_w', + 'salinity', + 'surf_roughness', + 'surf_temp_flux', + 'water_v', + # 'time_offset', + 'water_temp', + 'water_baro_v', + 'surf_atm_press', + 'surf_el', + 'surf_salt_flux', + 'water_u', + 'surf_wnd_stress_gridy', + 'water_baro_u', + 'watdep', + 'surf_solar_flux', + # 'time1_run', + 'surf_wnd_stress_gridx', + # 'time1_offset' + } + +# might not be needed if tested elsewhere. +def test_data_vars2(): + """ + redundant with above, by already written ... + """ + print("Testing data_vars error...") + ds = create_synthetic_rectangular_grid_dataset() + # Ensure it is recognized as a RegularGrid + assert RegularGrid.recognize(ds) + + # Access xsg accessor + data_vars = ds.xsg.data_vars + print(f"data_vars: {data_vars}") + + assert data_vars == {'salt', 'temp'} + + +def test_extra_vars(): + """ + Check if the extra vars are defined properly + """ + ds = xr.open_dataset(EXAMPLE_DATA / "AMSEAS-subset.nc") + + extra_vars = ds.xsg.extra_vars + + # the extra "time" variables are not using the grid + # so they should be listed as extra_vars + assert extra_vars == { + 'time_offset', + 'time1_run', + 'time1_offset' + } + +def test_subset_to_bb(): + """ + Not a complete test by any means, but the basics are there. + + NOTE: it doesn't test if the variables got subset corectly ... + + """ + ds = xr.open_dataset(EXAMPLE_DATA / "2D-rectangular_grid_wind.nc") + + print("initial bounds:", ds['lon'].data.min(), + ds['lat'].data.min(), + ds['lon'].data.max(), + ds['lat'].data.max(), + ) + + bbox = (-0.5, 0, 0.5, 0.5) + + ds2 = ds.xsg.subset_bbox(bbox) + + assert ds2['lat'].size == 15 + assert ds2['lon'].size == 29 + + new_bounds = (ds2['lon'].data.min(), + ds2['lat'].data.min(), + ds2['lon'].data.max(), + ds2['lat'].data.max(), + ) + print("new bounds:", new_bounds) + assert new_bounds == bbox + +def test_decreasing_latitude(): + """ + Some datasets have the latitude or longitude decreasing: 10, 9, 8 etc. + e.g the NOAA GFS met model + + subsetting should still work + + """ + ds = xr.open_dataset(EXAMPLE_DATA / "rectangular_grid_decreasing.nc") + + print("initial bounds:", ds['lon'].data.min(), + ds['lat'].data.min(), + ds['lon'].data.max(), + ds['lat'].data.max(), + ) + + bbox = (-0.5, 0, 0.5, 0.5) + + ds2 = ds.xsg.subset_bbox(bbox) + + assert ds2['lat'].size == 15 + assert ds2['lon'].size == 29 + + new_bounds = (ds2['lon'].data.min(), + ds2['lat'].data.min(), + ds2['lon'].data.max(), + ds2['lat'].data.max(), + ) + print("new bounds:", new_bounds) + assert new_bounds == bbox + +def test_decreasing_coords(): + """ + Redundant with above, but already written ... + """ + print("\nTesting decreasing coordinates support...") + ds = create_synthetic_rectangular_grid_dataset(decreasing=True) + # assert RegularGrid.recognize(ds) + + # bbox: (min_lon, min_lat, max_lon, max_lat) + bbox = (-95, 35, -85, 45) + + subset = ds.xsg.subset_bbox(bbox) + print(f"Subset size: {subset.sizes}") + + # Check if subset has data + assert subset.sizes["lat"] > 0 + assert subset.sizes["lon"] > 0 + +def test_subset_polygon(): + """ + Not a complete test by any means, but the basics are there. + + NOTE: it doesn't test if the variables got subset corectly ... + + """ + ds = xr.open_dataset(EXAMPLE_DATA / "2D-rectangular_grid_wind.nc") + + print("initial bounds:", ds['lon'].data.min(), + ds['lat'].data.min(), + ds['lon'].data.max(), + ds['lat'].data.max(), + ) + + poly = [(-0.5, 0.0), (0.0, 0.5), (0.5, 0.5), (0.5, 0.0), (0, 0.0)] + # this poly has this bounding box: + # bbox = (-0.5, 0, 0.5, 0.5) + # so results should be the same as the bbox tests + + ds2 = ds.xsg.subset_polygon(poly) + + assert ds2['lat'].size == 15 + assert ds2['lon'].size == 29 + + new_bounds = (ds2['lon'].data.min(), + ds2['lat'].data.min(), + ds2['lon'].data.max(), + ds2['lat'].data.max(), + ) + print("new bounds:", new_bounds) + assert new_bounds == (-0.5, 0, 0.5, 0.5) -# coords = ds.xsg.coords - -# print(f"{coords=}") -# print(f"{ds.coords=}") - -# assert set(coords) == set( -# [ -# "lon", -# "lat", -# "lonc", -# "latc", -# "time", -# "siglay", -# "siglev", -# ] -# ) # def test_vertical_levels(): diff --git a/tests/test_grids/test_ugrid.py b/tests/test_grids/test_ugrid.py index 3bdcfd3..1c0e99e 100644 --- a/tests/test_grids/test_ugrid.py +++ b/tests/test_grids/test_ugrid.py @@ -12,12 +12,47 @@ import pytest import xarray as xr +from tests.conftest import RGRID_FILES, SGRID_FILES, UGRID_FILES from xarray_subset_grid import Selector from xarray_subset_grid.grids import ugrid +from xarray_subset_grid.grids.ugrid import UGrid -EXAMPLE_DATA = Path(__file__).parent.parent.parent / "docs" / "examples" / "example_data" +EXAMPLE_DATA = Path(__file__).parent.parent / "example_data" + +@pytest.mark.parametrize("test_file", UGRID_FILES[:3]) +def test_recognize(test_file): + """ + works for at least one file ... + """ + print("testing: ", test_file) + ds = xr.open_dataset(test_file) + try: + ds.cf.cf_roles["mesh_topology"][0] + except KeyError: # no mesh variable + # Hacky way to deal with non-conforming examples + # This should be in a config somewhere, or ?? + if 'tris' in ds: + grid_top = {'face_node_connectivity': 'tris', + 'node_coordinates': ('lon', 'lat') + } + elif 'nv' in ds: + grid_top = {'face_node_connectivity': 'nv', + 'node_coordinates': ('lon', 'lat') + } + ds = ugrid.assign_ugrid_topology(ds, **grid_top) + + assert UGrid.recognize(ds) + + +@pytest.mark.parametrize("test_file", RGRID_FILES + SGRID_FILES) +def test_recognize_not(test_file): + """ + should not recognize an SGrid + """ + ds = xr.open_dataset(test_file) + + assert not UGrid.recognize(ds) -TEST_FILE1 = EXAMPLE_DATA / "SFBOFS_subset1.nc" # SFBOFS_subset1.nc is a smallish subset of the SFBOFS FVCOM model @@ -218,7 +253,7 @@ # cell:standard_name = "cell number" ; # cell:long_name = "Mapping to original mesh cell number" ; -# topology for TEST_FILE1 +# topology for SFBOFS_subset1 grid_topology = { "node_coordinates": "lon lat", "face_node_connectivity": "nv", diff --git a/tests/test_visualization/test_mpl_plotting.py b/tests/test_visualization/test_mpl_plotting.py index 2564158..8e6cf31 100644 --- a/tests/test_visualization/test_mpl_plotting.py +++ b/tests/test_visualization/test_mpl_plotting.py @@ -21,7 +21,7 @@ pytestmark = pytest.mark.skip(reason="matplotlib is not installed") -EXAMPLE_DATA = Path(__file__).parent.parent.parent / "docs" / "examples" / "example_data" +EXAMPLE_DATA = Path(__file__).parent.parent / "example_data" OUTPUT_DIR = Path(__file__).parent / "output" diff --git a/xarray_subset_grid/accessor.py b/xarray_subset_grid/accessor.py index 2868b4e..648f21b 100644 --- a/xarray_subset_grid/accessor.py +++ b/xarray_subset_grid/accessor.py @@ -5,9 +5,23 @@ import xarray as xr from xarray_subset_grid.grid import Grid -from xarray_subset_grid.grids import FVCOMGrid, RegularGrid, RegularGrid2d, SELFEGrid, SGrid, UGrid - -_grid_impls = [FVCOMGrid, SELFEGrid, UGrid, SGrid, RegularGrid2d, RegularGrid] +from xarray_subset_grid.grids import ( + FVCOMGrid, + RegularGrid, + # @D version doesn't appear to be different ?? + # RegularGrid2d, + SELFEGrid, + SGrid, + UGrid, +) + +_grid_impls = [FVCOMGrid, + SELFEGrid, + UGrid, + SGrid, + # RegularGrid2d, + RegularGrid + ] def register_grid_impl(grid_impl: Grid, priority: int = 0): diff --git a/xarray_subset_grid/grids/regular_grid.py b/xarray_subset_grid/grids/regular_grid.py index ee49048..a8a70a8 100644 --- a/xarray_subset_grid/grids/regular_grid.py +++ b/xarray_subset_grid/grids/regular_grid.py @@ -17,31 +17,31 @@ from xarray_subset_grid.utils import ( normalize_bbox_x_coords, normalize_polygon_x_coords, - ray_tracing_numpy, ) +# class RegularGridPolygonSelector(Selector): +# """Polygon Selector for regular lat/lon grids.""" +# # with a regular grid, you have to select the full boudning box anyway +# # this this simply computes the bounding box, and used that -class RegularGridPolygonSelector(Selector): - """Polygon Selector for regular lat/lon grids.""" +# polygon: list[tuple[float, float]] | np.ndarray +# _polygon_mask: xr.DataArray - polygon: list[tuple[float, float]] | np.ndarray - _polygon_mask: xr.DataArray +# def __init__(self, polygon: list[tuple[float, float]] | np.ndarray, mask: xr.DataArray, +# name: str): +# super().__init__() +# self.name = name +# self.polygon = polygon +# self.polygon_mask = mask - def __init__( - self, polygon: list[tuple[float, float]] | np.ndarray, mask: xr.DataArray, name: str - ): - super().__init__() - self.name = name - self.polygon = polygon - self.polygon_mask = mask +# def select(self, ds: xr.Dataset) -> xr.Dataset: +# """Perform the selection on the dataset.""" +# ds_subset = ds.cf.isel( +# lon=self._polygon_mask, +# lat=self._polygon_mask, +# ) +# return ds_subset - def select(self, ds: xr.Dataset) -> xr.Dataset: - """Perform the selection on the dataset.""" - ds_subset = ds.cf.isel( - lon=self._polygon_mask, - lat=self._polygon_mask, - ) - return ds_subset class RegularGridBBoxSelector(Selector): @@ -58,25 +58,64 @@ def __init__(self, bbox: tuple[float, float, float, float]): self._latitude_selection = slice(bbox[1], bbox[3]) def select(self, ds: xr.Dataset) -> xr.Dataset: - """Perform the selection on the dataset.""" + """ + Perform the selection on the dataset. + """ + lat = ds[ds.cf.coordinates.get("latitude")[0]] + lon = ds[ds.cf.coordinates.get("longitude")[0]] + if np.all(np.diff(lat) < 0): + # swap the slice if the latitudes are decending + self._latitude_selection = slice(self._latitude_selection.stop, + self._latitude_selection.start) + # and np.all(np.diff(lon) > 0): + if np.all(np.diff(lon) < 0): + # swap the slice if the longitudes are decending + self._longitude_selection = slice(self._longitude_selection.stop, + self._longitude_selection.start) + return ds.cf.sel(lon=self._longitude_selection, lat=self._latitude_selection) +class RegularGridPolygonSelector(RegularGridBBoxSelector): + """Polygon Selector for regular lat/lon grids.""" + # with a regular grid, you have to select the full bounding box anyway + # this this simply computes the bounding box, and uses the same code. + + def __init__(self, polygon: list[tuple[float, float]] | np.ndarray): + polygon = np.asarray(polygon) + bbox = (polygon[:,0].min(), + polygon[:,1].min(), + polygon[:,0].max(), + polygon[:,1].max(), + ) + super().__init__(bbox=bbox) + class RegularGrid(Grid): """Grid implementation for regular lat/lng grids.""" - @staticmethod def recognize(ds: xr.Dataset) -> bool: - """Recognize if the dataset matches the given grid.""" + """ + Recognize if the dataset matches the given grid. + """ lat = ds.cf.coordinates.get("latitude", None) lon = ds.cf.coordinates.get("longitude", None) if lat is None or lon is None: return False + # choose first one -- valid assumption?? + lat = lat[0] + lon = lon[0] # Make sure the coordinates are 1D and match - lat_ndim = ds[lat[0]].ndim - lon_ndim = ds[lon[0]].ndim - return lat_ndim == lon_ndim and lon_ndim == 1 + if not (1 == ds[lat].ndim == ds[lon].ndim): + return False + + # make sure that at least one variable is using both the + # latitude and longitude dimensions + # (ugrids have both coordinates, but not both dimensions) + for var_name, var in ds.data_vars.items(): + if (lon in var.dims) and (lat in var.dims): + return True + return False @property def name(self) -> str: @@ -102,34 +141,30 @@ def data_vars(self, ds: xr.Dataset) -> set[str]: """ lat = ds.cf.coordinates["latitude"][0] lon = ds.cf.coordinates["longitude"][0] - return { - var - for var in ds.data_vars - if var not in {lat, lon} - and "latitude" in var.cf.coordinates - and "longitude" in var.cf.coordinates + data_vars = {var.name for var in ds.data_vars.values() + if var.name not in {lat, lon} + and "latitude" in var.cf.coordinates + and "longitude" in var.cf.coordinates } + return data_vars + + def compute_polygon_subset_selector(self, + ds: xr.Dataset, + polygon: list[tuple[float, float]], + ) -> Selector: - def compute_polygon_subset_selector( - self, ds: xr.Dataset, polygon: list[tuple[float, float]], name: str = None - ) -> Selector: - lat = ds.cf["latitude"] - lon = ds.cf["longitude"] + polygon = np.asarray(polygon) + lon = ds.cf["longitude"].data - x = np.array(lon.flat) - polygon = normalize_polygon_x_coords(x, polygon) - polygon_mask = ray_tracing_numpy(x, lat.flat, polygon).reshape(lon.shape) + polygon = normalize_polygon_x_coords(lon, polygon) - selector = RegularGridPolygonSelector( - polygon=polygon, mask=polygon_mask, name=name or "selector" - ) + selector = RegularGridPolygonSelector(polygon=polygon) return selector - def compute_bbox_subset_selector( - self, - ds: xr.Dataset, - bbox: tuple[float, float, float, float], - ) -> Selector: + def compute_bbox_subset_selector(self, + ds: xr.Dataset, + bbox: tuple[float, float, float, float], + ) -> Selector: bbox = normalize_bbox_x_coords(ds.cf["longitude"].values, bbox) selector = RegularGridBBoxSelector(bbox) return selector diff --git a/xarray_subset_grid/grids/regular_grid_2d.py b/xarray_subset_grid/grids/regular_grid_2d.py index 56ef57a..38445b6 100644 --- a/xarray_subset_grid/grids/regular_grid_2d.py +++ b/xarray_subset_grid/grids/regular_grid_2d.py @@ -1,3 +1,9 @@ +# 2D and 3D should share a lot of code +# +# do we need a separate class for this? +# This doesn't appear, right now, to check for depth to dedermine if it's 2D + + import numpy as np import xarray as xr @@ -10,9 +16,10 @@ class RegularGrid2dSelector(Selector): polygon: list[tuple[float, float]] | np.ndarray _subset_mask: xr.DataArray - def __init__( - self, polygon: list[tuple[float, float]] | np.ndarray, subset_mask: xr.DataArray, name: str - ): + def __init__(self, + polygon: list[tuple[float, float]] | np.ndarray, + subset_mask: xr.DataArray, + name: str): super().__init__() self.name = name self.polygon = polygon @@ -72,13 +79,13 @@ def data_vars(self, ds: xr.Dataset) -> set[str]: """ lat = ds.cf.coordinates["latitude"][0] lon = ds.cf.coordinates["longitude"][0] - return { - var - for var in ds.data_vars - if var not in {lat, lon} - and "latitude" in var.cf.coordinates - and "longitude" in var.cf.coordinates + data_vars = {var.name for var in ds.data_vars.values() + if var.name not in {lat, lon} + and "latitude" in var.cf.coordinates + and "longitude" in var.cf.coordinates } + return data_vars + def compute_polygon_subset_selector( self, ds: xr.Dataset, polygon: list[tuple[float, float]], name: str = None diff --git a/xarray_subset_grid/grids/ugrid.py b/xarray_subset_grid/grids/ugrid.py index b2b0a8e..4428581 100644 --- a/xarray_subset_grid/grids/ugrid.py +++ b/xarray_subset_grid/grids/ugrid.py @@ -116,7 +116,7 @@ def recognize(ds: xr.Dataset) -> bool: try: mesh_key = ds.cf.cf_roles["mesh_topology"][0] mesh = ds[mesh_key] - except Exception: + except KeyError: return False return mesh.attrs.get("face_node_connectivity") is not None @@ -360,7 +360,7 @@ def assign_ugrid_topology( ``` grid_topology = {'node_coordinates': ('lon', 'lat'), 'face_node_connectivity': 'nv', - 'node_coordinates': ('lon', 'lat'), + 'edge_coordinates': ('lon', 'lat'), 'face_coordinates': ('lonc', 'latc'), } diff --git a/xarray_subset_grid/utils.py b/xarray_subset_grid/utils.py index 1b0b535..150d3c1 100644 --- a/xarray_subset_grid/utils.py +++ b/xarray_subset_grid/utils.py @@ -1,8 +1,11 @@ import warnings +from datetime import datetime import cf_xarray # noqa +import cftime import numpy as np import xarray as xr +from dateutil.parser import parse as parsetime def normalize_polygon_x_coords(x, poly): @@ -151,3 +154,22 @@ def compute_2d_subset_mask( polygon_mask = np.where(polygon_mask > 1, True, False) return xr.DataArray(polygon_mask, dims=mask_dims) + +def asdatetime(dt): + """ + makes sure the input is a datetime.datetime object + + if it already is, it will be passed through. + + If not it will attempt to parse a string to make a datetime object. + + None will also be passed through silently + """ + if dt is None: + return dt + # if not isinstance(dt, datetime): + if not isinstance(dt, datetime | cftime.datetime): + # assume it's an iso string, or something that dateutils can parse. + return parsetime(dt, ignoretz=True) + else: + return dt