From c77a20ad4b557712460dc03a01e7a3902907e69e Mon Sep 17 00:00:00 2001 From: jonasleitner Date: Thu, 5 Jun 2025 08:13:42 +0200 Subject: [PATCH 1/5] update split_idx_string using regular expressions --- adcgen/indices.py | 19 +++++++------------ tests/indices_test.py | 31 ++++++++++++++++++++++++++++++- 2 files changed, 37 insertions(+), 13 deletions(-) diff --git a/adcgen/indices.py b/adcgen/indices.py index ab61638..eb715a7 100644 --- a/adcgen/indices.py +++ b/adcgen/indices.py @@ -1,5 +1,6 @@ from collections.abc import Sequence, Collection, Mapping from typing import Any, TypeGuard, TYPE_CHECKING +import re from sympy import Dummy, Tuple @@ -288,18 +289,12 @@ def split_idx_string(str_tosplit: str) -> list[str]: """ Splits an index string of the form 'ij12a3b' in a list ['i','j12','a3','b'] """ - splitted = [] - temp = [] - for i, idx in enumerate(str_tosplit): - temp.append(idx) - try: - if str_tosplit[i+1].isdigit(): - continue - else: - splitted.append("".join(temp)) - temp.clear() - except IndexError: - splitted.append("".join(temp)) + # findall only returns the matching parts of the string + # -> ensure that we don't loose part of the string + # (string starting with a numnber) + splitted = re.findall(r"\D\d*", str_tosplit) + if "".join(splitted) != str_tosplit: + raise ValueError(f"Invalid index string {str_tosplit}") return splitted diff --git a/tests/indices_test.py b/tests/indices_test.py index dba0d84..1417f80 100644 --- a/tests/indices_test.py +++ b/tests/indices_test.py @@ -1,4 +1,6 @@ -from adcgen.indices import Indices, get_symbols +from adcgen.indices import Indices, get_symbols, split_idx_string + +import pytest class TestIndices: @@ -57,3 +59,30 @@ def test_get_symbols(self): assert i is get_symbols("i", "a").pop() assert [i, I, a] == get_symbols("iIa", "abb") assert [a, i, I] == get_symbols("aiI", "bab") + + +def test_split_idx_string(): + # single index + res = split_idx_string("i") + assert res == ["i"] + res = split_idx_string("i3") + assert res == ["i3"] + res = split_idx_string("i3234235") + assert res == ["i3234235"] + # multiple indices without number + res = split_idx_string("iJa") + assert res == ["i", "J", "a"] + # multiple indices with numbers + res = split_idx_string("i3J11a2") + assert res == ["i3", "J11", "a2"] + # some indices with number + res = split_idx_string("i334Jab") + assert res == ["i334", "J", "a", "b"] + res = split_idx_string("i3J33ab") + assert res == ["i3", "J33", "a", "b"] + # arbitrary index names + res = split_idx_string("i⍺β3Ɣ23") + assert res == ["i", "⍺", "β3", "Ɣ23"] + # invalid string: starting with a number + with pytest.raises(ValueError): + split_idx_string("3Ɣ23") From 17c34bf68d5b2735042d55aa14f28684604bc3bd Mon Sep 17 00:00:00 2001 From: jonasleitner Date: Thu, 5 Jun 2025 16:00:31 +0200 Subject: [PATCH 2/5] refactor import_from_sympy_latex --- adcgen/__init__.py | 4 +- adcgen/expression/__init__.py | 4 +- adcgen/expression/import_from_sympy_latex.py | 450 ++++++++++++++++++ adcgen/func.py | 250 +--------- tests/conftest.py | 2 +- tests/expression/__init__.py | 0 .../import_from_sympy_latex_test.py | 137 ++++++ .../substitute_contracted_test.py | 0 tests/{ => expression}/symbolic_denom_test.py | 0 tests/factor_intermediates_test.py | 3 +- tests/generate_code/__init__.py | 0 tests/{ => generate_code}/contraction_test.py | 0 .../optimize_contractions_test.py | 2 +- 13 files changed, 598 insertions(+), 254 deletions(-) create mode 100644 adcgen/expression/import_from_sympy_latex.py create mode 100644 tests/expression/__init__.py create mode 100644 tests/expression/import_from_sympy_latex_test.py rename tests/{ => expression}/substitute_contracted_test.py (100%) rename tests/{ => expression}/symbolic_denom_test.py (100%) create mode 100644 tests/generate_code/__init__.py rename tests/{ => generate_code}/contraction_test.py (100%) rename tests/{ => generate_code}/optimize_contractions_test.py (99%) diff --git a/adcgen/__init__.py b/adcgen/__init__.py index f2aad38..4fc56f7 100644 --- a/adcgen/__init__.py +++ b/adcgen/__init__.py @@ -1,9 +1,9 @@ from .core_valence_separation import apply_cvs_approximation from .derivative import derivative from .eri_orbenergy import EriOrbenergy -from .expression import ExprContainer +from .expression import ExprContainer, import_from_sympy_latex from .factor_intermediates import factor_intermediates -from .func import import_from_sympy_latex, evaluate_deltas, wicks +from .func import evaluate_deltas, wicks from .generate_code import (generate_code, optimize_contractions, Contraction, unoptimized_contraction) from .groundstate import GroundState diff --git a/adcgen/expression/__init__.py b/adcgen/expression/__init__.py index d88cab0..6b86c3f 100644 --- a/adcgen/expression/__init__.py +++ b/adcgen/expression/__init__.py @@ -1,4 +1,5 @@ from .expr_container import ExprContainer +from .import_from_sympy_latex import import_from_sympy_latex from .normal_ordered_container import NormalOrderedContainer from .object_container import ObjectContainer from .polynom_container import PolynomContainer @@ -8,5 +9,6 @@ __all__ = [ "ExprContainer", "NormalOrderedContainer", "PolynomContainer", - "ObjectContainer", "TermContainer" + "ObjectContainer", "TermContainer", + "import_from_sympy_latex" ] diff --git a/adcgen/expression/import_from_sympy_latex.py b/adcgen/expression/import_from_sympy_latex.py new file mode 100644 index 0000000..38e0174 --- /dev/null +++ b/adcgen/expression/import_from_sympy_latex.py @@ -0,0 +1,450 @@ +import re + +from sympy.physics.secondquant import F, Fd, NO +from sympy import Expr, Mul, Pow, S, Symbol, sqrt, sympify + +from ..indices import Index, get_symbols +from ..sympy_objects import ( + Amplitude, AntiSymmetricTensor, KroneckerDelta, NonSymmetricTensor, + SymmetricTensor +) +from ..tensor_names import tensor_names, is_adc_amplitude, is_t_amplitude +from .expr_container import ExprContainer + + +def import_from_sympy_latex(expr_string: str, + convert_default_names: bool = False + ) -> ExprContainer: + """ + Imports an expression from a string created by the + :py:function:`sympy.latex` function. + + Parameters + ---------- + convert_default_names : bool, optional + If set, all default tensor names found in the expression to import + will be converted to the currently configured names. + + Returns + ------- + ExprContainer + The imported expression in a :py:class:`ExprContainer` container. + Note that bra-ket symmetry is not set during the import. + """ + expr_string = expr_string.strip() + if not expr_string: + return ExprContainer(0) + # split the expression in the individual terms and (potentially) + # add a '+' sign to the first term + term_strings = _split_terms(expr_string) + if not term_strings[0].startswith(("+", "-")): + term_strings[0] = "+ " + term_strings[0] + + expr = S.Zero + for term in term_strings: + expr += _import_term( + term_string=term, convert_default_names=convert_default_names + ) + return ExprContainer(expr) + + +######################## +# import functionality # +######################## +def _import_term(term_string: str, convert_default_names: bool) -> Expr: + """ + Import the given term from string + """ + # extract the sign of the term + sign = term_string[0] + if sign not in ["+", "-"]: + raise ValueError("Term string has to start with '+' or '-' sign.") + term: Expr = S.NegativeOne if sign == "-" else S.One + term_string = term_string[1:].strip() + + if term_string.startswith(r"\frac"): # fraction + term *= _import_fraction( + term_string, convert_default_names=convert_default_names + ) + else: + term *= _import_product( + term_string, convert_default_names=convert_default_names + ) + return term + + +def _import_fraction(fraction: str, convert_default_names: bool) -> Expr: + """ + Imports a fraction '\\frac{num}{denom}' from string. + """ + numerator, denominator = _split_fraction(fraction) + res: Expr = S.One + # import num + if _is_sum(numerator): + res *= import_from_sympy_latex( + numerator, convert_default_names=convert_default_names + ).inner + else: + res *= _import_product( + numerator, convert_default_names=convert_default_names + ) + # import denom + if _is_sum(denominator): + res /= import_from_sympy_latex( + denominator, convert_default_names=convert_default_names + ).inner + else: + res /= _import_product( + denominator, convert_default_names=convert_default_names + ) + assert isinstance(res, Expr) + return res + + +def _import_product(product: str, convert_default_names: bool) -> Expr: + """ + Imports a product (a term that is no fraction) of objects. + Objects are separated by a space. + """ + # we have to have a product at this point + return Mul(*( + _import_object(obj, convert_default_names=convert_default_names) + for obj in _split_product(product) + )) + + +def _import_object(obj_str: str, convert_default_names: bool) -> Expr: + """ + Imports the given object (a part of a product) from string. + """ + if obj_str.isnumeric(): # prefactor + return sympify(int(obj_str)) + elif obj_str.startswith(r"\sqrt{"): # sqrt{n} prefactor + assert obj_str[-1] == "}" + return sqrt(import_from_sympy_latex( + obj_str[6:-1].strip(), convert_default_names=convert_default_names + ).inner) + elif obj_str.startswith(r"\delta_{"): # KroneckerDelta + return _import_kronecker_delta( + obj_str, convert_default_names=convert_default_names + ) + elif obj_str.startswith(r"\left("): + return _import_polynom( + obj_str, convert_default_names=convert_default_names + ) + elif obj_str.startswith(r"\left\{"): # NO + return _import_normal_ordered( + obj_str, convert_default_names=convert_default_names + ) + else: + # the remaining objects are harder to identify: + # tensor, creation, annihilation or symbol + return _import_tensor_like( + obj_str, convert_default_names=convert_default_names + ) + + +def _import_kronecker_delta(delta: str, convert_default_names: bool) -> Expr: + """ + Imports the given KroneckerDelta of the from + '\\delta_{p q}' from string + """ + # a delta should not have an exponent! + delta, exponent = _split_base_and_exponent(delta) + assert exponent is None + _ = convert_default_names + assert delta.startswith(r"\delta_{") and delta.endswith("}") + # extract and import the indices + p, q = delta[8:-1].strip().split() + p, q = _import_index(p), _import_index(q) + return KroneckerDelta(p, q) + + +def _import_polynom(polynom: str, convert_default_names: bool): + """ + Imports the given polynom of the form '\\left(...)\\right^{exp}' + from string + """ + # try to extract the exponent (if available) (base)^{exponent} + base, exponent = _split_base_and_exponent(polynom) + assert base.startswith(r"\left(") and base.endswith(r"\right)") + # import base and exponent and build a Pow object + res = import_from_sympy_latex( + base[6:-7].strip(), convert_default_names=convert_default_names + ).inner + if exponent is not None: + assert exponent.startswith("{") and exponent.endswith("}") + exponent = import_from_sympy_latex( + exponent[1:-1].strip(), convert_default_names=convert_default_names + ).inner + res = Pow(res, exponent) + return res + + +def _import_normal_ordered(no: str, convert_default_names: bool): + """ + Imports the given NormalOrdered object of the form + '\\left\\{...\\right\\}' from string + """ + # a NO object should not have an exponent! + no, exponent = _split_base_and_exponent(no) + assert exponent is None + res = import_from_sympy_latex( + no[7:-8], convert_default_names=convert_default_names + ).inner + return NO(res) + + +def _import_tensor_like(tensor: str, convert_default_names: bool) -> Expr: + """ + Imports a tensor like object (Symbol, creation and annihilation operators + and Tensors). + """ + # possible input: + # - symbol: 'A^{exp}' + # - create: 'a_{p}^{exp}' + # - annihilate: 'a^\dagger_{p}' or '{a^\dagger_{p}}^{exp}' + # - antisymtensor + symtensor: {name^{upper}_{lower}}^{exp} + # - nonsymtensor: {name_{indices}}^{exp} + + # split the object in base and exponent + base, exponent = _split_base_and_exponent(tensor) + # remove the outer layer of curly braces (if present) + # so we can split with a stack size of zero + if base.startswith("{"): + base = base[1:] + if base.endswith("}"): + base = base[:-1] + components = _split_tensor_like(base) + assert components # we have to have at least a name + name = components[0] + components = components[1:] + # remove 1 layer of curly braces from the components (mostly indices) + # -> there should be no curly braces left afterwards + for i, comp in enumerate(components): + if comp.startswith("{"): + comp = comp[1:] + if comp.endswith("}"): + comp = comp[:-1] + components[i] = comp + # if desired map the default tensor names to their currently + # configured name. + # -> this allows expressions with the default names to + # be imported and mapped to the current configuration, correctly + # recognizing Amplitudes and SymmetricTensors. + # Name should be free of curly braces by now + if convert_default_names: + name = tensor_names.map_default_name(name) + # import the tensor like object + if not components: # Symbol + res = Symbol(name) + elif name == "a": # creation or annihilation operator + if len(components) == 2: # create + assert components[0] == r"\dagger" + res = Fd(_import_index(components[1])) + elif len(components) == 1: # annihilate + res = F(_import_index(components[0])) + else: + raise RuntimeError(f"Invalid second quantized operator: {tensor}.") + elif len(components) == 2: # antisymtensor, symtensor or amplitude + upper = _import_indices(components[0]) + lower = _import_indices(components[1]) + # figure out which tensor class to use + if is_adc_amplitude(name) or is_t_amplitude(name): + res = Amplitude(name, upper, lower) + elif name in (tensor_names.coulomb, tensor_names.sym_orb_denom, + tensor_names.ri_sym, + tensor_names.ri_asym_eri, + tensor_names.ri_asym_factor): + res = SymmetricTensor(name, upper, lower) + else: + res = AntiSymmetricTensor(name, upper, lower) + elif len(components) == 1: # nonsymtensor + res = NonSymmetricTensor(name, _import_indices(components[0])) + else: + raise RuntimeError(f"Unknown tensor like object {tensor}") + # attach the exponent if necessary + if exponent is not None: + assert exponent.startswith("{") and exponent.endswith("}") + exponent = import_from_sympy_latex( + exponent[1:-1], convert_default_names=convert_default_names + ).inner + res = Pow(res, exponent) + return res + + +def _import_indices(idx_str: str) -> list[Index]: + """ + Imports the given string of indices + """ + return [_import_index(idx) for idx in _split_indices(idx_str)] + + +def _import_index(index_str: str) -> Index: + """ + Imports the given index of the form + 'a', 'a2', 'a_{\\alpha}' or 'a2_{\\alpha}' + from string. + """ + # extract the spin + spin = None + if index_str.endswith(r"_{\alpha}"): + spin = "a" + index_str = index_str[:-9] + elif index_str.endswith(r"_{\beta}"): + spin = "b" + index_str = index_str[:-8] + # build the index + idx = get_symbols(index_str, spin) + assert len(idx) == 1 + return idx.pop() + + +################################################# +# functionality to split a string in components # +################################################# +def _split_terms(expr_string: str) -> list[str]: + """ + Split the expression string into a list of term strings + """ + # we need to split the string on +- signs while keeping track of the + # brackets (+ in a bracket does not indicate a new term) + stack: list[str] = [] + terms: list[str] = [] + term_start_idx = 0 + for i, char in enumerate(expr_string): + if char in ["{", "("]: + stack.append(char) + elif char == "}": + assert stack.pop() == "{" + elif char == ")": + assert stack.pop() == "(" + elif char in ["+", "-"] and not stack and i != term_start_idx: + terms.append(expr_string[term_start_idx:i].strip()) + term_start_idx = i + assert not stack + terms.append(expr_string[term_start_idx:].strip()) # last term + return terms + + +def _split_fraction(fraction: str) -> tuple[str, str]: + """ + Splits a fraction '\\frac{num}{denom}' string in numerator and denominator + """ + assert fraction.startswith(r"\frac{") and fraction.endswith("}") + # remove outer opening and closing brace + # -> num}{denom + fraction = fraction[6:-1] + stack = 0 + num, denom = None, None + for i, char in enumerate(fraction): + if char == "{": + stack += 1 + if not stack: # found the opening brace of the denominator + denom = fraction[i+1:].strip() # consume remaining string + break + elif char == "}": + if not stack: # found the closing brace of the numberator + num = fraction[:i].strip() + stack -= 1 + if num is None or denom is None: + raise ValueError("Could not extract numerator and denominator from " + f"{fraction}") + assert not stack + return num, denom + + +def _split_product(product: str) -> list[str]: + """ + Splits the product of objects that are separated by space into the + individual objects. + """ + # individual objects are separated by spaces + stack: list[str] = [] + objects: list[str] = [] + obj_start_idx = 0 + for i, char in enumerate(product): + if char in ["{", "("]: + stack.append(char) + elif char == "}": + assert stack.pop() == "{" + elif char == ")": + assert stack.pop() == "(" + elif char in ["+", "-"]: + # this breaks down if the input is a sum and no product + assert stack + elif char == " " and not stack and i != obj_start_idx: + objects.append(product[obj_start_idx:i].strip()) + obj_start_idx = i + objects.append(product[obj_start_idx:].strip()) # the last object + return objects + + +def _split_tensor_like(obj_str: str) -> list[str]: + """ + Splits a tensor like object (Symbol, F, Fd, SymbolicTensor) + onto its components using the delimiters '^' and '_'. + """ + stack = 0 + components: list[str] = [] + component_start_idx = 0 + for i, char in enumerate(obj_str): + if char == "{": + stack += 1 + elif char == "}": + stack -= 1 + elif char in ["^", "_"] and not stack and i != component_start_idx: + components.append(obj_str[component_start_idx:i].strip()) + component_start_idx = i + 1 + remainder = obj_str[component_start_idx:].strip() + if remainder: + components.append(remainder) + return components + + +def _split_base_and_exponent(obj_str: str) -> tuple[str, str | None]: + """ + Splits the object of the form base^exp in base and exponent. + If the object has no exponent, None is returned as exponent. + """ + stack = 0 + base, exponent = None, None + for i, char in enumerate(obj_str): + if char in ["{", "("]: + stack += 1 + elif char in ["}", ")"]: + stack -= 1 + elif char == "^" and not stack: + base = obj_str[:i].strip() + exponent = obj_str[i+1:].strip() + break + if base is None: # we have no exponent + base = obj_str + return base, exponent + + +def _split_indices(idx_str: str) -> list[str]: + """ + Splits an index string of the form 'ab2b_{alpha}b2_{beta}' into a + list [‘a‘, 'b2', 'b_{\\alpha}', 'b2_{\\beta}']. + """ + splitted = re.findall(r"\D\d*(?:_\{\\(?:alpha|beta)\})?", idx_str) + # ensure that we did not drop a part of the string + assert "".join(splitted) == idx_str + return splitted + + +###################################################### +# Functionality to identify the character of strings # +###################################################### +def _is_sum(sum: str) -> bool: + stack = 0 + for char in sum: + if char in ["{", "("]: + stack += 1 + elif char in ["}", ")"]: + stack -= 1 + elif char in ["+", "-"] and not stack: + return True + assert not stack + return False diff --git a/adcgen/func.py b/adcgen/func.py index ec5319b..6f26c88 100644 --- a/adcgen/func.py +++ b/adcgen/func.py @@ -4,17 +4,13 @@ from sympy.physics.secondquant import ( F, Fd, FermionicOperator, NO ) -from sympy import S, Add, Expr, Mul, Pow, sqrt, Symbol, sympify +from sympy import S, Add, Expr, Mul from .expression import ExprContainer from .misc import Inputerror from .rules import Rules -from .indices import Index, Indices, get_symbols, split_idx_string -from .sympy_objects import ( - KroneckerDelta, NonSymmetricTensor, AntiSymmetricTensor, SymmetricTensor, - Amplitude -) -from .tensor_names import is_adc_amplitude, is_t_amplitude, tensor_names +from .indices import Index, Indices, get_symbols +from .sympy_objects import KroneckerDelta def gen_term_orders(order: int, term_length: int, min_order: int @@ -48,246 +44,6 @@ def gen_term_orders(order: int, term_length: int, min_order: int return [comb for comb in combinations if sum(comb) == order] -def import_from_sympy_latex(expr_string: str, - convert_default_names: bool = False - ) -> ExprContainer: - """ - Imports an expression from a string created by the 'sympy.latex' function. - - Parameters - ---------- - convert_default_names : bool, optional - If set, all default tensor names found in the expression to import - will be converted to the currently configured names. - - Returns - ------- - ExprContainer - The imported expression in a 'Expr' container. Note that no assumptions - (sym_tensors or antisym_tensors) have been applied yet. - """ - - def import_indices(indices: str) -> list[Index]: - # split at the end of each index with a spin label - # -> n1n2n3_{spin} - idx: list[Index] = [] - for sub_part in indices.split("}"): - if not sub_part: # skip empty string - continue - if "_{\\" in sub_part: # the last index has a spin label - names, spin = sub_part.split("_{\\") - if spin not in ["alpha", "beta"]: - raise RuntimeError(f"Found invalid spin on Index: {spin}. " - f"Input: {indices}") - names = split_idx_string(names) - idx.extend(get_symbols(names[:-1])) - idx.extend(get_symbols(names[-1], spin[0])) - else: # no index has a spin label - idx.extend(get_symbols(sub_part)) - return idx - - def import_tensor(tensor: str) -> Expr: - # split the tensor in base and exponent - stack: list[str] = [] - separator: int | None = None - for i, c in enumerate(tensor): - if c == "{": - stack.append(c) - elif c == "}": - assert stack.pop() == "{" - elif not stack and c == "^": - separator = i - break - if separator is None: - exponent = 1 - else: - exponent = tensor[separator+1:] - exponent = int(exponent.lstrip("{").rstrip("}")) - tensor = tensor[:separator] - # done with processing the exponent - # -> deal with the tensor. remove 1 layer of curly brackets and - # afterwards split the tensor string into its components - if tensor[0] == "{": - tensor = tensor[1:] - if tensor[-1] == "}": - tensor = tensor[:-1] - stack.clear() - components: list[str] = [] - temp: list[str] = [] - for i, c in enumerate(tensor): - if c == "{": - stack.append(c) - elif c == "}": - assert stack.pop() == "{" - elif not stack and c in ["^", "_"]: - components.append("".join(temp)) - temp.clear() - continue - temp.append(c) - if temp: - components.append("".join(temp)) - name, indices = components[0], components[1:] - # if desired map the default tensor names to their currently - # configured name - # -> this allows expressions with the default names to - # be imported and mapped to the current configuration, correctly - # recognizing Amplitudes and SymmetricTensors. - if convert_default_names: - name = tensor_names.map_default_name(name) - - # remove 1 layer of brackets from all indices - for i, idx in enumerate(indices): - if idx[0] == "{": - idx = idx[1:] - if idx[-1] == "}": - idx = idx[:-1] - indices[i] = idx - - if len(indices) == 0: # no indices -> a symbol - base: Expr = Symbol(name) - elif name == "a": # create / annihilate - if len(indices) == 2 and indices[0] == "\\dagger": - base: Expr = Fd(*import_indices(indices[1])) - elif len(indices) == 1: - base: Expr = F(*import_indices(indices[0])) - else: - raise RuntimeError("Unknown second quantized operator: ", - tensor) - elif len(indices) == 2: # antisym-/symtensor or amplitude - upper = import_indices(indices[0]) - lower = import_indices(indices[1]) - # ADC-Amplitude or t-amplitudes - if is_adc_amplitude(name) or is_t_amplitude(name): - base: Expr = Amplitude(name, upper, lower) - elif name in (tensor_names.coulomb, tensor_names.ri_sym, - tensor_names.ri_asym_eri, - tensor_names.ri_asym_factor): - # eri in chemist notation or RI tensor - base: Expr = SymmetricTensor(name, upper, lower) - else: - base: Expr = AntiSymmetricTensor(name, upper, lower) - elif len(indices) == 1: # nonsymtensor - base: Expr = NonSymmetricTensor(name, import_indices(indices[0])) - else: - raise RuntimeError(f"Unknown tensor object: {tensor}") - assert isinstance(base, Expr) - return Pow(base, exponent) - - def import_obj(obj_str: str) -> Expr: - # import an individial object - if obj_str.isnumeric(): # prefactor - return sympify(int(obj_str)) - elif obj_str.startswith("\\sqrt{"): # sqrt{x} prefactor - return sqrt(int(obj_str[:-1].replace("\\sqrt{", "", 1))) - elif obj_str.startswith("\\delta_"): # KroneckerDelta - idx = obj_str[:-1].replace("\\delta_{", "", 1).split() - idx = import_indices("".join(idx)) - if len(idx) != 2: - raise RuntimeError(f"Invalid indices for delta: {idx}.") - ret = KroneckerDelta(*idx) - assert isinstance(ret, Expr) - return ret - elif obj_str.startswith("\\left("): # braket - # need to take care of exponent of the braket! - base, exponent = obj_str.rsplit('\\right)', 1) - if exponent: # exponent != "" -> ^{x} -> exponent != 1 - exponent = int(exponent[:-1].lstrip('^{')) - else: - exponent = 1 - obj_str = base.replace("\\left(", "", 1) - obj = import_from_sympy_latex( - obj_str, convert_default_names=convert_default_names - ) - return Pow(obj.inner, exponent) - elif obj_str.startswith("\\left\\{"): # NO - no, unexpected_stuff = obj_str.rsplit("\\right\\}", 1) - if unexpected_stuff: - raise NotImplementedError(f"Unexpected NO object: {obj_str}.") - obj_str = no.replace("\\left\\{", "", 1) - obj = import_from_sympy_latex( - obj_str, convert_default_names=convert_default_names - ) - return NO(obj.inner) - else: # tensor or creation/annihilation operator or symbol - return import_tensor(obj_str) - - def split_terms(expr_string: str) -> list[str]: - stack: list[str] = [] - terms: list[str] = [] - - term_start_idx = 0 - for i, char in enumerate(expr_string): - if char in ['{', '(']: - stack.append(char) - elif char == '}': - assert stack.pop() == '{' - elif char == ')': - assert stack.pop() == '(' - elif char in ['+', '-'] and not stack and i != term_start_idx: - terms.append(expr_string[term_start_idx:i]) - term_start_idx = i - terms.append(expr_string[term_start_idx:]) # append last term - return terms - - def import_term(term_string: str) -> Expr: - from sympy import Mul - - stack: list[str] = [] - objects: list[str] = [] - - obj_start_idx = 0 - for i, char in enumerate(term_string): - if char in ['{', '(']: - stack.append(char) - elif char == '}': - assert stack.pop() == '{' - elif char == ')': - assert stack.pop() == '(' - # in case we have a denom of the form: - # 2a+2b+4c and not 2 * (a+b+2c) - elif char in ['+', '-'] and not stack: - return import_from_sympy_latex( - term_string, convert_default_names=convert_default_names - ).inner - elif char == " " and not stack and i != obj_start_idx: - objects.append(term_string[obj_start_idx:i]) - obj_start_idx = i + 1 - objects.append(term_string[obj_start_idx:]) # last object - return Mul(*(import_obj(o) for o in objects)) - - expr_string = expr_string.strip() - if not expr_string: - return ExprContainer(0) - - terms = split_terms(expr_string) - if terms[0][0] not in ['+', '-']: - terms[0] = '+ ' + terms[0] - - sympy_expr = S.Zero - for term in terms: - sign = term[0] # extract the sign of the term - if sign not in ['+', '-']: - raise ValueError(f"Found invalid sign {sign} in term {term}") - term = term[1:].strip() - - sympy_term = S.NegativeOne if sign == '-' else S.One - assert isinstance(sympy_term, Expr) - - if term.startswith("\\frac"): # fraction - # remove frac layout and split: \\frac{...}{...} - num, denom = term[:-1].replace("\\frac{", "", 1).split("}{") - else: # no denominator - num, denom = term, None - - sympy_term = Mul(sympy_term, import_term(num)) - assert isinstance(sympy_term, Expr) - if denom is not None: - sympy_term = Mul(sympy_term, S.One/import_term(denom)) - sympy_expr += sympy_term - assert isinstance(sympy_expr, Expr) - return ExprContainer(sympy_expr) - - def evaluate_deltas( expr: Expr, target_idx: Sequence[str] | Index | Sequence[Index] | None = None diff --git a/tests/conftest.py b/tests/conftest.py index ce946f4..999031a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,7 +3,7 @@ import pathlib import json -from adcgen.func import import_from_sympy_latex +from adcgen import import_from_sympy_latex @pytest.fixture(scope='session') diff --git a/tests/expression/__init__.py b/tests/expression/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/expression/import_from_sympy_latex_test.py b/tests/expression/import_from_sympy_latex_test.py new file mode 100644 index 0000000..f24b9c7 --- /dev/null +++ b/tests/expression/import_from_sympy_latex_test.py @@ -0,0 +1,137 @@ +from adcgen import import_from_sympy_latex, get_symbols, tensor_names +from adcgen.sympy_objects import ( + Amplitude, AntiSymmetricTensor, KroneckerDelta, NonSymmetricTensor, + SymmetricTensor +) + +from sympy.physics.secondquant import F, Fd, NO +from sympy import Pow, Rational, S, latex, sqrt + + +class TestImportFromSympyLatex: + def test_empty(self): + assert import_from_sympy_latex("").inner is S.Zero + + def test_number(self): + assert import_from_sympy_latex("2").inner == 2 + assert import_from_sympy_latex(r"\frac{1}{2}").inner == Rational(1, 2) + assert import_from_sympy_latex(r"\sqrt{2}").inner == sqrt(2) + + def test_delta(self): + # spin orbitals + delta = KroneckerDelta(*get_symbols("pq")) + assert import_from_sympy_latex(latex(delta)).inner - delta is S.Zero + # spatial orbitals + delta = KroneckerDelta(*get_symbols("pq", spins="aa")) + assert import_from_sympy_latex(latex(delta)).inner - delta is S.Zero + # mixed + delta = KroneckerDelta(*get_symbols("p", spins="b"), *get_symbols("p")) + assert import_from_sympy_latex(latex(delta)).inner - delta is S.Zero + + def test_antisymmetric_tensor(self): + # spin orbitals + i, j, k, l = get_symbols("ijkl") # noqa E741 + tensor = AntiSymmetricTensor("x", (i, j), (k, l)) + assert import_from_sympy_latex(latex(tensor)).inner - tensor is S.Zero + # spatial orbitals + i, j, k, l = get_symbols("ijkl", spins="abab") # noqa E741 + tensor = AntiSymmetricTensor("tensor", (i,), (j, k, l)) + assert import_from_sympy_latex(latex(tensor)).inner - tensor is S.Zero + # exponent > 1 + tensor = Pow(tensor, 2) + assert import_from_sympy_latex(latex(tensor)).inner - tensor is S.Zero + # exponent < 1 + tensor = Pow(tensor, -1) + assert import_from_sympy_latex(latex(tensor)).inner - tensor is S.Zero + + def test_symmetric_tensor(self): + # spin orbitals + i, j, k, l = get_symbols("ijkl") # noqa E741 + tensor = SymmetricTensor(tensor_names.coulomb, (i, j), (k, l)) + assert import_from_sympy_latex(latex(tensor)).inner - tensor is S.Zero + # spatial orbitals + i, j, k, l = get_symbols("ijkl", spins="abab") # noqa E741 + tensor = SymmetricTensor(tensor_names.coulomb, (i,), (j, k, l)) + assert import_from_sympy_latex(latex(tensor)).inner - tensor is S.Zero + + def test_amplitude(self): + # spin orbitals + i, j, k, l = get_symbols("ijkl") # noqa E741 + tensor = Amplitude(tensor_names.left_adc_amplitude, (i, j), (k, l)) + assert import_from_sympy_latex(latex(tensor)).inner - tensor is S.Zero + # spatial orbitals + i, j, k, l = get_symbols("ijkl", spins="abab") # noqa E741 + tensor = Amplitude(tensor_names.gs_amplitude, (i,), (j, k, l)) + assert import_from_sympy_latex(latex(tensor)).inner - tensor is S.Zero + + def test_nonsymmetric_tensor(self): + # spin orbitals + i, j = get_symbols("ij") # noqa E741 + tensor = NonSymmetricTensor("bla", (i, j)) + assert import_from_sympy_latex(latex(tensor)).inner - tensor is S.Zero + # spatial orbitals + i, j = get_symbols("ij", spins="ab") # noqa E741 + tensor = NonSymmetricTensor("bla", (i, j)) + assert import_from_sympy_latex(latex(tensor)).inner - tensor is S.Zero + + def test_second_quant_operator(self): + i, j = get_symbols("ij") + op = F(i) + assert import_from_sympy_latex(latex(op)).inner - op is S.Zero + op = Fd(i) + assert import_from_sympy_latex(latex(op)).inner - op is S.Zero + op = NO(F(i) * Fd(j)) + assert import_from_sympy_latex(latex(op)).inner - op is S.Zero + + def test_product(self): + i, j = get_symbols("ij") + prod = ( + AntiSymmetricTensor(tensor_names.fock, (i,), (j,)) + * NonSymmetricTensor(tensor_names.orb_energy, (i,)) + ) + assert import_from_sympy_latex(latex(prod)).inner - prod is S.Zero + + def test_sum(self): + i, j = get_symbols("ij") + sum = ( + AntiSymmetricTensor(tensor_names.fock, (i,), (j,)) + + Pow(NonSymmetricTensor(tensor_names.orb_energy, (i,)), -42) + ) + assert import_from_sympy_latex(latex(sum)).inner - sum is S.Zero + + def test_simple_frac(self): + # t2_1 amplitude + i, j, a, b = get_symbols("ijab") + num = AntiSymmetricTensor(tensor_names.eri, (i, j), (a, b)) + denom = ( + NonSymmetricTensor(tensor_names.orb_energy, (a,)) + + NonSymmetricTensor(tensor_names.orb_energy, (b,)) + - NonSymmetricTensor(tensor_names.orb_energy, (i,)) + - NonSymmetricTensor(tensor_names.orb_energy, (j,)) + ) + fraction = num / denom + imported = import_from_sympy_latex(latex(fraction)) + assert imported.inner - fraction is S.Zero + + def test_complex_frac(self): + # t2_1 * 1 / singles denom + i, j, k, a, b, c = get_symbols("ijkabc", "ababab") + num = AntiSymmetricTensor(tensor_names.eri, (i, j), (a, b)) + denom = ( + NonSymmetricTensor(tensor_names.orb_energy, (a,)) + + NonSymmetricTensor(tensor_names.orb_energy, (b,)) + - NonSymmetricTensor(tensor_names.orb_energy, (i,)) + - NonSymmetricTensor(tensor_names.orb_energy, (j,)) + ) + denom *= ( + - NonSymmetricTensor(tensor_names.orb_energy, (c,)) + + NonSymmetricTensor(tensor_names.orb_energy, (k,)) + ) + fraction = num / denom + # check the not expanded expression + imported = import_from_sympy_latex(latex(fraction)) + assert imported.inner - fraction is S.Zero + # check the expanded expression + fraction = fraction.expand() + imported = import_from_sympy_latex(latex(fraction)) + assert imported.inner - fraction is S.Zero diff --git a/tests/substitute_contracted_test.py b/tests/expression/substitute_contracted_test.py similarity index 100% rename from tests/substitute_contracted_test.py rename to tests/expression/substitute_contracted_test.py diff --git a/tests/symbolic_denom_test.py b/tests/expression/symbolic_denom_test.py similarity index 100% rename from tests/symbolic_denom_test.py rename to tests/expression/symbolic_denom_test.py diff --git a/tests/factor_intermediates_test.py b/tests/factor_intermediates_test.py index 35d68fb..10b31f1 100644 --- a/tests/factor_intermediates_test.py +++ b/tests/factor_intermediates_test.py @@ -1,6 +1,5 @@ -from adcgen.expression import ExprContainer +from adcgen.expression import ExprContainer, import_from_sympy_latex from adcgen.factor_intermediates import factor_intermediates -from adcgen.func import import_from_sympy_latex from adcgen.intermediates import t2eri_2, p0_3_oo from adcgen.simplify import simplify diff --git a/tests/generate_code/__init__.py b/tests/generate_code/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/contraction_test.py b/tests/generate_code/contraction_test.py similarity index 100% rename from tests/contraction_test.py rename to tests/generate_code/contraction_test.py diff --git a/tests/optimize_contractions_test.py b/tests/generate_code/optimize_contractions_test.py similarity index 99% rename from tests/optimize_contractions_test.py rename to tests/generate_code/optimize_contractions_test.py index 12a2866..01d5528 100644 --- a/tests/optimize_contractions_test.py +++ b/tests/generate_code/optimize_contractions_test.py @@ -1,4 +1,4 @@ -from adcgen.func import import_from_sympy_latex +from adcgen.expression import import_from_sympy_latex from adcgen.generate_code.contraction import Contraction from adcgen.generate_code.optimize_contractions import ( _group_objects, optimize_contractions From 211d09ec3a58b174d19fce7902063be2ab521175 Mon Sep 17 00:00:00 2001 From: jonasleitner Date: Fri, 6 Jun 2025 20:12:54 +0200 Subject: [PATCH 3/5] add callables to identify amplitudes and symmetrictensors during import --- adcgen/expression/import_from_sympy_latex.py | 183 ++++++++++++++---- .../import_from_sympy_latex_test.py | 36 ++++ 2 files changed, 177 insertions(+), 42 deletions(-) diff --git a/adcgen/expression/import_from_sympy_latex.py b/adcgen/expression/import_from_sympy_latex.py index 38e0174..d6237be 100644 --- a/adcgen/expression/import_from_sympy_latex.py +++ b/adcgen/expression/import_from_sympy_latex.py @@ -1,9 +1,10 @@ +from collections.abc import Callable import re from sympy.physics.secondquant import F, Fd, NO -from sympy import Expr, Mul, Pow, S, Symbol, sqrt, sympify +from sympy import Expr, Pow, S, Symbol, sqrt, sympify -from ..indices import Index, get_symbols +from ..indices import Index, get_symbols, split_idx_string from ..sympy_objects import ( Amplitude, AntiSymmetricTensor, KroneckerDelta, NonSymmetricTensor, SymmetricTensor @@ -12,18 +13,44 @@ from .expr_container import ExprContainer -def import_from_sympy_latex(expr_string: str, - convert_default_names: bool = False - ) -> ExprContainer: +def import_from_sympy_latex( + expr_string: str, convert_default_names: bool = False, + is_amplitude: Callable[[str], bool] | None = None, + is_symmetric_tensor: Callable[[str], bool] | None = None + ) -> ExprContainer: """ Imports an expression from a string created by the - :py:function:`sympy.latex` function. + :py:func:`sympy.latex` function. Parameters ---------- convert_default_names : bool, optional If set, all default tensor names found in the expression to import will be converted to the currently configured names. + For instance, ERIs named 'V' by default will be renamed to + whatever :py:attr:`TensorNames.eri` defines. + is_amplitude: callable, optional + A callable that takes a tensor name and returns whether a + tensor with the corresponding name should be imported + as :py:class:`Amplitude`. + Note that this is checked after the (optional) conversion + of default names, i.e., tensors named 't' (default name + for ground state amplitudes) will first be converted to + :py:attr:`TensorNames.gs_amplitude` before + consulting the callable. + Defaults to :py:func:`is_amplitude` defined below. + is_symmetric_tensor: callable, optional + A callable that takes a tensor name and returns whether a + tensor with the corresponding name should be imported + as :py:class:`SymmetricTensor`. + Note that this is checked after the (optional) conversion + of default names and after checking whether the + tensor should be imported as :py:class:`Amplitude`. + Tensors (with upper and lower indices) that are not + identified as :py:class:`Amplitude` or + :py:class:`SymmetricTensor` will finally be imported as + :py:class:`AntiSymmetricTensor`. + Defaults to :py:func:`is_symmetric_tensor` defined below. Returns ------- @@ -31,6 +58,13 @@ def import_from_sympy_latex(expr_string: str, The imported expression in a :py:class:`ExprContainer` container. Note that bra-ket symmetry is not set during the import. """ + if is_amplitude is None: + is_amplitude = globals()["is_amplitude"] + assert isinstance(is_amplitude, Callable) + if is_symmetric_tensor is None: + is_symmetric_tensor = globals()["is_symmetric_tensor"] + assert isinstance(is_symmetric_tensor, Callable) + expr_string = expr_string.strip() if not expr_string: return ExprContainer(0) @@ -43,15 +77,41 @@ def import_from_sympy_latex(expr_string: str, expr = S.Zero for term in term_strings: expr += _import_term( - term_string=term, convert_default_names=convert_default_names + term_string=term, convert_default_names=convert_default_names, + is_amplitude=is_amplitude, is_symmetric_tensor=is_symmetric_tensor ) return ExprContainer(expr) +def is_amplitude(name: str) -> bool: + """ + Whether a tensor with the given name should be imported as + :py:class:`Amplitude` tensor. + (ADC or ground state amplitude) + """ + return is_adc_amplitude(name) or is_t_amplitude(name) + + +def is_symmetric_tensor(name: str) -> bool: + """ + Whether a tensor with the given name should be imported as + :py:class:`SymmetricTensor`. + (Coulomb, symbolic orbital energy denominator, RI tensors) + """ + sym_names = ( + tensor_names.coulomb, tensor_names.sym_orb_denom, + tensor_names.ri_sym, tensor_names.ri_asym_eri, + tensor_names.ri_asym_factor + ) + return name in sym_names + + ######################## # import functionality # ######################## -def _import_term(term_string: str, convert_default_names: bool) -> Expr: +def _import_term(term_string: str, convert_default_names: bool, + is_amplitude: Callable[[str], bool], + is_symmetric_tensor: Callable[[str], bool]) -> Expr: """ Import the given term from string """ @@ -64,16 +124,20 @@ def _import_term(term_string: str, convert_default_names: bool) -> Expr: if term_string.startswith(r"\frac"): # fraction term *= _import_fraction( - term_string, convert_default_names=convert_default_names + term_string, convert_default_names=convert_default_names, + is_amplitude=is_amplitude, is_symmetric_tensor=is_symmetric_tensor ) else: term *= _import_product( - term_string, convert_default_names=convert_default_names + term_string, convert_default_names=convert_default_names, + is_amplitude=is_amplitude, is_symmetric_tensor=is_symmetric_tensor ) return term -def _import_fraction(fraction: str, convert_default_names: bool) -> Expr: +def _import_fraction(fraction: str, convert_default_names: bool, + is_amplitude: Callable[[str], bool], + is_symmetric_tensor: Callable[[str], bool]) -> Expr: """ Imports a fraction '\\frac{num}{denom}' from string. """ @@ -82,38 +146,49 @@ def _import_fraction(fraction: str, convert_default_names: bool) -> Expr: # import num if _is_sum(numerator): res *= import_from_sympy_latex( - numerator, convert_default_names=convert_default_names + numerator, convert_default_names=convert_default_names, + is_amplitude=is_amplitude, is_symmetric_tensor=is_symmetric_tensor ).inner else: res *= _import_product( - numerator, convert_default_names=convert_default_names + numerator, convert_default_names=convert_default_names, + is_amplitude=is_amplitude, is_symmetric_tensor=is_symmetric_tensor ) # import denom if _is_sum(denominator): res /= import_from_sympy_latex( - denominator, convert_default_names=convert_default_names + denominator, convert_default_names=convert_default_names, + is_amplitude=is_amplitude, is_symmetric_tensor=is_symmetric_tensor ).inner else: res /= _import_product( - denominator, convert_default_names=convert_default_names + denominator, convert_default_names=convert_default_names, + is_amplitude=is_amplitude, is_symmetric_tensor=is_symmetric_tensor ) assert isinstance(res, Expr) return res -def _import_product(product: str, convert_default_names: bool) -> Expr: +def _import_product(product: str, convert_default_names: bool, + is_amplitude: Callable[[str], bool], + is_symmetric_tensor: Callable[[str], bool]) -> Expr: """ Imports a product (a term that is no fraction) of objects. Objects are separated by a space. """ # we have to have a product at this point - return Mul(*( - _import_object(obj, convert_default_names=convert_default_names) - for obj in _split_product(product) - )) + res = S.One + for obj in _split_product(product): + res *= _import_object( + obj, convert_default_names=convert_default_names, + is_amplitude=is_amplitude, is_symmetric_tensor=is_symmetric_tensor + ) + return res -def _import_object(obj_str: str, convert_default_names: bool) -> Expr: +def _import_object(obj_str: str, convert_default_names: bool, + is_amplitude: Callable[[str], bool], + is_symmetric_tensor: Callable[[str], bool]) -> Expr: """ Imports the given object (a part of a product) from string. """ @@ -122,37 +197,45 @@ def _import_object(obj_str: str, convert_default_names: bool) -> Expr: elif obj_str.startswith(r"\sqrt{"): # sqrt{n} prefactor assert obj_str[-1] == "}" return sqrt(import_from_sympy_latex( - obj_str[6:-1].strip(), convert_default_names=convert_default_names + obj_str[6:-1].strip(), convert_default_names=convert_default_names, + is_amplitude=is_amplitude, is_symmetric_tensor=is_symmetric_tensor ).inner) elif obj_str.startswith(r"\delta_{"): # KroneckerDelta return _import_kronecker_delta( - obj_str, convert_default_names=convert_default_names + obj_str, convert_default_names=convert_default_names, + is_amplitude=is_amplitude, is_symmetric_tensor=is_symmetric_tensor ) elif obj_str.startswith(r"\left("): return _import_polynom( - obj_str, convert_default_names=convert_default_names + obj_str, convert_default_names=convert_default_names, + is_amplitude=is_amplitude, is_symmetric_tensor=is_symmetric_tensor ) elif obj_str.startswith(r"\left\{"): # NO return _import_normal_ordered( - obj_str, convert_default_names=convert_default_names + obj_str, convert_default_names=convert_default_names, + is_amplitude=is_amplitude, is_symmetric_tensor=is_symmetric_tensor ) else: # the remaining objects are harder to identify: # tensor, creation, annihilation or symbol return _import_tensor_like( - obj_str, convert_default_names=convert_default_names + obj_str, convert_default_names=convert_default_names, + is_amplitude=is_amplitude, is_symmetric_tensor=is_symmetric_tensor ) -def _import_kronecker_delta(delta: str, convert_default_names: bool) -> Expr: +def _import_kronecker_delta(delta: str, convert_default_names: bool, + is_amplitude: Callable[[str], bool], + is_symmetric_tensor: Callable[[str], bool] + ) -> Expr: """ Imports the given KroneckerDelta of the from '\\delta_{p q}' from string """ + _ = convert_default_names, is_amplitude, is_symmetric_tensor # a delta should not have an exponent! delta, exponent = _split_base_and_exponent(delta) assert exponent is None - _ = convert_default_names assert delta.startswith(r"\delta_{") and delta.endswith("}") # extract and import the indices p, q = delta[8:-1].strip().split() @@ -160,7 +243,10 @@ def _import_kronecker_delta(delta: str, convert_default_names: bool) -> Expr: return KroneckerDelta(p, q) -def _import_polynom(polynom: str, convert_default_names: bool): +def _import_polynom(polynom: str, convert_default_names: bool, + is_amplitude: Callable[[str], bool], + is_symmetric_tensor: Callable[[str], bool] + ) -> Expr: """ Imports the given polynom of the form '\\left(...)\\right^{exp}' from string @@ -170,18 +256,24 @@ def _import_polynom(polynom: str, convert_default_names: bool): assert base.startswith(r"\left(") and base.endswith(r"\right)") # import base and exponent and build a Pow object res = import_from_sympy_latex( - base[6:-7].strip(), convert_default_names=convert_default_names + base[6:-7].strip(), convert_default_names=convert_default_names, + is_amplitude=is_amplitude, is_symmetric_tensor=is_symmetric_tensor ).inner if exponent is not None: assert exponent.startswith("{") and exponent.endswith("}") exponent = import_from_sympy_latex( - exponent[1:-1].strip(), convert_default_names=convert_default_names + exponent[1:-1].strip(), + convert_default_names=convert_default_names, + is_amplitude=is_amplitude, is_symmetric_tensor=is_symmetric_tensor ).inner res = Pow(res, exponent) return res -def _import_normal_ordered(no: str, convert_default_names: bool): +def _import_normal_ordered(no: str, convert_default_names: bool, + is_amplitude: Callable[[str], bool], + is_symmetric_tensor: Callable[[str], bool] + ) -> Expr: """ Imports the given NormalOrdered object of the form '\\left\\{...\\right\\}' from string @@ -190,12 +282,16 @@ def _import_normal_ordered(no: str, convert_default_names: bool): no, exponent = _split_base_and_exponent(no) assert exponent is None res = import_from_sympy_latex( - no[7:-8], convert_default_names=convert_default_names + no[7:-8], convert_default_names=convert_default_names, + is_amplitude=is_amplitude, is_symmetric_tensor=is_symmetric_tensor ).inner return NO(res) -def _import_tensor_like(tensor: str, convert_default_names: bool) -> Expr: +def _import_tensor_like(tensor: str, convert_default_names: bool, + is_amplitude: Callable[[str], bool], + is_symmetric_tensor: Callable[[str], bool] + ) -> Expr: """ Imports a tensor like object (Symbol, creation and annihilation operators and Tensors). @@ -250,12 +346,9 @@ def _import_tensor_like(tensor: str, convert_default_names: bool) -> Expr: upper = _import_indices(components[0]) lower = _import_indices(components[1]) # figure out which tensor class to use - if is_adc_amplitude(name) or is_t_amplitude(name): + if is_amplitude(name): res = Amplitude(name, upper, lower) - elif name in (tensor_names.coulomb, tensor_names.sym_orb_denom, - tensor_names.ri_sym, - tensor_names.ri_asym_eri, - tensor_names.ri_asym_factor): + elif is_symmetric_tensor(name): res = SymmetricTensor(name, upper, lower) else: res = AntiSymmetricTensor(name, upper, lower) @@ -267,7 +360,8 @@ def _import_tensor_like(tensor: str, convert_default_names: bool) -> Expr: if exponent is not None: assert exponent.startswith("{") and exponent.endswith("}") exponent = import_from_sympy_latex( - exponent[1:-1], convert_default_names=convert_default_names + exponent[1:-1], convert_default_names=convert_default_names, + is_amplitude=is_amplitude, is_symmetric_tensor=is_symmetric_tensor ).inner res = Pow(res, exponent) return res @@ -294,6 +388,9 @@ def _import_index(index_str: str) -> Index: elif index_str.endswith(r"_{\beta}"): spin = "b" index_str = index_str[:-8] + else: + # ensure that we have a single index without a spin + assert len(split_idx_string(index_str)) == 1 # build the index idx = get_symbols(index_str, spin) assert len(idx) == 1 @@ -374,7 +471,9 @@ def _split_product(product: str) -> list[str]: # this breaks down if the input is a sum and no product assert stack elif char == " " and not stack and i != obj_start_idx: - objects.append(product[obj_start_idx:i].strip()) + obj = product[obj_start_idx:i].strip() + if obj: + objects.append(obj) obj_start_idx = i objects.append(product[obj_start_idx:].strip()) # the last object return objects @@ -425,7 +524,7 @@ def _split_base_and_exponent(obj_str: str) -> tuple[str, str | None]: def _split_indices(idx_str: str) -> list[str]: """ - Splits an index string of the form 'ab2b_{alpha}b2_{beta}' into a + Splits an index string of the form 'ab2b_{\\alpha}b2_{\\beta}' into a list [‘a‘, 'b2', 'b_{\\alpha}', 'b2_{\\beta}']. """ splitted = re.findall(r"\D\d*(?:_\{\\(?:alpha|beta)\})?", idx_str) diff --git a/tests/expression/import_from_sympy_latex_test.py b/tests/expression/import_from_sympy_latex_test.py index f24b9c7..20e8dd8 100644 --- a/tests/expression/import_from_sympy_latex_test.py +++ b/tests/expression/import_from_sympy_latex_test.py @@ -135,3 +135,39 @@ def test_complex_frac(self): fraction = fraction.expand() imported = import_from_sympy_latex(latex(fraction)) assert imported.inner - fraction is S.Zero + + def test_is_amplitude(self): + i, j = get_symbols("ij") + tensor = Amplitude(tensor_names.eri, (i,), (j,)) + # by default we should get a AntiSymmetricTensor + imported = import_from_sympy_latex(latex(tensor)) + assert tensor - imported.inner is not S.Zero + # now import as amplitude + imported = import_from_sympy_latex( + latex(tensor), is_amplitude=lambda n: n == tensor_names.eri + ) + assert tensor - imported.inner is S.Zero + # is_amplitude should have priority over is_symmetric_tensor + imported = import_from_sympy_latex( + latex(tensor), is_amplitude=lambda n: n == tensor_names.eri, + is_symmetric_tensor=lambda n: n == tensor_names.eri + ) + assert tensor - imported.inner is S.Zero + + def test_is_symmetric_tensor(self): + i, j = get_symbols("ij") + tensor = SymmetricTensor(tensor_names.eri, (i,), (j,)) + # by default we should get a AntiSymmetricTensor + imported = import_from_sympy_latex(latex(tensor)) + assert tensor - imported.inner is not S.Zero + # now import as symmetrictensor + imported = import_from_sympy_latex( + latex(tensor), is_symmetric_tensor=lambda n: n == tensor_names.eri + ) + assert tensor - imported.inner is S.Zero + # is_amplitude should have priority over is_symmetric_tensor + imported = import_from_sympy_latex( + latex(tensor), is_amplitude=lambda n: n == tensor_names.eri, + is_symmetric_tensor=lambda n: n == tensor_names.eri + ) + assert tensor - imported.inner is not S.Zero From 914462c222976b078cc484456319489c2aab6f68 Mon Sep 17 00:00:00 2001 From: jonasleitner Date: Sat, 7 Jun 2025 08:21:08 +0200 Subject: [PATCH 4/5] fix bug in spin detection in coulomb expansion and add ri unit tests --- adcgen/expression/object_container.py | 4 +- adcgen/resolution_of_identity.py | 17 ++--- tests/resolution_of_identity_test.py | 100 +++++++++++++++++++++++++- 3 files changed, 107 insertions(+), 14 deletions(-) diff --git a/adcgen/expression/object_container.py b/adcgen/expression/object_container.py index cb39958..559e614 100644 --- a/adcgen/expression/object_container.py +++ b/adcgen/expression/object_container.py @@ -750,7 +750,7 @@ def expand_coulomb_ri(self, factorisation: str = 'sym', """ from .expr_container import ExprContainer - if factorisation not in ('sym', 'asym'): + if factorisation not in ("sym", "asym"): raise NotImplementedError("Only symmetric ('sym') and asymmetric " "('asym') factorisation of the Coulomb " "integral is implemented") @@ -776,7 +776,7 @@ def expand_coulomb_ri(self, factorisation: str = 'sym', # assign alpha spin if represented in spatial orbitals idx = self.idx has_spin = bool(idx[0].spin) - if any(bool(s) != has_spin for s in idx): + if any(bool(s.spin) != has_spin for s in idx): raise NotImplementedError(f"The coulomb integral {self} has " "to be represented either in spatial" " or spin orbitals. A mixture is not" diff --git a/adcgen/resolution_of_identity.py b/adcgen/resolution_of_identity.py index d38bb25..63044d4 100644 --- a/adcgen/resolution_of_identity.py +++ b/adcgen/resolution_of_identity.py @@ -56,20 +56,17 @@ def apply_resolution_of_identity(expr: ExprContainer, Inputerror If the expression still contains antisymmetric ERIs. """ - + assert isinstance(expr, ExprContainer) # Check if a valid factorisation is given - if factorisation not in ('sym', 'asym'): - raise Inputerror('Only symmetric (sym) and asymmetric (asym) ' - 'factorisation modes are supported. ' - f'Received: {factorisation}') - + if factorisation not in ("sym", "asym"): + raise Inputerror("Only symmetric (sym) and asymmetric (asym) " + "factorisation modes are supported. " + f"Received: {factorisation}") # Check whether the expression contains antisymmetric ERIs if Symbol(tensor_names.eri) in expr.inner.atoms(Symbol): - raise Inputerror('Resolution of Identity requires that the ERIs' - ' be expanded first.') - + raise Inputerror("Resolution of Identity requires that the ERIs" + " be expanded first.") ri_expr = expr.expand_coulomb_ri(factorisation=factorisation) if resolve_indices: ri_expr.substitute_contracted() - return ri_expr diff --git a/tests/resolution_of_identity_test.py b/tests/resolution_of_identity_test.py index fb745d4..ebfb238 100644 --- a/tests/resolution_of_identity_test.py +++ b/tests/resolution_of_identity_test.py @@ -1,7 +1,12 @@ -from adcgen.spatial_orbitals import transform_to_spatial_orbitals +from adcgen.expression import ExprContainer +from adcgen.indices import Index, get_symbols +from adcgen.misc import Inputerror from adcgen.resolution_of_identity import apply_resolution_of_identity from adcgen.simplify import simplify -from adcgen.expression import ExprContainer +from adcgen.spatial_orbitals import transform_to_spatial_orbitals +from adcgen import ( + AntiSymmetricTensor, SymmetricTensor, tensor_names, +) from sympy import S @@ -9,6 +14,97 @@ class TestResolutionOfIdentity(): + def test_sanity_checks(self): + i, j, k, l = get_symbols("ijkl") # noqa E741 + # forgot to expand the antisym ERI + tensor = AntiSymmetricTensor(tensor_names.eri, (i, j), (k, l)) + with pytest.raises(Inputerror): + apply_resolution_of_identity(ExprContainer(tensor)) + # implementation assumes braket symmetry + tensor = SymmetricTensor(tensor_names.coulomb, (i, j), (k, l)) + with pytest.raises(NotImplementedError): + apply_resolution_of_identity(ExprContainer(tensor)) + # mix of spin and spatial orbitals is not allowed + i, j = get_symbols("ij", "aa") + tensor = SymmetricTensor(tensor_names.coulomb, (i, j), (k, l)) + with pytest.raises(NotImplementedError): + apply_resolution_of_identity(ExprContainer(tensor)) + + def test_sym_factorisation(self): + i, j, k, l, P = get_symbols("ijklP") # noqa E741 + # expand a single coulomb integral in spin orbital basis + tensor = SymmetricTensor(tensor_names.coulomb, (i, j), (k, l), 1) + res = apply_resolution_of_identity( + ExprContainer(tensor), factorisation="sym" + ) + ref = ( + SymmetricTensor(tensor_names.ri_sym, (P,), (i, j)) + * SymmetricTensor(tensor_names.ri_sym, (P,), (k, l)) + ) + assert res.inner - ref is S.Zero + # with spatial orbitals + i, j, k, l, P = get_symbols("ijklP", "ababa") # noqa E741 + tensor = SymmetricTensor(tensor_names.coulomb, (i, j), (k, l), 1) + res = apply_resolution_of_identity( + ExprContainer(tensor), factorisation="sym" + ) + ref = ( + SymmetricTensor(tensor_names.ri_sym, (P,), (i, j)) + * SymmetricTensor(tensor_names.ri_sym, (P,), (k, l)) + ) + assert res.inner - ref is S.Zero + + def test_asym_factorisation(self): + i, j, k, l, P = get_symbols("ijklP") # noqa E741 + tensor = SymmetricTensor(tensor_names.coulomb, (i, j), (k, l), 1) + res = apply_resolution_of_identity( + ExprContainer(tensor), factorisation="asym" + ) + ref = ( + SymmetricTensor(tensor_names.ri_asym_factor, (P,), (i, j)) + * SymmetricTensor(tensor_names.ri_asym_eri, (P,), (k, l)) + ) + assert res.inner - ref is S.Zero + # with spatial orbitals + i, j, k, l, P = get_symbols("ijklP", "ababa") # noqa E741 + tensor = SymmetricTensor(tensor_names.coulomb, (i, j), (k, l), 1) + res = apply_resolution_of_identity( + ExprContainer(tensor), factorisation="asym" + ) + ref = ( + SymmetricTensor(tensor_names.ri_asym_factor, (P,), (i, j)) + * SymmetricTensor(tensor_names.ri_asym_eri, (P,), (k, l)) + ) + assert res.inner - ref is S.Zero + + def test_resolve_indices(self): + i, j, k, l, P, Q = get_symbols("ijklPQ") + # without resolve indices a unknown P should be in the res + tensor = SymmetricTensor(tensor_names.coulomb, (i, j), (k, l), 1) + res = apply_resolution_of_identity( + ExprContainer(tensor), factorisation="sym", resolve_indices=False + ) + ref = ( + SymmetricTensor(tensor_names.ri_sym, (P,), (i, j)) + * SymmetricTensor(tensor_names.ri_sym, (P,), (k, l)) + ) + assert P not in res.inner.atoms(Index) + # subsitute contracted should resolve the unknown P + res.substitute_contracted() + assert P in res.inner.atoms(Index) + assert res.inner - ref is S.Zero + # add another tensor and try with resolve_indices to immediately + # get a good result + tensor *= SymmetricTensor(tensor_names.coulomb, (i, k), (j, l), 1) + res = apply_resolution_of_identity( + ExprContainer(tensor), factorisation="sym", resolve_indices=True + ) + ref *= ( + SymmetricTensor(tensor_names.ri_sym, (Q,), (i, k)) + * SymmetricTensor(tensor_names.ri_sym, (Q,), (j, l)) + ) + assert P in res.inner.atoms(Index) and Q in res.inner.atoms(Index) + assert simplify(res - ref).inner is S.Zero @pytest.mark.parametrize('variant', ['mp', 're']) @pytest.mark.parametrize('order', [0, 1, 2, 3]) From 081429328f204316d5ab081bfa06f89dca808a50 Mon Sep 17 00:00:00 2001 From: jonasleitner Date: Mon, 9 Jun 2025 10:30:14 +0200 Subject: [PATCH 5/5] add float import --- adcgen/expression/import_from_sympy_latex.py | 4 +++- tests/expression/import_from_sympy_latex_test.py | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/adcgen/expression/import_from_sympy_latex.py b/adcgen/expression/import_from_sympy_latex.py index d6237be..c7499cc 100644 --- a/adcgen/expression/import_from_sympy_latex.py +++ b/adcgen/expression/import_from_sympy_latex.py @@ -192,8 +192,10 @@ def _import_object(obj_str: str, convert_default_names: bool, """ Imports the given object (a part of a product) from string. """ - if obj_str.isnumeric(): # prefactor + if obj_str.isnumeric(): # prefactor: int return sympify(int(obj_str)) + elif re.fullmatch(r"\d+\.\d+", obj_str): # prefactor: float + return sympify(float(obj_str)) elif obj_str.startswith(r"\sqrt{"): # sqrt{n} prefactor assert obj_str[-1] == "}" return sqrt(import_from_sympy_latex( diff --git a/tests/expression/import_from_sympy_latex_test.py b/tests/expression/import_from_sympy_latex_test.py index 20e8dd8..e59ac3d 100644 --- a/tests/expression/import_from_sympy_latex_test.py +++ b/tests/expression/import_from_sympy_latex_test.py @@ -16,6 +16,8 @@ def test_number(self): assert import_from_sympy_latex("2").inner == 2 assert import_from_sympy_latex(r"\frac{1}{2}").inner == Rational(1, 2) assert import_from_sympy_latex(r"\sqrt{2}").inner == sqrt(2) + assert import_from_sympy_latex("2.0").inner == 2.0 + assert import_from_sympy_latex("42.42").inner == 42.42 def test_delta(self): # spin orbitals