diff --git a/adcgen/generate_code/contraction.py b/adcgen/generate_code/contraction.py index 4579518..553fba9 100644 --- a/adcgen/generate_code/contraction.py +++ b/adcgen/generate_code/contraction.py @@ -66,6 +66,18 @@ class Contraction: The names of the contracted tensors. term_target_indices: tuple[Index] The target indices of the term the contraction belongs to. + Note that it is not possible for indices that are given as + 'term_target_indices' to become contracted indices! + contracted: Sequence[Index], optional + The contracted indices for the contraction. Note that the + indices might be reordered. + If not given they will be determined from the given indices + according to the einstein sum convention. + target: Sequence[Index], optional + The target indices for the contraction. Note that the given + indices might be reordered. + If not given they will be determined from the given indices + according to the einstein sum convention. """ # use counter that essentially counts how many class instances have # been created @@ -79,7 +91,9 @@ class Contraction: def __init__(self, indices: Sequence[tuple[Index, ...]], names: Sequence[str], - term_target_indices: Sequence[Index]) -> None: + term_target_indices: Sequence[Index], + contracted: Sequence[Index] | None = None, + target: Sequence[Index] | None = None) -> None: if not isinstance(indices, tuple): indices = tuple(indices) if isinstance(names, str): @@ -89,13 +103,18 @@ def __init__(self, indices: Sequence[tuple[Index, ...]], self.indices: tuple[tuple[Index, ...], ...] = indices self.names: tuple[str, ...] = names - self.contracted: tuple[Index, ...] = tuple() - self.target: tuple[Index, ...] = tuple() + + self.contracted: tuple[Index, ...] + self.target: tuple[Index, ...] self.scaling: Scaling + self._determine_contracted_and_target( + term_target_indices=term_target_indices, contracted=contracted, + target=target + ) + self._determine_scaling() + self.id: int = next(self._instance_counter) self.contraction_name: str = f"{self._base_name}_{self.id}" - self._determine_contracted_and_target(term_target_indices) - self._determine_scaling() def __str__(self): return (f"Contraction(indices={self.indices}, names={self.names}, " @@ -106,18 +125,46 @@ def __str__(self): def __repr__(self): return self.__str__() - def _determine_contracted_and_target(self, - term_target_indices: Sequence[Index] - ) -> None: + def _determine_contracted_and_target( + self, term_target_indices: Sequence[Index], + contracted: Sequence[Index] | None = None, + target: Sequence[Index] | None = None) -> None: """ Determines and sets the contracted and target indices on the contraction using the provided target indices of the term - the contraction is a part of. In case the target indices of the - contraction contain the same indices as the target indices of the - term, the target indices of the term will be used instead. + the contraction is a part of. In case of an outer contraction + (the target indices are the reordered 'term_target_indices') + the term_target_indices will be used as target indices. + If provided, the provided contracted and target indices + will be used after reordering them. """ - contracted, target = self._split_contracted_and_target( - self.indices, term_target_indices + if contracted is None and target is None: # determine both + contracted, target = self._split_contracted_and_target( + self.indices, term_target_indices + ) + elif contracted is None: # target indices are given + # -> compute contracted removing duplicates + assert isinstance(target, Sequence) + contracted = list({ + idx for idx in itertools.chain.from_iterable(self.indices) + if idx not in target + }) + elif target is None: # contracted are given + # -> compute contracted removing duplicates + assert isinstance(contracted, Sequence) + target = list({ + idx for idx in itertools.chain.from_iterable(self.indices) + if idx not in contracted + }) + # sanity checks: + # - no index in contracted and target at the same time + # - no contracted index can be a target index of the overall term + # - all indices of the objects have to be in contracted or target + assert not any(idx in target for idx in contracted) + assert not any(idx in term_target_indices for idx in contracted) + assert all( + idx in target or idx in contracted + for idx in itertools.chain.from_iterable(self.indices) ) # sort the indices canonical contracted = sorted(contracted, key=sort_idx_canonical) diff --git a/adcgen/generate_code/optimize_contractions.py b/adcgen/generate_code/optimize_contractions.py index f90f23e..82e57c2 100644 --- a/adcgen/generate_code/optimize_contractions.py +++ b/adcgen/generate_code/optimize_contractions.py @@ -5,7 +5,7 @@ from sympy import Symbol, S from ..expression import TermContainer -from ..indices import get_symbols, Index +from ..indices import get_symbols, Index, sort_idx_canonical from ..sympy_objects import SymbolicTensor, KroneckerDelta from .contraction import Contraction, Sizes @@ -137,36 +137,57 @@ def _optimize_contractions(relevant_obj_names: Sequence[str], ) -> Generator[list[Contraction], None, None]: """ Find the optimal contractions for the given relevant objects of a term. + + Parameters + ---------- + relevant_obj_names: Sequence[str] + The names of the objects to consider. + relevant_obj_indices: Sequence[tuple[Index, ...]] + The indices of the objects to consider. + target_indices: Sequence[Index] + The target indices of the term. + max_itmd_dim: int, optional + The maximum allowed dimensionality of temporary intermediate results. + If not given, the dimensionality is not restricted. + max_n_simultaneous_contracted: int, optional + The maximum number of tensors allowed in a single contraction, e.g., + 2 to prevent any hyper-contractions. """ assert len(relevant_obj_indices) == len(relevant_obj_names) if len(relevant_obj_names) < 2: raise ValueError("Need at least 2 objects to define a contraction.") # split the relevant objects into subgroups that share contracted indices - # and therefore should be contracted simultaneously connected_groups = _group_objects( obj_indices=relevant_obj_indices, target_indices=target_indices, max_group_size=max_n_simultaneous_contracted ) - - for group in connected_groups: - contr_indices = tuple(relevant_obj_indices[pos] for pos in group) - contr_names = tuple(relevant_obj_names[pos] for pos in group) - contraction = Contraction(indices=contr_indices, names=contr_names, - term_target_indices=target_indices) + for group, contracted in connected_groups: + # build a contraction for the objects and the contracted indices + indices = [relevant_obj_indices[pos] for pos in group] + names = [relevant_obj_names[pos] for pos in group] + contraction = Contraction( + indices=indices, names=names, term_target_indices=target_indices, + contracted=contracted + ) # if the contraction is not an outer contraction we have to check # the dimensionality of the intermediate tensor if max_itmd_dim is not None and \ contraction.target != target_indices and \ len(contraction.target) > max_itmd_dim: continue - # remove the contracted names and indices + # update the data excluding the just contracted objects + # and adding the contraction to the pool remaining_pos = [pos for pos in range(len(relevant_obj_names)) if pos not in group] - remaining_names = (contraction.contraction_name, - *(relevant_obj_names[pos] for pos in remaining_pos)) - remaining_indices = (contraction.target, *(relevant_obj_indices[pos] - for pos in remaining_pos)) + remaining_names = ( + contraction.contraction_name, + *(relevant_obj_names[pos] for pos in remaining_pos) + ) + remaining_indices = ( + contraction.target, + *(relevant_obj_indices[pos] for pos in remaining_pos) + ) # there are no objects left to contract -> we are done if len(remaining_names) == 1: yield [contraction] @@ -179,30 +200,29 @@ def _optimize_contractions(relevant_obj_names: Sequence[str], max_n_simultaneous_contracted=max_n_simultaneous_contracted ) for contraction_scheme in completed_schemes: + # ensure that the contracted indices don't appear in any later + # contraction again + assert not any( + s in idx for s in contraction.contracted + for c in contraction_scheme for idx in c.indices + ) contraction_scheme.insert(0, contraction) yield contraction_scheme -def _group_objects(obj_indices: Sequence[tuple[Index, ...]], - target_indices: Sequence[Index], - max_group_size: int | None = None - ) -> tuple[tuple[int, ...], ...]: +def _group_objects( + obj_indices: Sequence[tuple[Index, ...]], + target_indices: Sequence[Index], + max_group_size: int | None = None + ) -> Generator[tuple[tuple[int, ...], tuple[Index, ...]], None, None]: """ - Split the provided relevant objects into subgroups that share common - contracted indices. Thereby, a group can at most contain 'max_group_size' - objects. By default, all objects are allowed to be in one group. + Split the provided relevant objects defined by their indices + (``obj_indices``) into subgroups that share common contracted indices. + Thereby, a group can at most contain ``max_group_size`` + objects and produce a result with ``max_result_dim`` dimensions. + By default, all objects are allowed to be in one group and arbitrary + result dimensionalities are allowed. """ - # NOTE: the algorithm currently maximizes the number of contracted - # indices, i.e., a contraction runs over all common contracted - # indices. While this is fine in most cases, it might be benefitial - # to not contract all possible indices simultaneously - # in certain cases, since this leads to an increased group size: - # 0 1 2 - # d_ijk d_ij d_jl - # 0 and 1 share i and j. A contraction running only over i can be - # performed for the pair (0, 1). However, if the contraction runs - # runs over i and j, we have to consider the triple (0, 1, 2). - # sanity checks for input assert len(obj_indices) > 1 # we need at least 2 objects if max_group_size is None: @@ -211,70 +231,200 @@ def _group_objects(obj_indices: Sequence[tuple[Index, ...]], # track on which objects the indices appear idx_occurences: dict[Index, list[int]] = {} - for pos, indices in enumerate(obj_indices): - for idx in indices: - if idx not in idx_occurences: - idx_occurences[idx] = [] - idx_occurences[idx].append(pos) + for pos, idx in enumerate(obj_indices): + for s in idx: + if s not in idx_occurences: + idx_occurences[s] = [] + if s not in idx_occurences[s]: + idx_occurences[s].append(pos) + del idx - # store grouped objects and isolated objects (outer products) - # for the groups we are using a dict, since it by default returns - # keys in the order they were inserted. A set would need to be sorted - # before returning to produce consistent results. - groups: dict[tuple[int, ...], None] = {} - outer_products: list[tuple[int, int]] = [] + # cache already encountered valid groups + # excluding outer products since they can not appear twice + seen_groups: set[tuple[tuple[int, ...], tuple[Index, ...]]] = set() # iterate over all pairs of objects (index tuples) for (pos1, indices1), (pos2, indices2) in \ itertools.combinations(enumerate(obj_indices), 2): # check if the objects have any common contracted indices # -> outer products can be treated as pair contracted, _ = Contraction._split_contracted_and_target( - (indices1, indices2), target_indices + indices=(indices1, indices2), term_target_indices=target_indices + ) + if not contracted: # outer product + yield ((pos1, pos2), tuple()) + continue + contracted = sorted(contracted, key=sort_idx_canonical) + # Starting from the given pair try to explore all sensible + # combinations of contracted indices possibly increasing + # the group size (also exploring hyper-contractions) + groups = _explore_group( + seen_groups=seen_groups, obj_indices=obj_indices, + target_indices=target_indices, max_group_size=max_group_size, + idx_occurences=idx_occurences, contracted=contracted, + positions=(pos1, pos2) ) - if not contracted: - outer_products.append((pos1, pos2)) + for group in groups: + yield group + return None + + +def _explore_group( + seen_groups: set[tuple[tuple[int, ...], tuple[Index, ...]]], + obj_indices: Sequence[tuple[Index, ...]], + target_indices: Sequence[Index], + max_group_size: int, + idx_occurences: dict[Index, list[int]], + contracted: Sequence[Index], + positions: Sequence[int], + forbidden_contracted: tuple[Index, ...] = tuple() + ) -> Generator[tuple[tuple[int, ...], tuple[Index, ...]], None, None]: + """ + Recursively explores the group by expanding the number of + contracted indices and the group size. + + Parameters + ---------- + seen_groups: set[tuple[tuple[int, ...], tuple[Index, ...]]] + Cache to store already encountered groups to avoid duplications. + obj_indices: Sequence[tuple[Index, ...]] + The indices of all objects. + target_indices: Sequence[Index] + The target indices of the term the objects describe. + max_group_size: int + Upper limit for the allowed size of groups to consider. + idx_occurences: dict[Index, list[int]] + Map to connect an index to the objects (by position) it appears on + contracted: Sequence[Inde] + The common contracted indices the objects at ``positions`` share. + positions: Sequence[int] + The positions defining the group to further explore and expand. + forbidden_contracted: tuple[Index, ...], optional + Indices that are not allowed to be considered as contracted indices + during the exploration of the given group. + """ + # Iterate over all sensible subsets of contracted indices. + # For instance 2 objects might share the indices ijkl. + # However, k and l appear on 2 distinct other objects. + # Therefore, we should always contract over ij but the + # contraction over k and l should be optional since the + # group size has to grow for those contractions + contracted_variants = _contracted_variants( + contracted, positions, idx_occurences + ) + for contracted_indices in contracted_variants: + # update the positions including all objects that hold any of the + # contracted indices while checking the groups size + new_positions = tuple(sorted({ + pos for idx in contracted_indices + for pos in idx_occurences[idx] + })) + if len(new_positions) > max_group_size: continue - # get all the objects any of the contracted indices appears - positions = {pos for idx in contracted for pos in idx_occurences[idx]} - # group too large - if len(positions) > max_group_size: + # - try to update the contracted indices covering all indices + # that only appear on tensors already in the group. + # It does not make any sense to not contract over any of them + # since we can safely do so using the current group + # -> contracted_indices has to be a subset of new_contracted + indices = tuple(obj_indices[p] for p in new_positions) + new_contracted, _ = Contraction._split_contracted_and_target( + indices=indices, term_target_indices=target_indices + ) + new_contracted = [ + idx for idx in new_contracted + if all(pos in new_positions for pos in idx_occurences[idx]) + ] + assert all(s in new_contracted for s in contracted_indices) + # - however, if any of the safely contractable indices is marked + # as forbidden, we have to skip to avoid duplications + # since the combination will then be explored later + # -> new_contracted can not contain forbidden indices + if any(idx in forbidden_contracted for idx in new_contracted): continue - # avoid duplication: 0, 1 and 2 are connected by a common index + new_contracted = tuple(sorted(new_contracted, key=sort_idx_canonical)) + # - avoid duplications. For instance: + # 0, 1 and 2 are connected by a common index # -> the pair 0,1 and 0,2 will both give the triple 0,1,2 - # which will then grow in the same way independent of the starting - # pair. - key = tuple(sorted(positions)) - if key in groups: + # which will then grow in the same way independent of the starting + # pair. + if (new_positions, new_contracted) in seen_groups: continue - # store the minimal group - groups[key] = None - - # self-consistently update the contracted indices and the positions - # This corresponds to maximizing the group size. - # However, it is unclear if growing the group leads to a better - # scaling contraction. Therefore, also store smaller groups - while True: - # update the contracted indices - new_contracted, _ = Contraction._split_contracted_and_target( - [obj_indices[pos] for pos in positions], target_indices + # - current group is not a duplicate and can safely be returned while + # marking the group as explored. + seen_groups.add((new_positions, new_contracted)) + yield (new_positions, new_contracted) + # - To prevent duplications we don't want to contract over indices + # that will be covered in another iteration of contracted_variants. + # Also we don't want to mark any safely contractable indices + # as forbidden to avoid exploring stupid groups. + # -> mark missing optionaly contracted indices as forbidden + # (all indices in new_contracted not forbidden and safely contractable + # for the current group) + new_forbidden_contracted = forbidden_contracted + tuple( + s for s in contracted if s not in new_contracted + ) + # - See if there are any other contracted indices that repeat + # on new_positions that are not forbidden (will be explored later). + # -> new_contracted has logically to be a subset of + # available_contracted, since it is not possible for any index in + # new_contracted to appear in new_forbidden_contracted!! + available_contracted, _ = Contraction._split_contracted_and_target( + indices=indices, term_target_indices=target_indices + ) + available_contracted = [ + idx for idx in available_contracted + if idx not in new_forbidden_contracted + ] + available_contracted = sorted( + available_contracted, key=sort_idx_canonical + ) + assert all(s in available_contracted for s in new_contracted) + if len(available_contracted) > len(new_contracted): + child_groups = _explore_group( + seen_groups=seen_groups, obj_indices=obj_indices, + target_indices=target_indices, max_group_size=max_group_size, + idx_occurences=idx_occurences, contracted=available_contracted, + positions=new_positions, + forbidden_contracted=new_forbidden_contracted ) - # no new contracted indices pulled in -> we are done - if contracted == new_contracted: - break - # update the positions - new_positions = { - pos for idx in new_contracted for pos in idx_occurences[idx] - } - # no new positions or the extended group is too large - if new_positions == positions or \ - len(new_positions) > max_group_size: - break - # store the current group before trying to further - # increase the size - groups[tuple(sorted(new_positions))] = None - contracted = new_contracted - positions = new_positions - return (*groups.keys(), *outer_products) + for group in child_groups: + yield group + + +def _contracted_variants(contracted: Sequence[Index], + positions: Sequence[int], + idx_occurences: dict[Index, list[int]] + ) -> Generator[tuple[Index, ...], None, None]: + """ + Generates all sensible subsets of contracted indices for the + given ``contracted`` indices generated by a contraction of objects + at ``positions``. Thereby, a map connecting an index + to the objects (by position) they appear on is required + (``idx_occurences``) to avoid bad subsets. + """ + # we can always safely contract over indices that only appear on the + # already included positions (not contracting any of those would be + # stupid since the scaling remains the same but the memory scaling + # would increase) + safe_contracted = [] # those should always be contracted + optional_contracted = [] # contracting those will grow the group + for idx in contracted: + if all(pos in positions for pos in idx_occurences[idx]): + safe_contracted.append(idx) + else: + optional_contracted.append(idx) + safe_contracted = tuple(safe_contracted) + if safe_contracted: + yield safe_contracted + + if not optional_contracted: + return + # try to form all possible combinations for the optional contracted indices + combinations = itertools.chain.from_iterable( + itertools.combinations(optional_contracted, n) + for n in range(1, len(optional_contracted) + 1) + ) + for addition in combinations: + yield safe_contracted + addition def unoptimized_contraction(term: TermContainer, diff --git a/tests/generate_code/contraction_test.py b/tests/generate_code/contraction_test.py index b48aaca..da262ad 100644 --- a/tests/generate_code/contraction_test.py +++ b/tests/generate_code/contraction_test.py @@ -8,7 +8,7 @@ class TestContraction: def test_indices(self): - i, j, k, b, c = get_symbols("ijkbc") + i, j, k, b, c, P = get_symbols("ijkbcP") indices = ((i, k), (j, k)) names = ("f_oo", "f_oo") target_indices = (i, j) @@ -30,10 +30,29 @@ def test_indices(self): assert contr.contracted == (j, k) assert contr.target == (i,) indices = ((j, b), (j, k, b, c)) + target_indices = (k, c) names = ("ur1", "t2_1") contr = Contraction(indices, names, target_indices) assert contr.contracted == (j, b) assert contr.target == (k, c) + # custom contracted indices + indices = ((j, b), (j, k, b, c), (P, j, k), (P, b, c)) + names = ("ur1", "V_oovv", "U_aoo", "U_avv") + contr = Contraction(indices, names, term_target_indices=[]) + assert contr.contracted == (j, k, b, c, P) + assert contr.target == tuple() + contr = Contraction(indices, names, term_target_indices=[], + contracted=[j, b]) + assert contr.contracted == (j, b) + assert contr.target == (k, c, P) + contr = Contraction(indices, names, term_target_indices=[], + contracted=[j, b], target=[P, c, k]) + assert contr.contracted == (j, b) + assert contr.target == (k, c, P) + contr = Contraction(indices, names, term_target_indices=[], + target=[j, b, P, c, k]) + assert contr.contracted == tuple() + assert contr.target == (j, k, b, c, P) def test_scaling(self): i, j, k = get_symbols("ijk") diff --git a/tests/generate_code/optimize_contractions_test.py b/tests/generate_code/optimize_contractions_test.py index 01d5528..76f946c 100644 --- a/tests/generate_code/optimize_contractions_test.py +++ b/tests/generate_code/optimize_contractions_test.py @@ -10,7 +10,7 @@ class TestOptimizeContractions: - sizes = {"core": 5, "occ": 20, "virt": 200} + sizes = {"core": 5, "occ": 20, "virt": 200, "aux": 250} def test_factor(self): test = "{d^{i}_{a}} {d^{i}_{a}} {d^{j}_{b}} {d^{j}_{b}}" @@ -73,6 +73,56 @@ def test_hypercontraction(self): max_n_simultaneous_contracted=3, space_dims=self.sizes) + def test_qadc2(self): + # a more complicated example from the qadc2 equations + test = ( + "{p^{P}_{l}} {p^{P}_{a}} {B^{P}_{jk}} {p^{Q}_{k}} {p^{Q}_{b}} " + "{B^{Q}_{il}} {Y^{b}_{j}}" + ) + test = import_from_sympy_latex(test, convert_default_names=True) + res = optimize_contractions( + test.terms[0], target_indices="ia", space_dims=self.sizes + ) + for c in res: + print(c) + assert len(res) == 6 + i, j, k, l, a, b, P, Q = get_symbols("ijklabPQ") + ref = Contraction( + indices=((j, b), (Q, b)), names=("ur1", "p0_av"), + term_target_indices=(i, a) + ) + assert res[0] == ref + ref = Contraction( + indices=((j, Q), (P, j, k)), + names=(res[0].contraction_name, "B_aoo"), + term_target_indices=(i, a) + ) + assert res[1] == ref + ref = Contraction( + indices=((k, P, Q), (Q, k)), + names=(res[1].contraction_name, "p0_ao"), + term_target_indices=(i, a), contracted=(k,) + ) + assert res[2] == ref + ref = Contraction( + indices=((P, Q), (Q, i, l)), + names=(res[2].contraction_name, "B_aoo"), + term_target_indices=(i, a) + ) + assert res[3] == ref + ref = Contraction( + indices=((i, l, P), (P, l)), + names=(res[3].contraction_name, "p0_ao"), + term_target_indices=(i, a), contracted=(l,) + ) + assert res[4] == ref + ref = Contraction( + indices=((i, P), (P, a)), + names=(res[4].contraction_name, "p0_av"), + term_target_indices=(i, a) + ) + assert res[5] == ref + class TestGroupObjects: def test_full_connected(self): @@ -80,49 +130,63 @@ def test_full_connected(self): # 4 connected objects in different order without target indices relevant_obj_indices = [(i, j), (j, k), (j, k), (i, k)] target_indices = tuple() - res = _group_objects(obj_indices=relevant_obj_indices, - target_indices=target_indices) - assert res == ((0, 1, 2), (0, 1, 2, 3), (0, 3), (1, 2, 3)) + res = list(_group_objects( + obj_indices=relevant_obj_indices, target_indices=target_indices + )) + assert res == [((0, 1, 2), (j,)), ((0, 1, 2, 3), (i, j, k)), + ((0, 3), (i,)), ((1, 2, 3), (k,))] relevant_obj_indices = [(i, j), (j, k), (i, k), (j, k)] target_indices = tuple() - res = _group_objects(obj_indices=relevant_obj_indices, - target_indices=target_indices) - assert res == ((0, 1, 3), (0, 1, 2, 3), (0, 2), (1, 2, 3)) + res = list(_group_objects( + obj_indices=relevant_obj_indices, target_indices=target_indices + )) + assert res == [((0, 1, 3), (j,)), ((0, 1, 2, 3), (i, j, k)), + ((0, 2), (i,)), ((1, 2, 3), (k,))] relevant_obj_indices = [(j, k), (j, k), (i, j), (i, k)] target_indices = tuple() - res = _group_objects(obj_indices=relevant_obj_indices, - target_indices=target_indices) - assert res == ((0, 1, 2, 3), (0, 1, 2), (0, 1, 3), (2, 3)) + res = list(_group_objects( + obj_indices=relevant_obj_indices, target_indices=target_indices + )) + assert res == [((0, 1, 2), (j,)), ((0, 1, 3), (k,)), + ((0, 1, 2, 3), (i, j, k)), ((2, 3), (i,))] # with a target index (and a possible outer product) relevant_obj_indices = [(j, k), (j, k), (i, j), (i, k)] target_indices = (i,) - res = _group_objects(obj_indices=relevant_obj_indices, - target_indices=target_indices) - assert res == ((0, 1, 2, 3), (0, 1, 2), (0, 1, 3), (2, 3)) + res = list(_group_objects( + obj_indices=relevant_obj_indices, target_indices=target_indices + )) + assert res == [((0, 1, 2), (j,)), ((0, 1, 3), (k,)), + ((0, 1, 2, 3), (j, k)), ((2, 3), ())] def test_multiple_groups(self): i, j, k, l, p, q, r, s = get_symbols("ijklpqrs") # eri mo transformation relevant_obj_indices = [(p, q, r, s), (i, p), (j, q), (k, r), (l, s)] target_indices = (i, j, k, l) - res = _group_objects(obj_indices=relevant_obj_indices, - target_indices=target_indices) - ref = ((0, 1), (0, 2), (0, 3), (0, 4), - (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)) # outer + res = list(_group_objects( + obj_indices=relevant_obj_indices, target_indices=target_indices + )) + ref = [((0, 1), (p,)), ((0, 2), (q,)), ((0, 3), (r,)), ((0, 4), (s,)), + ((1, 2), ()), ((1, 3), ()), ((1, 4), ()), ((2, 3), ()), + ((2, 4), ()), ((3, 4), ())] assert res == ref # groups of different size relevant_obj_indices = [(p, q, r, s), (i, p), (j, p), (k, r), (l, s)] - target_indices = (i, j, k, l) - res = _group_objects(obj_indices=relevant_obj_indices, - target_indices=target_indices) - ref = ((0, 1, 2), (0, 3), (0, 4), - (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)) # outer + target_indices = (q, i, j, k, l) + res = list(_group_objects( + obj_indices=relevant_obj_indices, target_indices=target_indices + )) + ref = [((0, 1, 2), (p,)), ((0, 3), (r,)), ((0, 4), (s,)), + ((1, 3), ()), ((1, 4), ()), ((2, 3), ()), ((2, 4), ()), + ((3, 4), ())] assert res == ref # with an isolated group relevant_obj_indices = [(p, q, p, s), (i, p), (j, p), (k, r), (l, s)] - target_indices = (i, j, l) - res = _group_objects(obj_indices=relevant_obj_indices, - target_indices=target_indices) - ref = ((0, 1, 2), (0, 1, 2, 4), - (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)) # outer + target_indices = (q, i, j, l) + res = list(_group_objects( + obj_indices=relevant_obj_indices, target_indices=target_indices + )) + ref = [((0, 1, 2), (p,)), ((0, 4), (s,)), + ((0, 1, 2, 4), (p, s)), ((1, 3), ()), ((1, 4), ()), + ((2, 3), ()), ((2, 4), ()), ((3, 4), ())] assert res == ref