diff --git a/src/opi/input/blocks/base.py b/src/opi/input/blocks/base.py index 3e2267e3..348e556e 100644 --- a/src/opi/input/blocks/base.py +++ b/src/opi/input/blocks/base.py @@ -169,6 +169,12 @@ def format_orca(self) -> str: return s + def __add__(self, other: "Block") -> "Block": + new_block = self.__class__.model_validate( + {**self.model_dump(), **other.model_dump(exclude_unset=True)} + ) + return new_block + @property def name(self) -> str: return self._name @@ -189,3 +195,16 @@ def init_inputpath(cls, inp: Any) -> Any: return InputFilePath(file=inp) else: return inp + + @classmethod + def get_subclass_by_name(cls, name: str) -> type["Block"]: + matches = {sub.__name__.lower(): sub for sub in cls.__subclasses__()} + name_matches = {sub().name.lower(): sub for sub in cls.__subclasses__()} + match = matches.get(name.lower()) or name_matches.get(name.lower()) + + if match is None: + raise ValueError( + f"No Block subclass found with name {name!r}. Available: {list(matches.keys())}" + ) + + return match diff --git a/src/opi/input/blocks/block_ice.py b/src/opi/input/blocks/block_ice.py index 8a119d0d..8c81f75d 100644 --- a/src/opi/input/blocks/block_ice.py +++ b/src/opi/input/blocks/block_ice.py @@ -21,4 +21,4 @@ class BlockIce(Block): etol: float | None = None icetype: Literal["CFGs", "CSFs", "DETs"] | None = None # > algorithm details - integrals: Literal["exact", "ri"] | None + integrals: Literal["exact", "ri"] | None = None diff --git a/src/opi/input/simple_keywords/base.py b/src/opi/input/simple_keywords/base.py index 85b70efe..f71e03a7 100644 --- a/src/opi/input/simple_keywords/base.py +++ b/src/opi/input/simple_keywords/base.py @@ -1,8 +1,52 @@ -__all__ = ("SimpleKeyword",) +__all__ = ("SimpleKeyword", "SimpleKeywordBox") + +from typing import Any class SimpleKeywordBox: - pass + """ + TODO: + - rework registry to account for latest changes. + """ + _registry: list[type["SimpleKeywordBox"]] = [] + + def __init_subclass__(cls, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + cls._registry = [] + + for base in cls.__bases__: + if hasattr(base, "_registry"): + base._registry.append(cls) + + cls._registry.append(cls) + + @classmethod + def registry(cls) -> list[type["SimpleKeywordBox"]]: + return cls._registry + + @classmethod + def from_string(cls, s: str) -> "SimpleKeyword": + norm = s.lower() + for c in cls._registry: + for attr in dir(c): + if attr.startswith('_'): # Skip private/magic attributes + continue + value = getattr(c, attr) + if isinstance(value, SimpleKeyword) and value.keyword.lower() == norm: + return value + elif isinstance(value, SimpleKeyword) and attr.lower() == norm: + return value + elif isinstance(value, SimpleKeyword) and value.alias and value.alias.lower() == norm: + return value + + raise ValueError(f"Keyword {s} not found in class {cls.__name__}") + + @classmethod + def find_keyword(cls, inp: "SimpleKeyword | str") -> "SimpleKeyword": + if isinstance(inp, SimpleKeyword): + inp = inp.keyword + + return cls.from_string(inp) class SimpleKeyword: @@ -17,12 +61,14 @@ class SimpleKeyword: keyword: str simple keyword as it will appear in the ORCA .inp file """ + alias: str | None = None - def __init__(self, keyword: str) -> None: + def __init__(self, keyword: str, alias:str|None=None) -> None: self._keyword: str = "" self.keyword = keyword self._name: str = "" # self.name = name + self.alias = alias @property def keyword(self) -> str: diff --git a/src/opi/input/simple_keywords/grid.py b/src/opi/input/simple_keywords/grid.py index bcb10704..7383c969 100644 --- a/src/opi/input/simple_keywords/grid.py +++ b/src/opi/input/simple_keywords/grid.py @@ -9,11 +9,11 @@ class Grid(SimpleKeywordBox): """Enum to store all simple keywords of type Grid.""" - DEFGRID1 = SimpleKeyword("defgrid1") + DEFGRID1 = SimpleKeyword("defgrid1", alias="1") """SimpleKeyword: small grid.""" - DEFGRID2 = SimpleKeyword("defgrid2") + DEFGRID2 = SimpleKeyword("defgrid2", alias="2") """SimpleKeyword: medium grid.""" - DEFGRID3 = SimpleKeyword("defgrid3") + DEFGRID3 = SimpleKeyword("defgrid3", alias="3") """SimpleKeyword: large grid.""" REFGRID = SimpleKeyword("refgrid") """SimpleKeyword: reference grid.""" diff --git a/src/opi/input/simple_keywords/opt.py b/src/opi/input/simple_keywords/opt.py index f9f20161..90278607 100644 --- a/src/opi/input/simple_keywords/opt.py +++ b/src/opi/input/simple_keywords/opt.py @@ -6,24 +6,27 @@ __all__ = ("Opt",) -class Opt(SimpleKeywordBox): - """Enum to store all simple keywords of type Opt.""" - - OPT = SimpleKeyword("opt") - """SimpleKeyword: Perform a geometry optimization.""" - CRUDEOPT = SimpleKeyword("crudeopt") +class OptThreshold(SimpleKeywordBox): + CRUDEOPT = SimpleKeyword("crudeopt", alias="crude") """SimpleKeyword: Geometry optimization with thresholds.""" - INTERPOPT = SimpleKeyword("interpopt") + LOOSEOPT = SimpleKeyword("looseopt", alias="loose") """SimpleKeyword: Geometry optimization with thresholds.""" - LOOSEOPT = SimpleKeyword("looseopt") + NORMALOPT = SimpleKeyword("normalopt", alias="normal") """SimpleKeyword: Geometry optimization with thresholds.""" - NORMALOPT = SimpleKeyword("normalopt") + SLOPPYOPT = SimpleKeyword("sloppyopt", alias="sloppy") """SimpleKeyword: Geometry optimization with thresholds.""" - SLOPPYOPT = SimpleKeyword("sloppyopt") + TIGHTOPT = SimpleKeyword("tightopt", alias="tight") """SimpleKeyword: Geometry optimization with thresholds.""" - TIGHTOPT = SimpleKeyword("tightopt") + VERYTIGHTOPT = SimpleKeyword("verytightopt", alias="verytight") """SimpleKeyword: Geometry optimization with thresholds.""" - VERYTIGHTOPT = SimpleKeyword("verytightopt") + + +class Opt(SimpleKeywordBox): + """Enum to store all simple keywords of type Opt.""" + + OPT = SimpleKeyword("opt") + """SimpleKeyword: Perform a geometry optimization.""" + INTERPOPT = SimpleKeyword("interpopt") """SimpleKeyword: Geometry optimization with thresholds.""" OPTH = SimpleKeyword("opth") """SimpleKeyword: Optimize only hydrogen atoms.""" diff --git a/src/opi/input/simple_keywords/scf.py b/src/opi/input/simple_keywords/scf.py index 0bd72599..f4d53929 100644 --- a/src/opi/input/simple_keywords/scf.py +++ b/src/opi/input/simple_keywords/scf.py @@ -5,8 +5,67 @@ __all__ = ("Scf",) +class ScfThreshold(SimpleKeywordBox): + SLOPPYSCF = SimpleKeyword("sloppyscf", alias="sloppy") + """SimpleKeyword: SCF convergence threshold settings.""" + LOOSESCF = SimpleKeyword("loosescf", alias="loose") + """SimpleKeyword: SCF convergence threshold settings.""" + NORMALSCF = SimpleKeyword("normalscf", alias="normal") + """SimpleKeyword: SCF convergence threshold settings.""" + STRONGSCF = SimpleKeyword("strongscf", alias="strong") + """SimpleKeyword: SCF convergence threshold settings.""" + TIGHTSCF = SimpleKeyword("tightscf", alias="tight") + """SimpleKeyword: SCF convergence threshold settings.""" + VERYTIGHTSCF = SimpleKeyword("verytightscf", alias="verytight") + """SimpleKeyword: SCF convergence threshold settings.""" + EXTREMESCF = SimpleKeyword("extremescf", alias="extreme") + """SimpleKeyword: SCF convergence threshold settings.""" + + +class ScfGuess(SimpleKeywordBox): + MOREAD = SimpleKeyword("moread") + """SimpleKeyword: SCF initial guess read orbitals from gbw file.""" + EHTANO = SimpleKeyword("ehtano") + """SimpleKeyword: SCF initial guess.""" + HCORE = SimpleKeyword("hcore") + """SimpleKeyword: SCF initial guess.""" + PATOM = SimpleKeyword("patom") + """SimpleKeyword: SCF initial guess.""" + PMODEL = SimpleKeyword("pmodel") + """SimpleKeyword: SCF initial guess.""" + PMODELX = SimpleKeyword("pmodelx") + """SimpleKeyword: SCF initial guess.""" + PMODELXAV = SimpleKeyword("pmodelxav") + """SimpleKeyword: SCF initial guess.""" + PMODELXPM = SimpleKeyword("pmodelxpm") + """SimpleKeyword: SCF initial guess.""" + SYMBREAKGUESS = SimpleKeyword("symbreakguess") + """SimpleKeyword: SCF initial guess.""" + + +class ScfConvergence(SimpleKeywordBox): + EASYCONV = SimpleKeyword("easyconv", alias="easy") + """SimpleKeyword: SCF convergence strategy.""" + NORMALCONV = SimpleKeyword("normalconv", alias="normal") + """SimpleKeyword: SCF convergence strategy.""" + SLOWCONV = SimpleKeyword("slowconv", alias="slow") + """SimpleKeyword: SCF convergence strategy.""" + VERYSLOWCONV = SimpleKeyword("veryslowconv", alias="veryslow") + """SimpleKeyword: SCF convergence strategy.""" + + +class ScfSolver(SimpleKeywordBox): + DIIS = SimpleKeyword("diis") + """SimpleKeyword: SCF solver.""" + KDIIS = SimpleKeyword("kdiis") + """SimpleKeyword: SCF solver.""" + SOSCF = SimpleKeyword("soscf") + """SimpleKeyword: SCF solver.""" + TRAH = SimpleKeyword("trah") + """SimpleKeyword: SCF solver.""" -class Scf(SimpleKeywordBox): + +class Scf(ScfThreshold, ScfConvergence, ScfGuess, ScfSolver): """Enum to store all simple keywords of type Scf.""" G3CONV = SimpleKeyword("3conv") @@ -17,50 +76,29 @@ class Scf(SimpleKeywordBox): """SimpleKeyword: SCF solver combination.""" KDIISTRAH = SimpleKeyword("kdiistrah") """SimpleKeyword: SCF solver combination.""" - DIIS = SimpleKeyword("diis") - """SimpleKeyword: SCF solver.""" + NODIIS = SimpleKeyword("nodiis") """SimpleKeyword: SCF solver.""" AODIIS = SimpleKeyword("aodiis") """SimpleKeyword: SCF solver.""" NOAODIIS = SimpleKeyword("noaodiis") """SimpleKeyword: SCF solver.""" - KDIIS = SimpleKeyword("kdiis") - """SimpleKeyword: SCF solver.""" + NOKDIIS = SimpleKeyword("nokdiis") """SimpleKeyword: SCF solver.""" - SOSCF = SimpleKeyword("soscf") - """SimpleKeyword: SCF solver.""" + NOSOSCF = SimpleKeyword("nososcf") """SimpleKeyword: SCF solver.""" - TRAH = SimpleKeyword("trah") - """SimpleKeyword: SCF solver.""" + NOTRAH = SimpleKeyword("notrah") """SimpleKeyword: SCF solver.""" AUTOSTART = SimpleKeyword("autostart") """SimpleKeyword: SCF initial guess start SCF from a gbw file with the same basename (default).""" NOAUTOSTART = SimpleKeyword("noautostart") """SimpleKeyword: SCF initial guess do not start SCF from a gbw file with the same basename.""" - MOREAD = SimpleKeyword("moread") - """SimpleKeyword: SCF initial guess read orbitals from gbw file.""" - EHTANO = SimpleKeyword("ehtano") - """SimpleKeyword: SCF initial guess.""" - HCORE = SimpleKeyword("hcore") - """SimpleKeyword: SCF initial guess.""" HUECKEL = SimpleKeyword("hueckel") """SimpleKeyword: SCF initial guess.""" - PATOM = SimpleKeyword("patom") - """SimpleKeyword: SCF initial guess.""" - PMODEL = SimpleKeyword("pmodel") - """SimpleKeyword: SCF initial guess.""" - PMODELX = SimpleKeyword("pmodelx") - """SimpleKeyword: SCF initial guess.""" - PMODELXAV = SimpleKeyword("pmodelxav") - """SimpleKeyword: SCF initial guess.""" - PMODELXPM = SimpleKeyword("pmodelxpm") - """SimpleKeyword: SCF initial guess.""" - SYMBREAKGUESS = SimpleKeyword("symbreakguess") - """SimpleKeyword: SCF initial guess.""" + UNITMATRIXGUESS = SimpleKeyword("unitmatrixguess") """SimpleKeyword: SCF initial guess.""" USEGRAMSCHMIDT = SimpleKeyword("usegramschmidt") @@ -91,20 +129,6 @@ class Scf(SimpleKeywordBox): """SimpleKeyword: enable fractional occupations.""" SCFCONVFORCED = SimpleKeyword("scfconvforced") """SimpleKeyword: Force SCF convergence for subsequent operations.""" - SLOPPYSCF = SimpleKeyword("sloppyscf") - """SimpleKeyword: SCF convergence threshold settings.""" - LOOSESCF = SimpleKeyword("loosescf") - """SimpleKeyword: SCF convergence threshold settings.""" - NORMALSCF = SimpleKeyword("normalscf") - """SimpleKeyword: SCF convergence threshold settings.""" - STRONGSCF = SimpleKeyword("strongscf") - """SimpleKeyword: SCF convergence threshold settings.""" - TIGHTSCF = SimpleKeyword("tightscf") - """SimpleKeyword: SCF convergence threshold settings.""" - VERYTIGHTSCF = SimpleKeyword("verytightscf") - """SimpleKeyword: SCF convergence threshold settings.""" - EXTREMESCF = SimpleKeyword("extremescf") - """SimpleKeyword: SCF convergence threshold settings.""" SLOPPYSCFCHECK = SimpleKeyword("sloppyscfcheck") """SimpleKeyword: SCF convergence threshold settings.""" NOSLOPPYSCFCHECK = SimpleKeyword("nosloppyscfcheck") @@ -125,14 +149,6 @@ class Scf(SimpleKeywordBox): """SimpleKeyword: SCF convergence threshold settings.""" SCFCONV12 = SimpleKeyword("scfconv12") """SimpleKeyword: SCF convergence threshold settings.""" - EASYCONV = SimpleKeyword("easyconv") - """SimpleKeyword: SCF convergence strategy.""" - NORMALCONV = SimpleKeyword("normalconv") - """SimpleKeyword: SCF convergence strategy.""" - SLOWCONV = SimpleKeyword("slowconv") - """SimpleKeyword: SCF convergence strategy.""" - VERYSLOWCONV = SimpleKeyword("veryslowconv") - """SimpleKeyword: SCF convergence strategy.""" DAMP = SimpleKeyword("damp") """SimpleKeyword: SCF settings.""" NODAMP = SimpleKeyword("nodamp") diff --git a/src/opi/input/simple_keywords/solvation_model.py b/src/opi/input/simple_keywords/solvation_model.py index dd22568c..0568f22f 100644 --- a/src/opi/input/simple_keywords/solvation_model.py +++ b/src/opi/input/simple_keywords/solvation_model.py @@ -12,7 +12,7 @@ def __init__(self, model: str) -> None: super().__init__(model) def __call__(self, solvent: Solvent, /) -> SimpleKeyword: - if not isinstance(solvent, Solvent): + if not isinstance(solvent, Solvent | str): raise TypeError(f"Solvent '{solvent}' is not of Solvent type!") return SimpleKeyword(f"{self.keyword}({solvent})") diff --git a/src/opi/input/simple_keywords/solvent.py b/src/opi/input/simple_keywords/solvent.py index 9ae95c30..e71f0cb5 100644 --- a/src/opi/input/simple_keywords/solvent.py +++ b/src/opi/input/simple_keywords/solvent.py @@ -242,3 +242,11 @@ class Solvent(StrEnum): WOCTANOL = "WOCTANOL" WETOCTANOL = "WETOCTANOL" CONDUCTOR = "CONDUCTOR" + + @classmethod + def find_keyword(cls, key: str) -> str: + norm = key.lower() + for member in cls: + if member.value.lower() == norm or member.name.lower() == norm: + return str(member.value) + raise ValueError(f"Key {key} not found in {cls.__name__}") diff --git a/src/opi/tasks/engrad_task.py b/src/opi/tasks/engrad_task.py new file mode 100644 index 00000000..921c4cfc --- /dev/null +++ b/src/opi/tasks/engrad_task.py @@ -0,0 +1,53 @@ +import typing + +from opi.input.simple_keywords import Task, SimpleKeyword, Solvent +from opi.tasks.method_settings import DFTSettings +from opi.tasks.task_base import SimpleTask, TaskSettings, TaskResults + + +class EngradSettings(TaskSettings): + _name: str = "engrad" + task_keyword: typing.Annotated[SimpleKeyword, Task] = Task.ENGRAD + + +class EngradTask(SimpleTask): + def __init__(self, + method: str | SimpleKeyword, + basis_set: str | SimpleKeyword | None = None, + solvation_model: str | SimpleKeyword | None = None, + solvent: str | Solvent | None = None, + task: str | SimpleKeyword | None = None): + self._method_settings = DFTSettings( + method=method, basis_set=basis_set, solvation_model=solvation_model, solvent=solvent + ) + self._task_settings = ( + EngradSettings(task_keyword=task) + ) if task else EngradSettings() + + self._results_type = EngradResults + + +class EngradResults(TaskResults): + @property + @TaskResults.output_parse + def final_energy(self) -> float: + final_energy = self.output.get_final_energy() + + if final_energy is None: + raise ValueError("Could not get final energy from ORCA Output") + + return final_energy + + @property + @TaskResults.output_parse + def gradient(self) -> list[float]: + gradient = self.output.get_gradient() + + if gradient is None: + raise ValueError("Could not get gradient from ORCA Output") + + return gradient + + @property + def primary_property(self) -> tuple[float, list[float]]: + return self.final_energy, self.gradient \ No newline at end of file diff --git a/src/opi/tasks/freq_task.py b/src/opi/tasks/freq_task.py new file mode 100644 index 00000000..6a453941 --- /dev/null +++ b/src/opi/tasks/freq_task.py @@ -0,0 +1,49 @@ +import typing +from functools import cached_property + +from opi.input.simple_keywords import Task, SimpleKeyword, Solvent +from opi.tasks.method_settings import DFTSettings +from opi.tasks.task_base import TaskSettings, TaskResults, SimpleTask + + +class FreqSettings(TaskSettings): + _name: str = "freq" + task_keyword: typing.Annotated[SimpleKeyword, Task] = Task.FREQ + + +class FreqTask(SimpleTask): + def __init__(self, + method: str | SimpleKeyword, + basis_set: str | SimpleKeyword | None = None, + solvation_model: str | SimpleKeyword | None = None, + solvent: str | Solvent | None = None, + task: str | SimpleKeyword | None = None): + self._method_settings = DFTSettings( + method=method, basis_set=basis_set, solvation_model=solvation_model, solvent=solvent + ) + self._task_settings = ( + FreqSettings(task_keyword=task) + ) if task else FreqSettings() + + self._results_type = FreqResults + + +class FreqResults(TaskResults): + @cached_property + def status(self) -> bool: + return self.output.terminated_normally() + + @cached_property + @TaskResults.output_parse + def free_energy_delta(self) -> float: + free_energy_delta = self.output.get_free_energy_delta() + + if not free_energy_delta: + raise ValueError("Could not get free energy delta from ORCA output") + + return free_energy_delta + + @property + def primary_property(self) -> float: + return self.free_energy_delta + diff --git a/src/opi/tasks/method_settings.py b/src/opi/tasks/method_settings.py new file mode 100644 index 00000000..273a4cec --- /dev/null +++ b/src/opi/tasks/method_settings.py @@ -0,0 +1,110 @@ +import typing +import warnings + +from pydantic import field_validator, model_validator + +from opi.input import Input +from opi.input.simple_keywords import Dft, SimpleKeyword, Grid, DispersionCorrection +from opi.input.simple_keywords.scf import ScfThreshold, ScfSolver, Scf, ScfConvergence +from opi.tasks.task_base import MethodSettings + + +class DFTSettings(MethodSettings): + _name: str = "dft" + method: typing.Annotated[SimpleKeyword, Dft] + grid: typing.Annotated[SimpleKeyword, Grid] | None = None + scf_maxiter: typing.Annotated[int, "BlockScf", "maxiter"] | None = None + scf_threshold: typing.Annotated[SimpleKeyword, ScfThreshold] | None = None + scf_solver: typing.Annotated[SimpleKeyword, ScfSolver] | None = None + scf_stab: bool = False + scf_conv: typing.Annotated[SimpleKeyword, ScfConvergence] | None = None + + + @field_validator("*", mode="before") + @classmethod + def validate_fields(cls, value: typing.Any, info): + if info.field_name == "method": + try: + new_keyword = Dft.find_keyword(value) + except ValueError: + new_keyword = cls._find_dft_disp_keyword(value) + return new_keyword + else: + return super().validate_fields(value, info) + + + + @model_validator(mode="after") + @classmethod + def cross_validate(cls, data: "DFTSettings") -> "DFTSettings": + """ + Cross-validation for `DftSettings`. + If the method keyword contains '3c', the `basis_set` attribute will be set to `None`. + + The `DftSettings` object is then returned. + Parameters + ---------- + data: DFTSettings + `DFTSettings` object given as input. + + Returns + ------- + DFTSettings + Cross-validated `DFTSettings` object. + + """ + if "3c" in data.method.keyword and data.basis_set: + warnings.warn("Basis Set will be ignored due to selection of Method", UserWarning) + data.basis_set = None + + return data + + def map_to_input(self, input_object: Input) -> Input: + input_object = super().map_to_input(input_object) + + if self.scf_stab: + input_object.add_simple_keywords(Scf.SCFSTAB) + + return input_object + + @classmethod + def _find_dft_disp_keyword(cls, value: str | SimpleKeyword) -> SimpleKeyword: + """ + Function to search for a `Dft` keyword along with `DispersionCorrection`. + In the case that an - is present in keyword, the keyword is split along the - and it is verified whether the + given keyword is a valid combination of `Dft` and `DispersionCorrection` keyword. + + If it is , a `SimpleKeyword` object is created and returned. If not , a `ValueError` is raised. + Parameters + ---------- + value: str | SimpleKeyword + The value to search for. + + Returns + ------- + SimpleKeyword + The created `SimpleKeyword` object. + + Raises + ------ + ValueError + If given value is invalid + + """ + if isinstance(value, SimpleKeyword): + value = value.keyword + + if '-' in value: + try: + keywords = value.split('-') + Dft.find_keyword(keywords[0]) + DispersionCorrection.find_keyword(keywords[1]) + + return SimpleKeyword(value) + except ValueError: + raise ValueError(f"Invalid Dft keyword or dispersion correction given: {value}") + else: + raise ValueError(f"Invalid Dft keyword '{value}'") + + + diff --git a/src/opi/tasks/opt_task.py b/src/opi/tasks/opt_task.py new file mode 100644 index 00000000..08d3a6d3 --- /dev/null +++ b/src/opi/tasks/opt_task.py @@ -0,0 +1,84 @@ +import typing + +from opi.input import Input +from opi.input.simple_keywords import Task, SimpleKeyword, Solvent +from opi.input.simple_keywords.opt import OptThreshold, Opt +from opi.input.structures import Structure +from opi.tasks.method_settings import DFTSettings +from opi.tasks.task_base import TaskSettings, TaskResults, SimpleTask + + +class OptSettings(TaskSettings): + _name:str = "opt" + task_keyword: typing.Annotated[SimpleKeyword, Task] = Task.OPT + opt_threshold: typing.Annotated[SimpleKeyword, OptThreshold] | None = None + optrigid: bool = False + opt_h: bool = False + lopt:bool = False + opt_maxiter: typing.Annotated[int, "BlockGeom", "maxiter"] | None = None + + def map_to_input(self, input_object: Input) -> Input: + input_object = super().map_to_input(input_object) + + if self.optrigid: + input_object.add_simple_keywords(Opt.RIGIDBODYOPT) + + opt_map = { + (True, True): Opt.L_OPTH, + (True, False): Opt.OPTH, + (False, True): Opt.L_OPT, + } + + keyword = opt_map.get((self.opt_h, self.lopt)) + if keyword: + input_object.add_simple_keywords(keyword) + + return input_object + + +class OptTask(SimpleTask): + def __init__(self, + method: str | SimpleKeyword, + basis_set: str | SimpleKeyword | None = None, + solvation_model: str | SimpleKeyword | None = None, + solvent: str | Solvent | None = None, + task: str | SimpleKeyword | None = None): + self._method_settings = DFTSettings( + method=method, basis_set=basis_set, solvation_model=solvation_model, solvent=solvent + ) + self._task_settings = ( + OptSettings(task_keyword=task) + ) if task else OptSettings() + + self._results_type = OptResults + + +class OptResults(TaskResults): + @property + @TaskResults.output_parse + def final_energy(self) -> float: + final_energy = self.output.get_final_energy() + + if final_energy is None: + raise ValueError("Could not get final energy from ORCA Output") + + return final_energy + + + @property + @TaskResults.output_parse + def structure(self) -> Structure: + structure = self.output.get_structure() + if structure is None: + raise ValueError("Could not get structure from ORCA Output") + + return structure + + + @property + def primary_property(self) -> tuple[float, Structure]: + return self.final_energy, self.structure + + + + diff --git a/src/opi/tasks/singlepointtask.py b/src/opi/tasks/singlepointtask.py new file mode 100644 index 00000000..9d98de5d --- /dev/null +++ b/src/opi/tasks/singlepointtask.py @@ -0,0 +1,69 @@ +import typing +from pathlib import Path + +from opi.input.simple_keywords import SimpleKeyword, Solvent, Task +from opi.input.structures import BaseStructureFile, Structure +from opi.tasks.method_settings import DFTSettings +from opi.tasks.task_base import SimpleTask, TaskResults, TaskSettings + + +class SinglePointSettings(TaskSettings): + _name: str = "singlepoint" + task_keyword: typing.Annotated[SimpleKeyword, Task] = Task.SP + + +class SinglePointTask(SimpleTask): + def __init__( + self, + method: str | SimpleKeyword, + basis_set: str | SimpleKeyword | None = None, + solvation_model: str | SimpleKeyword | None = None, + solvent: str | Solvent | None = None, + task: str | SimpleKeyword | None = None, + ): + self._method_settings = DFTSettings( + method=method, basis_set=basis_set, solvation_model=solvation_model, solvent=solvent + ) + self._task_settings = ( + SinglePointSettings(task_keyword=task) + ) if task else SinglePointSettings() + + self._results_type = SinglePointResults + + def run( + self, + basename: str, + struct: Structure | BaseStructureFile, + working_dir: Path = Path("RUN"), + ncores: int | None = None, + memory: int | None = None, + moinp: Path | None = None, + strict: bool = False + ) -> "SinglePointResults": + single_point_result = super().run( + basename=basename, + struct=struct, + working_dir=working_dir, + ncores=ncores, + memory=memory, + moinp=moinp, + strict=strict + ) + + return typing.cast(SinglePointResults, single_point_result) + + +class SinglePointResults(TaskResults): + @property + @TaskResults.output_parse + def final_energy(self) -> float: + final_energy = self.output.get_final_energy() + + if final_energy is None: + raise ValueError("Could not get final energy from ORCA Output") + + return final_energy + + @property + def primary_property(self) -> float: + return float(self.final_energy) diff --git a/src/opi/tasks/task_base.py b/src/opi/tasks/task_base.py new file mode 100644 index 00000000..c9af038f --- /dev/null +++ b/src/opi/tasks/task_base.py @@ -0,0 +1,534 @@ +import shutil +import typing +from abc import ABC, abstractmethod +from functools import cached_property, wraps +from pathlib import Path + +from pydantic import BaseModel, ConfigDict, field_validator, model_validator + +from opi.core import Calculator +from opi.input import Input +from opi.input.blocks import Block +from opi.input.simple_keywords import ( + BasisSet, + Method, + SimpleKeyword, + SimpleKeywordBox, + SolvationModel, + Solvent, +) +from opi.input.structures import BaseStructureFile, Structure +from opi.output.core import Output + + +class Settings(BaseModel): + """ + TODO: + - add checking for Solvent and SolvationModel now that they are optional. + """ + model_config = ConfigDict(arbitrary_types_allowed=True, validate_assignment=True) + _name: str + + def __str__(self) -> str: + """ + String representation of `Settings`. Mostly for debugging purposes. + + Returns + ------- + str + String representation of `Settings`. + + """ + lines = [f"{self._name.title()} Settings:"] + for field_name, value in self.model_dump().items(): + lines.append(f" {field_name}: {value}") + return "\n".join(lines) + + @staticmethod + def _get_field_metadata(hint: typing.Any) -> tuple: + """ + Function to return the metadata of the type annotation of a field. + The type hints are first retrieved after which the metadata is extracted and then returned. + + Parameters + ---------- + hint: Type hint / annotation of field + + Returns + ------- + tuple + Tuple of metadata about the field. + + """ + origin = typing.get_origin(hint) + args = typing.get_args(hint) + if origin is typing.Union: + non_none_args = [arg for arg in args if arg is not type(None)] + if non_none_args: + return Settings._get_field_metadata(non_none_args[0]) + return args[1:] if len(args) > 1 else () + + @staticmethod + def _resolve_field_value( + value: typing.Any, metadata: tuple[type["SimpleKeywordBox"]] | tuple[str, str] + ) -> typing.Any: + """ + Function to translate user input into OPI compatible types. This is done using the metadata of the type + annotation associated with the field. There are two cases: + + Case 1 + ------ + Metadata has only one value. In this case the field is a simple keyword. Here the class associated with the input + option (eg: `Dft`) , will be checked to see if the user-given input option exists , and if it does , returns + the enum member associated with the input option. + + Case 2 + ------ + Metadata has two values, validator class and key. In this case the field is a block option. The associated block + class is first fetched, and then the attribute of the block is set to the user-given input option. This is then + validated by the block itself using Pydantic features. The validation function of the block translates user input + to OPI compatible types, which is then returned. + + + Parameters + ---------- + value: typing.Any + User input value. + metadata: tuple + Tuple of metadata about the field. + + Returns + ------- + typing.Any + User input value translated to OPI compatible types. + + """ + match metadata: + case (validator,): + return validator.find_keyword(value) + + case (validator, key): + block_cls = Block.get_subclass_by_name(validator) + instance = block_cls.model_validate({key: value}) + return getattr(instance, key) + + case _: + return value + + def _get_simple_keyword(self, validator: type[SimpleKeywordBox], value) -> SimpleKeyword: + if validator == SolvationModel: + solvent = getattr(self, "solvent", None) + new_keyword = value(solvent) + else: + new_keyword = value + + return new_keyword + + def map_to_input(self, input_object: Input) -> Input: + """ + Function to map all information held in `Settings` class to an `Input` class object. The function receives an + `Input` object , which may or may not be already populated, after which the function uses the type hints of + every field defined in the class to either fetch a `SimpleKeyword` from the appropriate Enum, and adds it to + `Input.simple_keywords`, or initializes the appropriate block with the attribute set to user defined value, and + adds it to `Input.blocks`. + + The modified `Input` object is then returned. + + Parameters + ---------- + input_object: Input + `Input` object to be modified + + Returns + ------- + Input + Modified `Input` object. + """ + hints = typing.get_type_hints(self.__class__, include_extras=True) + + for field_name, field_type in hints.items(): + value = getattr(self, field_name) + if value is None: + continue + + metadata = self._get_field_metadata(field_type) + + match metadata: + case (validator,): + if validator == Solvent: + continue + + new_keyword = self._get_simple_keyword(validator, value) + if new_keyword: + input_object.add_simple_keywords(new_keyword) + + case (validator, key): + block_type = Block.get_subclass_by_name(validator) + block_class = block_type(**{key: value}) + + block_exists, *_ = input_object.has_blocks(block_type) + if not block_exists: + input_object.add_blocks(block_class) + else: + existing_block = next(iter(input_object.get_blocks(block_type).values())) + new_block = existing_block + block_class + input_object.add_blocks(new_block, overwrite=True) + + return input_object + + @field_validator("*", mode="before") + @classmethod + def validate_fields(cls, value: typing.Any, info): + """ + This field validator handles validation upon reassignment of values of class attributes. + + This validator is applied to all fields and is executed prior to Pydantics internal validation, which is useful + for handling reassignment of values or custom preprocessing logic. + + The method does the following: + 1. Retrieves type hints, including metadata and checks whether current field has corresponding type hint. + 2. Extracts field specific metadata from type hint. + 3. Resolves incoming value using metadata. + + Parameters + ---------- + value: Any + User input value. + info + Object containing contextual information about field being validated. + + Returns + ------- + Any + User input value processed and converted to OPI compatible types. + """ + if value is None: + return value + + hints = typing.get_type_hints(cls, include_extras=True) + + if info.field_name not in hints: + return value + + hint = hints[info.field_name] + metadata = cls._get_field_metadata(hint) + return cls._resolve_field_value(value, metadata) + + @model_validator(mode="before") + @classmethod + def cross_validate(cls, data: dict[str, typing.Any]) -> dict[str, typing.Any]: + """ + Function to process and validate user input, this validator handles validation upon model initialization. Since + `self.validate_field()` already exists, this function will be reserved only for cross validation. + + Parameters + ---------- + data: dict + User input data. + + Returns + ------- + dict + Cross-validated user input data + + """ + if not isinstance(data, dict): + return data + + return data + + +class TaskSettings(Settings): + task_keyword: typing.Annotated[SimpleKeyword, SimpleKeywordBox] + + +class MethodSettings(Settings): + method: typing.Annotated[SimpleKeyword, Method] + basis_set: typing.Annotated[SimpleKeyword, BasisSet] | None = None + solvation_model: typing.Annotated[SimpleKeyword, SolvationModel] | None = None + solvent: typing.Annotated[str, Solvent] | None = None + + +class SimpleTask(ABC): + _results_type: type["TaskResults"] + _task_settings: TaskSettings + _method_settings: MethodSettings + + @property + def task_settings(self) -> TaskSettings: + return self._task_settings + + @property + def method_settings(self) -> MethodSettings: + return self._method_settings + + @property + def input_object(self) -> Input: + """ + Creates configured `Input` object. First it initializes an empty instance of `Input` , and then passes it as + to corresponding `TaskSettings` and `MethodSettings` objects to be configured by user-defined data stored in those + objects. + + Returns + ------- + `Input` + `Input` object configured by user-defined data. + + """ + inp = Input() + inp = self._task_settings.map_to_input(input_object=inp) + inp = self.method_settings.map_to_input(input_object=inp) + return inp + + def __getattr__(self, name: str) -> typing.Any: + """ + Dynamically resolve attribute access by delegating to internal settings objects. + + This method is called when an attribute is not found on the instance through + the normal lookup process. It attempts to retrieve the attribute from the + internal `_task_settings` and `_method_settings` objects, in that order. + + Parameters + ---------- + name: str + Attribute name. + + Returns + -------- + Any + Attribute value. + """ + if name.startswith("_"): + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + + try: + return getattr(self._task_settings, name) + except AttributeError: + pass + + try: + return getattr(self._method_settings, name) + except AttributeError: + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + + def __setattr__(self, name: str, value: typing.Any) -> None: + """ + Dynamically assign attributes, delegating to internal settings objects when appropriate. + + This method overrides the default attribute assignment behavior to route + assignments to `_task_settings` or `_method_settings` if the attribute + exists. Otherwise, the attribute is set directly on the instance. + + Parameters + ---------- + name : str + The name of the attribute to assign. + value : Any + The value to assign to the attribute. + """ + # Allow setting private attributes and special attributes normally + if name.startswith("_") or name in ("task_settings", "run"): + super().__setattr__(name, value) + else: + # Check if _task_settings exists and has this attribute + if hasattr(self, "_task_settings") and hasattr(self._task_settings, name): + setattr(self._task_settings, name, value) + elif hasattr(self, "_method_settings") and hasattr(self._method_settings, name): + setattr(self._method_settings, name, value) + else: + super().__setattr__(name, value) + + def run( + self, + basename: str, + struct: Structure | BaseStructureFile, + working_dir: Path = Path("RUN"), + ncores: int | None = None, + memory: int | None = None, + moinp: Path | None = None, + strict: bool = False + ) -> "TaskResults": + """ + Execute the computational task with the given structure and settings. + + This method prepares the working directory, configures the calculation + input parameters, and runs the calculation using an external calculator. + The results are returned as an instance of the configured results type. + + Parameters + ---------- + basename : str + Base name for the calculation. + struct : Structure or BaseStructureFile + The input structure for the calculation. + working_dir : pathlib.Path, optional + Directory in which the calculation will be executed. + If it exists, it will be removed and recreated. Defaults to "RUN". + ncores : int, optional + Number of CPU cores to use for the calculation. Overrides the default + value in the input object if provided. + memory : int, optional + Amount of memory to allocate for the calculation. Overrides the default + value in the input object if provided. + moinp : pathlib.Path, optional + Path to a molecular orbital input file. Overrides the default if provided. + strict : bool, optional + Controls whether working directory will be created/overwritten. Defaults to False. + + Returns + ------- + TaskResults + An instance of the configured results type containing the results + of the calculation. + """ + if strict: + # Must already exist + if not working_dir.exists(): + raise ValueError(f"Working directory {working_dir.resolve()} does not exist (strict mode)") + + # Must be empty + if any(working_dir.iterdir()): + raise ValueError(f"Working directory {working_dir.resolve()} is not empty (strict mode)") + + else: + # Non-strict: recreate directory + if working_dir.exists(): + shutil.rmtree(working_dir) + working_dir.mkdir() + + inp = self.input_object + + if ncores is not None: + inp.ncores = ncores + + if memory is not None: + inp.memory = memory + + if moinp is not None: + inp.moinp = moinp + + calc = Calculator(basename, working_dir=working_dir) + calc.structure = struct + calc.input = inp + + calc.write_and_run() + + return self._results_type(calc_object=calc) + + def _restart( + self, + previous_results: "TaskResults", + basename: str | None = None, + struct: Structure | BaseStructureFile | None = None, + working_dir: Path | None = None, + ncores: int | None = None, + memory: int | None = None, + moinp: Path | None = None, + use_previous_orbitals: bool = False, + ) -> "TaskResults": + """ + TODO: + - finish restart implementation (low on priority list) + Parameters + ---------- + previous_results + basename + struct + working_dir + ncores + memory + moinp + use_previous_orbitals + + Returns + ------- + + """ + prev_calc = previous_results.calc_object + + basename = basename if basename else prev_calc.basename + struct = struct if struct else prev_calc.structure + working_dir = working_dir if working_dir else prev_calc.working_dir + ncores = ncores if ncores else prev_calc.input.ncores + memory = memory if memory else prev_calc.input.memory + + if use_previous_orbitals: + prev_gbw = prev_calc.working_dir / f"{prev_calc.basename}.gbw" + if not prev_gbw.exists(): + raise FileNotFoundError(f"GBW file not found: {prev_gbw}") + moinp = prev_gbw + else: + moinp = moinp if moinp else prev_calc.input.moinp + + return self.run(basename, struct, working_dir, ncores, memory, moinp) + + +class TaskResults(ABC): + def __init__(self, calc_object: Calculator): + self.calc_object = calc_object + self._parsed = False + + @staticmethod + def output_parse( + func: typing.Callable[["TaskResults"], typing.Any], + ) -> typing.Callable[["TaskResults"], typing.Any]: + """ + Decorator to ensure output parsing is performed before accessing results. + + This decorator wraps methods of a `TaskResults` instance and guarantees + that the associated output has been parsed before the method is executed. + Parsing is performed lazily and only once per instance. + + Parameters + ---------- + func : Callable[[TaskResults], Any] + The method to wrap. It must be a method of `TaskResults` that relies + on parsed output data. + + Returns + ------- + Callable[[TaskResults], Any] + A wrapped method that ensures `self.output.parse()` has been called + before delegating to the original function. + """ + + @wraps(func) + def wrapper(self: "TaskResults"): + if not self._parsed: + self.output.parse() + self._parsed = True + return func(self) + + return wrapper + + @cached_property + def output(self) -> Output: + if not self.calc_object: + raise ValueError("calc_object not set") + + return self.calc_object.get_output() + + @cached_property + def status(self) -> bool: + return self.output.terminated_normally() and self.output.scf_converged() + + @cached_property + @abstractmethod + def primary_property(self) -> typing.Any: + pass + + def __getattr__(self, name): + """ + First tries to get attribute from the object itself. + If not found, tries to get it from self.output. + """ + # Check if 'output' exists to avoid infinite recursion + if name == 'output': + raise AttributeError(f"'{type(self).__name__}' object has no attribute 'output'") + + # Try to get the attribute from self.output + try: + return getattr(self.output, name) + except AttributeError: + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") +