|
17 | 17 | from larray.util.misc import (duplicates, array_lookup2, ReprString, index_by_id, renamed_to, LHDFStore, |
18 | 18 | lazy_attribute, _isnoneslice, unique_list, unique_multi, Product, argsort, has_duplicates, |
19 | 19 | exactly_one, concatenate_ndarrays) |
| 20 | +from larray.util.misc import first |
20 | 21 | from larray.util.types import Scalar |
21 | 22 |
|
22 | | - |
23 | 23 | np_frompyfunc = np.frompyfunc |
24 | 24 |
|
25 | 25 |
|
@@ -1330,12 +1330,12 @@ def difference(self, other) -> 'Axis': |
1330 | 1330 | to_drop = set(other) |
1331 | 1331 | return Axis([label for label in self.labels if label not in to_drop], self.name) |
1332 | 1332 |
|
1333 | | - def align(self, other, join='outer') -> 'Axis': |
| 1333 | + def align(self, *other, join='outer') -> 'Axis': |
1334 | 1334 | r"""Align axis with other object(s) using specified join method. |
1335 | 1335 |
|
1336 | 1336 | Parameters |
1337 | 1337 | ---------- |
1338 | | - other : Axis or label sequence |
| 1338 | + *other : Axis or label sequence |
1339 | 1339 | join : {'outer', 'inner', 'left', 'right', 'exact'}, optional |
1340 | 1340 | Defaults to 'outer'. |
1341 | 1341 |
|
@@ -1366,22 +1366,16 @@ def align(self, other, join='outer') -> 'Axis': |
1366 | 1366 | ValueError: align method with join='exact' expected |
1367 | 1367 | Axis(['a0', 'a1', 'a2'], 'a') to be equal to Axis(['a1', 'a2', 'a3'], 'a') |
1368 | 1368 | """ |
1369 | | - assert join in {'outer', 'inner', 'left', 'right', 'exact'} |
1370 | | - if join == 'outer': |
1371 | | - return self.union(other) |
1372 | | - elif join == 'inner': |
1373 | | - return self.intersection(other) |
1374 | | - elif join == 'left': |
1375 | | - return self |
1376 | | - elif join == 'right': |
1377 | | - if not isinstance(other, Axis): |
1378 | | - other = Axis(other) |
1379 | | - return other |
1380 | | - elif join == 'exact': |
1381 | | - if not self.equals(other): |
1382 | | - raise ValueError(f"align method with join='exact' expected {self!r} to be equal to {other!r}") |
1383 | | - else: |
1384 | | - return self |
| 1369 | + bad_objs = [obj for obj in other if not isinstance(obj, Axis)] |
| 1370 | + if bad_objs: |
| 1371 | + for obj in bad_objs: |
| 1372 | + obj_type = type(obj).__name__ |
| 1373 | + warnings.warn(f"aligning an Axis to a non-Axis object " |
| 1374 | + f"({obj_type}) is deprecated. Please convert to " |
| 1375 | + f"an Axis first.", FutureWarning, stacklevel=2) |
| 1376 | + other = [obj if isinstance(obj, Axis) else Axis(obj) |
| 1377 | + for obj in other] |
| 1378 | + return align_axes((self, *other), join=join) |
1385 | 1379 |
|
1386 | 1380 | def to_hdf(self, filepath, key=None) -> None: |
1387 | 1381 | r""" |
@@ -1462,6 +1456,50 @@ def ignore_labels(self) -> 'Axis': |
1462 | 1456 | return Axis(len(self), self.name) |
1463 | 1457 |
|
1464 | 1458 |
|
| 1459 | +def align_axes(axes: Sequence[Axis], join: str = 'outer') -> Axis: |
| 1460 | + if not all(isinstance(axis, Axis) for axis in axes): |
| 1461 | + raise TypeError("all objects to align must be Axis objects") |
| 1462 | + |
| 1463 | + if join not in {'outer', 'inner', 'left', 'right', 'exact'}: |
| 1464 | + raise ValueError(f"join must be one of 'outer', 'inner', 'left', " |
| 1465 | + f"'right' or 'exact', got {join!r}") |
| 1466 | + |
| 1467 | + names = [axis.name for axis in axes] |
| 1468 | + first_name = first((name for name in names if name is not None), |
| 1469 | + default=None) |
| 1470 | + if first_name is not None: |
| 1471 | + if not all(name is None or name == first_name for name in names): |
| 1472 | + raise ValueError("In align, all axes must be anonymous or " |
| 1473 | + "have the same name: " |
| 1474 | + f"{', '.join(repr(name) for name in names)}") |
| 1475 | + |
| 1476 | + def join_left(axis1, axis2): |
| 1477 | + return axis1 |
| 1478 | + def join_right(axis1, axis2): |
| 1479 | + return axis2 |
| 1480 | + def join_exact(axis1, axis2): |
| 1481 | + if not axis1.equals(axis2): |
| 1482 | + raise ValueError(f"align method with join='exact' expected " |
| 1483 | + f"{axis1!r} to be equal to {axis2!r}") |
| 1484 | + else: |
| 1485 | + return axis1 |
| 1486 | + if join == 'outer': |
| 1487 | + join_labels_func = Axis.union |
| 1488 | + elif join == 'inner': |
| 1489 | + join_labels_func = Axis.intersection |
| 1490 | + elif join == 'left': |
| 1491 | + join_labels_func = join_left |
| 1492 | + elif join == 'right': |
| 1493 | + join_labels_func = join_right |
| 1494 | + else: |
| 1495 | + assert join == 'exact' |
| 1496 | + join_labels_func = join_exact |
| 1497 | + aligned_axis = axes[0] |
| 1498 | + for axis in axes[1:]: |
| 1499 | + aligned_axis = join_labels_func(aligned_axis, axis) |
| 1500 | + return aligned_axis |
| 1501 | + |
| 1502 | + |
1465 | 1503 | def _make_axis(obj) -> Axis: |
1466 | 1504 | if isinstance(obj, Axis): |
1467 | 1505 | return obj |
@@ -3552,23 +3590,25 @@ def _prepare_split_axes(self, axes, names, sep) -> dict: |
3552 | 3590 |
|
3553 | 3591 | split_axis = renamed_to(split_axes, 'split_axis', raise_error=True) |
3554 | 3592 |
|
3555 | | - def align(self, other, join='outer', axes=None) -> Tuple['AxisCollection', 'AxisCollection']: |
3556 | | - r"""Align this axis collection with another. |
| 3593 | + def align(self, *other, join='outer', axes=None) -> Tuple['AxisCollection']: |
| 3594 | + r"""Align this AxisCollection with (an)other AxisCollection(s). |
3557 | 3595 |
|
3558 | 3596 | This ensures all common axes are compatible. |
3559 | 3597 |
|
3560 | 3598 | Parameters |
3561 | 3599 | ---------- |
3562 | | - other : AxisCollection |
| 3600 | + *other : AxisCollection |
| 3601 | + AxisCollection(s) to align with this one. |
3563 | 3602 | join : {'outer', 'inner', 'left', 'right', 'exact'}, optional |
3564 | 3603 | Defaults to 'outer'. |
3565 | 3604 | axes : AxisReference or sequence of them, optional |
3566 | | - Axes to align. Need to be valid in both arrays. Defaults to None (all common axes). This must be specified |
| 3605 | + Axes to align. Need to be valid in all axis collections. |
| 3606 | + Defaults to None (all common axes). This must be specified |
3567 | 3607 | when mixing anonymous and non-anonymous axes. |
3568 | 3608 |
|
3569 | 3609 | Returns |
3570 | 3610 | ------- |
3571 | | - (left, right) : (AxisCollection, AxisCollection) |
| 3611 | + tuple of AxisCollection |
3572 | 3612 | Aligned collections |
3573 | 3613 |
|
3574 | 3614 | See Also |
@@ -3631,31 +3671,20 @@ def align(self, other, join='outer', axes=None) -> Tuple['AxisCollection', 'Axis |
3631 | 3671 | Axis(['c0'], None) |
3632 | 3672 | ]) |
3633 | 3673 | """ |
3634 | | - if join not in {'outer', 'inner', 'left', 'right', 'exact'}: |
3635 | | - raise ValueError("join should be one of 'outer', 'inner', 'left', 'right' or 'exact'") |
3636 | | - other = other if isinstance(other, AxisCollection) else AxisCollection(other) |
3637 | | - |
3638 | | - # if axes not specified |
3639 | | - if axes is None: |
3640 | | - # and we have only anonymous axes on both sides |
3641 | | - if all(name is None for name in self.names) and all(name is None for name in other.names): |
3642 | | - # use N first axes by position |
3643 | | - join_axes = list(range(min(len(self), len(other)))) |
3644 | | - elif any(name is None for name in self.names) or any(name is None for name in other.names): |
3645 | | - raise ValueError("axes collections with mixed anonymous/non anonymous axes are not supported by align" |
3646 | | - "without specifying axes explicitly") |
3647 | | - else: |
3648 | | - assert all(name is not None for name in self.names) and all(name is not None for name in other.names) |
3649 | | - # use all common axes |
3650 | | - join_axes = list(OrderedSet(self.names) & OrderedSet(other.names)) |
3651 | | - else: |
3652 | | - if isinstance(axes, (int, str, Axis)): |
3653 | | - axes = [axes] |
3654 | | - join_axes = axes |
3655 | | - new_axes = [self_axis.align(other_axis, join=join) |
3656 | | - for self_axis, other_axis in zip(self[join_axes], other[join_axes])] |
3657 | | - axes_changes = list(zip(join_axes, new_axes)) |
3658 | | - return self.replace(axes_changes), other.replace(axes_changes) |
| 3674 | + # For backward compatibility with older code using align with a |
| 3675 | + # non-AxisCollection second argument, we only support aligning more |
| 3676 | + # than two collections when other contains actual AxisCollection objects |
| 3677 | + bad_objs = [obj for obj in other if not isinstance(obj, AxisCollection)] |
| 3678 | + if bad_objs: |
| 3679 | + for obj in bad_objs: |
| 3680 | + obj_type = type(obj).__name__ |
| 3681 | + warnings.warn(f"aligning an AxisCollection to a " |
| 3682 | + f"non-AxisCollection object ({obj_type}) is " |
| 3683 | + f"deprecated. Please convert to an AxisCollection " |
| 3684 | + f"first.", FutureWarning, stacklevel=2) |
| 3685 | + other = [AxisCollection(obj) for obj in other] |
| 3686 | + |
| 3687 | + return align_axis_collections((self, *other), join=join, axes=axes) |
3659 | 3688 |
|
3660 | 3689 | # XXX: make this into a public method/property? AxisCollection.flat_labels[flat_indices]? |
3661 | 3690 | def _flat_lookup(self, flat_indices): |
@@ -3802,6 +3831,57 @@ def _adv_keys_to_combined_axes(self, key, wildcard=False, sep='_'): |
3802 | 3831 | return AxisCollection(combined_axis) |
3803 | 3832 |
|
3804 | 3833 |
|
| 3834 | +def align_axis_collections(axis_collections, join='outer', axes=None): |
| 3835 | + if join not in {'outer', 'inner', 'left', 'right', 'exact'}: |
| 3836 | + raise ValueError("join should be one of 'outer', 'inner', 'left', " |
| 3837 | + "'right' or 'exact'") |
| 3838 | + |
| 3839 | + # if axes not specified |
| 3840 | + if axes is None: |
| 3841 | + # and we have only anonymous axes |
| 3842 | + if all(name is None for col in axis_collections |
| 3843 | + for name in col.names): |
| 3844 | + # use all axes by position |
| 3845 | + max_length = max(len(col) for col in axis_collections) |
| 3846 | + join_axes_refs = list(range(max_length)) |
| 3847 | + elif any(name is None for col in axis_collections |
| 3848 | + for name in col.names): |
| 3849 | + raise ValueError( |
| 3850 | + "axes collections with mixed anonymous/non anonymous axes " |
| 3851 | + "are not supported by align without specifying axes " |
| 3852 | + "explicitly") |
| 3853 | + else: |
| 3854 | + assert all(name is not None for col in axis_collections |
| 3855 | + for name in col.names) |
| 3856 | + # use all axes by name |
| 3857 | + join_axes_refs = OrderedSet(axis_collections[0].names) |
| 3858 | + for col in axis_collections[1:]: |
| 3859 | + join_axes_refs |= OrderedSet(col.names) |
| 3860 | + else: |
| 3861 | + if isinstance(axes, (int, str, Axis)): |
| 3862 | + axes = [axes] |
| 3863 | + join_axes_refs = axes |
| 3864 | + |
| 3865 | + # first compute all aligned axes for all collections |
| 3866 | + axes_changes = { |
| 3867 | + axis_ref: align_axes([axis_col[axis_ref] |
| 3868 | + for axis_col in axis_collections |
| 3869 | + if axis_ref in axis_col], |
| 3870 | + join=join) |
| 3871 | + for axis_ref in join_axes_refs |
| 3872 | + } |
| 3873 | + |
| 3874 | + # then apply the changed axes for the collections where the axis exists |
| 3875 | + return tuple( |
| 3876 | + axis_col.replace({ |
| 3877 | + axis_ref: aligned_axis |
| 3878 | + for axis_ref, aligned_axis in axes_changes.items() |
| 3879 | + if axis_ref in axis_col |
| 3880 | + }) |
| 3881 | + for axis_col in axis_collections |
| 3882 | + ) |
| 3883 | + |
| 3884 | + |
3805 | 3885 | class AxisReference(ABCAxisReference, ExprNode, Axis): |
3806 | 3886 | def __init__(self, name): |
3807 | 3887 | self.name = name |
|
0 commit comments