diff --git a/dascore/core/attrs.py b/dascore/core/attrs.py index 93b5c9c1..7c56e09a 100644 --- a/dascore/core/attrs.py +++ b/dascore/core/attrs.py @@ -239,8 +239,12 @@ def update(self, **kwargs) -> Self: new_coords = dict(out.get("coords", {})) if isinstance(passed_in_coords, dc.CoordManager): new_coords = passed_in_coords.model_dump(exclude_unset=True) + out["dims"] = passed_in_coords.dims - coord_info, attr_info = separate_coord_info(kwargs, dims=self.dim_tuple) + dims = out.get("dims", self.dim_tuple) + if isinstance(dims, str): + dims = tuple(dims.split(",")) if dims else tuple() + coord_info, attr_info = separate_coord_info(kwargs, dims=tuple(dims)) out.update(attr_info) # Iterate the coordinate information and update. for name, coord_dict in coord_info.items(): diff --git a/dascore/core/coordmanager.py b/dascore/core/coordmanager.py index 8b518625..e748f495 100644 --- a/dascore/core/coordmanager.py +++ b/dascore/core/coordmanager.py @@ -818,7 +818,9 @@ def ndim(self): def validate_data(self, data): """Ensure data conforms to coordinates.""" - data = np.asarray([]) if data is None else data + data = np.asarray([]) if data is None else np.asarray(data) + if not self.dims and data.shape == (): + return data if self.shape != data.shape: msg = ( f"Data array has a shape of {data.shape} which doesnt match " diff --git a/tests/test_core/test_coordmanager.py b/tests/test_core/test_coordmanager.py index 93e47c66..94b31126 100644 --- a/tests/test_core/test_coordmanager.py +++ b/tests/test_core/test_coordmanager.py @@ -169,6 +169,19 @@ def test_init_list_range(self): out = get_coord_manager(input_dict, dims=list(input_dict)) assert set(out.dims) == set(input_dict) + def test_scalar_data_ok_for_no_dims(self): + """Ensure scalar data validates for coordinate managers with no dims.""" + cm = get_coord_manager() + out = cm.validate_data(np.array(1.0)) + assert out.shape == () + + def test_python_scalar_data_ok_for_no_dims(self): + """Ensure python scalars are converted and validate with no dims.""" + cm = get_coord_manager() + out = cm.validate_data(1.0) + assert isinstance(out, np.ndarray) + assert out.shape == () + def test_bad_datashape_raises(self, cm_basic): """Ensure a bad datashape raises.""" match = "match the coordinate manager shape" diff --git a/tests/test_proc/test_proc_coords.py b/tests/test_proc/test_proc_coords.py index 9704b3e6..0f2b08a5 100644 --- a/tests/test_proc/test_proc_coords.py +++ b/tests/test_proc/test_proc_coords.py @@ -673,6 +673,14 @@ def test_coord_summary(self, flat_patch): if coords: assert set(coords) == set(patch.coords.coord_map) + def test_squeeze_all_dims_len_one(self, random_patch): + """Squeezing a (1, 1) patch should produce scalar patch with no dims.""" + patch = random_patch.aggregate(dim=None, method="mean", dim_reduce="empty") + out = patch.squeeze() + assert out.data.shape == () + assert out.dims == () + assert out.attrs.dim_tuple == () + class TestGetCoord: """Tests for the get_coord convenience function."""