Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions tests/test_grids/test_regular_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,36 @@ def create_synthetic_rectangular_grid_dataset(decreasing=False):
return ds


def create_synthetic_global_rectangular_grid_dataset(*, use_360=True, decreasing_lon=False):
"""Create a synthetic global regular-grid dataset for longitude wrap tests."""
lat = np.linspace(-10, 10, 21)
if use_360:
lon = np.arange(0, 360)
else:
lon = np.arange(-180, 180)

if decreasing_lon:
lon = lon[::-1]

data = np.random.rand(lat.size, lon.size)

ds = xr.Dataset(
data_vars={
"temp": (("lat", "lon"), data),
},
coords={
"lat": lat,
"lon": lon,
},
)

ds.lat.attrs = {"standard_name": "latitude", "units": "degrees_north"}
ds.lon.attrs = {"standard_name": "longitude", "units": "degrees_east"}
ds.temp.attrs = {"standard_name": "sea_water_temperature"}

return ds


def test_grid_vars():
"""
Check if the grid vars are defined properly
Expand Down Expand Up @@ -282,6 +312,53 @@ def test_subset_polygon():
assert new_bounds == (-0.5, 0, 0.5, 0.5)


def test_subset_bbox_wrap_prime_meridian_on_360_grid():
ds = create_synthetic_global_rectangular_grid_dataset(use_360=True)

ds_subset = ds.xsg.subset_bbox((-10, -5, 10, 5))

assert ds_subset["lat"].size == 11
assert ds_subset["lon"].size == 21
lon_values = ds_subset["lon"].values
assert lon_values.min() == 0
assert lon_values.max() == 359
assert set(range(0, 11)).issubset(set(lon_values.tolist()))
assert set(range(350, 360)).issubset(set(lon_values.tolist()))


def test_subset_bbox_wrap_dateline_on_180_grid():
ds = create_synthetic_global_rectangular_grid_dataset(use_360=False)

ds_subset = ds.xsg.subset_bbox((170, -5, -170, 5))

assert ds_subset["lat"].size == 11
assert ds_subset["lon"].size == 21
lon_values = ds_subset["lon"].values
assert lon_values.min() == -180
assert lon_values.max() == 179
assert set(range(170, 180)).issubset(set(lon_values.tolist()))
assert set(range(-180, -169)).issubset(set(lon_values.tolist()))


def test_subset_bbox_wrap_prime_meridian_descending_lon():
ds = create_synthetic_global_rectangular_grid_dataset(use_360=True, decreasing_lon=True)

ds_subset = ds.xsg.subset_bbox((-10, -5, 10, 5))

assert ds_subset["lat"].size == 11
assert ds_subset["lon"].size == 21
lon_values = ds_subset["lon"].values
assert set(range(0, 11)).issubset(set(lon_values.tolist()))
assert set(range(350, 360)).issubset(set(lon_values.tolist()))


def test_subset_bbox_raises_for_span_ge_half_earth():
ds = create_synthetic_global_rectangular_grid_dataset(use_360=False)

with pytest.raises(ValueError, match="less than half-way around the earth"):
ds.xsg.subset_bbox((-170, -5, 170, 5))


# def test_vertical_levels():
# ds = xr.open_dataset(EXAMPLE_DATA / "SFBOFS_subset1.nc")
# ds = ugrid.assign_ugrid_topology(ds, **grid_topology)
Expand Down
2 changes: 2 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ def test_normalize_x_coords(lons, poly, norm_poly):
([-85, -84, -83, 10], bbox2_360, bbox2_180), # x1
([60, 45, 85, 70], bbox2_360, bbox2_360), # x2
([190, 200, 220, 250, 260], bbox2_360, bbox2_360), # x3
([0, 90, 180, 270, 359], [-10, 39, 10, 41], [350, 39, 10, 41]),
([-180, -90, 0, 90, 179], [350, 39, 10, 41], [-10, 39, 10, 41]),
],
)
def test_normalize_x_coords_bbox(lons, bbox, norm_bbox):
Expand Down
66 changes: 54 additions & 12 deletions xarray_subset_grid/grids/regular_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,34 +47,76 @@ class RegularGridBBoxSelector(Selector):
"""Selector for regular lat/lng grids."""

bbox: tuple[float, float, float, float]
_longitude_selection: slice
_longitude_bounds: tuple[float, float]
_latitude_selection: slice

def __init__(self, bbox: tuple[float, float, float, float]):
super().__init__()
self.bbox = bbox
self._longitude_selection = slice(bbox[0], bbox[2])
self._longitude_bounds = (bbox[0], bbox[2])
self._latitude_selection = slice(bbox[1], bbox[3])

def _longitude_span(self) -> float:
west, east = self._longitude_bounds
return (east - west) % 360

def _validate_longitude_span(self):
span = self._longitude_span()
if np.isclose(span, 0.0):
raise ValueError(
"Invalid longitude bounds: west and east bounds "
"cannot define a zero-width selection"
)
if span >= 180.0 and not np.isclose(span, 180.0):
raise ValueError(
"Invalid longitude bounds: subsetting bounds "
"must span less than half-way around the earth"
)
if np.isclose(span, 180.0):
raise ValueError(
"Invalid longitude bounds: subsetting bounds "
"must span less than half-way around the earth"
)

def _build_longitude_slices(self, lon: xr.DataArray) -> list[slice]:
west, east = self._longitude_bounds
lon_min = float(lon.min().values)
lon_max = float(lon.max().values)

if west <= east:
longitude_slices = [slice(west, east)]
else:
longitude_slices = [slice(west, lon_max), slice(lon_min, east)]

if np.all(np.diff(lon) < 0):
longitude_slices = [slice(sl.stop, sl.start) for sl in longitude_slices]

return longitude_slices

def select(self, ds: xr.Dataset) -> xr.Dataset:
"""
Perform the selection on the dataset.
"""
self._validate_longitude_span()

lat = ds[ds.cf.coordinates.get("latitude")[0]]
lon = ds[ds.cf.coordinates.get("longitude")[0]]

latitude_selection = self._latitude_selection
if np.all(np.diff(lat) < 0):
# swap the slice if the latitudes are descending
self._latitude_selection = slice(
self._latitude_selection.stop, self._latitude_selection.start
)
# and np.all(np.diff(lon) > 0):
if np.all(np.diff(lon) < 0):
# swap the slice if the longitudes are descending
self._longitude_selection = slice(
self._longitude_selection.stop, self._longitude_selection.start
)
latitude_selection = slice(latitude_selection.stop, latitude_selection.start)

longitude_selections = self._build_longitude_slices(lon)
selections = [
ds.cf.sel(lon=lon_sel, lat=latitude_selection) for lon_sel in longitude_selections
]

if len(selections) == 1:
return selections[0]

return ds.cf.sel(lon=self._longitude_selection, lat=self._latitude_selection)
lon_dim = lon.dims[0]
return xr.concat(selections, dim=lon_dim)


class RegularGridPolygonSelector(RegularGridBBoxSelector):
Expand Down
16 changes: 10 additions & 6 deletions xarray_subset_grid/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,16 @@ def normalize_bbox_x_coords(x, bbox):

bbox_x_min, bbox_x_max = bbox[0], bbox[2]

if x_max > 180 and bbox_x_max < 0:
bbox_x_min += 360
bbox_x_max += 360
elif x_min < 0 and bbox_x_max > 180:
bbox_x_min -= 360
bbox_x_max -= 360
if x_max > 180:
if bbox_x_min < 0:
bbox_x_min += 360
if bbox_x_max < 0:
bbox_x_max += 360
elif x_min < 0:
if bbox_x_min > 180:
bbox_x_min -= 360
if bbox_x_max > 180:
bbox_x_max -= 360

return bbox_x_min, bbox[1], bbox_x_max, bbox[3]

Expand Down