From 1e1c892b38cbefcf78dfc42394d285c63f227b94 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20de=20Menten?= Date: Thu, 4 Dec 2025 15:12:40 +0100 Subject: [PATCH] FIX: fixed array[slice_group_from_an_incompatible_axis] (fixes #1146 and #1117) probably also fixed a few edge cases in Axis.index() --- doc/source/changes/version_0_35.rst.inc | 22 +++++ larray/core/array.py | 20 ++-- larray/core/axis.py | 116 +++++++++++++++--------- larray/tests/test_array.py | 27 +++++- larray/tests/test_axis.py | 10 ++ 5 files changed, 143 insertions(+), 52 deletions(-) diff --git a/doc/source/changes/version_0_35.rst.inc b/doc/source/changes/version_0_35.rst.inc index ffd9d2187..bc85a1c46 100644 --- a/doc/source/changes/version_0_35.rst.inc +++ b/doc/source/changes/version_0_35.rst.inc @@ -112,6 +112,28 @@ Miscellaneous improvements Fixes ^^^^^ +* fixed array[slice_group_from_an_incompatible_axis] and + array.sum(slice_group_from_an_incompatible_axis) (closes :issue:`1146` + and :issue:`1117`). + It used to evaluate the slice on the array axis instead of first evaluating + the slice on the axis it was created on, then take the corresponding labels + from the array axis. + + >>> arr = ndtest(3) + >>> arr + a a0 a1 a2 + 0 1 2 + >>> other_axis_a = Axis('a=a0,a1') + >>> group = other_axis_a[:] + >>> print(group) + ['a0' 'a1'] + >>> arr[group] # <-- before + a a0 a1 a2 + 0 1 2 + >>> arr[group] # <-- now + a a0 a1 + 0 1 + * fixed error message when trying to take a subset of an array with an array key which has ndim > 1 and some bad values in the key. The message was also improved (see the issue for details). Closes :issue:`1134`. diff --git a/larray/core/array.py b/larray/core/array.py index 501a65d4d..13be5152e 100644 --- a/larray/core/array.py +++ b/larray/core/array.py @@ -2140,6 +2140,9 @@ def sort_values(self, key=None, axis=None, ascending=True) -> 'Array': # FWIW, using .data, I get IGroup([1, 2, 0], axis='nat'), which works. sorter = axis.i[indicesofsorted.data] res = self[sorter] + # res has its axis in a different order than the original axis + # so we need this line to reverse the order below if not ascending + axis = res.axes[axis] else: res = self.combine_axes() indicesofsorted = np.argsort(res.data) @@ -2799,22 +2802,27 @@ def _group_aggregate(self, op, items, keepaxes=False, out=None, **kwargs) -> 'Ar if isinstance(item, tuple): assert all(isinstance(g, Group) for g in item) groups = item - axis = groups[0].axis + group_axis = groups[0].axis # they should all have the same axis (this is already checked # in _prepare_aggregate though) - assert all(g.axis.equals(axis) for g in groups[1:]) + assert all(g.axis.equals(group_axis) for g in groups[1:]) killaxis = False else: # item is in fact a single group assert isinstance(item, Group), type(item) groups = (item,) - axis = item.axis + group_axis = item.axis # it is easier to kill the axis after the fact killaxis = True - axis, axis_idx = res.axes[axis], res.axes.index(axis) + axis_idx = res.axes.index(group_axis) + res_axis = res.axes[axis_idx] + assert group_axis.equals(res_axis) + # potentially translate axis reference to real axes - groups = tuple(g.with_axis(axis) for g in groups) + # with_axis is correct because we already checked + # that g.axis.equals(axis) + groups = tuple(g.with_axis(res_axis) for g in groups) res_shape[axis_idx] = len(groups) # XXX: this code is fragile. I wonder if there isn't a way to ask the function what kind of dtype/shape it @@ -2866,7 +2874,7 @@ def _group_aggregate(self, op, items, keepaxes=False, out=None, **kwargs) -> 'Ar # We do NOT modify the axis name (eg append "_agg" or "*") even though this creates a new axis that is # independent from the original one because the original name is what users will want to use to access # that axis (eg in .filter kwargs) - res_axes[axis_idx] = Axis(groups, axis.name) + res_axes[axis_idx] = Axis(groups, res_axis.name) if isinstance(res_data, np.ndarray): res = Array(res_data, res_axes) diff --git a/larray/core/axis.py b/larray/core/axis.py index 9b18c4905..37e5116fd 100644 --- a/larray/core/axis.py +++ b/larray/core/axis.py @@ -922,32 +922,41 @@ def index(self, key) -> Union[int, np.ndarray, slice]: """ mapping = self._mapping - if isinstance(key, Group) and key.axis is not self and key.axis is not None: - try: - # XXX: this is potentially very expensive if key.key is an array or list and should be tried as a last - # resort - potential_tick = _to_tick(key) - - # avoid matching 0 against False or 0.0, note that None has object dtype and so always pass this test - if self._is_key_type_compatible(potential_tick): - try: - res_idx = mapping[potential_tick] - if potential_tick != key.key: - # only warn if no KeyError was raised (potential_tick is in mapping) - msg = "Using a Group object which was used to create an aggregate to " \ - "target its aggregated label is deprecated. " \ - "Please use the aggregated label directly instead. " \ - f"In this case, you should use {potential_tick!r} instead of " \ - f"using {key!r}." - # let us hope the stacklevel does not vary by codepath - warnings.warn(msg, FutureWarning, stacklevel=8) - return res_idx - except KeyError: - pass - # we must catch TypeError because key might not be hashable (eg slice) - # IndexError is for when mapping is an ndarray - except (KeyError, TypeError, IndexError): - pass + if isinstance(key, Group): + if key.axis is self: + if isinstance(key, IGroup): + return key.key + else: + # at this point we do not care about the axis nor the name + key = key.key + elif key.axis is not None: + try: + # TODO: remove this as it is potentially very expensive + # if key.key is an array or list and should be tried + # as a last resort + potential_tick = _to_tick(key) + + # avoid matching 0 against False or 0.0, note that None has + # object dtype and so always pass this test + if self._is_key_type_compatible(potential_tick): + try: + res_idx = mapping[potential_tick] + if potential_tick != key.key: + # only warn if no KeyError was raised (potential_tick is in mapping) + msg = "Using a Group object which was used to create an aggregate to " \ + "target its aggregated label is deprecated. " \ + "Please use the aggregated label directly instead. " \ + f"In this case, you should use {potential_tick!r} instead of " \ + f"using {key!r}." + # let us hope the stacklevel does not vary by codepath + warnings.warn(msg, FutureWarning, stacklevel=8) + return res_idx + except KeyError: + pass + # we must catch TypeError because key might not be hashable (eg slice) + # IndexError is for when mapping is an ndarray + except (KeyError, TypeError, IndexError): + pass if isinstance(key, str): # try the key as-is to allow getting at ticks with special characters (",", ":", ...) @@ -961,24 +970,35 @@ def index(self, key) -> Union[int, np.ndarray, slice]: except (KeyError, TypeError, IndexError): pass - # transform "specially formatted strings" for slices, lists, LGroup and IGroup to actual objects + # transform "specially formatted strings" for slices, lists, LGroup + # and IGroup to actual objects key = _to_key(key) if isinstance(key, range): key = list(key) - - # this can happen when key was passed as a string and converted to a Group via _to_key - if isinstance(key, Group) and isinstance(key.axis, str) and key.axis != self.name: - raise KeyError(key) - - if isinstance(key, IGroup): - if isinstance(key.axis, Axis): - assert key.axis is self - return key.key - - if isinstance(key, LGroup): - # at this point we do not care about the axis nor the name - key = key.key + elif isinstance(key, Group): + key_axis = key.axis + if isinstance(key_axis, str): + if key_axis != self.name: + raise KeyError(key) + elif isinstance(key_axis, AxisReference): + if key_axis.name != self.name: + raise KeyError(key) + elif isinstance(key_axis, Axis): # we know it is not self + # IGroups will be retargeted to LGroups + key = key.retarget_to(self) + elif isinstance(key_axis, int): + raise TypeError('Axis.index() does not support Group keys with ' + 'integer axis') + else: + assert key_axis is None + # an IGroup can still exist at this point if the key was an IGroup + # with a compatible axis (string or AxisReference axis with the + # correct name or Axis object equal to self) + if isinstance(key, IGroup): + return key.key + else: + key = key.key if isinstance(key, slice): start = mapping[key.start] if key.start is not None else None @@ -1915,7 +1935,8 @@ def __contains__(self, key) -> bool: if isinstance(key, int): return -len(self) <= key < len(self) elif isinstance(key, Axis): - # the special case is just a performance optimization to avoid scanning through the whole list + # the special case is just a performance optimization to avoid + # scanning through the whole list if key.name is not None: return key.name in self._map else: @@ -2808,7 +2829,7 @@ def _guess_axis(self, axis_key): # we have axis information but not necessarily an Axis object from self real_axis = self[group_axis] if group_axis is not real_axis: - axis_key = axis_key.with_axis(real_axis) + axis_key = axis_key.retarget_to(real_axis) return axis_key real_axis, axis_pos_key = self._translate_nice_key(axis_key) @@ -2828,6 +2849,7 @@ def _translate_axis_key_chunk(self, axis_key): (axis, indices) Indices group with a valid axis (from self) """ + orig_key = axis_key axis_key = remove_nested_groups(axis_key) if isinstance(axis_key, IGroup): @@ -2852,11 +2874,16 @@ def _translate_axis_key_chunk(self, axis_key): # labels but known axis if isinstance(axis_key, LGroup) and axis_key.axis is not None: try: - real_axis = self[axis_key.axis] + key_axis = axis_key.axis + real_axis = self[key_axis] + if isinstance(key_axis, (AxisReference, int)): + # this is one of the rare cases where with_axis is correct ! + axis_key = axis_key.with_axis(real_axis) + try: axis_pos_key = real_axis.index(axis_key) except KeyError: - raise ValueError(f"{axis_key!r} is not a valid label for the {real_axis.name!r} axis " + raise ValueError(f"{orig_key!r} is not a valid label for the {real_axis.name!r} axis " f"with labels: {', '.join(repr(label) for label in real_axis.labels)}") return real_axis, axis_pos_key except KeyError: @@ -3889,6 +3916,7 @@ def align_axis_collections(axis_collections, join='outer', axes=None): class AxisReference(ABCAxisReference, ExprNode, Axis): def __init__(self, name): + assert isinstance(name, (int, str)) self.name = name self._labels = None self._iswildcard = False diff --git a/larray/tests/test_array.py b/larray/tests/test_array.py index 050a4692a..d90556c15 100644 --- a/larray/tests/test_array.py +++ b/larray/tests/test_array.py @@ -560,6 +560,17 @@ def test_getitem(array): _ = array[bad[1, 2], a[3, 4]] +def test_getitem_group_from_another_axis(): + # using slice Group from an axis not present, we must retarget the group + arr = ndtest(3) + a2 = Axis('a=a0,a1') + + # issue #1146 + expected = ndtest(2) + res = arr[a2[:]] + assert_larray_equal(res, expected) + + def test_getitem_abstract_axes(array): raw = array.data a, b, c, d = array.axes @@ -1130,8 +1141,6 @@ def test_getitem_single_larray_key_guess(): _ = arr[key] - - def test_getitem_multiple_larray_key_guess(): a, b, c, d, e = ndtest((2, 3, 2, 3, 2)).axes arr = ndtest((a, b)) @@ -2161,6 +2170,20 @@ def test_group_agg_label_group(array): res = array.sum(a, c).sum((g1, g2, g3, g_all)) assert res.shape == (4, 6) + # d) group aggregate using a group from another axis + # 1) LGroup + array = ndtest(3) + smaller_a_axis = Axis('a=a0,a1') + group = smaller_a_axis[:] + res = array.sum(group) + assert res == 1 + + # 2) IGroup + group = Axis("a=a1,a0").i[0] # targets a1 + assert array[group] == 1 + res = array.sum(group) + assert res == 1 + def test_group_agg_label_group_no_axis(array): a, b, c, d = array.axes diff --git a/larray/tests/test_axis.py b/larray/tests/test_axis.py index 108c61ae0..6499b3c9a 100644 --- a/larray/tests/test_axis.py +++ b/larray/tests/test_axis.py @@ -123,6 +123,16 @@ def test_index(): assert a.index('a1') == 1 assert a.index('a1 >> A1') == 1 + time = Axis([2007, 2009], 'time') + res = time.index(time.i[1]) + assert res == 1 + + res = time.index(X.time.i[1]) + assert res == 1 + + res = time.index('time.i[1]') + assert res == 1 + def test_astype(): arr = ndtest(Axis('time=2015..2020,total')).drop('total')