-
Notifications
You must be signed in to change notification settings - Fork 25
Refactor AdcMethod + ISR(1)-d #208
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
390228a
d1d8a01
8cb3888
8241293
e2f0e6c
f781101
1c0e21b
5544c61
0a28bbb
ca0b8c9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -28,7 +28,7 @@ | |||||||||
| from .LazyMp import LazyMp | ||||||||||
| from .adc_pp import matrix as ppmatrix | ||||||||||
| from .timings import Timer, timed_member_call | ||||||||||
| from .AdcMethod import AdcMethod | ||||||||||
| from .AdcMethod import AdcMethod, IsrMethod, Method | ||||||||||
| from .functions import ones_like | ||||||||||
| from .Intermediates import Intermediates | ||||||||||
| from .AmplitudeVector import AmplitudeVector | ||||||||||
|
|
@@ -73,17 +73,22 @@ class AdcMatrixlike: | |||||||||
|
|
||||||||||
| _special_block_orders = { | ||||||||||
| "adc2x": {"ph_ph": 2, "ph_pphh": 1, "pphh_ph": 1, "pphh_pphh": 1}, | ||||||||||
| "isr1s": {"ph_ph": 1, "ph_pphh": None, "pphh_ph": None, "pphh_pphh": None}, | ||||||||||
| } | ||||||||||
|
|
||||||||||
| @classmethod | ||||||||||
| def _default_block_orders(cls, method: AdcMethod) -> dict[str, int]: | ||||||||||
| def _default_block_orders(cls, | ||||||||||
| method: Method | ||||||||||
| ) -> dict[str, int]: | ||||||||||
|
Comment on lines
+80
to
+82
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
| """ | ||||||||||
| Determines the default block orders for the given adc method. | ||||||||||
| """ | ||||||||||
| # check if we have a special method like adc2x | ||||||||||
| # I guess base_method should also contain the adc_type prefix so | ||||||||||
| # we don't need to separate different adc_types | ||||||||||
| block_orders = cls._special_block_orders.get(method.base_method.name, None) | ||||||||||
| block_orders = cls._special_block_orders.get( | ||||||||||
| method.base_method.name, None | ||||||||||
| ) | ||||||||||
| if block_orders is not None: | ||||||||||
| return block_orders.copy() | ||||||||||
| # otherwise assume that we have a "normal" PP/IP/...-ADC(n) method | ||||||||||
|
|
@@ -96,9 +101,23 @@ def _default_block_orders(cls, method: AdcMethod) -> dict[str, int]: | |||||||||
| raise ValueError(f"Unknown adc type {method.adc_type} for method " | ||||||||||
| f"{method.name}. Can not determine default block " | ||||||||||
| "orders.") | ||||||||||
| spaces = [ | ||||||||||
| "p" * i + min_space + "h" * i for i in range(0, (method.level // 2) + 1) | ||||||||||
| ] | ||||||||||
| # ADC matrices have first-order coupling whereas ISR matrices | ||||||||||
| # have a zeroth-order coupling between adjacent excitation classes | ||||||||||
| # see https://doi.org/10.1063/1.1752875 | ||||||||||
| if isinstance(method, AdcMethod): | ||||||||||
| # First-order coupling between adjacent excitation classes. | ||||||||||
| spaces = [ | ||||||||||
| "p" * i + min_space + "h" * i | ||||||||||
| for i in range(0, (method.level.to_int() // 2) + 1) | ||||||||||
| ] | ||||||||||
| elif isinstance(method, IsrMethod): | ||||||||||
| # Zeroth-order coupling between adjacent excitation classes. | ||||||||||
| spaces = [ | ||||||||||
| "p" * i + min_space + "h" * i | ||||||||||
| for i in range(0, ((method.level.to_int() + 1) // 2) + 1) | ||||||||||
| ] | ||||||||||
| else: | ||||||||||
| raise ValueError(f"Invalid method: {method.name}") | ||||||||||
| # exploit the fact that the spaces are sorted from small to high: | ||||||||||
| # If we walk the adc matrix in any direction we always have to subtract 1! | ||||||||||
| # Therefore, we can determine the order according to the position of the | ||||||||||
|
|
@@ -107,26 +126,27 @@ def _default_block_orders(cls, method: AdcMethod) -> dict[str, int]: | |||||||||
| ret = {} | ||||||||||
| for ((i1, bra), (i2, ket)) in \ | ||||||||||
| itertools.product(enumerate(spaces), repeat=2): | ||||||||||
| order = method.level - i1 - i2 | ||||||||||
| assert order >= 0 | ||||||||||
| order = method.level.to_int() - i1 - i2 | ||||||||||
| # For ISR matrices allow missing diagonal blocks. | ||||||||||
| order = None if order < 0 else order | ||||||||||
| ret[f"{bra}_{ket}"] = order | ||||||||||
| return ret | ||||||||||
|
|
||||||||||
| @classmethod | ||||||||||
| def _validate_block_orders(cls, block_orders: dict[str, int], | ||||||||||
| method: AdcMethod, | ||||||||||
| method: Method, | ||||||||||
| allow_missing_diagonal_blocks: bool = False) -> None: | ||||||||||
| """ | ||||||||||
| Validates that the given block_orders form a valid adc matrix for the given | ||||||||||
| adc method. | ||||||||||
| Validates that the given block_orders form a valid adc/isr matrix for the | ||||||||||
| given adc(isr) method. | ||||||||||
|
|
||||||||||
| Parameters | ||||||||||
| ---------- | ||||||||||
| block_orders: dict[str, int] | ||||||||||
| The block orders to validate. Block orders should be of the form | ||||||||||
| {'ph_ph': 2, 'ph_pphh': 1, ...} | ||||||||||
| method: AdcMethod | ||||||||||
| The adc method/adc type (PP-ADC, ...) for which to validate | ||||||||||
| method: Method | ||||||||||
| The adc/isr method/adc type (PP-ADC/ISR, ...) for which to validate | ||||||||||
| the block_orders. | ||||||||||
| allow_missing_diagonal_blocks: bool, optional | ||||||||||
| If set, couplings between missing diagonal blocks are allowed, e.g., | ||||||||||
|
|
@@ -137,6 +157,7 @@ def _validate_block_orders(cls, block_orders: dict[str, int], | |||||||||
| for block, order in block_orders.items(): | ||||||||||
| if order is None: | ||||||||||
| continue | ||||||||||
| assert order >= 0 | ||||||||||
| # ensure that the block is valid for the given adc type | ||||||||||
| bra, ket = block.split("_") | ||||||||||
| if not cls._is_valid_space(bra, method) or \ | ||||||||||
|
|
@@ -164,11 +185,11 @@ def _validate_block_orders(cls, block_orders: dict[str, int], | |||||||||
| f"{ket_diag} are in the matrix too.") | ||||||||||
|
|
||||||||||
| @classmethod | ||||||||||
| def _is_valid_space(cls, space: str, method: AdcMethod) -> bool: | ||||||||||
| def _is_valid_space(cls, space: str, method: Method) -> bool: | ||||||||||
| """ | ||||||||||
| Checks whether the given space ('ph' for instance) is valid for the given | ||||||||||
| adc method. Thereby we only verify that the space matches the adc_type of | ||||||||||
| adc method! | ||||||||||
| method! | ||||||||||
| """ | ||||||||||
| n_particle, n_hole = space.count("p"), space.count("h") | ||||||||||
| # ensure that the space is of the form pp...hh... | ||||||||||
|
|
@@ -236,11 +257,7 @@ def __init__(self, method, hf_or_mp, block_orders=None, intermediates=None, | |||||||||
| self.intermediates = Intermediates(self.ground_state) | ||||||||||
|
|
||||||||||
| self.block_orders = self._default_block_orders(self.method) | ||||||||||
| if block_orders is None: | ||||||||||
| if method.level > 3: | ||||||||||
| raise NotImplementedError("The ADC secular matrix is not " | ||||||||||
| f"implemented for method {method.name}.") | ||||||||||
| else: | ||||||||||
| if block_orders is not None: | ||||||||||
| self.block_orders.update(block_orders) | ||||||||||
| self._validate_block_orders( | ||||||||||
| block_orders=self.block_orders, method=self.method, | ||||||||||
|
|
||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -20,76 +20,125 @@ | |||||||||||||||
| ## along with adcc. If not, see <http://www.gnu.org/licenses/>. | ||||||||||||||||
| ## | ||||||||||||||||
| ## --------------------------------------------------------------------- | ||||||||||||||||
| from typing import Optional, TypeVar | ||||||||||||||||
| from enum import Enum | ||||||||||||||||
|
|
||||||||||||||||
| T = TypeVar("T", bound="Method") | ||||||||||||||||
|
|
||||||||||||||||
| def get_valid_methods(): | ||||||||||||||||
| valid_prefixes = ["cvs"] | ||||||||||||||||
| valid_bases = ["adc0", "adc1", "adc2", "adc2x", "adc3"] | ||||||||||||||||
|
|
||||||||||||||||
| ret = valid_bases + [p + "-" + m for p in valid_prefixes | ||||||||||||||||
| for m in valid_bases] | ||||||||||||||||
| return ret | ||||||||||||||||
| class MethodLevel(Enum): | ||||||||||||||||
| # numeric levels | ||||||||||||||||
| ZERO = 0 | ||||||||||||||||
| ONE = 1 | ||||||||||||||||
| TWO = 2 | ||||||||||||||||
| THREE = 3 | ||||||||||||||||
| FOUR = 4 | ||||||||||||||||
| FIVE = 5 | ||||||||||||||||
|
|
||||||||||||||||
| # special levels | ||||||||||||||||
| TWO_X = "2x" | ||||||||||||||||
| ONE_S = "1s" | ||||||||||||||||
| THREE_D = "3d" | ||||||||||||||||
|
Comment on lines
+38
to
+41
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it would be good to add a comment for each special level that explains what the level means |
||||||||||||||||
|
|
||||||||||||||||
| class AdcMethod: | ||||||||||||||||
| available_methods = get_valid_methods() | ||||||||||||||||
| def to_str(self) -> str: | ||||||||||||||||
| return str(self.value) | ||||||||||||||||
|
|
||||||||||||||||
| def __init__(self, method): | ||||||||||||||||
| if method not in self.available_methods: | ||||||||||||||||
| raise ValueError("Invalid method " + str(method) + ". Only " | ||||||||||||||||
| + ",".join(self.available_methods) + " are known.") | ||||||||||||||||
| def to_int(self) -> int: | ||||||||||||||||
| # numerical methods | ||||||||||||||||
| if isinstance(self.value, int): | ||||||||||||||||
| return self.value | ||||||||||||||||
| # return base int for special methods | ||||||||||||||||
| elif isinstance(self.value, str): | ||||||||||||||||
| return int(self.value[0]) | ||||||||||||||||
| else: | ||||||||||||||||
| raise ValueError | ||||||||||||||||
|
|
||||||||||||||||
|
|
||||||||||||||||
| class Method: | ||||||||||||||||
| # this has to be set on the child classes | ||||||||||||||||
| _method_base_name: Optional[str] = None | ||||||||||||||||
| max_level: int = 0 | ||||||||||||||||
| special_levels: tuple[MethodLevel, ...] = tuple() | ||||||||||||||||
|
|
||||||||||||||||
| def __init__(self, method: str): | ||||||||||||||||
| assert self._method_base_name is not None | ||||||||||||||||
|
|
||||||||||||||||
| # validate base method type | ||||||||||||||||
| split = method.split("-") | ||||||||||||||||
| self.__base_method = split[-1] | ||||||||||||||||
| if not split[-1].startswith(self._method_base_name): | ||||||||||||||||
| raise ValueError(f"{split[-1]} is not a valid method type") | ||||||||||||||||
|
|
||||||||||||||||
| # validate method level | ||||||||||||||||
| level = split[-1][len(self._method_base_name):] | ||||||||||||||||
| if level.isnumeric(): | ||||||||||||||||
| self.level: MethodLevel = MethodLevel(int(level)) | ||||||||||||||||
| else: | ||||||||||||||||
| self.level: MethodLevel = MethodLevel(level) | ||||||||||||||||
| self._validate_level(self.level) | ||||||||||||||||
|
|
||||||||||||||||
| assert self._base_method == split[-1] | ||||||||||||||||
|
|
||||||||||||||||
| # validate prefix | ||||||||||||||||
| split = split[:-1] | ||||||||||||||||
| self.is_core_valence_separated = "cvs" in split | ||||||||||||||||
| if split and split[0] not in ["cvs"]: | ||||||||||||||||
| raise ValueError(f"{split[0]} is not a valid method prefix") | ||||||||||||||||
|
Comment on lines
+83
to
+84
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we also need to verify that split is at most of length 1 with cvs being the only valid entry, right?
Suggested change
|
||||||||||||||||
|
|
||||||||||||||||
| self.is_core_valence_separated: bool = "cvs" in split | ||||||||||||||||
| # NOTE: added this to make the testdata generation ready for IP/EA | ||||||||||||||||
| self.adc_type = "pp" | ||||||||||||||||
| self.adc_type: str = "pp" | ||||||||||||||||
|
|
||||||||||||||||
| try: | ||||||||||||||||
| if self.__base_method == "adc2x": | ||||||||||||||||
| self.level = 2 | ||||||||||||||||
| else: | ||||||||||||||||
| self.level = int(self.__base_method[-1]) | ||||||||||||||||
| except ValueError: | ||||||||||||||||
| raise ValueError("Not a valid base method: " + self.__base_method) | ||||||||||||||||
| def _validate_level(self, level: MethodLevel) -> None: | ||||||||||||||||
| if isinstance(level.value, int): | ||||||||||||||||
| if level.value <= self.max_level: | ||||||||||||||||
| return | ||||||||||||||||
|
Comment on lines
+91
to
+93
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||
|
|
||||||||||||||||
| def at_level(self, newlevel): | ||||||||||||||||
| """ | ||||||||||||||||
| Return an equivalent method, where only the level is changed | ||||||||||||||||
| (e.g. calling this on a CVS method returns a CVS method) | ||||||||||||||||
| """ | ||||||||||||||||
| if self.is_core_valence_separated: | ||||||||||||||||
| return AdcMethod("cvs-adc" + str(newlevel)) | ||||||||||||||||
| else: | ||||||||||||||||
| return AdcMethod("adc" + str(newlevel)) | ||||||||||||||||
| # special cases | ||||||||||||||||
| if level in self.special_levels: | ||||||||||||||||
| return | ||||||||||||||||
|
|
||||||||||||||||
| raise NotImplementedError(f"{self._base_method} is not implemented.") | ||||||||||||||||
|
|
||||||||||||||||
| @property | ||||||||||||||||
| def name(self): | ||||||||||||||||
| def name(self) -> str: | ||||||||||||||||
| """The name of the Method as string.""" | ||||||||||||||||
| if self.is_core_valence_separated: | ||||||||||||||||
| return "cvs-" + self.__base_method | ||||||||||||||||
| return "cvs-" + self._base_method | ||||||||||||||||
| else: | ||||||||||||||||
| return self.__base_method | ||||||||||||||||
| return self._base_method | ||||||||||||||||
|
|
||||||||||||||||
| @property | ||||||||||||||||
| def property_method(self): | ||||||||||||||||
| """ | ||||||||||||||||
| The name of the canonical method to use for computing properties | ||||||||||||||||
| for this ADC method. This only differs from the name property | ||||||||||||||||
| for the ADC(2)-x family of methods. | ||||||||||||||||
| """ | ||||||||||||||||
| if self.__base_method == "adc2x": | ||||||||||||||||
| return AdcMethod(self.name.replace("adc2x", "adc2")).name | ||||||||||||||||
| else: | ||||||||||||||||
| return self.name | ||||||||||||||||
| def _base_method(self) -> str: | ||||||||||||||||
| return self._method_base_name + self.level.to_str() | ||||||||||||||||
|
Comment on lines
+110
to
+111
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||
|
|
||||||||||||||||
| @property | ||||||||||||||||
| def base_method(self): | ||||||||||||||||
| def base_method(self: T) -> T: | ||||||||||||||||
| """ | ||||||||||||||||
| The base (full) method, i.e. with all approximations such as | ||||||||||||||||
| CVS stripped off. | ||||||||||||||||
| """ | ||||||||||||||||
| return AdcMethod(self.__base_method) | ||||||||||||||||
| return self.__class__(self._base_method) | ||||||||||||||||
|
|
||||||||||||||||
| def at_level(self: T, newlevel: int) -> T: | ||||||||||||||||
| """ | ||||||||||||||||
| Return an equivalent method, where only the level is changed | ||||||||||||||||
| (e.g. calling this on a CVS method returns a CVS method) | ||||||||||||||||
| """ | ||||||||||||||||
| assert self._method_base_name is not None | ||||||||||||||||
| if self.is_core_valence_separated: | ||||||||||||||||
| return self.__class__("cvs-" + self._method_base_name + str(newlevel)) | ||||||||||||||||
| else: | ||||||||||||||||
| return self.__class__(self._method_base_name + str(newlevel)) | ||||||||||||||||
|
|
||||||||||||||||
| def as_method(self, method_cls: type[T]) -> T: | ||||||||||||||||
| """ | ||||||||||||||||
| Return a equivalent Method with the method base name replaced | ||||||||||||||||
| by the provided name. | ||||||||||||||||
| """ | ||||||||||||||||
| assert self._method_base_name is not None | ||||||||||||||||
| assert method_cls._method_base_name is not None | ||||||||||||||||
| return method_cls( | ||||||||||||||||
| self.name.replace(self._method_base_name, method_cls._method_base_name) | ||||||||||||||||
| ) | ||||||||||||||||
|
|
||||||||||||||||
| def __eq__(self, other): | ||||||||||||||||
| return self.name == other.name | ||||||||||||||||
|
|
@@ -99,3 +148,15 @@ def __ne__(self, other): | |||||||||||||||
|
|
||||||||||||||||
| def __repr__(self): | ||||||||||||||||
| return "Method(name={})".format(self.name) | ||||||||||||||||
|
|
||||||||||||||||
|
|
||||||||||||||||
| class AdcMethod(Method): | ||||||||||||||||
| _method_base_name = "adc" | ||||||||||||||||
| max_level = 3 | ||||||||||||||||
| special_levels = (MethodLevel.TWO_X,) | ||||||||||||||||
|
|
||||||||||||||||
|
|
||||||||||||||||
| class IsrMethod(Method): | ||||||||||||||||
| _method_base_name = "isr" | ||||||||||||||||
| max_level = 2 | ||||||||||||||||
| special_levels = (MethodLevel.ONE_S,) | ||||||||||||||||
Uh oh!
There was an error while loading. Please reload this page.