035more #1149 (Merged)

doc/source/changes/version_0_35.rst.inc (22 additions, 0 deletions)
@@ -112,6 +112,28 @@ Miscellaneous improvements
Fixes
^^^^^

* fixed array[slice_group_from_an_incompatible_axis] and
array.sum(slice_group_from_an_incompatible_axis) (closes :issue:`1146`
and :issue:`1117`).
It used to evaluate the slice directly on the array axis instead of first
evaluating the slice on the axis it was created on and then taking the
corresponding labels from the array axis.

>>> arr = ndtest(3)
>>> arr
a a0 a1 a2
0 1 2
>>> other_axis_a = Axis('a=a0,a1')
>>> group = other_axis_a[:]
>>> print(group)
['a0' 'a1']
>>> arr[group] # <-- before
a a0 a1 a2
0 1 2
>>> arr[group] # <-- now
a a0 a1
0 1
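
The aggregate case is now consistent with this (a sketch of the expected
behaviour, assuming the same arr and group as above; the exact scalar repr
may differ):

>>> arr.sum(group)   # <-- now: only a0 and a1 are summed
1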

* fixed error message when trying to take a subset of an array with an array
key which has ndim > 1 and some bad values in the key. The message was also
improved (see the issue for details). Closes :issue:`1134`.
larray/core/array.py (14 additions, 6 deletions)
@@ -2140,6 +2140,9 @@ def sort_values(self, key=None, axis=None, ascending=True) -> 'Array':
# FWIW, using .data, I get IGroup([1, 2, 0], axis='nat'), which works.
sorter = axis.i[indicesofsorted.data]
res = self[sorter]
# res has its axis in a different order than the original axis, so we
# re-fetch it here so the reversal below (when not ascending) uses the sorted axis
axis = res.axes[axis]
else:
res = self.combine_axes()
indicesofsorted = np.argsort(res.data)
@@ -2799,22 +2802,27 @@ def _group_aggregate(self, op, items, keepaxes=False, out=None, **kwargs) -> 'Ar
if isinstance(item, tuple):
assert all(isinstance(g, Group) for g in item)
groups = item
axis = groups[0].axis
group_axis = groups[0].axis
# they should all have the same axis (this is already checked
# in _prepare_aggregate though)
assert all(g.axis.equals(axis) for g in groups[1:])
assert all(g.axis.equals(group_axis) for g in groups[1:])
killaxis = False
else:
# item is in fact a single group
assert isinstance(item, Group), type(item)
groups = (item,)
axis = item.axis
group_axis = item.axis
# it is easier to kill the axis after the fact
killaxis = True

axis, axis_idx = res.axes[axis], res.axes.index(axis)
axis_idx = res.axes.index(group_axis)
res_axis = res.axes[axis_idx]
assert group_axis.equals(res_axis)

# potentially translate axis reference to real axes
groups = tuple(g.with_axis(axis) for g in groups)
# with_axis is correct here because we already checked that
# g.axis (i.e. group_axis) equals res_axis
groups = tuple(g.with_axis(res_axis) for g in groups)
res_shape[axis_idx] = len(groups)

# XXX: this code is fragile. I wonder if there isn't a way to ask the function what kind of dtype/shape it
@@ -2866,7 +2874,7 @@ def _group_aggregate(self, op, items, keepaxes=False, out=None, **kwargs) -> 'Ar
# We do NOT modify the axis name (eg append "_agg" or "*") even though this creates a new axis that is
# independent from the original one because the original name is what users will want to use to access
# that axis (eg in .filter kwargs)
res_axes[axis_idx] = Axis(groups, axis.name)
res_axes[axis_idx] = Axis(groups, res_axis.name)

if isinstance(res_data, np.ndarray):
res = Array(res_data, res_axes)
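
For context, a minimal sketch of the behaviour the _group_aggregate retargeting
enables (it mirrors the new tests in larray/tests/test_array.py further down;
ndtest and Axis are larray's own helpers):

    from larray import Axis, ndtest

    arr = ndtest(3)               # axis 'a' with labels a0, a1, a2 and values 0, 1, 2

    # IGroup defined on another 'a' axis whose labels are in a different order:
    # position 0 on that axis is 'a1', so both indexing and aggregating target a1
    group = Axis('a=a1,a0').i[0]
    assert arr[group] == 1
    assert arr.sum(group) == 1
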
larray/core/axis.py (72 additions, 44 deletions)
@@ -922,32 +922,41 @@ def index(self, key) -> Union[int, np.ndarray, slice]:
"""
mapping = self._mapping

if isinstance(key, Group) and key.axis is not self and key.axis is not None:
try:
# XXX: this is potentially very expensive if key.key is an array or list and should be tried as a last
# resort
potential_tick = _to_tick(key)

# avoid matching 0 against False or 0.0, note that None has object dtype and so always pass this test
if self._is_key_type_compatible(potential_tick):
try:
res_idx = mapping[potential_tick]
if potential_tick != key.key:
# only warn if no KeyError was raised (potential_tick is in mapping)
msg = "Using a Group object which was used to create an aggregate to " \
"target its aggregated label is deprecated. " \
"Please use the aggregated label directly instead. " \
f"In this case, you should use {potential_tick!r} instead of " \
f"using {key!r}."
# let us hope the stacklevel does not vary by codepath
warnings.warn(msg, FutureWarning, stacklevel=8)
return res_idx
except KeyError:
pass
# we must catch TypeError because key might not be hashable (eg slice)
# IndexError is for when mapping is an ndarray
except (KeyError, TypeError, IndexError):
pass
if isinstance(key, Group):
if key.axis is self:
if isinstance(key, IGroup):
return key.key
else:
# at this point we do not care about the axis nor the name
key = key.key
elif key.axis is not None:
try:
# TODO: remove this as it is potentially very expensive
# if key.key is an array or list and should be tried
# as a last resort
potential_tick = _to_tick(key)

# avoid matching 0 against False or 0.0; note that None has
# object dtype and so always passes this test
if self._is_key_type_compatible(potential_tick):
try:
res_idx = mapping[potential_tick]
if potential_tick != key.key:
# only warn if no KeyError was raised (potential_tick is in mapping)
msg = "Using a Group object which was used to create an aggregate to " \
"target its aggregated label is deprecated. " \
"Please use the aggregated label directly instead. " \
f"In this case, you should use {potential_tick!r} instead of " \
f"using {key!r}."
# let us hope the stacklevel does not vary by codepath
warnings.warn(msg, FutureWarning, stacklevel=8)
return res_idx
except KeyError:
pass
# we must catch TypeError because key might not be hashable (eg slice)
# IndexError is for when mapping is an ndarray
except (KeyError, TypeError, IndexError):
pass

if isinstance(key, str):
# try the key as-is to allow getting at ticks with special characters (",", ":", ...)
@@ -961,24 +970,35 @@ def index(self, key) -> Union[int, np.ndarray, slice]:
except (KeyError, TypeError, IndexError):
pass

# transform "specially formatted strings" for slices, lists, LGroup and IGroup to actual objects
# transform "specially formatted strings" for slices, lists, LGroup
# and IGroup to actual objects
key = _to_key(key)

if isinstance(key, range):
key = list(key)

# this can happen when key was passed as a string and converted to a Group via _to_key
if isinstance(key, Group) and isinstance(key.axis, str) and key.axis != self.name:
raise KeyError(key)

if isinstance(key, IGroup):
if isinstance(key.axis, Axis):
assert key.axis is self
return key.key

if isinstance(key, LGroup):
# at this point we do not care about the axis nor the name
key = key.key
elif isinstance(key, Group):
key_axis = key.axis
if isinstance(key_axis, str):
if key_axis != self.name:
raise KeyError(key)
elif isinstance(key_axis, AxisReference):
if key_axis.name != self.name:
raise KeyError(key)
elif isinstance(key_axis, Axis): # we know it is not self
# IGroups will be retargeted to LGroups
key = key.retarget_to(self)
elif isinstance(key_axis, int):
raise TypeError('Axis.index() does not support Group keys with '
'integer axis')
else:
assert key_axis is None
# an IGroup can still exist at this point if the key was an IGroup
# with a compatible axis (string or AxisReference axis with the
# correct name or Axis object equal to self)
if isinstance(key, IGroup):
return key.key
else:
key = key.key

if isinstance(key, slice):
start = mapping[key.start] if key.start is not None else None
@@ -1915,7 +1935,8 @@ def __contains__(self, key) -> bool:
if isinstance(key, int):
return -len(self) <= key < len(self)
elif isinstance(key, Axis):
# the special case is just a performance optimization to avoid scanning through the whole list
# the special case is just a performance optimization to avoid
# scanning through the whole list
if key.name is not None:
return key.name in self._map
else:
@@ -2808,7 +2829,7 @@ def _guess_axis(self, axis_key):
# we have axis information but not necessarily an Axis object from self
real_axis = self[group_axis]
if group_axis is not real_axis:
axis_key = axis_key.with_axis(real_axis)
axis_key = axis_key.retarget_to(real_axis)
return axis_key

real_axis, axis_pos_key = self._translate_nice_key(axis_key)
Expand All @@ -2828,6 +2849,7 @@ def _translate_axis_key_chunk(self, axis_key):
(axis, indices)
Indices group with a valid axis (from self)
"""
orig_key = axis_key
axis_key = remove_nested_groups(axis_key)

if isinstance(axis_key, IGroup):
@@ -2852,11 +2874,16 @@ def _translate_axis_key_chunk(self, axis_key):
# labels but known axis
if isinstance(axis_key, LGroup) and axis_key.axis is not None:
try:
real_axis = self[axis_key.axis]
key_axis = axis_key.axis
real_axis = self[key_axis]
if isinstance(key_axis, (AxisReference, int)):
# this is one of the rare cases where with_axis is correct!
axis_key = axis_key.with_axis(real_axis)

try:
axis_pos_key = real_axis.index(axis_key)
except KeyError:
raise ValueError(f"{axis_key!r} is not a valid label for the {real_axis.name!r} axis "
raise ValueError(f"{orig_key!r} is not a valid label for the {real_axis.name!r} axis "
f"with labels: {', '.join(repr(label) for label in real_axis.labels)}")
return real_axis, axis_pos_key
except KeyError:
@@ -3889,6 +3916,7 @@ def align_axis_collections(axis_collections, join='outer', axes=None):

class AxisReference(ABCAxisReference, ExprNode, Axis):
def __init__(self, name):
assert isinstance(name, (int, str))
self.name = name
self._labels = None
self._iswildcard = False
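
As a quick usage sketch of the new Axis.index() code paths (it mirrors the
test_axis additions below; Axis and X are larray's own objects):

    from larray import Axis, X

    time = Axis([2007, 2009], 'time')

    # group defined on the axis itself: the IGroup short-circuits to its raw key
    assert time.index(time.i[1]) == 1

    # group defined on an axis reference with a matching name
    assert time.index(X.time.i[1]) == 1

    # "specially formatted string" form, converted via _to_key
    assert time.index('time.i[1]') == 1
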
larray/tests/test_array.py (25 additions, 2 deletions)
@@ -560,6 +560,17 @@ def test_getitem(array):
_ = array[bad[1, 2], a[3, 4]]


def test_getitem_group_from_another_axis():
# when using a slice Group from an axis not present in the array, we must retarget the group
arr = ndtest(3)
a2 = Axis('a=a0,a1')

# issue #1146
expected = ndtest(2)
res = arr[a2[:]]
assert_larray_equal(res, expected)


def test_getitem_abstract_axes(array):
raw = array.data
a, b, c, d = array.axes
@@ -1130,8 +1141,6 @@ def test_getitem_single_larray_key_guess():
_ = arr[key]




def test_getitem_multiple_larray_key_guess():
a, b, c, d, e = ndtest((2, 3, 2, 3, 2)).axes
arr = ndtest((a, b))
@@ -2161,6 +2170,20 @@ def test_group_agg_label_group(array):
res = array.sum(a, c).sum((g1, g2, g3, g_all))
assert res.shape == (4, 6)

# d) group aggregate using a group from another axis
# 1) LGroup
array = ndtest(3)
smaller_a_axis = Axis('a=a0,a1')
group = smaller_a_axis[:]
res = array.sum(group)
assert res == 1

# 2) IGroup
group = Axis("a=a1,a0").i[0] # targets a1
assert array[group] == 1
res = array.sum(group)
assert res == 1


def test_group_agg_label_group_no_axis(array):
a, b, c, d = array.axes
larray/tests/test_axis.py (10 additions, 0 deletions)
@@ -123,6 +123,16 @@ def test_index():
assert a.index('a1') == 1
assert a.index('a1 >> A1') == 1

time = Axis([2007, 2009], 'time')
res = time.index(time.i[1])
assert res == 1

res = time.index(X.time.i[1])
assert res == 1

res = time.index('time.i[1]')
assert res == 1


def test_astype():
arr = ndtest(Axis('time=2015..2020,total')).drop('total')