diff --git a/tests/test_grids/test_regular_grid.py b/tests/test_grids/test_regular_grid.py index 42bd5cd..b346561 100644 --- a/tests/test_grids/test_regular_grid.py +++ b/tests/test_grids/test_regular_grid.py @@ -78,6 +78,36 @@ def create_synthetic_rectangular_grid_dataset(decreasing=False): return ds +def create_synthetic_global_rectangular_grid_dataset(*, use_360=True, decreasing_lon=False): + """Create a synthetic global regular-grid dataset for longitude wrap tests.""" + lat = np.linspace(-10, 10, 21) + if use_360: + lon = np.arange(0, 360) + else: + lon = np.arange(-180, 180) + + if decreasing_lon: + lon = lon[::-1] + + data = np.random.rand(lat.size, lon.size) + + ds = xr.Dataset( + data_vars={ + "temp": (("lat", "lon"), data), + }, + coords={ + "lat": lat, + "lon": lon, + }, + ) + + ds.lat.attrs = {"standard_name": "latitude", "units": "degrees_north"} + ds.lon.attrs = {"standard_name": "longitude", "units": "degrees_east"} + ds.temp.attrs = {"standard_name": "sea_water_temperature"} + + return ds + + def test_grid_vars(): """ Check if the grid vars are defined properly @@ -282,6 +312,53 @@ def test_subset_polygon(): assert new_bounds == (-0.5, 0, 0.5, 0.5) +def test_subset_bbox_wrap_prime_meridian_on_360_grid(): + ds = create_synthetic_global_rectangular_grid_dataset(use_360=True) + + ds_subset = ds.xsg.subset_bbox((-10, -5, 10, 5)) + + assert ds_subset["lat"].size == 11 + assert ds_subset["lon"].size == 21 + lon_values = ds_subset["lon"].values + assert lon_values.min() == 0 + assert lon_values.max() == 359 + assert set(range(0, 11)).issubset(set(lon_values.tolist())) + assert set(range(350, 360)).issubset(set(lon_values.tolist())) + + +def test_subset_bbox_wrap_dateline_on_180_grid(): + ds = create_synthetic_global_rectangular_grid_dataset(use_360=False) + + ds_subset = ds.xsg.subset_bbox((170, -5, -170, 5)) + + assert ds_subset["lat"].size == 11 + assert ds_subset["lon"].size == 21 + lon_values = ds_subset["lon"].values + assert lon_values.min() == -180 + assert lon_values.max() == 179 + assert set(range(170, 180)).issubset(set(lon_values.tolist())) + assert set(range(-180, -169)).issubset(set(lon_values.tolist())) + + +def test_subset_bbox_wrap_prime_meridian_descending_lon(): + ds = create_synthetic_global_rectangular_grid_dataset(use_360=True, decreasing_lon=True) + + ds_subset = ds.xsg.subset_bbox((-10, -5, 10, 5)) + + assert ds_subset["lat"].size == 11 + assert ds_subset["lon"].size == 21 + lon_values = ds_subset["lon"].values + assert set(range(0, 11)).issubset(set(lon_values.tolist())) + assert set(range(350, 360)).issubset(set(lon_values.tolist())) + + +def test_subset_bbox_raises_for_span_ge_half_earth(): + ds = create_synthetic_global_rectangular_grid_dataset(use_360=False) + + with pytest.raises(ValueError, match="less than half-way around the earth"): + ds.xsg.subset_bbox((-170, -5, 170, 5)) + + # def test_vertical_levels(): # ds = xr.open_dataset(EXAMPLE_DATA / "SFBOFS_subset1.nc") # ds = ugrid.assign_ugrid_topology(ds, **grid_topology) diff --git a/tests/test_utils.py b/tests/test_utils.py index 6b64a2c..8e8a2df 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -93,6 +93,8 @@ def test_normalize_x_coords(lons, poly, norm_poly): ([-85, -84, -83, 10], bbox2_360, bbox2_180), # x1 ([60, 45, 85, 70], bbox2_360, bbox2_360), # x2 ([190, 200, 220, 250, 260], bbox2_360, bbox2_360), # x3 + ([0, 90, 180, 270, 359], [-10, 39, 10, 41], [350, 39, 10, 41]), + ([-180, -90, 0, 90, 179], [350, 39, 10, 41], [-10, 39, 10, 41]), ], ) def test_normalize_x_coords_bbox(lons, bbox, norm_bbox): diff --git a/xarray_subset_grid/grids/regular_grid.py b/xarray_subset_grid/grids/regular_grid.py index 9b01aa8..b171823 100644 --- a/xarray_subset_grid/grids/regular_grid.py +++ b/xarray_subset_grid/grids/regular_grid.py @@ -47,34 +47,76 @@ class RegularGridBBoxSelector(Selector): """Selector for regular lat/lng grids.""" bbox: tuple[float, float, float, float] - _longitude_selection: slice + _longitude_bounds: tuple[float, float] _latitude_selection: slice def __init__(self, bbox: tuple[float, float, float, float]): super().__init__() self.bbox = bbox - self._longitude_selection = slice(bbox[0], bbox[2]) + self._longitude_bounds = (bbox[0], bbox[2]) self._latitude_selection = slice(bbox[1], bbox[3]) + def _longitude_span(self) -> float: + west, east = self._longitude_bounds + return (east - west) % 360 + + def _validate_longitude_span(self): + span = self._longitude_span() + if np.isclose(span, 0.0): + raise ValueError( + "Invalid longitude bounds: west and east bounds " + "cannot define a zero-width selection" + ) + if span >= 180.0 and not np.isclose(span, 180.0): + raise ValueError( + "Invalid longitude bounds: subsetting bounds " + "must span less than half-way around the earth" + ) + if np.isclose(span, 180.0): + raise ValueError( + "Invalid longitude bounds: subsetting bounds " + "must span less than half-way around the earth" + ) + + def _build_longitude_slices(self, lon: xr.DataArray) -> list[slice]: + west, east = self._longitude_bounds + lon_min = float(lon.min().values) + lon_max = float(lon.max().values) + + if west <= east: + longitude_slices = [slice(west, east)] + else: + longitude_slices = [slice(west, lon_max), slice(lon_min, east)] + + if np.all(np.diff(lon) < 0): + longitude_slices = [slice(sl.stop, sl.start) for sl in longitude_slices] + + return longitude_slices + def select(self, ds: xr.Dataset) -> xr.Dataset: """ Perform the selection on the dataset. """ + self._validate_longitude_span() + lat = ds[ds.cf.coordinates.get("latitude")[0]] lon = ds[ds.cf.coordinates.get("longitude")[0]] + + latitude_selection = self._latitude_selection if np.all(np.diff(lat) < 0): # swap the slice if the latitudes are descending - self._latitude_selection = slice( - self._latitude_selection.stop, self._latitude_selection.start - ) - # and np.all(np.diff(lon) > 0): - if np.all(np.diff(lon) < 0): - # swap the slice if the longitudes are descending - self._longitude_selection = slice( - self._longitude_selection.stop, self._longitude_selection.start - ) + latitude_selection = slice(latitude_selection.stop, latitude_selection.start) + + longitude_selections = self._build_longitude_slices(lon) + selections = [ + ds.cf.sel(lon=lon_sel, lat=latitude_selection) for lon_sel in longitude_selections + ] + + if len(selections) == 1: + return selections[0] - return ds.cf.sel(lon=self._longitude_selection, lat=self._latitude_selection) + lon_dim = lon.dims[0] + return xr.concat(selections, dim=lon_dim) class RegularGridPolygonSelector(RegularGridBBoxSelector): diff --git a/xarray_subset_grid/utils.py b/xarray_subset_grid/utils.py index e004525..c50c5b4 100644 --- a/xarray_subset_grid/utils.py +++ b/xarray_subset_grid/utils.py @@ -51,12 +51,16 @@ def normalize_bbox_x_coords(x, bbox): bbox_x_min, bbox_x_max = bbox[0], bbox[2] - if x_max > 180 and bbox_x_max < 0: - bbox_x_min += 360 - bbox_x_max += 360 - elif x_min < 0 and bbox_x_max > 180: - bbox_x_min -= 360 - bbox_x_max -= 360 + if x_max > 180: + if bbox_x_min < 0: + bbox_x_min += 360 + if bbox_x_max < 0: + bbox_x_max += 360 + elif x_min < 0: + if bbox_x_min > 180: + bbox_x_min -= 360 + if bbox_x_max > 180: + bbox_x_max -= 360 return bbox_x_min, bbox[1], bbox_x_max, bbox[3]