From 125f439f77092c2ffc297b0fcced76bfae01acdf Mon Sep 17 00:00:00 2001 From: Nakul Santhosh Date: Wed, 25 Feb 2026 14:07:24 +0100 Subject: [PATCH 01/10] add from_string method for simple keywords --- src/opi/input/simple_keywords/base.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/opi/input/simple_keywords/base.py b/src/opi/input/simple_keywords/base.py index 85b70efe..e262fc45 100644 --- a/src/opi/input/simple_keywords/base.py +++ b/src/opi/input/simple_keywords/base.py @@ -1,7 +1,26 @@ __all__ = ("SimpleKeyword",) +_SIMPLE_KEYWORD_REGISTRY = [] class SimpleKeywordBox: + _registry = [] + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + cls._registry.append(cls) + + @classmethod + def from_string(cls, s): + norm = s.lower() + for c in cls._registry: + for attr, value in vars(c).items(): + if isinstance(value, SimpleKeyword) and value.keyword.lower() == norm: + return value + + raise ValueError(f"Keyword {s} not found.") + + +class Method(SimpleKeywordBox): pass From 665e5d32478c868241025c2533b306916b6f9976 Mon Sep 17 00:00:00 2001 From: Nakul Santhosh Date: Mon, 2 Mar 2026 16:03:12 +0100 Subject: [PATCH 02/10] registry --- src/opi/input/simple_keywords/base.py | 10 +++++++++- src/opi/input/simple_keywords/dft.py | 4 +++- src/opi/input/simple_keywords/function.py | 5 +++++ src/opi/input/simple_keywords/wft.py | 4 +++- 4 files changed, 20 insertions(+), 3 deletions(-) create mode 100644 src/opi/input/simple_keywords/function.py diff --git a/src/opi/input/simple_keywords/base.py b/src/opi/input/simple_keywords/base.py index e262fc45..c5e2575a 100644 --- a/src/opi/input/simple_keywords/base.py +++ b/src/opi/input/simple_keywords/base.py @@ -7,7 +7,15 @@ class SimpleKeywordBox: def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) - cls._registry.append(cls) + cls._registry = [] + + for base in cls.__bases__: + if hasattr(base, "_registry"): + base._registry.append(cls) + + @classmethod + def registry(cls): + return cls._registry @classmethod def from_string(cls, s): diff --git a/src/opi/input/simple_keywords/dft.py b/src/opi/input/simple_keywords/dft.py index f37f969d..fdbda82f 100644 --- a/src/opi/input/simple_keywords/dft.py +++ b/src/opi/input/simple_keywords/dft.py @@ -5,8 +5,10 @@ __all__ = ("Dft",) +from opi.input.simple_keywords.function import Function -class Dft(SimpleKeywordBox): + +class Dft(Function): """Enum to store all simple keywords of type Dft.""" B3LYP3C = SimpleKeyword("b3lyp3c") diff --git a/src/opi/input/simple_keywords/function.py b/src/opi/input/simple_keywords/function.py new file mode 100644 index 00000000..d51115c0 --- /dev/null +++ b/src/opi/input/simple_keywords/function.py @@ -0,0 +1,5 @@ +from opi.input.simple_keywords import SimpleKeywordBox + + +class Function(SimpleKeywordBox): + pass \ No newline at end of file diff --git a/src/opi/input/simple_keywords/wft.py b/src/opi/input/simple_keywords/wft.py index e383e97c..4d946e55 100644 --- a/src/opi/input/simple_keywords/wft.py +++ b/src/opi/input/simple_keywords/wft.py @@ -5,8 +5,10 @@ __all__ = ("Wft",) +from opi.input.simple_keywords.function import Function -class Wft(SimpleKeywordBox): + +class Wft(Function): """Enum to store all simple keywords of type Wft.""" HF = SimpleKeyword("hf") From aa9d68fcdb1a8d108d2bad39029aeacdc65d1390 Mon Sep 17 00:00:00 2001 From: Nakul Santhosh Date: Mon, 16 Mar 2026 17:24:47 +0100 Subject: [PATCH 03/10] task params class, simple keyword lookup --- src/opi/input/blocks/base.py | 11 +++ src/opi/input/simple_keywords/base.py | 19 ++-- src/opi/tasks/task_base.py | 119 ++++++++++++++++++++++++++ 3 files changed, 144 insertions(+), 5 deletions(-) create mode 100644 src/opi/tasks/task_base.py diff --git a/src/opi/input/blocks/base.py b/src/opi/input/blocks/base.py index 3e2267e3..efdb7d05 100644 --- a/src/opi/input/blocks/base.py +++ b/src/opi/input/blocks/base.py @@ -189,3 +189,14 @@ def init_inputpath(cls, inp: Any) -> Any: return InputFilePath(file=inp) else: return inp + + @classmethod + def get_subclass_by_name(cls, name: str) -> type["Block"]: + matches = { + sub.__name__.lower(): sub + for sub in cls.__subclasses__() + } + match = matches.get(name.lower()) + if match is None: + raise ValueError(f"No Block subclass found with name {name!r}. Available: {list(matches.keys())}") + return match diff --git a/src/opi/input/simple_keywords/base.py b/src/opi/input/simple_keywords/base.py index c5e2575a..104dbe56 100644 --- a/src/opi/input/simple_keywords/base.py +++ b/src/opi/input/simple_keywords/base.py @@ -1,11 +1,9 @@ __all__ = ("SimpleKeyword",) -_SIMPLE_KEYWORD_REGISTRY = [] - class SimpleKeywordBox: - _registry = [] + _registry: list[type["SimpleKeywordBox"]] = [] - def __init_subclass__(cls, **kwargs): + def __init_subclass__(cls, **kwargs) -> None: super().__init_subclass__(**kwargs) cls._registry = [] @@ -13,8 +11,10 @@ def __init_subclass__(cls, **kwargs): if hasattr(base, "_registry"): base._registry.append(cls) + cls._registry.append(cls) + @classmethod - def registry(cls): + def registry(cls) -> list: return cls._registry @classmethod @@ -28,6 +28,15 @@ def from_string(cls, s): raise ValueError(f"Keyword {s} not found.") + @classmethod + def find_keyword(cls, inp: "SimpleKeyword | str") -> "SimpleKeyword": + if isinstance(inp, SimpleKeyword): + inp = inp.keyword + + return cls.from_string(inp) + + + class Method(SimpleKeywordBox): pass diff --git a/src/opi/tasks/task_base.py b/src/opi/tasks/task_base.py new file mode 100644 index 00000000..e8ac8fd1 --- /dev/null +++ b/src/opi/tasks/task_base.py @@ -0,0 +1,119 @@ +import shutil +import typing +from pathlib import Path +from typing import Any + +from pydantic import model_validator, BaseModel, ConfigDict + +from opi.core import Calculator +from opi.input import Input +from opi.input.blocks import Block +from opi.input.structures import Structure + + + +class TaskParams(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + + def map_to_input(self, input_object: Input) -> Input: + hints = typing.get_type_hints(self.__class__, include_extras=True) + + for field_name, field_type in hints.items(): + value = getattr(self, field_name) + + args = typing.get_args(field_type) + metadata = args[1:] + + + match metadata: + case (validator, ): + input_object.add_simple_keywords(value) + case (validator, key): + block_type = Block.get_subclass_by_name(validator) + block_class = block_type(**{key: value}) + + block_exists, *_ = input_object.has_blocks(block_type) + if not block_exists: + input_object.add_blocks(block_class) + else: + existing_block = next(iter(input_object.get_blocks(type(block_class)).values())) + new_block = block_type.model_validate({**existing_block.model_dump(), **block_class.model_dump(exclude_unset=True)}) + input_object.add_blocks(new_block, overwrite=True) + + return input_object + + + @model_validator(mode='before') + @classmethod + def validate(cls, data: dict) -> dict: + hints = typing.get_type_hints(cls, include_extras=True) + + for field_name, hint in hints.items(): + value = data.get(field_name) + print(f"{field_name}: {hint}") + if field_name not in data: + continue + + args = typing.get_args(hint) + metadata = args[1:] + + match metadata: + case (validator,): + keyword = validator.find_keyword(data[field_name]) + data[field_name] = keyword + + case (validator, key): + block_cls = Block.get_subclass_by_name(validator) + instance = block_cls.model_validate({key: value}) + data[field_name] = getattr(instance, key) + + + return data + + + +class Task(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + task_parameters: TaskParams + input_object: Input | None = Input() + + def __init__(self, /, **data: Any): + super().__init__(**data) + + + def run(self, basename: str, struct: Structure, working_dir:Path | None = Path("RUN"), ncores:int | None = None, memory:int | None = None, moinp: Path | None = None) -> "TaskResults": + + # > recreate the working dir + shutil.rmtree(working_dir, ignore_errors=True) + working_dir.mkdir() + + if ncores: + self.input_object.ncores = ncores + + if memory: + self.input_object.memory = memory + + if moinp: + self.input_object.moinp = moinp + + + self.input_object = self.task_parameters.map_to_input(input_object=self.input_object) + + calc = Calculator(basename, working_dir=working_dir) + calc.structure = struct + calc.input = self.input_object + + calc.write_and_run() + + return TaskResults() + + +class TaskResults: + pass + + + + + + From 1796784b7fa6593bd03e5d606db58ea1a34e55aa Mon Sep 17 00:00:00 2001 From: Nakul Santhosh Date: Wed, 25 Mar 2026 09:44:38 +0100 Subject: [PATCH 04/10] minimal implementation of Simple Tasks --- src/opi/input/blocks/base.py | 15 +++++- src/opi/input/blocks/block_ice.py | 2 +- src/opi/input/simple_keywords/base.py | 6 ++- .../input/simple_keywords/solvation_model.py | 2 +- src/opi/input/simple_keywords/solvent.py | 8 +++ src/opi/tasks/task_base.py | 49 ++++++++++++++----- 6 files changed, 65 insertions(+), 17 deletions(-) diff --git a/src/opi/input/blocks/base.py b/src/opi/input/blocks/base.py index efdb7d05..a9151518 100644 --- a/src/opi/input/blocks/base.py +++ b/src/opi/input/blocks/base.py @@ -169,6 +169,10 @@ def format_orca(self) -> str: return s + def __add__(self, other: "Block" ) -> "Block": + new_block = self.__class__.model_validate({**self.model_dump(), **other.model_dump(exclude_unset=True)}) + return new_block + @property def name(self) -> str: return self._name @@ -196,7 +200,16 @@ def get_subclass_by_name(cls, name: str) -> type["Block"]: sub.__name__.lower(): sub for sub in cls.__subclasses__() } - match = matches.get(name.lower()) + name_matches = { + sub().name.lower(): sub + for sub in cls.__subclasses__() + } + match = matches.get(name.lower()) or name_matches.get(name.lower()) + if match is None: raise ValueError(f"No Block subclass found with name {name!r}. Available: {list(matches.keys())}") + + return match + + diff --git a/src/opi/input/blocks/block_ice.py b/src/opi/input/blocks/block_ice.py index 8a119d0d..8c81f75d 100644 --- a/src/opi/input/blocks/block_ice.py +++ b/src/opi/input/blocks/block_ice.py @@ -21,4 +21,4 @@ class BlockIce(Block): etol: float | None = None icetype: Literal["CFGs", "CSFs", "DETs"] | None = None # > algorithm details - integrals: Literal["exact", "ri"] | None + integrals: Literal["exact", "ri"] | None = None diff --git a/src/opi/input/simple_keywords/base.py b/src/opi/input/simple_keywords/base.py index 104dbe56..aa974251 100644 --- a/src/opi/input/simple_keywords/base.py +++ b/src/opi/input/simple_keywords/base.py @@ -14,16 +14,18 @@ def __init_subclass__(cls, **kwargs) -> None: cls._registry.append(cls) @classmethod - def registry(cls) -> list: + def registry(cls) -> list[type["SimpleKeywordBox"]]: return cls._registry @classmethod - def from_string(cls, s): + def from_string(cls, s:str) -> "SimpleKeyword": norm = s.lower() for c in cls._registry: for attr, value in vars(c).items(): if isinstance(value, SimpleKeyword) and value.keyword.lower() == norm: return value + elif isinstance(value, SimpleKeyword) and attr.lower() == norm: + return value raise ValueError(f"Keyword {s} not found.") diff --git a/src/opi/input/simple_keywords/solvation_model.py b/src/opi/input/simple_keywords/solvation_model.py index dd22568c..0568f22f 100644 --- a/src/opi/input/simple_keywords/solvation_model.py +++ b/src/opi/input/simple_keywords/solvation_model.py @@ -12,7 +12,7 @@ def __init__(self, model: str) -> None: super().__init__(model) def __call__(self, solvent: Solvent, /) -> SimpleKeyword: - if not isinstance(solvent, Solvent): + if not isinstance(solvent, Solvent | str): raise TypeError(f"Solvent '{solvent}' is not of Solvent type!") return SimpleKeyword(f"{self.keyword}({solvent})") diff --git a/src/opi/input/simple_keywords/solvent.py b/src/opi/input/simple_keywords/solvent.py index 9ae95c30..e71f0cb5 100644 --- a/src/opi/input/simple_keywords/solvent.py +++ b/src/opi/input/simple_keywords/solvent.py @@ -242,3 +242,11 @@ class Solvent(StrEnum): WOCTANOL = "WOCTANOL" WETOCTANOL = "WETOCTANOL" CONDUCTOR = "CONDUCTOR" + + @classmethod + def find_keyword(cls, key: str) -> str: + norm = key.lower() + for member in cls: + if member.value.lower() == norm or member.name.lower() == norm: + return str(member.value) + raise ValueError(f"Key {key} not found in {cls.__name__}") diff --git a/src/opi/tasks/task_base.py b/src/opi/tasks/task_base.py index e8ac8fd1..70046f93 100644 --- a/src/opi/tasks/task_base.py +++ b/src/opi/tasks/task_base.py @@ -7,9 +7,10 @@ from opi.core import Calculator from opi.input import Input -from opi.input.blocks import Block +from opi.input.blocks import Block, BlockScf +from opi.input.simple_keywords import SimpleKeyword, BasisSet, SolvationModel, Solvent from opi.input.structures import Structure - +from opi.output.core import Output class TaskParams(BaseModel): @@ -28,7 +29,15 @@ def map_to_input(self, input_object: Input) -> Input: match metadata: case (validator, ): - input_object.add_simple_keywords(value) + if validator == SolvationModel: + if not self.solvent: + raise ValueError("Solvent not set") + new_keyword = value(self.solvent) + input_object.add_simple_keywords(new_keyword) + elif validator == Solvent: + continue + else: + input_object.add_simple_keywords(value) case (validator, key): block_type = Block.get_subclass_by_name(validator) block_class = block_type(**{key: value}) @@ -37,8 +46,8 @@ def map_to_input(self, input_object: Input) -> Input: if not block_exists: input_object.add_blocks(block_class) else: - existing_block = next(iter(input_object.get_blocks(type(block_class)).values())) - new_block = block_type.model_validate({**existing_block.model_dump(), **block_class.model_dump(exclude_unset=True)}) + existing_block = next(iter(input_object.get_blocks(block_type).values())) + new_block = existing_block + block_class input_object.add_blocks(new_block, overwrite=True) return input_object @@ -51,7 +60,6 @@ def validate(cls, data: dict) -> dict: for field_name, hint in hints.items(): value = data.get(field_name) - print(f"{field_name}: {hint}") if field_name not in data: continue @@ -75,14 +83,14 @@ def validate(cls, data: dict) -> dict: class Task(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) - task_parameters: TaskParams - input_object: Input | None = Input() + # task_parameters: TaskParams + input_object: Input = Input() def __init__(self, /, **data: Any): super().__init__(**data) - def run(self, basename: str, struct: Structure, working_dir:Path | None = Path("RUN"), ncores:int | None = None, memory:int | None = None, moinp: Path | None = None) -> "TaskResults": + def run(self, basename: str, struct: Structure, working_dir:Path = Path("RUN"), ncores:int | None = None, memory:int | None = None, moinp: Path | None = None) -> "TaskResults": # > recreate the working dir shutil.rmtree(working_dir, ignore_errors=True) @@ -106,11 +114,28 @@ def run(self, basename: str, struct: Structure, working_dir:Path | None = Path(" calc.write_and_run() - return TaskResults() + output = calc.get_output() + + print(output.get_final_energy()) + + return TaskResults(calc_object=calc) + + +class TaskResults(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + calc_object: Calculator | None = None + # + # def __init__(self, calc: Calculator, /, **data: Any): + # super().__init__(**data) + # self.calc_object = calc + + @property + def output(self) -> Output: + if not self.calc_object: + raise ValueError("calc_object not set") + return self.calc_object.get_output() -class TaskResults: - pass From 9fa47b5cd3d230615a879d7e3066a44ac99a75ac Mon Sep 17 00:00:00 2001 From: Nakul Santhosh Date: Wed, 25 Mar 2026 16:50:25 +0100 Subject: [PATCH 05/10] flesh out structure of base Task classes --- src/opi/input/blocks/base.py | 23 ++--- src/opi/input/simple_keywords/base.py | 15 ++- src/opi/input/simple_keywords/dft.py | 4 +- src/opi/input/simple_keywords/function.py | 5 - src/opi/input/simple_keywords/wft.py | 4 +- src/opi/tasks/singlepointtask.py | 71 ++++++++++++++ src/opi/tasks/task_base.py | 108 ++++++++++++---------- 7 files changed, 146 insertions(+), 84 deletions(-) delete mode 100644 src/opi/input/simple_keywords/function.py create mode 100644 src/opi/tasks/singlepointtask.py diff --git a/src/opi/input/blocks/base.py b/src/opi/input/blocks/base.py index a9151518..348e556e 100644 --- a/src/opi/input/blocks/base.py +++ b/src/opi/input/blocks/base.py @@ -169,8 +169,10 @@ def format_orca(self) -> str: return s - def __add__(self, other: "Block" ) -> "Block": - new_block = self.__class__.model_validate({**self.model_dump(), **other.model_dump(exclude_unset=True)}) + def __add__(self, other: "Block") -> "Block": + new_block = self.__class__.model_validate( + {**self.model_dump(), **other.model_dump(exclude_unset=True)} + ) return new_block @property @@ -196,20 +198,13 @@ def init_inputpath(cls, inp: Any) -> Any: @classmethod def get_subclass_by_name(cls, name: str) -> type["Block"]: - matches = { - sub.__name__.lower(): sub - for sub in cls.__subclasses__() - } - name_matches = { - sub().name.lower(): sub - for sub in cls.__subclasses__() - } + matches = {sub.__name__.lower(): sub for sub in cls.__subclasses__()} + name_matches = {sub().name.lower(): sub for sub in cls.__subclasses__()} match = matches.get(name.lower()) or name_matches.get(name.lower()) if match is None: - raise ValueError(f"No Block subclass found with name {name!r}. Available: {list(matches.keys())}") - + raise ValueError( + f"No Block subclass found with name {name!r}. Available: {list(matches.keys())}" + ) return match - - diff --git a/src/opi/input/simple_keywords/base.py b/src/opi/input/simple_keywords/base.py index aa974251..7366859b 100644 --- a/src/opi/input/simple_keywords/base.py +++ b/src/opi/input/simple_keywords/base.py @@ -1,9 +1,12 @@ -__all__ = ("SimpleKeyword",) +__all__ = ("SimpleKeyword", "SimpleKeywordBox") + +from typing import Any + class SimpleKeywordBox: _registry: list[type["SimpleKeywordBox"]] = [] - def __init_subclass__(cls, **kwargs) -> None: + def __init_subclass__(cls, **kwargs: Any) -> None: super().__init_subclass__(**kwargs) cls._registry = [] @@ -18,7 +21,7 @@ def registry(cls) -> list[type["SimpleKeywordBox"]]: return cls._registry @classmethod - def from_string(cls, s:str) -> "SimpleKeyword": + def from_string(cls, s: str) -> "SimpleKeyword": norm = s.lower() for c in cls._registry: for attr, value in vars(c).items(): @@ -29,7 +32,6 @@ def from_string(cls, s:str) -> "SimpleKeyword": raise ValueError(f"Keyword {s} not found.") - @classmethod def find_keyword(cls, inp: "SimpleKeyword | str") -> "SimpleKeyword": if isinstance(inp, SimpleKeyword): @@ -38,11 +40,6 @@ def find_keyword(cls, inp: "SimpleKeyword | str") -> "SimpleKeyword": return cls.from_string(inp) - -class Method(SimpleKeywordBox): - pass - - class SimpleKeyword: """ Class to represent simple keywords used in ORCA input files diff --git a/src/opi/input/simple_keywords/dft.py b/src/opi/input/simple_keywords/dft.py index fdbda82f..f37f969d 100644 --- a/src/opi/input/simple_keywords/dft.py +++ b/src/opi/input/simple_keywords/dft.py @@ -5,10 +5,8 @@ __all__ = ("Dft",) -from opi.input.simple_keywords.function import Function - -class Dft(Function): +class Dft(SimpleKeywordBox): """Enum to store all simple keywords of type Dft.""" B3LYP3C = SimpleKeyword("b3lyp3c") diff --git a/src/opi/input/simple_keywords/function.py b/src/opi/input/simple_keywords/function.py deleted file mode 100644 index d51115c0..00000000 --- a/src/opi/input/simple_keywords/function.py +++ /dev/null @@ -1,5 +0,0 @@ -from opi.input.simple_keywords import SimpleKeywordBox - - -class Function(SimpleKeywordBox): - pass \ No newline at end of file diff --git a/src/opi/input/simple_keywords/wft.py b/src/opi/input/simple_keywords/wft.py index 4d946e55..e383e97c 100644 --- a/src/opi/input/simple_keywords/wft.py +++ b/src/opi/input/simple_keywords/wft.py @@ -5,10 +5,8 @@ __all__ = ("Wft",) -from opi.input.simple_keywords.function import Function - -class Wft(Function): +class Wft(SimpleKeywordBox): """Enum to store all simple keywords of type Wft.""" HF = SimpleKeyword("hf") diff --git a/src/opi/tasks/singlepointtask.py b/src/opi/tasks/singlepointtask.py new file mode 100644 index 00000000..fbd9c906 --- /dev/null +++ b/src/opi/tasks/singlepointtask.py @@ -0,0 +1,71 @@ +import typing +from pathlib import Path + +from opi.input.simple_keywords import BasisSet, Method, SimpleKeyword, SolvationModel, Solvent +from opi.input.structures import Structure +from opi.tasks.task_base import Task, TaskParams, TaskResults + + +class SinglePointParams(TaskParams): + method: typing.Annotated[SimpleKeyword, Method] + basis_set: typing.Annotated[SimpleKeyword, BasisSet] + solvation_model: typing.Annotated[SimpleKeyword, SolvationModel] + solvent: typing.Annotated[str, Solvent] + + +class SinglePointTask(Task): + def __init__( + self, + method: str | SimpleKeyword, + basis_set: str | SimpleKeyword, + solvation_model: str | SolvationModel, + solvent: str | Solvent, + ): + self._task_parameters = SinglePointParams( + method=method, basis_set=basis_set, solvation_model=solvation_model, solvent=solvent + ) + self._results_type = SinglePointResults + + @property + def task_parameters(self) -> SinglePointParams: + return self._task_parameters + + def run( + self, + basename: str, + struct: Structure, + working_dir: Path = Path("RUN"), + ncores: int | None = None, + memory: int | None = None, + moinp: Path | None = None, + ) -> "SinglePointResults": + single_point_result = super().run( + basename=basename, + struct=struct, + working_dir=working_dir, + ncores=ncores, + memory=memory, + moinp=moinp, + ) + + return single_point_result + + +class SinglePointResults(TaskResults): + @property + def status(self) -> bool: + return self.output.terminated_normally() and self.output.scf_converged() + + @property + @TaskResults.output_parse + def final_energy(self) -> float: + final_energy = self.output.get_final_energy() + + if final_energy is None: + raise ValueError("Could not get final energy from ORCA Output") + + return final_energy + + @property + def primary_property(self) -> float: + return float(self.final_energy) diff --git a/src/opi/tasks/task_base.py b/src/opi/tasks/task_base.py index 70046f93..9dbe9e46 100644 --- a/src/opi/tasks/task_base.py +++ b/src/opi/tasks/task_base.py @@ -1,21 +1,21 @@ import shutil import typing +from abc import ABC, abstractmethod +from functools import cached_property, wraps from pathlib import Path -from typing import Any -from pydantic import model_validator, BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, model_validator from opi.core import Calculator from opi.input import Input -from opi.input.blocks import Block, BlockScf -from opi.input.simple_keywords import SimpleKeyword, BasisSet, SolvationModel, Solvent +from opi.input.blocks import Block +from opi.input.simple_keywords import SolvationModel, Solvent from opi.input.structures import Structure from opi.output.core import Output class TaskParams(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - + model_config = ConfigDict(arbitrary_types_allowed=True, validate_assignment=True) def map_to_input(self, input_object: Input) -> Input: hints = typing.get_type_hints(self.__class__, include_extras=True) @@ -26,13 +26,13 @@ def map_to_input(self, input_object: Input) -> Input: args = typing.get_args(field_type) metadata = args[1:] - match metadata: - case (validator, ): + case (validator,): if validator == SolvationModel: - if not self.solvent: + solvent = getattr(self, "solvent", None) + if not solvent: raise ValueError("Solvent not set") - new_keyword = value(self.solvent) + new_keyword = value(solvent) input_object.add_simple_keywords(new_keyword) elif validator == Solvent: continue @@ -52,8 +52,7 @@ def map_to_input(self, input_object: Input) -> Input: return input_object - - @model_validator(mode='before') + @model_validator(mode="before") @classmethod def validate(cls, data: dict) -> dict: hints = typing.get_type_hints(cls, include_extras=True) @@ -76,69 +75,78 @@ def validate(cls, data: dict) -> dict: instance = block_cls.model_validate({key: value}) data[field_name] = getattr(instance, key) - return data +class Task(ABC): + _results_type: type["TaskResults"] -class Task(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - # task_parameters: TaskParams - input_object: Input = Input() - - def __init__(self, /, **data: Any): - super().__init__(**data) - + @property + @abstractmethod + def task_parameters(self) -> TaskParams: + pass - def run(self, basename: str, struct: Structure, working_dir:Path = Path("RUN"), ncores:int | None = None, memory:int | None = None, moinp: Path | None = None) -> "TaskResults": + @property + def input_object(self) -> Input: + inp = Input() + inp = self.task_parameters.map_to_input(input_object=inp) + return inp + + def run( + self, + basename: str, + struct: Structure, + working_dir: Path = Path("RUN"), + ncores: int | None = None, + memory: int | None = None, + moinp: Path | None = None, + ) -> "TaskResults": # > recreate the working dir shutil.rmtree(working_dir, ignore_errors=True) working_dir.mkdir() - if ncores: - self.input_object.ncores = ncores - - if memory: - self.input_object.memory = memory + inp = self.input_object - if moinp: - self.input_object.moinp = moinp + if ncores is not None: + inp.ncores = ncores + if memory is not None: + inp.memory = memory - self.input_object = self.task_parameters.map_to_input(input_object=self.input_object) + if moinp is not None: + inp.moinp = moinp calc = Calculator(basename, working_dir=working_dir) calc.structure = struct - calc.input = self.input_object + calc.input = inp calc.write_and_run() - output = calc.get_output() + return self._results_type(calc_object=calc) - print(output.get_final_energy()) - return TaskResults(calc_object=calc) +class TaskResults(ABC): + def __init__(self, calc_object: Calculator): + self.calc_object = calc_object + self._parsed = False + @staticmethod + def output_parse( + func: typing.Callable[["TaskResults"], typing.Any], + ) -> typing.Callable[["TaskResults"], typing.Any]: + @wraps(func) + def wrapper(self: "TaskResults"): + if not self._parsed: + self.output.parse() + self._parsed = True + return func(self) -class TaskResults(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - calc_object: Calculator | None = None - # - # def __init__(self, calc: Calculator, /, **data: Any): - # super().__init__(**data) - # self.calc_object = calc + return wrapper - - @property + @cached_property def output(self) -> Output: if not self.calc_object: raise ValueError("calc_object not set") - return self.calc_object.get_output() - - - - - - + return self.calc_object.get_output() From dac75a28dbe7377d7b91226d90c39a00b85286e0 Mon Sep 17 00:00:00 2001 From: Nakul Santhosh Date: Fri, 27 Mar 2026 16:27:36 +0100 Subject: [PATCH 06/10] added field validator for assignment and overloaded getters and setters for easier access to attributes --- src/opi/tasks/singlepointtask.py | 22 +++++++++++++++++++++ src/opi/tasks/task_base.py | 33 ++++++++++++++++++++++++++++++-- 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/src/opi/tasks/singlepointtask.py b/src/opi/tasks/singlepointtask.py index fbd9c906..a5944e11 100644 --- a/src/opi/tasks/singlepointtask.py +++ b/src/opi/tasks/singlepointtask.py @@ -30,6 +30,28 @@ def __init__( def task_parameters(self) -> SinglePointParams: return self._task_parameters + def __getattr__(self, name): + """Delegate attribute access to _task_parameters.""" + if name.startswith('_'): + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + + try: + return getattr(self._task_parameters, name) + except AttributeError: + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + + def __setattr__(self, name, value): + """Delegate attribute setting to _task_parameters.""" + # Allow setting private attributes and special attributes normally + if name.startswith('_') or name in ('task_parameters', 'run'): + super().__setattr__(name, value) + else: + # Check if _task_parameters exists and has this attribute + if hasattr(self, '_task_parameters') and hasattr(self._task_parameters, name): + setattr(self._task_parameters, name, value) + else: + super().__setattr__(name, value) + def run( self, basename: str, diff --git a/src/opi/tasks/task_base.py b/src/opi/tasks/task_base.py index 9dbe9e46..f086c118 100644 --- a/src/opi/tasks/task_base.py +++ b/src/opi/tasks/task_base.py @@ -4,7 +4,7 @@ from functools import cached_property, wraps from pathlib import Path -from pydantic import BaseModel, ConfigDict, model_validator +from pydantic import BaseModel, ConfigDict, model_validator, field_validator from opi.core import Calculator from opi.input import Input @@ -52,6 +52,31 @@ def map_to_input(self, input_object: Input) -> Input: return input_object + + @field_validator('*', mode='before') + @classmethod + def validate_each_field(cls, value, info): + field_name = info.field_name + hints = typing.get_type_hints(cls, include_extras=True) + + if field_name not in hints: + return value + + hint = hints[field_name] + args = typing.get_args(hint) + metadata = args[1:] if len(args) > 1 else () + + match metadata: + case (validator, ): + return validator.find_keyword(value) + + case (validator, key): + block_cls = Block.get_subclass_by_name(validator) + instance = block_cls.model_validate({key: value}) + return getattr(instance, key) + + return value + @model_validator(mode="before") @classmethod def validate(cls, data: dict) -> dict: @@ -83,7 +108,7 @@ class Task(ABC): @property @abstractmethod - def task_parameters(self) -> TaskParams: + def task_parameters(self) -> TaskParams : pass @property @@ -126,6 +151,10 @@ def run( return self._results_type(calc_object=calc) + def change_parameter(self, param: str, value: typing.Any) -> None: + setattr(self.task_parameters, param, value) + + class TaskResults(ABC): def __init__(self, calc_object: Calculator): self.calc_object = calc_object From 038e98ced47d229d409de3a2d36dbc7fcc0f7ec3 Mon Sep 17 00:00:00 2001 From: Nakul Santhosh Date: Mon, 30 Mar 2026 18:02:04 +0200 Subject: [PATCH 07/10] added restart function, optimized duplicate code --- src/opi/tasks/singlepointtask.py | 31 +-------- src/opi/tasks/task_base.py | 111 +++++++++++++++++++++---------- 2 files changed, 79 insertions(+), 63 deletions(-) diff --git a/src/opi/tasks/singlepointtask.py b/src/opi/tasks/singlepointtask.py index a5944e11..8f6a19fa 100644 --- a/src/opi/tasks/singlepointtask.py +++ b/src/opi/tasks/singlepointtask.py @@ -2,7 +2,7 @@ from pathlib import Path from opi.input.simple_keywords import BasisSet, Method, SimpleKeyword, SolvationModel, Solvent -from opi.input.structures import Structure +from opi.input.structures import Structure, BaseStructureFile from opi.tasks.task_base import Task, TaskParams, TaskResults @@ -26,36 +26,11 @@ def __init__( ) self._results_type = SinglePointResults - @property - def task_parameters(self) -> SinglePointParams: - return self._task_parameters - - def __getattr__(self, name): - """Delegate attribute access to _task_parameters.""" - if name.startswith('_'): - raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") - - try: - return getattr(self._task_parameters, name) - except AttributeError: - raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") - - def __setattr__(self, name, value): - """Delegate attribute setting to _task_parameters.""" - # Allow setting private attributes and special attributes normally - if name.startswith('_') or name in ('task_parameters', 'run'): - super().__setattr__(name, value) - else: - # Check if _task_parameters exists and has this attribute - if hasattr(self, '_task_parameters') and hasattr(self._task_parameters, name): - setattr(self._task_parameters, name, value) - else: - super().__setattr__(name, value) def run( self, basename: str, - struct: Structure, + struct: Structure | BaseStructureFile, working_dir: Path = Path("RUN"), ncores: int | None = None, memory: int | None = None, @@ -70,7 +45,7 @@ def run( moinp=moinp, ) - return single_point_result + return typing.cast(SinglePointResults ,single_point_result) class SinglePointResults(TaskResults): diff --git a/src/opi/tasks/task_base.py b/src/opi/tasks/task_base.py index f086c118..f6e08199 100644 --- a/src/opi/tasks/task_base.py +++ b/src/opi/tasks/task_base.py @@ -10,13 +10,39 @@ from opi.input import Input from opi.input.blocks import Block from opi.input.simple_keywords import SolvationModel, Solvent -from opi.input.structures import Structure +from opi.input.structures import Structure, BaseStructureFile from opi.output.core import Output class TaskParams(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True, validate_assignment=True) + def __str__(self) -> str: + lines = ["Task Parameters:"] + for field_name, value in self.model_dump().items(): + lines.append(f" {field_name}: {value}") + return "\n".join(lines) + + @staticmethod + def _get_field_metadata(hint: typing.Any) -> tuple: + args = typing.get_args(hint) + return args[1:] if len(args) > 1 else () + + + @staticmethod + def _resolve_field_value(value: typing.Any, metadata: tuple) -> typing.Any: + match metadata: + case (validator,): + return validator.find_keyword(value) + + case (validator, key): + block_cls = Block.get_subclass_by_name(validator) + instance = block_cls.model_validate({key: value}) + return getattr(instance, key) + + case _: + return value + def map_to_input(self, input_object: Input) -> Input: hints = typing.get_type_hints(self.__class__, include_extras=True) @@ -55,27 +81,15 @@ def map_to_input(self, input_object: Input) -> Input: @field_validator('*', mode='before') @classmethod - def validate_each_field(cls, value, info): - field_name = info.field_name + def validate_field(cls, value, info): hints = typing.get_type_hints(cls, include_extras=True) - if field_name not in hints: + if info.field_name not in hints: return value - hint = hints[field_name] - args = typing.get_args(hint) - metadata = args[1:] if len(args) > 1 else () - - match metadata: - case (validator, ): - return validator.find_keyword(value) - - case (validator, key): - block_cls = Block.get_subclass_by_name(validator) - instance = block_cls.model_validate({key: value}) - return getattr(instance, key) - - return value + hint = hints[info.field_name] + metadata = cls._get_field_metadata(hint) + return cls._resolve_field_value(value, metadata) @model_validator(mode="before") @classmethod @@ -83,33 +97,21 @@ def validate(cls, data: dict) -> dict: hints = typing.get_type_hints(cls, include_extras=True) for field_name, hint in hints.items(): - value = data.get(field_name) if field_name not in data: continue - args = typing.get_args(hint) - metadata = args[1:] - - match metadata: - case (validator,): - keyword = validator.find_keyword(data[field_name]) - data[field_name] = keyword - - case (validator, key): - block_cls = Block.get_subclass_by_name(validator) - instance = block_cls.model_validate({key: value}) - data[field_name] = getattr(instance, key) + metadata = cls._get_field_metadata(hint) + data[field_name] = cls._resolve_field_value(data[field_name], metadata) return data - class Task(ABC): _results_type: type["TaskResults"] + _task_parameters: TaskParams @property - @abstractmethod def task_parameters(self) -> TaskParams : - pass + return self._task_parameters @property def input_object(self) -> Input: @@ -117,10 +119,30 @@ def input_object(self) -> Input: inp = self.task_parameters.map_to_input(input_object=inp) return inp + def __getattr__(self, name): + if name.startswith('_'): + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + + try: + return getattr(self._task_parameters, name) + except AttributeError: + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + + def __setattr__(self, name, value): + # Allow setting private attributes and special attributes normally + if name.startswith('_') or name in ('task_parameters', 'run'): + super().__setattr__(name, value) + else: + # Check if _task_parameters exists and has this attribute + if hasattr(self, '_task_parameters') and hasattr(self._task_parameters, name): + setattr(self._task_parameters, name, value) + else: + super().__setattr__(name, value) + def run( self, basename: str, - struct: Structure, + struct: Structure | BaseStructureFile, working_dir: Path = Path("RUN"), ncores: int | None = None, memory: int | None = None, @@ -155,6 +177,25 @@ def change_parameter(self, param: str, value: typing.Any) -> None: setattr(self.task_parameters, param, value) + def restart(self, previous_results: "TaskResults", + basename: str|None = None, + struct: Structure| BaseStructureFile| None = None, + working_dir: Path | None = None, + ncores: int | None = None, + memory: int | None = None, + moinp: Path | None = None, + use_previous_orbitals: bool = False ) -> "TaskResults": + + basename = basename if basename else previous_results.calc_object.basename + struct = struct if struct else previous_results.calc_object.structure + working_dir = working_dir if working_dir else previous_results.calc_object.working_dir + ncores = ncores if ncores else previous_results.calc_object.input.ncores + memory = memory if memory else previous_results.calc_object.input.memory + moinp = moinp if moinp else previous_results.calc_object.input.moinp + + return self.run(basename, struct, working_dir, ncores, memory, moinp) + + class TaskResults(ABC): def __init__(self, calc_object: Calculator): self.calc_object = calc_object From e1cb6abb9a659ebd4f398f5c90add45a3dbf3ce4 Mon Sep 17 00:00:00 2001 From: Nakul Santhosh Date: Wed, 1 Apr 2026 16:26:20 +0200 Subject: [PATCH 08/10] added docstrings , renamed TaskParams to Settings, added TaskSettings and MethodSettings --- src/opi/tasks/method_settings.py | 9 + src/opi/tasks/singlepointtask.py | 30 ++- src/opi/tasks/task_base.py | 379 ++++++++++++++++++++++++++----- 3 files changed, 347 insertions(+), 71 deletions(-) create mode 100644 src/opi/tasks/method_settings.py diff --git a/src/opi/tasks/method_settings.py b/src/opi/tasks/method_settings.py new file mode 100644 index 00000000..58a2f3ec --- /dev/null +++ b/src/opi/tasks/method_settings.py @@ -0,0 +1,9 @@ +import typing + +from opi.input.simple_keywords import Dft, SimpleKeyword +from opi.tasks.task_base import MethodSettings + + +class DFTSettings(MethodSettings): + method: typing.Annotated[SimpleKeyword, Dft] + _name: str = "dft" diff --git a/src/opi/tasks/singlepointtask.py b/src/opi/tasks/singlepointtask.py index 8f6a19fa..84501b5d 100644 --- a/src/opi/tasks/singlepointtask.py +++ b/src/opi/tasks/singlepointtask.py @@ -1,32 +1,34 @@ import typing from pathlib import Path -from opi.input.simple_keywords import BasisSet, Method, SimpleKeyword, SolvationModel, Solvent -from opi.input.structures import Structure, BaseStructureFile -from opi.tasks.task_base import Task, TaskParams, TaskResults +from opi.input.simple_keywords import SimpleKeyword, SolvationModel, Solvent, Task +from opi.input.structures import BaseStructureFile, Structure +from opi.tasks.method_settings import DFTSettings +from opi.tasks.task_base import SimpleTask, TaskResults, TaskSettings -class SinglePointParams(TaskParams): - method: typing.Annotated[SimpleKeyword, Method] - basis_set: typing.Annotated[SimpleKeyword, BasisSet] - solvation_model: typing.Annotated[SimpleKeyword, SolvationModel] - solvent: typing.Annotated[str, Solvent] +class SinglePointSettings(TaskSettings): + _name: str = "singlepoint" + task_keyword: typing.Annotated[SimpleKeyword, Task] = Task.SP -class SinglePointTask(Task): +class SinglePointTask(SimpleTask): def __init__( self, method: str | SimpleKeyword, basis_set: str | SimpleKeyword, solvation_model: str | SolvationModel, solvent: str | Solvent, + task: str | SimpleKeyword | None = None, ): - self._task_parameters = SinglePointParams( + self._method_settings = DFTSettings( method=method, basis_set=basis_set, solvation_model=solvation_model, solvent=solvent ) + self._task_settings = ( + SinglePointSettings(task_keyword=task) if task else SinglePointSettings() + ) self._results_type = SinglePointResults - def run( self, basename: str, @@ -45,14 +47,10 @@ def run( moinp=moinp, ) - return typing.cast(SinglePointResults ,single_point_result) + return typing.cast(SinglePointResults, single_point_result) class SinglePointResults(TaskResults): - @property - def status(self) -> bool: - return self.output.terminated_normally() and self.output.scf_converged() - @property @TaskResults.output_parse def final_energy(self) -> float: diff --git a/src/opi/tasks/task_base.py b/src/opi/tasks/task_base.py index f6e08199..88f7cadd 100644 --- a/src/opi/tasks/task_base.py +++ b/src/opi/tasks/task_base.py @@ -4,33 +4,96 @@ from functools import cached_property, wraps from pathlib import Path -from pydantic import BaseModel, ConfigDict, model_validator, field_validator +from pydantic import BaseModel, ConfigDict, field_validator, model_validator from opi.core import Calculator from opi.input import Input from opi.input.blocks import Block -from opi.input.simple_keywords import SolvationModel, Solvent -from opi.input.structures import Structure, BaseStructureFile +from opi.input.simple_keywords import ( + BasisSet, + Method, + SimpleKeyword, + SimpleKeywordBox, + SolvationModel, + Solvent, +) +from opi.input.structures import BaseStructureFile, Structure from opi.output.core import Output -class TaskParams(BaseModel): +class Settings(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True, validate_assignment=True) + _name: str def __str__(self) -> str: - lines = ["Task Parameters:"] + """ + String representation of `Settings`. Mostly for debugging purposes. + + Returns + ------- + str + String representation of `Settings`. + + """ + lines = [f"{self._name.title()} Settings:"] for field_name, value in self.model_dump().items(): lines.append(f" {field_name}: {value}") return "\n".join(lines) @staticmethod def _get_field_metadata(hint: typing.Any) -> tuple: + """ + Function to return the metadata of the type annotation of a field. + The type hints are first retrieved after which the metadata is extracted and then returned. + + Parameters + ---------- + hint: Type hint / annotation of field + + Returns + ------- + tuple + Tuple of metadata about the field. + + """ args = typing.get_args(hint) return args[1:] if len(args) > 1 else () - @staticmethod - def _resolve_field_value(value: typing.Any, metadata: tuple) -> typing.Any: + def _resolve_field_value( + value: typing.Any, metadata: tuple[type["SimpleKeywordBox"]] | tuple[str, str] + ) -> typing.Any: + """ + Function to translate user input into OPI compatible types. This is done using the metadata of the type + annotation associated with the field. There are two cases: + + Case 1 + ------ + Metadata has only one value. In this case the field is a simple keyword. Here the class associated with the input + option (eg: `Dft`) , will be checked to see if the user-given input option exists , and if it does , returns + the enum member associated with the input option. + + Case 2 + ------ + Metadata has two values, validator class and key. In this case the field is a block option. The associated block + class is first fetched, and then the attribute of the block is set to the user-given input option. This is then + validated by the block itself using Pydantic features. The validation function of the block translates user input + to OPI compatible types, which is then returned. + + + Parameters + ---------- + value: typing.Any + User input value. + metadata: tuple + Tuple of metadata about the field. + + Returns + ------- + typing.Any + User input value translated to OPI compatible types. + + """ match metadata: case (validator,): return validator.find_keyword(value) @@ -43,7 +106,35 @@ def _resolve_field_value(value: typing.Any, metadata: tuple) -> typing.Any: case _: return value + def _get_simple_keyword(self, validator: type[SimpleKeywordBox], value) -> SimpleKeyword: + if validator == SolvationModel: + solvent = getattr(self, "solvent", None) + new_keyword = value(solvent) + else: + new_keyword = value + + return new_keyword + def map_to_input(self, input_object: Input) -> Input: + """ + Function to map all information held in `Settings` class to an `Input` class object. The function receives an + `Input` object , which may or may not be already populated, after which the function uses the type hints of + every field defined in the class to either fetch a `SimpleKeyword` from the appropriate Enum, and adds it to + `Input.simple_keywords`, or initializes the appropriate block with the attribute set to user defined value, and + adds it to `Input.blocks`. + + The modified `Input` object is then returned. + + Parameters + ---------- + input_object: Input + `Input` object to be modified + + Returns + ------- + Input + Modified `Input` object. + """ hints = typing.get_type_hints(self.__class__, include_extras=True) for field_name, field_type in hints.items(): @@ -54,16 +145,12 @@ def map_to_input(self, input_object: Input) -> Input: match metadata: case (validator,): - if validator == SolvationModel: - solvent = getattr(self, "solvent", None) - if not solvent: - raise ValueError("Solvent not set") - new_keyword = value(solvent) - input_object.add_simple_keywords(new_keyword) - elif validator == Solvent: + if validator == Solvent: continue - else: - input_object.add_simple_keywords(value) + + new_keyword = self._get_simple_keyword(validator, value) + input_object.add_simple_keywords(new_keyword) + case (validator, key): block_type = Block.get_subclass_by_name(validator) block_class = block_type(**{key: value}) @@ -78,10 +165,32 @@ def map_to_input(self, input_object: Input) -> Input: return input_object - - @field_validator('*', mode='before') + @field_validator("*", mode="before") @classmethod - def validate_field(cls, value, info): + def validate_fields(cls, value: typing.Any, info): + """ + This field validator handles validation upon reassignment of values of class attributes. + + This validator is applied to all fields and is executed prior to Pydantics internal validation, which is useful + for handling reassignment of values or custom preprocessing logic. + + The method does the following: + 1. Retrieves type hints, including metadata and checks whether current field has corresponding type hint. + 2. Extracts field specific metadata from type hint. + 3. Resolves incoming value using metadata. + + Parameters + ---------- + value: Any + User input value. + info + Object containing contextual information about field being validated. + + Returns + ------- + Any + User input value processed and converted to OPI compatible types. + """ hints = typing.get_type_hints(cls, include_extras=True) if info.field_name not in hints: @@ -93,49 +202,125 @@ def validate_field(cls, value, info): @model_validator(mode="before") @classmethod - def validate(cls, data: dict) -> dict: - hints = typing.get_type_hints(cls, include_extras=True) + def cross_validate(cls, data: dict[str, typing.Any]) -> dict[str, typing.Any]: + """ + Function to process and validate user input, this validator handles validation upon model initialization. Since + `self.validate_field()` already exists, this function will be reserved only for cross validation. + + Parameters + ---------- + data: dict + User input data. - for field_name, hint in hints.items(): - if field_name not in data: - continue + Returns + ------- + dict + Cross-validated user input data - metadata = cls._get_field_metadata(hint) - data[field_name] = cls._resolve_field_value(data[field_name], metadata) + """ + if not isinstance(data, dict): + return data return data -class Task(ABC): + +class TaskSettings(Settings): + pass + + +class MethodSettings(Settings): + method: typing.Annotated[SimpleKeyword, Method] + basis_set: typing.Annotated[SimpleKeyword, BasisSet] + solvation_model: typing.Annotated[SimpleKeyword, SolvationModel] + solvent: typing.Annotated[str, Solvent] + + +class SimpleTask(ABC): _results_type: type["TaskResults"] - _task_parameters: TaskParams + _task_settings: TaskSettings + _method_settings: MethodSettings @property - def task_parameters(self) -> TaskParams : - return self._task_parameters + def task_settings(self) -> TaskSettings: + return self._task_settings + + @property + def method_settings(self) -> MethodSettings: + return self._method_settings @property def input_object(self) -> Input: + """ + Creates configured `Input` object. First it initializes an empty instance of `Input` , and then passes it as + to corresponding `TaskSettings` and `MethodSettings` objects to be configured by user-defined data stored in those + objects. + + Returns + ------- + `Input` + `Input` object configured by user-defined data. + + """ inp = Input() - inp = self.task_parameters.map_to_input(input_object=inp) + inp = self._task_settings.map_to_input(input_object=inp) + inp = self.method_settings.map_to_input(input_object=inp) return inp - def __getattr__(self, name): - if name.startswith('_'): + def __getattr__(self, name: str) -> typing.Any: + """ + Dynamically resolve attribute access by delegating to internal settings objects. + + This method is called when an attribute is not found on the instance through + the normal lookup process. It attempts to retrieve the attribute from the + internal `_task_settings` and `_method_settings` objects, in that order. + + Parameters + ---------- + name: str + Attribute name. + + Returns + -------- + Any + Attribute value. + """ + if name.startswith("_"): raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") try: - return getattr(self._task_parameters, name) + return getattr(self._task_settings, name) + except AttributeError: + pass + + try: + return getattr(self._method_settings, name) except AttributeError: raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") - def __setattr__(self, name, value): + def __setattr__(self, name: str, value: typing.Any) -> None: + """ + Dynamically assign attributes, delegating to internal settings objects when appropriate. + + This method overrides the default attribute assignment behavior to route + assignments to `_task_settings` or `_method_settings` if the attribute + exists. Otherwise, the attribute is set directly on the instance. + + Parameters + ---------- + name : str + The name of the attribute to assign. + value : Any + The value to assign to the attribute. + """ # Allow setting private attributes and special attributes normally - if name.startswith('_') or name in ('task_parameters', 'run'): + if name.startswith("_") or name in ("task_settings", "run"): super().__setattr__(name, value) else: - # Check if _task_parameters exists and has this attribute - if hasattr(self, '_task_parameters') and hasattr(self._task_parameters, name): - setattr(self._task_parameters, name, value) + # Check if _task_settings exists and has this attribute + if hasattr(self, "_task_settings") and hasattr(self._task_settings, name): + setattr(self._task_settings, name, value) + elif hasattr(self, "_method_settings") and hasattr(self._method_settings, name): + setattr(self._method_settings, name, value) else: super().__setattr__(name, value) @@ -148,6 +333,37 @@ def run( memory: int | None = None, moinp: Path | None = None, ) -> "TaskResults": + """ + Execute the computational task with the given structure and settings. + + This method prepares the working directory, configures the calculation + input parameters, and runs the calculation using an external calculator. + The results are returned as an instance of the configured results type. + + Parameters + ---------- + basename : str + Base name for the calculation. + struct : Structure or BaseStructureFile + The input structure for the calculation. + working_dir : pathlib.Path, optional + Directory in which the calculation will be executed. + If it exists, it will be removed and recreated. Defaults to "RUN". + ncores : int, optional + Number of CPU cores to use for the calculation. Overrides the default + value in the input object if provided. + memory : int, optional + Amount of memory to allocate for the calculation. Overrides the default + value in the input object if provided. + moinp : pathlib.Path, optional + Path to a molecular orbital input file. Overrides the default if provided. + + Returns + ------- + TaskResults + An instance of the configured results type containing the results + of the calculation. + """ # > recreate the working dir shutil.rmtree(working_dir, ignore_errors=True) @@ -172,26 +388,50 @@ def run( return self._results_type(calc_object=calc) - - def change_parameter(self, param: str, value: typing.Any) -> None: - setattr(self.task_parameters, param, value) - - - def restart(self, previous_results: "TaskResults", - basename: str|None = None, - struct: Structure| BaseStructureFile| None = None, + def _restart( + self, + previous_results: "TaskResults", + basename: str | None = None, + struct: Structure | BaseStructureFile | None = None, working_dir: Path | None = None, ncores: int | None = None, memory: int | None = None, moinp: Path | None = None, - use_previous_orbitals: bool = False ) -> "TaskResults": - - basename = basename if basename else previous_results.calc_object.basename - struct = struct if struct else previous_results.calc_object.structure - working_dir = working_dir if working_dir else previous_results.calc_object.working_dir - ncores = ncores if ncores else previous_results.calc_object.input.ncores - memory = memory if memory else previous_results.calc_object.input.memory - moinp = moinp if moinp else previous_results.calc_object.input.moinp + use_previous_orbitals: bool = False, + ) -> "TaskResults": + """ + TODO: + - finish restart implementation (low on priority list) + Parameters + ---------- + previous_results + basename + struct + working_dir + ncores + memory + moinp + use_previous_orbitals + + Returns + ------- + + """ + prev_calc = previous_results.calc_object + + basename = basename if basename else prev_calc.basename + struct = struct if struct else prev_calc.structure + working_dir = working_dir if working_dir else prev_calc.working_dir + ncores = ncores if ncores else prev_calc.input.ncores + memory = memory if memory else prev_calc.input.memory + + if use_previous_orbitals: + prev_gbw = prev_calc.working_dir / f"{prev_calc.basename}.gbw" + if not prev_gbw.exists(): + raise FileNotFoundError(f"GBW file not found: {prev_gbw}") + moinp = prev_gbw + else: + moinp = moinp if moinp else prev_calc.input.moinp return self.run(basename, struct, working_dir, ncores, memory, moinp) @@ -205,6 +445,26 @@ def __init__(self, calc_object: Calculator): def output_parse( func: typing.Callable[["TaskResults"], typing.Any], ) -> typing.Callable[["TaskResults"], typing.Any]: + """ + Decorator to ensure output parsing is performed before accessing results. + + This decorator wraps methods of a `TaskResults` instance and guarantees + that the associated output has been parsed before the method is executed. + Parsing is performed lazily and only once per instance. + + Parameters + ---------- + func : Callable[[TaskResults], Any] + The method to wrap. It must be a method of `TaskResults` that relies + on parsed output data. + + Returns + ------- + Callable[[TaskResults], Any] + A wrapped method that ensures `self.output.parse()` has been called + before delegating to the original function. + """ + @wraps(func) def wrapper(self: "TaskResults"): if not self._parsed: @@ -220,3 +480,12 @@ def output(self) -> Output: raise ValueError("calc_object not set") return self.calc_object.get_output() + + @cached_property + def status(self) -> bool: + return self.output.terminated_normally() and self.output.scf_converged() + + @cached_property + @abstractmethod + def primary_property(self) -> typing.Any: + pass From 0adf343881a44e494b3399a8bb3bbcad520fce77 Mon Sep 17 00:00:00 2001 From: Nakul Santhosh Date: Mon, 13 Apr 2026 16:39:31 +0200 Subject: [PATCH 09/10] changes from todays meeting --- src/opi/input/simple_keywords/base.py | 13 ++- src/opi/input/simple_keywords/grid.py | 6 +- src/opi/input/simple_keywords/scf.py | 114 +++++++++++++++----------- src/opi/tasks/method_settings.py | 101 ++++++++++++++++++++++- src/opi/tasks/singlepointtask.py | 15 ++-- src/opi/tasks/task_base.py | 57 +++++++++++-- 6 files changed, 234 insertions(+), 72 deletions(-) diff --git a/src/opi/input/simple_keywords/base.py b/src/opi/input/simple_keywords/base.py index 7366859b..16a897e4 100644 --- a/src/opi/input/simple_keywords/base.py +++ b/src/opi/input/simple_keywords/base.py @@ -24,13 +24,18 @@ def registry(cls) -> list[type["SimpleKeywordBox"]]: def from_string(cls, s: str) -> "SimpleKeyword": norm = s.lower() for c in cls._registry: - for attr, value in vars(c).items(): + for attr in dir(c): + if attr.startswith('_'): # Skip private/magic attributes + continue + value = getattr(c, attr) if isinstance(value, SimpleKeyword) and value.keyword.lower() == norm: return value elif isinstance(value, SimpleKeyword) and attr.lower() == norm: return value + elif isinstance(value, SimpleKeyword) and value.alias and value.alias.lower() == norm: + return value - raise ValueError(f"Keyword {s} not found.") + raise ValueError(f"Keyword {s} not found in class {cls.__name__}") @classmethod def find_keyword(cls, inp: "SimpleKeyword | str") -> "SimpleKeyword": @@ -52,12 +57,14 @@ class SimpleKeyword: keyword: str simple keyword as it will appear in the ORCA .inp file """ + alias: str | None = None - def __init__(self, keyword: str) -> None: + def __init__(self, keyword: str, alias:str|None=None) -> None: self._keyword: str = "" self.keyword = keyword self._name: str = "" # self.name = name + self.alias = alias @property def keyword(self) -> str: diff --git a/src/opi/input/simple_keywords/grid.py b/src/opi/input/simple_keywords/grid.py index bcb10704..7383c969 100644 --- a/src/opi/input/simple_keywords/grid.py +++ b/src/opi/input/simple_keywords/grid.py @@ -9,11 +9,11 @@ class Grid(SimpleKeywordBox): """Enum to store all simple keywords of type Grid.""" - DEFGRID1 = SimpleKeyword("defgrid1") + DEFGRID1 = SimpleKeyword("defgrid1", alias="1") """SimpleKeyword: small grid.""" - DEFGRID2 = SimpleKeyword("defgrid2") + DEFGRID2 = SimpleKeyword("defgrid2", alias="2") """SimpleKeyword: medium grid.""" - DEFGRID3 = SimpleKeyword("defgrid3") + DEFGRID3 = SimpleKeyword("defgrid3", alias="3") """SimpleKeyword: large grid.""" REFGRID = SimpleKeyword("refgrid") """SimpleKeyword: reference grid.""" diff --git a/src/opi/input/simple_keywords/scf.py b/src/opi/input/simple_keywords/scf.py index 0bd72599..f4d53929 100644 --- a/src/opi/input/simple_keywords/scf.py +++ b/src/opi/input/simple_keywords/scf.py @@ -5,8 +5,67 @@ __all__ = ("Scf",) +class ScfThreshold(SimpleKeywordBox): + SLOPPYSCF = SimpleKeyword("sloppyscf", alias="sloppy") + """SimpleKeyword: SCF convergence threshold settings.""" + LOOSESCF = SimpleKeyword("loosescf", alias="loose") + """SimpleKeyword: SCF convergence threshold settings.""" + NORMALSCF = SimpleKeyword("normalscf", alias="normal") + """SimpleKeyword: SCF convergence threshold settings.""" + STRONGSCF = SimpleKeyword("strongscf", alias="strong") + """SimpleKeyword: SCF convergence threshold settings.""" + TIGHTSCF = SimpleKeyword("tightscf", alias="tight") + """SimpleKeyword: SCF convergence threshold settings.""" + VERYTIGHTSCF = SimpleKeyword("verytightscf", alias="verytight") + """SimpleKeyword: SCF convergence threshold settings.""" + EXTREMESCF = SimpleKeyword("extremescf", alias="extreme") + """SimpleKeyword: SCF convergence threshold settings.""" + + +class ScfGuess(SimpleKeywordBox): + MOREAD = SimpleKeyword("moread") + """SimpleKeyword: SCF initial guess read orbitals from gbw file.""" + EHTANO = SimpleKeyword("ehtano") + """SimpleKeyword: SCF initial guess.""" + HCORE = SimpleKeyword("hcore") + """SimpleKeyword: SCF initial guess.""" + PATOM = SimpleKeyword("patom") + """SimpleKeyword: SCF initial guess.""" + PMODEL = SimpleKeyword("pmodel") + """SimpleKeyword: SCF initial guess.""" + PMODELX = SimpleKeyword("pmodelx") + """SimpleKeyword: SCF initial guess.""" + PMODELXAV = SimpleKeyword("pmodelxav") + """SimpleKeyword: SCF initial guess.""" + PMODELXPM = SimpleKeyword("pmodelxpm") + """SimpleKeyword: SCF initial guess.""" + SYMBREAKGUESS = SimpleKeyword("symbreakguess") + """SimpleKeyword: SCF initial guess.""" + + +class ScfConvergence(SimpleKeywordBox): + EASYCONV = SimpleKeyword("easyconv", alias="easy") + """SimpleKeyword: SCF convergence strategy.""" + NORMALCONV = SimpleKeyword("normalconv", alias="normal") + """SimpleKeyword: SCF convergence strategy.""" + SLOWCONV = SimpleKeyword("slowconv", alias="slow") + """SimpleKeyword: SCF convergence strategy.""" + VERYSLOWCONV = SimpleKeyword("veryslowconv", alias="veryslow") + """SimpleKeyword: SCF convergence strategy.""" + + +class ScfSolver(SimpleKeywordBox): + DIIS = SimpleKeyword("diis") + """SimpleKeyword: SCF solver.""" + KDIIS = SimpleKeyword("kdiis") + """SimpleKeyword: SCF solver.""" + SOSCF = SimpleKeyword("soscf") + """SimpleKeyword: SCF solver.""" + TRAH = SimpleKeyword("trah") + """SimpleKeyword: SCF solver.""" -class Scf(SimpleKeywordBox): + +class Scf(ScfThreshold, ScfConvergence, ScfGuess, ScfSolver): """Enum to store all simple keywords of type Scf.""" G3CONV = SimpleKeyword("3conv") @@ -17,50 +76,29 @@ class Scf(SimpleKeywordBox): """SimpleKeyword: SCF solver combination.""" KDIISTRAH = SimpleKeyword("kdiistrah") """SimpleKeyword: SCF solver combination.""" - DIIS = SimpleKeyword("diis") - """SimpleKeyword: SCF solver.""" + NODIIS = SimpleKeyword("nodiis") """SimpleKeyword: SCF solver.""" AODIIS = SimpleKeyword("aodiis") """SimpleKeyword: SCF solver.""" NOAODIIS = SimpleKeyword("noaodiis") """SimpleKeyword: SCF solver.""" - KDIIS = SimpleKeyword("kdiis") - """SimpleKeyword: SCF solver.""" + NOKDIIS = SimpleKeyword("nokdiis") """SimpleKeyword: SCF solver.""" - SOSCF = SimpleKeyword("soscf") - """SimpleKeyword: SCF solver.""" + NOSOSCF = SimpleKeyword("nososcf") """SimpleKeyword: SCF solver.""" - TRAH = SimpleKeyword("trah") - """SimpleKeyword: SCF solver.""" + NOTRAH = SimpleKeyword("notrah") """SimpleKeyword: SCF solver.""" AUTOSTART = SimpleKeyword("autostart") """SimpleKeyword: SCF initial guess start SCF from a gbw file with the same basename (default).""" NOAUTOSTART = SimpleKeyword("noautostart") """SimpleKeyword: SCF initial guess do not start SCF from a gbw file with the same basename.""" - MOREAD = SimpleKeyword("moread") - """SimpleKeyword: SCF initial guess read orbitals from gbw file.""" - EHTANO = SimpleKeyword("ehtano") - """SimpleKeyword: SCF initial guess.""" - HCORE = SimpleKeyword("hcore") - """SimpleKeyword: SCF initial guess.""" HUECKEL = SimpleKeyword("hueckel") """SimpleKeyword: SCF initial guess.""" - PATOM = SimpleKeyword("patom") - """SimpleKeyword: SCF initial guess.""" - PMODEL = SimpleKeyword("pmodel") - """SimpleKeyword: SCF initial guess.""" - PMODELX = SimpleKeyword("pmodelx") - """SimpleKeyword: SCF initial guess.""" - PMODELXAV = SimpleKeyword("pmodelxav") - """SimpleKeyword: SCF initial guess.""" - PMODELXPM = SimpleKeyword("pmodelxpm") - """SimpleKeyword: SCF initial guess.""" - SYMBREAKGUESS = SimpleKeyword("symbreakguess") - """SimpleKeyword: SCF initial guess.""" + UNITMATRIXGUESS = SimpleKeyword("unitmatrixguess") """SimpleKeyword: SCF initial guess.""" USEGRAMSCHMIDT = SimpleKeyword("usegramschmidt") @@ -91,20 +129,6 @@ class Scf(SimpleKeywordBox): """SimpleKeyword: enable fractional occupations.""" SCFCONVFORCED = SimpleKeyword("scfconvforced") """SimpleKeyword: Force SCF convergence for subsequent operations.""" - SLOPPYSCF = SimpleKeyword("sloppyscf") - """SimpleKeyword: SCF convergence threshold settings.""" - LOOSESCF = SimpleKeyword("loosescf") - """SimpleKeyword: SCF convergence threshold settings.""" - NORMALSCF = SimpleKeyword("normalscf") - """SimpleKeyword: SCF convergence threshold settings.""" - STRONGSCF = SimpleKeyword("strongscf") - """SimpleKeyword: SCF convergence threshold settings.""" - TIGHTSCF = SimpleKeyword("tightscf") - """SimpleKeyword: SCF convergence threshold settings.""" - VERYTIGHTSCF = SimpleKeyword("verytightscf") - """SimpleKeyword: SCF convergence threshold settings.""" - EXTREMESCF = SimpleKeyword("extremescf") - """SimpleKeyword: SCF convergence threshold settings.""" SLOPPYSCFCHECK = SimpleKeyword("sloppyscfcheck") """SimpleKeyword: SCF convergence threshold settings.""" NOSLOPPYSCFCHECK = SimpleKeyword("nosloppyscfcheck") @@ -125,14 +149,6 @@ class Scf(SimpleKeywordBox): """SimpleKeyword: SCF convergence threshold settings.""" SCFCONV12 = SimpleKeyword("scfconv12") """SimpleKeyword: SCF convergence threshold settings.""" - EASYCONV = SimpleKeyword("easyconv") - """SimpleKeyword: SCF convergence strategy.""" - NORMALCONV = SimpleKeyword("normalconv") - """SimpleKeyword: SCF convergence strategy.""" - SLOWCONV = SimpleKeyword("slowconv") - """SimpleKeyword: SCF convergence strategy.""" - VERYSLOWCONV = SimpleKeyword("veryslowconv") - """SimpleKeyword: SCF convergence strategy.""" DAMP = SimpleKeyword("damp") """SimpleKeyword: SCF settings.""" NODAMP = SimpleKeyword("nodamp") diff --git a/src/opi/tasks/method_settings.py b/src/opi/tasks/method_settings.py index 58a2f3ec..2e0dc603 100644 --- a/src/opi/tasks/method_settings.py +++ b/src/opi/tasks/method_settings.py @@ -1,9 +1,106 @@ import typing +import warnings -from opi.input.simple_keywords import Dft, SimpleKeyword +from pydantic import field_validator, model_validator + +from opi.input.simple_keywords import Dft, SimpleKeyword, Grid, DispersionCorrection +from opi.input.simple_keywords.scf import ScfThreshold, ScfSolver, Scf, ScfConvergence from opi.tasks.task_base import MethodSettings class DFTSettings(MethodSettings): - method: typing.Annotated[SimpleKeyword, Dft] _name: str = "dft" + method: typing.Annotated[SimpleKeyword, Dft] + grid: typing.Annotated[SimpleKeyword, Grid] | None = None + scf_maxiter: typing.Annotated[int, "BlockScf", "maxiter"] | None = None + scf_threshold: typing.Annotated[SimpleKeyword, ScfThreshold] | None = None + scf_solver: typing.Annotated[SimpleKeyword, ScfSolver] | None = None + scf_stab: typing.Annotated[SimpleKeyword, Scf] | None = None + scf_conv: typing.Annotated[SimpleKeyword, ScfConvergence] | None = None + + + @field_validator("*", mode="before") + @classmethod + def validate_fields(cls, value: typing.Any, info): + if info.field_name == "scf_stab": + if value: + return Scf.SCFSTAB + else: + return None + if info.field_name == "method": + try: + new_keyword = Dft.find_keyword(value) + except ValueError: + new_keyword = cls._find_dft_disp_keyword(value) + return new_keyword + else: + return super().validate_fields(value, info) + + + + @model_validator(mode="after") + @classmethod + def cross_validate(cls, data: "DFTSettings") -> "DFTSettings": + """ + Cross-validation for `DftSettings`. + If the method keyword contains '3c', the `basis_set` attribute will be set to `None`. + + The `DftSettings` object is then returned. + Parameters + ---------- + data: DFTSettings + `DFTSettings` object given as input. + + Returns + ------- + DFTSettings + Cross-validated `DFTSettings` object. + + """ + if "3c" in data.method.keyword and data.basis_set: + warnings.warn("Basis Set will be ignored due to selection of Method", UserWarning) + data.basis_set = None + + return data + + @classmethod + def _find_dft_disp_keyword(cls, value: str | SimpleKeyword) -> SimpleKeyword: + """ + Function to search for a `Dft` keyword along with `DispersionCorrection`. + In the case that an - is present in keyword, the keyword is split along the - and it is verified whether the + given keyword is a valid combination of `Dft` and `DispersionCorrection` keyword. + + If it is , a `SimpleKeyword` object is created and returned. If not , a `ValueError` is raised. + Parameters + ---------- + value: str | SimpleKeyword + The value to search for. + + Returns + ------- + SimpleKeyword + The created `SimpleKeyword` object. + + Raises + ------ + ValueError + If given value is invalid + + """ + if isinstance(value, SimpleKeyword): + value = value.keyword + + if '-' in value: + try: + keywords = value.split('-') + Dft.find_keyword(keywords[0]) + DispersionCorrection.find_keyword(keywords[1]) + + return SimpleKeyword(value) + except ValueError: + raise ValueError(f"Invalid Dft keyword or dispersion correction given: {value}") + else: + raise ValueError(f"Invalid Dft keyword '{value}'") + + + diff --git a/src/opi/tasks/singlepointtask.py b/src/opi/tasks/singlepointtask.py index 84501b5d..9d98de5d 100644 --- a/src/opi/tasks/singlepointtask.py +++ b/src/opi/tasks/singlepointtask.py @@ -1,7 +1,7 @@ import typing from pathlib import Path -from opi.input.simple_keywords import SimpleKeyword, SolvationModel, Solvent, Task +from opi.input.simple_keywords import SimpleKeyword, Solvent, Task from opi.input.structures import BaseStructureFile, Structure from opi.tasks.method_settings import DFTSettings from opi.tasks.task_base import SimpleTask, TaskResults, TaskSettings @@ -16,17 +16,18 @@ class SinglePointTask(SimpleTask): def __init__( self, method: str | SimpleKeyword, - basis_set: str | SimpleKeyword, - solvation_model: str | SolvationModel, - solvent: str | Solvent, + basis_set: str | SimpleKeyword | None = None, + solvation_model: str | SimpleKeyword | None = None, + solvent: str | Solvent | None = None, task: str | SimpleKeyword | None = None, ): self._method_settings = DFTSettings( method=method, basis_set=basis_set, solvation_model=solvation_model, solvent=solvent ) self._task_settings = ( - SinglePointSettings(task_keyword=task) if task else SinglePointSettings() - ) + SinglePointSettings(task_keyword=task) + ) if task else SinglePointSettings() + self._results_type = SinglePointResults def run( @@ -37,6 +38,7 @@ def run( ncores: int | None = None, memory: int | None = None, moinp: Path | None = None, + strict: bool = False ) -> "SinglePointResults": single_point_result = super().run( basename=basename, @@ -45,6 +47,7 @@ def run( ncores=ncores, memory=memory, moinp=moinp, + strict=strict ) return typing.cast(SinglePointResults, single_point_result) diff --git a/src/opi/tasks/task_base.py b/src/opi/tasks/task_base.py index 88f7cadd..a3ce0add 100644 --- a/src/opi/tasks/task_base.py +++ b/src/opi/tasks/task_base.py @@ -56,7 +56,12 @@ def _get_field_metadata(hint: typing.Any) -> tuple: Tuple of metadata about the field. """ + origin = typing.get_origin(hint) args = typing.get_args(hint) + if origin is typing.Union: + non_none_args = [arg for arg in args if arg is not type(None)] + if non_none_args: + return Settings._get_field_metadata(non_none_args[0]) return args[1:] if len(args) > 1 else () @staticmethod @@ -139,9 +144,10 @@ def map_to_input(self, input_object: Input) -> Input: for field_name, field_type in hints.items(): value = getattr(self, field_name) + if value is None: + continue - args = typing.get_args(field_type) - metadata = args[1:] + metadata = self._get_field_metadata(field_type) match metadata: case (validator,): @@ -149,7 +155,8 @@ def map_to_input(self, input_object: Input) -> Input: continue new_keyword = self._get_simple_keyword(validator, value) - input_object.add_simple_keywords(new_keyword) + if new_keyword: + input_object.add_simple_keywords(new_keyword) case (validator, key): block_type = Block.get_subclass_by_name(validator) @@ -191,6 +198,9 @@ def validate_fields(cls, value: typing.Any, info): Any User input value processed and converted to OPI compatible types. """ + if value is None: + return value + hints = typing.get_type_hints(cls, include_extras=True) if info.field_name not in hints: @@ -230,9 +240,9 @@ class TaskSettings(Settings): class MethodSettings(Settings): method: typing.Annotated[SimpleKeyword, Method] - basis_set: typing.Annotated[SimpleKeyword, BasisSet] - solvation_model: typing.Annotated[SimpleKeyword, SolvationModel] - solvent: typing.Annotated[str, Solvent] + basis_set: typing.Annotated[SimpleKeyword, BasisSet] | None = None + solvation_model: typing.Annotated[SimpleKeyword, SolvationModel] | None = None + solvent: typing.Annotated[str, Solvent] | None = None class SimpleTask(ABC): @@ -332,6 +342,7 @@ def run( ncores: int | None = None, memory: int | None = None, moinp: Path | None = None, + strict: bool = False ) -> "TaskResults": """ Execute the computational task with the given structure and settings. @@ -357,6 +368,8 @@ def run( value in the input object if provided. moinp : pathlib.Path, optional Path to a molecular orbital input file. Overrides the default if provided. + strict : bool, optional + Controls whether working directory will be created/overwritten. Defaults to False. Returns ------- @@ -364,10 +377,20 @@ def run( An instance of the configured results type containing the results of the calculation. """ + if strict: + # Must already exist + if not working_dir.exists(): + raise ValueError(f"Working directory {working_dir.resolve()} does not exist (strict mode)") + + # Must be empty + if any(working_dir.iterdir()): + raise ValueError(f"Working directory {working_dir.resolve()} is not empty (strict mode)") - # > recreate the working dir - shutil.rmtree(working_dir, ignore_errors=True) - working_dir.mkdir() + else: + # Non-strict: recreate directory + if working_dir.exists(): + shutil.rmtree(working_dir) + working_dir.mkdir() inp = self.input_object @@ -489,3 +512,19 @@ def status(self) -> bool: @abstractmethod def primary_property(self) -> typing.Any: pass + + def __getattr__(self, name): + """ + First tries to get attribute from the object itself. + If not found, tries to get it from self.output. + """ + # Check if 'output' exists to avoid infinite recursion + if name == 'output': + raise AttributeError(f"'{type(self).__name__}' object has no attribute 'output'") + + # Try to get the attribute from self.output + try: + return getattr(self.output, name) + except AttributeError: + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + From c009bfcffc4be223f1efdcfd49ce8c65c102a2d3 Mon Sep 17 00:00:00 2001 From: Nakul Santhosh Date: Wed, 15 Apr 2026 10:21:41 +0200 Subject: [PATCH 10/10] flesh out the different tasks --- src/opi/input/simple_keywords/base.py | 4 ++ src/opi/input/simple_keywords/opt.py | 27 +++++---- src/opi/tasks/engrad_task.py | 53 +++++++++++++++++ src/opi/tasks/freq_task.py | 49 ++++++++++++++++ src/opi/tasks/method_settings.py | 16 +++-- src/opi/tasks/opt_task.py | 84 +++++++++++++++++++++++++++ src/opi/tasks/task_base.py | 6 +- 7 files changed, 220 insertions(+), 19 deletions(-) create mode 100644 src/opi/tasks/engrad_task.py create mode 100644 src/opi/tasks/freq_task.py create mode 100644 src/opi/tasks/opt_task.py diff --git a/src/opi/input/simple_keywords/base.py b/src/opi/input/simple_keywords/base.py index 16a897e4..f71e03a7 100644 --- a/src/opi/input/simple_keywords/base.py +++ b/src/opi/input/simple_keywords/base.py @@ -4,6 +4,10 @@ class SimpleKeywordBox: + """ + TODO: + - rework registry to account for latest changes. + """ _registry: list[type["SimpleKeywordBox"]] = [] def __init_subclass__(cls, **kwargs: Any) -> None: diff --git a/src/opi/input/simple_keywords/opt.py b/src/opi/input/simple_keywords/opt.py index f9f20161..90278607 100644 --- a/src/opi/input/simple_keywords/opt.py +++ b/src/opi/input/simple_keywords/opt.py @@ -6,24 +6,27 @@ __all__ = ("Opt",) -class Opt(SimpleKeywordBox): - """Enum to store all simple keywords of type Opt.""" - - OPT = SimpleKeyword("opt") - """SimpleKeyword: Perform a geometry optimization.""" - CRUDEOPT = SimpleKeyword("crudeopt") +class OptThreshold(SimpleKeywordBox): + CRUDEOPT = SimpleKeyword("crudeopt", alias="crude") """SimpleKeyword: Geometry optimization with thresholds.""" - INTERPOPT = SimpleKeyword("interpopt") + LOOSEOPT = SimpleKeyword("looseopt", alias="loose") """SimpleKeyword: Geometry optimization with thresholds.""" - LOOSEOPT = SimpleKeyword("looseopt") + NORMALOPT = SimpleKeyword("normalopt", alias="normal") """SimpleKeyword: Geometry optimization with thresholds.""" - NORMALOPT = SimpleKeyword("normalopt") + SLOPPYOPT = SimpleKeyword("sloppyopt", alias="sloppy") """SimpleKeyword: Geometry optimization with thresholds.""" - SLOPPYOPT = SimpleKeyword("sloppyopt") + TIGHTOPT = SimpleKeyword("tightopt", alias="tight") """SimpleKeyword: Geometry optimization with thresholds.""" - TIGHTOPT = SimpleKeyword("tightopt") + VERYTIGHTOPT = SimpleKeyword("verytightopt", alias="verytight") """SimpleKeyword: Geometry optimization with thresholds.""" - VERYTIGHTOPT = SimpleKeyword("verytightopt") + + +class Opt(SimpleKeywordBox): + """Enum to store all simple keywords of type Opt.""" + + OPT = SimpleKeyword("opt") + """SimpleKeyword: Perform a geometry optimization.""" + INTERPOPT = SimpleKeyword("interpopt") """SimpleKeyword: Geometry optimization with thresholds.""" OPTH = SimpleKeyword("opth") """SimpleKeyword: Optimize only hydrogen atoms.""" diff --git a/src/opi/tasks/engrad_task.py b/src/opi/tasks/engrad_task.py new file mode 100644 index 00000000..921c4cfc --- /dev/null +++ b/src/opi/tasks/engrad_task.py @@ -0,0 +1,53 @@ +import typing + +from opi.input.simple_keywords import Task, SimpleKeyword, Solvent +from opi.tasks.method_settings import DFTSettings +from opi.tasks.task_base import SimpleTask, TaskSettings, TaskResults + + +class EngradSettings(TaskSettings): + _name: str = "engrad" + task_keyword: typing.Annotated[SimpleKeyword, Task] = Task.ENGRAD + + +class EngradTask(SimpleTask): + def __init__(self, + method: str | SimpleKeyword, + basis_set: str | SimpleKeyword | None = None, + solvation_model: str | SimpleKeyword | None = None, + solvent: str | Solvent | None = None, + task: str | SimpleKeyword | None = None): + self._method_settings = DFTSettings( + method=method, basis_set=basis_set, solvation_model=solvation_model, solvent=solvent + ) + self._task_settings = ( + EngradSettings(task_keyword=task) + ) if task else EngradSettings() + + self._results_type = EngradResults + + +class EngradResults(TaskResults): + @property + @TaskResults.output_parse + def final_energy(self) -> float: + final_energy = self.output.get_final_energy() + + if final_energy is None: + raise ValueError("Could not get final energy from ORCA Output") + + return final_energy + + @property + @TaskResults.output_parse + def gradient(self) -> list[float]: + gradient = self.output.get_gradient() + + if gradient is None: + raise ValueError("Could not get gradient from ORCA Output") + + return gradient + + @property + def primary_property(self) -> tuple[float, list[float]]: + return self.final_energy, self.gradient \ No newline at end of file diff --git a/src/opi/tasks/freq_task.py b/src/opi/tasks/freq_task.py new file mode 100644 index 00000000..6a453941 --- /dev/null +++ b/src/opi/tasks/freq_task.py @@ -0,0 +1,49 @@ +import typing +from functools import cached_property + +from opi.input.simple_keywords import Task, SimpleKeyword, Solvent +from opi.tasks.method_settings import DFTSettings +from opi.tasks.task_base import TaskSettings, TaskResults, SimpleTask + + +class FreqSettings(TaskSettings): + _name: str = "freq" + task_keyword: typing.Annotated[SimpleKeyword, Task] = Task.FREQ + + +class FreqTask(SimpleTask): + def __init__(self, + method: str | SimpleKeyword, + basis_set: str | SimpleKeyword | None = None, + solvation_model: str | SimpleKeyword | None = None, + solvent: str | Solvent | None = None, + task: str | SimpleKeyword | None = None): + self._method_settings = DFTSettings( + method=method, basis_set=basis_set, solvation_model=solvation_model, solvent=solvent + ) + self._task_settings = ( + FreqSettings(task_keyword=task) + ) if task else FreqSettings() + + self._results_type = FreqResults + + +class FreqResults(TaskResults): + @cached_property + def status(self) -> bool: + return self.output.terminated_normally() + + @cached_property + @TaskResults.output_parse + def free_energy_delta(self) -> float: + free_energy_delta = self.output.get_free_energy_delta() + + if not free_energy_delta: + raise ValueError("Could not get free energy delta from ORCA output") + + return free_energy_delta + + @property + def primary_property(self) -> float: + return self.free_energy_delta + diff --git a/src/opi/tasks/method_settings.py b/src/opi/tasks/method_settings.py index 2e0dc603..273a4cec 100644 --- a/src/opi/tasks/method_settings.py +++ b/src/opi/tasks/method_settings.py @@ -3,6 +3,7 @@ from pydantic import field_validator, model_validator +from opi.input import Input from opi.input.simple_keywords import Dft, SimpleKeyword, Grid, DispersionCorrection from opi.input.simple_keywords.scf import ScfThreshold, ScfSolver, Scf, ScfConvergence from opi.tasks.task_base import MethodSettings @@ -15,18 +16,13 @@ class DFTSettings(MethodSettings): scf_maxiter: typing.Annotated[int, "BlockScf", "maxiter"] | None = None scf_threshold: typing.Annotated[SimpleKeyword, ScfThreshold] | None = None scf_solver: typing.Annotated[SimpleKeyword, ScfSolver] | None = None - scf_stab: typing.Annotated[SimpleKeyword, Scf] | None = None + scf_stab: bool = False scf_conv: typing.Annotated[SimpleKeyword, ScfConvergence] | None = None @field_validator("*", mode="before") @classmethod def validate_fields(cls, value: typing.Any, info): - if info.field_name == "scf_stab": - if value: - return Scf.SCFSTAB - else: - return None if info.field_name == "method": try: new_keyword = Dft.find_keyword(value) @@ -63,6 +59,14 @@ def cross_validate(cls, data: "DFTSettings") -> "DFTSettings": return data + def map_to_input(self, input_object: Input) -> Input: + input_object = super().map_to_input(input_object) + + if self.scf_stab: + input_object.add_simple_keywords(Scf.SCFSTAB) + + return input_object + @classmethod def _find_dft_disp_keyword(cls, value: str | SimpleKeyword) -> SimpleKeyword: """ diff --git a/src/opi/tasks/opt_task.py b/src/opi/tasks/opt_task.py new file mode 100644 index 00000000..08d3a6d3 --- /dev/null +++ b/src/opi/tasks/opt_task.py @@ -0,0 +1,84 @@ +import typing + +from opi.input import Input +from opi.input.simple_keywords import Task, SimpleKeyword, Solvent +from opi.input.simple_keywords.opt import OptThreshold, Opt +from opi.input.structures import Structure +from opi.tasks.method_settings import DFTSettings +from opi.tasks.task_base import TaskSettings, TaskResults, SimpleTask + + +class OptSettings(TaskSettings): + _name:str = "opt" + task_keyword: typing.Annotated[SimpleKeyword, Task] = Task.OPT + opt_threshold: typing.Annotated[SimpleKeyword, OptThreshold] | None = None + optrigid: bool = False + opt_h: bool = False + lopt:bool = False + opt_maxiter: typing.Annotated[int, "BlockGeom", "maxiter"] | None = None + + def map_to_input(self, input_object: Input) -> Input: + input_object = super().map_to_input(input_object) + + if self.optrigid: + input_object.add_simple_keywords(Opt.RIGIDBODYOPT) + + opt_map = { + (True, True): Opt.L_OPTH, + (True, False): Opt.OPTH, + (False, True): Opt.L_OPT, + } + + keyword = opt_map.get((self.opt_h, self.lopt)) + if keyword: + input_object.add_simple_keywords(keyword) + + return input_object + + +class OptTask(SimpleTask): + def __init__(self, + method: str | SimpleKeyword, + basis_set: str | SimpleKeyword | None = None, + solvation_model: str | SimpleKeyword | None = None, + solvent: str | Solvent | None = None, + task: str | SimpleKeyword | None = None): + self._method_settings = DFTSettings( + method=method, basis_set=basis_set, solvation_model=solvation_model, solvent=solvent + ) + self._task_settings = ( + OptSettings(task_keyword=task) + ) if task else OptSettings() + + self._results_type = OptResults + + +class OptResults(TaskResults): + @property + @TaskResults.output_parse + def final_energy(self) -> float: + final_energy = self.output.get_final_energy() + + if final_energy is None: + raise ValueError("Could not get final energy from ORCA Output") + + return final_energy + + + @property + @TaskResults.output_parse + def structure(self) -> Structure: + structure = self.output.get_structure() + if structure is None: + raise ValueError("Could not get structure from ORCA Output") + + return structure + + + @property + def primary_property(self) -> tuple[float, Structure]: + return self.final_energy, self.structure + + + + diff --git a/src/opi/tasks/task_base.py b/src/opi/tasks/task_base.py index a3ce0add..c9af038f 100644 --- a/src/opi/tasks/task_base.py +++ b/src/opi/tasks/task_base.py @@ -22,6 +22,10 @@ class Settings(BaseModel): + """ + TODO: + - add checking for Solvent and SolvationModel now that they are optional. + """ model_config = ConfigDict(arbitrary_types_allowed=True, validate_assignment=True) _name: str @@ -235,7 +239,7 @@ def cross_validate(cls, data: dict[str, typing.Any]) -> dict[str, typing.Any]: class TaskSettings(Settings): - pass + task_keyword: typing.Annotated[SimpleKeyword, SimpleKeywordBox] class MethodSettings(Settings):