Skip to content

Commit 43ff83a

Browse files
committed
FEAT: implemented Session.align(other_session) (closes #501)
In the process, had to implement Axis.align(*axes), AxisCollection.align(*axis_collections) and Array.align(*arrays). Also deprecated passing anything other than arrays or scalars to Array.align, and now check that each aligned axis is either anonymous or has the same name as the others.
1 parent 34bfb98 commit 43ff83a

File tree

6 files changed

+320
-71
lines changed

6 files changed

+320
-71
lines changed

doc/source/changes/version_0_35.rst.inc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@ Syntax changes
1111
(:py:obj:`Array.plot.area()`, :py:obj:`Array.plot.bar()`,
1212
:py:obj:`Array.plot.barh()`, and :py:obj:`Array.plot.line()`).
1313

14+
* all align() methods (:py:obj:`Axis.align()`, :py:obj:`AxisCollection.align()`
15+
and :py:obj:`Array.align()`) only take options (``join``, ``axes`` and/or
16+
``fill_value``) as keyword arguments. Extra positional arguments will be
17+
considered as more objects to align (see below).
18+
1419

1520
Backward incompatible changes
1621
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -46,6 +51,9 @@ New features
4651

4752
>>> arr.plot.heatmap()
4853

54+
* implemented :py:obj:`Session.align()` to align all the arrays in several
55+
sessions at once. Closes :issue:`501`.
56+
4957
* added a feature (see the :ref:`miscellaneous section <misc>` for details). It works on :ref:`api-axis` and
5058
:ref:`api-group` objects.
5159

@@ -89,6 +97,12 @@ Miscellaneous improvements
8997

9098
* made :py:obj:`ipfp()` slightly faster when display_progress is False.
9199

100+
* all align() methods (:py:obj:`Axis.align()`, :py:obj:`AxisCollection.align()`
101+
and :py:obj:`Array.align()`) now support aligning more than two objects at
102+
once by passing them as positional arguments. For example:
103+
104+
>>> array1.align(array2, array3, join='outer')
105+
92106

93107
Fixes
94108
^^^^^

larray/core/array.py

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
from larray.core.group import (Group, IGroup, LGroup, _to_key, _to_keys,
5555
_translate_sheet_name, _translate_group_key_hdf)
5656
from larray.core.axis import Axis, AxisReference, AxisCollection, X, _make_axis # noqa: F401
57+
from larray.core.axis import align_axis_collections
5758
from larray.core.plot import PlotObject
5859
from larray.util.misc import (table2str, size2str, ReprString,
5960
float_error_handler_factory, light_product, common_dtype,
@@ -853,6 +854,34 @@ def np_array_to_pd_index(array, name=None, tupleize_cols=True):
853854
return pd.Index(array, dtype=dtype, name=name, tupleize_cols=tupleize_cols)
854855

855856

857+
def align_arrays(values, join='outer', fill_value=nan, axes=None):
    r"""Align several arrays (and/or scalars) on their axes.

    Parameters
    ----------
    values : sequence of Array or scalar
        Objects to align. It is iterated several times, so it must be a
        re-iterable sequence (e.g. a tuple or a list). Scalars are treated as
        having no axes and are returned unchanged.
    join : {'outer', 'inner', 'left', 'right', 'exact'}, optional
        Join method, passed as-is to align_axis_collections.
        Defaults to 'outer'.
    fill_value : scalar, optional
        Value passed to Array.reindex for each array. Defaults to NaN.
    axes : AxisReference or sequence of them, optional
        Axes to align, passed as-is to align_axis_collections.
        Defaults to None.

    Returns
    -------
    tuple
        One item per input value, in the same order: arrays are reindexed on
        their aligned axes, scalars are returned as they were given.

    Raises
    ------
    TypeError
        If any value is neither an Array nor a scalar.
    ValueError
        If any array has an anonymous axis, or if aligning the axes fails.
    """
    bad_values = [value for value in values
                  if not isinstance(value, Array) and not np.isscalar(value)]
    if bad_values:
        bad_types = {type(value) for value in bad_values}
        bad_type_names = sorted(t.__name__ for t in bad_types)
        # note the space after the colon: the two literals are concatenated
        # verbatim, so it must be part of the first one
        raise TypeError("align only supports Arrays and scalars but got: "
                        f"{', '.join(bad_type_names)}")

    # scalars take part in the alignment as an empty axis collection
    axis_collections = [
        value.axes if isinstance(value, Array) else AxisCollection()
        for value in values
    ]

    # fail early because reindex does not currently support anonymous axes
    if any(name is None
           for axis_col in axis_collections
           for name in axis_col.names):
        raise ValueError("arrays with anonymous axes are currently not "
                         "supported by Array.align")

    try:
        aligned_axis_collections = align_axis_collections(axis_collections,
                                                          join=join, axes=axes)
    except ValueError as e:
        # chain explicitly so that the original error is reported as the cause
        raise ValueError(f"Arrays are not aligned because {e}") from e

    return tuple(value.reindex(aligned_axes, fill_value=fill_value)
                 if isinstance(value, Array)
                 else value
                 for value, aligned_axes in zip(values, aligned_axis_collections))
883+
884+
856885
class Array(ABCArray):
857886
r"""
858887
An Array object represents a multidimensional, homogeneous array of fixed-size items with labeled axes.
@@ -1817,14 +1846,14 @@ def get_group(res_axes, self_axis):
18171846
else:
18181847
return res
18191848

1820-
def align(self, other, join='outer', fill_value=nan, axes=None) -> Tuple['Array', 'Array']:
1821-
r"""Align two arrays on their axes with the specified join method.
1849+
def align(self, *other, join='outer', fill_value=nan, axes=None) -> Tuple['Array', 'Array']:
1850+
r"""Align array with other(s) on their axes with the specified join method.
18221851
18231852
In other words, it ensures all common axes are compatible. Those arrays can then be used in binary operations.
18241853
18251854
Parameters
18261855
----------
1827-
other : Array-like
1856+
*other : Array-like
18281857
join : {'outer', 'inner', 'left', 'right', 'exact'}, optional
18291858
Join method. For each axis common to both arrays:
18301859
- outer: will use a label if it is in either arrays axis (ordered like the first array).
@@ -1837,13 +1866,13 @@ def align(self, other, join='outer', fill_value=nan, axes=None) -> Tuple['Array'
18371866
Value used to fill cells corresponding to label combinations which are not common to both arrays.
18381867
Defaults to NaN.
18391868
axes : AxisReference or sequence of them, optional
1840-
Axes to align. Need to be valid in both arrays. Defaults to None (all common axes). This must be specified
1869+
Axes to align. Need to be valid in all arrays. Defaults to None (all common axes). This must be specified
18411870
when mixing anonymous and non-anonymous axes.
18421871
18431872
Returns
18441873
-------
1845-
(left, right) : (Array, Array)
1846-
Aligned objects
1874+
arrays : tuple of Array
1875+
Aligned arrays
18471876
18481877
Notes
18491878
-----
@@ -1989,18 +2018,11 @@ def align(self, other, join='outer', fill_value=nan, axes=None) -> Tuple['Array'
19892018
>>> arr1.align(arr2, join='exact') # doctest: +NORMALIZE_WHITESPACE
19902019
Traceback (most recent call last):
19912020
...
1992-
ValueError: Both arrays are not aligned because align method with join='exact'
2021+
ValueError: Arrays are not aligned because align method with join='exact'
19932022
expected Axis(['a0', 'a1'], 'a') to be equal to Axis(['a0', 'a1', 'a2'], 'a')
19942023
"""
1995-
other = asarray(other)
1996-
# reindex does not currently support anonymous axes
1997-
if any(name is None for name in self.axes.names) or any(name is None for name in other.axes.names):
1998-
raise ValueError("arrays with anonymous axes are currently not supported by Array.align")
1999-
try:
2000-
left_axes, right_axes = self.axes.align(other.axes, join=join, axes=axes)
2001-
except ValueError as e:
2002-
raise ValueError(f"Both arrays are not aligned because {e}")
2003-
return self.reindex(left_axes, fill_value=fill_value), other.reindex(right_axes, fill_value=fill_value)
2024+
return align_arrays((self, *other),
2025+
join=join, fill_value=fill_value, axes=axes)
20042026

20052027
@deprecate_kwarg('reverse', 'ascending', {True: False, False: True})
20062028
def sort_values(self, key=None, axis=None, ascending=True) -> 'Array':

larray/core/axis.py

Lines changed: 129 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717
from larray.util.misc import (duplicates, array_lookup2, ReprString, index_by_id, renamed_to, LHDFStore,
1818
lazy_attribute, _isnoneslice, unique_list, unique_multi, Product, argsort, has_duplicates,
1919
exactly_one, concatenate_ndarrays)
20+
from larray.util.misc import first
2021
from larray.util.types import Scalar
2122

22-
2323
np_frompyfunc = np.frompyfunc
2424

2525

@@ -1330,12 +1330,12 @@ def difference(self, other) -> 'Axis':
13301330
to_drop = set(other)
13311331
return Axis([label for label in self.labels if label not in to_drop], self.name)
13321332

1333-
def align(self, other, join='outer') -> 'Axis':
1333+
def align(self, *other, join='outer') -> 'Axis':
13341334
r"""Align axis with other object using specified join method.
13351335
13361336
Parameters
13371337
----------
1338-
other : Axis or label sequence
1338+
*other : Axis or label sequence
13391339
join : {'outer', 'inner', 'left', 'right', 'exact'}, optional
13401340
Defaults to 'outer'.
13411341
@@ -1366,22 +1366,16 @@ def align(self, other, join='outer') -> 'Axis':
13661366
ValueError: align method with join='exact' expected
13671367
Axis(['a0', 'a1', 'a2'], 'a') to be equal to Axis(['a1', 'a2', 'a3'], 'a')
13681368
"""
1369-
assert join in {'outer', 'inner', 'left', 'right', 'exact'}
1370-
if join == 'outer':
1371-
return self.union(other)
1372-
elif join == 'inner':
1373-
return self.intersection(other)
1374-
elif join == 'left':
1375-
return self
1376-
elif join == 'right':
1377-
if not isinstance(other, Axis):
1378-
other = Axis(other)
1379-
return other
1380-
elif join == 'exact':
1381-
if not self.equals(other):
1382-
raise ValueError(f"align method with join='exact' expected {self!r} to be equal to {other!r}")
1383-
else:
1384-
return self
1369+
bad_objs = [obj for obj in other if not isinstance(obj, Axis)]
1370+
if bad_objs:
1371+
for obj in bad_objs:
1372+
obj_type = type(obj).__name__
1373+
warnings.warn(f"aligning an Axis to a non-Axis object "
1374+
f"({obj_type}) is deprecated. Please convert to "
1375+
f"an Axis first.", FutureWarning, stacklevel=2)
1376+
other = [obj if isinstance(obj, Axis) else Axis(obj)
1377+
for obj in other]
1378+
return align_axes((self, *other), join=join)
13851379

13861380
def to_hdf(self, filepath, key=None) -> None:
13871381
r"""
@@ -1462,6 +1456,50 @@ def ignore_labels(self) -> 'Axis':
14621456
return Axis(len(self), self.name)
14631457

14641458

1459+
def align_axes(axes: Sequence[Axis], join: str = 'outer') -> Axis:
    r"""Align several axes into a single axis using the given join method.

    The axes are combined pairwise from left to right: the result of joining
    the first two axes is joined with the third one, and so on.

    Parameters
    ----------
    axes : sequence of Axis
        Axes to align. Must contain at least one axis. All axes must either be
        anonymous or share the same name.
    join : {'outer', 'inner', 'left', 'right', 'exact'}, optional
        Join method. Defaults to 'outer'.

        - outer: combine the axes using Axis.union.
        - inner: combine the axes using Axis.intersection.
        - left: return the first axis.
        - right: return the last axis.
        - exact: return the first axis after checking that each axis is equal
          to it.

    Returns
    -------
    Axis
        The aligned axis.

    Raises
    ------
    TypeError
        If any object in axes is not an Axis.
    ValueError
        If join is not a supported method, if axes is empty, if named axes do
        not all have the same name, or if join is 'exact' and the axes are not
        all equal.
    """
    if not all(isinstance(axis, Axis) for axis in axes):
        raise TypeError("all objects to align must be Axis objects")

    if join not in {'outer', 'inner', 'left', 'right', 'exact'}:
        raise ValueError(f"join must be one of 'outer', 'inner', 'left', "
                         f"'right' or 'exact', got {join!r}")

    # without this check, an empty sequence would only fail later on axes[0]
    # with an uninformative IndexError
    if len(axes) == 0:
        raise ValueError("align requires at least one axis")

    names = [axis.name for axis in axes]
    first_name = first((name for name in names if name is not None),
                       default=None)
    if first_name is not None:
        if not all(name is None or name == first_name for name in names):
            raise ValueError("In align, all axes must be anonymous or "
                             "have the same name: "
                             f"{', '.join(repr(name) for name in names)}")

    def join_left(axis1, axis2):
        return axis1

    def join_right(axis1, axis2):
        return axis2

    def join_exact(axis1, axis2):
        if not axis1.equals(axis2):
            raise ValueError(f"align method with join='exact' expected "
                             f"{axis1!r} to be equal to {axis2!r}")
        return axis1

    # join was validated above, so the lookup cannot fail
    join_labels_funcs = {
        'outer': Axis.union,
        'inner': Axis.intersection,
        'left': join_left,
        'right': join_right,
        'exact': join_exact,
    }
    join_labels_func = join_labels_funcs[join]

    aligned_axis = axes[0]
    for axis in axes[1:]:
        aligned_axis = join_labels_func(aligned_axis, axis)
    return aligned_axis
1501+
1502+
14651503
def _make_axis(obj) -> Axis:
14661504
if isinstance(obj, Axis):
14671505
return obj
@@ -3552,23 +3590,25 @@ def _prepare_split_axes(self, axes, names, sep) -> dict:
35523590

35533591
split_axis = renamed_to(split_axes, 'split_axis', raise_error=True)
35543592

3555-
def align(self, other, join='outer', axes=None) -> Tuple['AxisCollection', 'AxisCollection']:
3556-
r"""Align this axis collection with another.
3593+
def align(self, *other, join='outer', axes=None) -> Tuple['AxisCollection']:
3594+
r"""Align this AxisCollection with (an)other AxisCollection(s).
35573595
35583596
This ensures all common axes are compatible.
35593597
35603598
Parameters
35613599
----------
3562-
other : AxisCollection
3600+
*other : AxisCollection
3601+
AxisCollection(s) to align with this one.
35633602
join : {'outer', 'inner', 'left', 'right', 'exact'}, optional
35643603
Defaults to 'outer'.
35653604
axes : AxisReference or sequence of them, optional
3566-
Axes to align. Need to be valid in both arrays. Defaults to None (all common axes). This must be specified
3605+
Axes to align. Need to be valid in all axis collections.
3606+
Defaults to None (all common axes). This must be specified
35673607
when mixing anonymous and non-anonymous axes.
35683608
35693609
Returns
35703610
-------
3571-
(left, right) : (AxisCollection, AxisCollection)
3611+
tuple of AxisCollection
35723612
Aligned collections
35733613
35743614
See Also
@@ -3631,31 +3671,20 @@ def align(self, other, join='outer', axes=None) -> Tuple['AxisCollection', 'Axis
36313671
Axis(['c0'], None)
36323672
])
36333673
"""
3634-
if join not in {'outer', 'inner', 'left', 'right', 'exact'}:
3635-
raise ValueError("join should be one of 'outer', 'inner', 'left', 'right' or 'exact'")
3636-
other = other if isinstance(other, AxisCollection) else AxisCollection(other)
3637-
3638-
# if axes not specified
3639-
if axes is None:
3640-
# and we have only anonymous axes on both sides
3641-
if all(name is None for name in self.names) and all(name is None for name in other.names):
3642-
# use N first axes by position
3643-
join_axes = list(range(min(len(self), len(other))))
3644-
elif any(name is None for name in self.names) or any(name is None for name in other.names):
3645-
raise ValueError("axes collections with mixed anonymous/non anonymous axes are not supported by align"
3646-
"without specifying axes explicitly")
3647-
else:
3648-
assert all(name is not None for name in self.names) and all(name is not None for name in other.names)
3649-
# use all common axes
3650-
join_axes = list(OrderedSet(self.names) & OrderedSet(other.names))
3651-
else:
3652-
if isinstance(axes, (int, str, Axis)):
3653-
axes = [axes]
3654-
join_axes = axes
3655-
new_axes = [self_axis.align(other_axis, join=join)
3656-
for self_axis, other_axis in zip(self[join_axes], other[join_axes])]
3657-
axes_changes = list(zip(join_axes, new_axes))
3658-
return self.replace(axes_changes), other.replace(axes_changes)
3674+
# For backward compatibility with older code using align with a
3675+
# non-AxisCollection second argument, we only support aligning more
3676+
# than two collection when other contains actual AxisCollection objects
3677+
bad_objs = [obj for obj in other if not isinstance(obj, AxisCollection)]
3678+
if bad_objs:
3679+
for obj in bad_objs:
3680+
obj_type = type(obj).__name__
3681+
warnings.warn(f"aligning an AxisCollection to a "
3682+
f"non-AxisCollection object ({obj_type}) is "
3683+
f"deprecated. Please convert to an AxisCollection "
3684+
f"first.", FutureWarning, stacklevel=2)
3685+
other = [AxisCollection(obj) for obj in other]
3686+
3687+
return align_axis_collections((self, *other), join=join, axes=axes)
36593688

36603689
# XXX: make this into a public method/property? AxisCollection.flat_labels[flat_indices]?
36613690
def _flat_lookup(self, flat_indices):
@@ -3802,6 +3831,57 @@ def _adv_keys_to_combined_axes(self, key, wildcard=False, sep='_'):
38023831
return AxisCollection(combined_axis)
38033832

38043833

3834+
def align_axis_collections(axis_collections, join='outer', axes=None):
    r"""Align several axis collections with each other.

    Parameters
    ----------
    axis_collections : sequence of AxisCollection
        Collections to align.
    join : {'outer', 'inner', 'left', 'right', 'exact'}, optional
        Join method, passed to align_axes for each axis. Defaults to 'outer'.
    axes : int, str, Axis or sequence of them, optional
        References of the axes to align. Defaults to None, in which case the
        axes are matched by position when all axes of all collections are
        anonymous, and by name when all of them are named.

    Returns
    -------
    tuple of AxisCollection
        One collection per input collection, in the same order. In each one,
        the axes it contains among those to align are replaced by their
        aligned version.

    Raises
    ------
    ValueError
        If join is not a supported method, or if axes is None and the
        collections mix anonymous and named axes.
    """
    if join not in {'outer', 'inner', 'left', 'right', 'exact'}:
        raise ValueError("join should be one of 'outer', 'inner', 'left', "
                         "'right' or 'exact'")

    if axes is not None:
        # a single axis reference is accepted in place of a sequence
        join_axes_refs = [axes] if isinstance(axes, (int, str, Axis)) else axes
    else:
        anonymous_flags = [name is None
                           for collection in axis_collections
                           for name in collection.names]
        if all(anonymous_flags):
            # only anonymous axes: use all axes by position
            longest = max(len(collection) for collection in axis_collections)
            join_axes_refs = list(range(longest))
        elif any(anonymous_flags):
            raise ValueError(
                "axes collections with mixed anonymous/non anonymous axes "
                "are not supported by align without specifying axes "
                "explicitly")
        else:
            # every axis is named here: use all axes by name
            join_axes_refs = OrderedSet(axis_collections[0].names)
            for collection in axis_collections[1:]:
                join_axes_refs |= OrderedSet(collection.names)

    # first compute the aligned axis for each reference, using only the
    # collections which contain that axis
    aligned_by_ref = {}
    for axis_ref in join_axes_refs:
        matching_axes = [collection[axis_ref]
                         for collection in axis_collections
                         if axis_ref in collection]
        aligned_by_ref[axis_ref] = align_axes(matching_axes, join=join)

    # then replace the axes in each collection where the axis exists
    aligned_collections = []
    for collection in axis_collections:
        changes = {axis_ref: aligned_axis
                   for axis_ref, aligned_axis in aligned_by_ref.items()
                   if axis_ref in collection}
        aligned_collections.append(collection.replace(changes))
    return tuple(aligned_collections)
3883+
3884+
38053885
class AxisReference(ABCAxisReference, ExprNode, Axis):
38063886
def __init__(self, name):
38073887
self.name = name

0 commit comments

Comments
 (0)