Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 60 additions & 13 deletions adcgen/generate_code/contraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,18 @@ class Contraction:
The names of the contracted tensors.
term_target_indices: tuple[Index]
The target indices of the term the contraction belongs to.
Note that it is not possible for indices that are given as
'term_target_indices' to become contracted indices!
contracted: Sequence[Index], optional
The contracted indices for the contraction. Note that the
indices might be reordered.
If not given they will be determined from the given indices
according to the einstein sum convention.
target: Sequence[Index], optional
The target indices for the contraction. Note that the given
indices might be reordered.
If not given they will be determined from the given indices
according to the einstein sum convention.
"""
# use counter that essentially counts how many class instances have
# been created
Expand All @@ -79,7 +91,9 @@ class Contraction:

def __init__(self, indices: Sequence[tuple[Index, ...]],
names: Sequence[str],
term_target_indices: Sequence[Index]) -> None:
term_target_indices: Sequence[Index],
contracted: Sequence[Index] | None = None,
target: Sequence[Index] | None = None) -> None:
if not isinstance(indices, tuple):
indices = tuple(indices)
if isinstance(names, str):
Expand All @@ -89,13 +103,18 @@ def __init__(self, indices: Sequence[tuple[Index, ...]],

self.indices: tuple[tuple[Index, ...], ...] = indices
self.names: tuple[str, ...] = names
self.contracted: tuple[Index, ...] = tuple()
self.target: tuple[Index, ...] = tuple()

self.contracted: tuple[Index, ...]
self.target: tuple[Index, ...]
self.scaling: Scaling
self._determine_contracted_and_target(
term_target_indices=term_target_indices, contracted=contracted,
target=target
)
self._determine_scaling()

self.id: int = next(self._instance_counter)
self.contraction_name: str = f"{self._base_name}_{self.id}"
self._determine_contracted_and_target(term_target_indices)
self._determine_scaling()

def __str__(self):
return (f"Contraction(indices={self.indices}, names={self.names}, "
Expand All @@ -106,18 +125,46 @@ def __str__(self):
def __repr__(self):
return self.__str__()

def _determine_contracted_and_target(self,
term_target_indices: Sequence[Index]
) -> None:
def _determine_contracted_and_target(
self, term_target_indices: Sequence[Index],
contracted: Sequence[Index] | None = None,
target: Sequence[Index] | None = None) -> None:
"""
Determines and sets the contracted and target indices on the
contraction using the provided target indices of the term
the contraction is a part of. In case the target indices of the
contraction contain the same indices as the target indices of the
term, the target indices of the term will be used instead.
the contraction is a part of. In case of an outer contraction
(the target indices are the reordered 'term_target_indices')
the term_target_indices will be used as target indices.
If provided, the provided contracted and target indices
will be used after reordering them.
"""
contracted, target = self._split_contracted_and_target(
self.indices, term_target_indices
if contracted is None and target is None: # determine both
contracted, target = self._split_contracted_and_target(
self.indices, term_target_indices
)
elif contracted is None: # target indices are given
# -> compute contracted removing duplicates
assert isinstance(target, Sequence)
contracted = list({
idx for idx in itertools.chain.from_iterable(self.indices)
if idx not in target
})
elif target is None: # contracted are given
# -> compute contracted removing duplicates
assert isinstance(contracted, Sequence)
target = list({
idx for idx in itertools.chain.from_iterable(self.indices)
if idx not in contracted
})
# sanity checks:
# - no index in contracted and target at the same time
# - no contracted index can be a target index of the overall term
# - all indices of the objects have to be in contracted or target
assert not any(idx in target for idx in contracted)
assert not any(idx in term_target_indices for idx in contracted)
assert all(
idx in target or idx in contracted
for idx in itertools.chain.from_iterable(self.indices)
)
# sort the indices canonical
contracted = sorted(contracted, key=sort_idx_canonical)
Expand Down
Loading