From ea4d8d82a253ba5fbc7e107f7b0e76a83f701f38 Mon Sep 17 00:00:00 2001 From: crjacinto Date: Fri, 17 Apr 2026 13:59:30 +0200 Subject: [PATCH 1/2] RMSD and Rotationl Constants Modules --- src/opi/utils/rmsd.py | 132 +++++++++++++++ src/opi/utils/rotconst.py | 345 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 477 insertions(+) create mode 100644 src/opi/utils/rmsd.py create mode 100644 src/opi/utils/rotconst.py diff --git a/src/opi/utils/rmsd.py b/src/opi/utils/rmsd.py new file mode 100644 index 00000000..221546d1 --- /dev/null +++ b/src/opi/utils/rmsd.py @@ -0,0 +1,132 @@ +import numpy as np +import os +from typing import Tuple, List, Union + + +def read_xyz( + data: Union[str, Tuple[List[str], np.ndarray]] +) -> Tuple[List[str], np.ndarray]: + """ + Read geometry from: + - XYZ file path + - XYZ block string + - (symbols, coords) tuple + + Returns + ------- + symbols : list[str] + coords : (N, 3) ndarray + """ + + # ------------------------- + # Case 1: string input + # ------------------------- + if isinstance(data, str): + + # Case 1a: a file path + if os.path.isfile(data): + with open(data, "r") as f: + lines = f.readlines() + + # Case 1b: an XYZ block string + else: + lines = data.strip().splitlines() + + n_atoms = int(lines[0].strip()) + symbols = [] + coords = [] + + for line in lines[2:2 + n_atoms]: + parts = line.split() + sym = parts[0] + xyz = [float(x) for x in parts[1:4]] + + symbols.append(sym) + coords.append(xyz) + + return symbols, np.array(coords, dtype=np.float64) + + # ------------------------- + # Case 2: in-memory tuple + # ------------------------- + elif isinstance(data, tuple) and len(data) == 2: + symbols, coords = data + + coords = np.asarray(coords, dtype=np.float64) + + if coords.ndim != 2 or coords.shape[1] != 3: + raise ValueError("Coordinates must be an (N,3) array") + + if len(symbols) != len(coords): + raise ValueError("Symbols and coordinates must have same length") + + return list(symbols), coords + + else: + raise TypeError( + "Input must be: file path, XYZ string, or (symbols, coords)" + ) + + +def _validate_geometries(symA: List[str], symB: List[str]) -> None: + if len(symA) != len(symB): + raise ValueError("Geometries have different number of atoms") + + for i, (a, b) in enumerate(zip(symA, symB)): + if a != b: + raise ValueError(f"Atom mismatch at index {i}: {a} != {b}") + + +def kabsch_rmsd( + ref_xyz: str, + target_xyz: str, + *, + align: bool = True, +) -> float: + """Compute RMSD between two XYZ geometries. + + Parameters + ---------- + ref_xyz : str + Path to reference geometry. + target_xyz : str + Path to target geometry. + align : bool, default True + Whether to perform optimal alignment (Kabsch). + + Returns + ------- + float + RMSD (Å) + """ + + symA, A = read_xyz(ref_xyz) + symB, B = read_xyz(target_xyz) + + _validate_geometries(symA, symB) + + # Center using centroid (simple average) + A_cent = A - A.mean(axis=0) + B_cent = B - B.mean(axis=0) + + if not align: + diff = A_cent - B_cent + return float(np.sqrt(np.sum(diff**2) / len(A_cent))) + + # Standard Kabsch covariance matrix + H = B_cent.T @ A_cent + + U, _, Vt = np.linalg.svd(H) + + d = np.linalg.det(Vt.T @ U.T) + D = np.diag([1.0, 1.0, d]) + + R = Vt.T @ D @ U.T + + B_rot = B_cent @ R + + diff = A_cent - B_rot + + rmsd = np.sqrt(np.sum(diff**2) / len(A_cent)) + + return float(rmsd) diff --git a/src/opi/utils/rotconst.py b/src/opi/utils/rotconst.py new file mode 100644 index 00000000..d43ad295 --- /dev/null +++ b/src/opi/utils/rotconst.py @@ -0,0 +1,345 @@ +import numpy as np +import os +from dataclasses import dataclass +import warnings +from typing import Iterable, Union, Tuple, List + + +# ============================================================ +# Physical constants +# ============================================================ +_AMU_TO_KG = 1.66053906660e-27 +_ANGSTROM_TO_M = 1.0e-10 +_H_PLANCK = 6.62607015e-34 +_C_CM = 2.99792458e10 + + +# ============================================================ +# Atomic masses +# ============================================================ +ATOMIC_MASSES = { + "X": 0.0, + "PointCharge": 0.0, + + "H": 1.008, "He": 4.003, + "Li": 6.941, "Be": 9.012, "B": 10.810, "C": 12.011, "N": 14.007, + "O": 15.999, "F": 18.998, "Ne": 20.179, + + "Na": 22.990, "Mg": 24.305, "Al": 26.982, "Si": 28.086, + "P": 30.974, "S": 32.060, "Cl": 35.453, "Ar": 39.948, + + "K": 39.100, "Ca": 40.080, + + "Sc": 44.960, "Ti": 47.900, "V": 50.940, "Cr": 52.000, + "Mn": 54.940, "Fe": 55.850, "Co": 58.930, "Ni": 58.700, + "Cu": 63.550, "Zn": 65.380, + + "Ga": 69.720, "Ge": 72.590, "As": 74.920, "Se": 78.960, + "Br": 79.900, "Kr": 83.800, + + "Rb": 85.479, "Sr": 87.620, + + "Y": 88.910, "Zr": 91.220, "Nb": 92.910, "Mo": 95.940, + "Tc": 97.000, "Ru": 101.070, "Rh": 102.910, "Pd": 106.400, + "Ag": 107.870, "Cd": 112.410, + + "In": 114.820, "Sn": 118.690, "Sb": 121.750, "Te": 127.600, + "I": 126.900, "Xe": 131.300, + + "Cs": 132.9054, "Ba": 137.3300, + + "La": 138.9055, "Ce": 140.1200, "Pr": 140.9077, "Nd": 144.2400, + "Pm": 145.0000, "Sm": 150.4000, "Eu": 151.9600, "Gd": 157.2500, + "Tb": 158.9254, "Dy": 162.5000, "Ho": 164.9304, "Er": 167.2600, + "Tm": 168.9342, "Yb": 173.0400, "Lu": 174.9670, + + "Hf": 178.4900, "Ta": 180.9479, "W": 183.8500, "Re": 186.2070, + "Os": 190.2000, "Ir": 192.2200, "Pt": 195.0900, "Au": 196.9665, + "Hg": 200.5900, + + "Tl": 204.3700, "Pb": 207.2000, "Bi": 208.9804, "Po": 209.0000, + "At": 210.0000, "Rn": 222.0000, + + "Fr": 223.0000, "Ra": 226.0254, + + "Ac": 227.0278, "Th": 232.0381, "Pa": 231.0359, "U": 238.0290, + "Np": 237.0482, "Pu": 244.0000, "Am": 243.0000, "Cm": 247.0000, + "Bk": 247.0000, "Cf": 251.0000, "Es": 252.0000, "Fm": 257.0000, + "Md": 258.0000, "No": 259.0000, "Lr": 262.0000, + + "Rf": 267.0000, "Db": 268.0000, "Sg": 269.0000, "Bh": 270.0000, + "Hs": 269.0000, "Mt": 278.0000, "Ds": 281.0000, "Rg": 281.0000, + "Cn": 285.0000, + + "Nh": 283.0000, "Fl": 289.0000, "Mc": 288.0000, + "Lv": 293.0000, "Ts": 294.0000, "Og": 294.0000, +} + +# ============================================================ +# Dataclass +# ============================================================ +@dataclass +class RotationalConstants: + A: float | None + B: float | None + C: float | None + A_cm: float | None + B_cm: float | None + C_cm: float | None + moments: tuple[float, float, float] + rotor_type: str + + def __str__(self) -> str: + def fmt(x, unit=""): + return f"{x:.6f} {unit}" if x is not None else "None" + + return ( + # "Rotational Spectrum\n" + # "--------------------\n" + f"Rotor type : {self.rotor_type}\n\n" + "Moments of inertia (amu·Å²):\n" + f" Ia = {self.moments[0]:.6f}\n" + f" Ib = {self.moments[1]:.6f}\n" + f" Ic = {self.moments[2]:.6f}\n\n" + "Rotational constants:\n" + f" A = {fmt(self.A, 'MHz')} ({fmt(self.A_cm, 'cm⁻¹')})\n" + f" B = {fmt(self.B, 'MHz')} ({fmt(self.B_cm, 'cm⁻¹')})\n" + f" C = {fmt(self.C, 'MHz')} ({fmt(self.C_cm, 'cm⁻¹')})" + ) + + +# ============================================================ +# Utilities +# ============================================================ +def _normalize_symbol(s: str) -> str: + if s in ("X", "PointCharge"): + return s + return s.capitalize() + + +def _read_xyz( + data: Union[str, Tuple[List[str], np.ndarray]] +) -> Tuple[List[str], np.ndarray]: + """ + Read geometry from: + - XYZ file path + - XYZ block string + - (symbols, coords) tuple + + Returns + ------- + symbols : list[str] + coords : (N, 3) ndarray + """ + + # ------------------------- + # Case 1: string input + # ------------------------- + if isinstance(data, str): + + # Case 1a: file path + if os.path.isfile(data): + with open(data, "r") as f: + lines = f.readlines() + else: + # Case 1b: XYZ block string + lines = data.strip().splitlines() + + if len(lines) < 3: + raise ValueError("Invalid XYZ format: too few lines") + + try: + n_atoms = int(lines[0].strip()) + except ValueError: + raise ValueError("First line must contain number of atoms") + + symbols = [] + coords = [] + + for i, line in enumerate(lines[2:2 + n_atoms], start=3): + parts = line.split() + + if len(parts) < 4: + raise ValueError(f"Line {i} malformed: '{line}'") + + # --- clean symbol --- + raw_sym = parts[0].strip() + + # Remove numeric labels (e.g. C1 → C) + sym = ''.join(filter(str.isalpha, raw_sym)) + + # Capitalization (cap insensitive) + sym = sym.capitalize() + + try: + xyz = [float(x) for x in parts[1:4]] + except ValueError: + raise ValueError(f"Invalid coordinates at line {i}: '{line}'") + + symbols.append(sym) + coords.append(xyz) + + return symbols, np.array(coords, dtype=np.float64) + + # ------------------------- + # Case 2: in-memory tuple + # ------------------------- + elif isinstance(data, tuple) and len(data) == 2: + symbols, coords = data + + coords = np.asarray(coords, dtype=np.float64) + + if coords.ndim != 2 or coords.shape[1] != 3: + raise ValueError("Coordinates must be an (N,3) array") + + if len(symbols) != len(coords): + raise ValueError("Symbols and coordinates must have same length") + + # Clean symbols here as well (consistency!) + clean_symbols = [] + for s in symbols: + s = ''.join(filter(str.isalpha, str(s))) + clean_symbols.append(s.capitalize()) + + return clean_symbols, coords + + else: + raise TypeError( + "Input must be: file path, XYZ string, or (symbols, coords)" + ) + + +# ============================================================ +# Main function +# ============================================================ +def rotational_constants( + symbols: list[str] | None = None, + coords: np.ndarray | None = None, + xyz: str | None = None, + masses: np.ndarray | None = None, + weights: dict[str, float] | None = None, + atom_weights: dict[int, float] | None = None, +) -> RotationalConstants | None: + """ + Flexible rotational constant calculator. + + Input options + ------------- + - symbols + coords + - xyz (file path, string, or lines) + + Mass priority + ------------- + masses > atom_weights > weights > default + + Unknown atoms + ------------- + Assigned mass = 0 with warning (unless overridden). + """ + + # --- Input parsing --- + if xyz is not None: + symbols, coords = _read_xyz(xyz) + + if symbols is None or coords is None: + raise ValueError("Provide either (symbols, coords) or xyz input.") + + coords = np.asarray(coords, dtype=np.float64) + + # --- Normalize symbols --- + symbols = [_normalize_symbol(s) for s in symbols] + + # --- Prepare weights --- + weights = { _normalize_symbol(k): v for k, v in (weights or {}).items() } + atom_weights = atom_weights or {} + + # --- Assign masses --- + if masses is not None: + masses = np.asarray(masses, dtype=np.float64) + + else: + masses_list = [] + for i, s in enumerate(symbols): + + if i in atom_weights: + m = atom_weights[i] + + elif s in weights: + m = weights[s] + + elif s in ATOMIC_MASSES: + m = ATOMIC_MASSES[s] + + else: + warnings.warn(f"Unknown element '{s}' → mass set to 0.0") + m = 0.0 + + masses_list.append(m) + + masses = np.array(masses_list, dtype=np.float64) + + # --- Filter zero-mass atoms --- + mask = masses > 0.0 + if not np.any(mask): + return None + + masses = masses[mask] + coords = coords[mask] + + total_mass = masses.sum() + + # --- Center of mass --- + com = (masses[:, None] * coords).sum(axis=0) / total_mass + coords -= com + + # --- Inertia tensor --- + inertia = np.zeros((3, 3), dtype=np.float64) + for m, r in zip(masses, coords): + inertia += m * (np.dot(r, r) * np.eye(3) - np.outer(r, r)) + + # --- Diagonalize --- + moments_raw, _ = np.linalg.eigh(inertia) + moments_raw = np.maximum(moments_raw, 0.0) + Ia, Ib, Ic = moments_raw + + # --- Convert to rotational constants --- + def _moment_to_mhz(I): + if I < 1e-6: + return None + I_si = I * _AMU_TO_KG * (_ANGSTROM_TO_M ** 2) + return _H_PLANCK / (8.0 * np.pi**2 * I_si) / 1e6 + + def _mhz_to_cm(mhz): + return None if mhz is None else mhz * 1e6 / _C_CM + + A = _moment_to_mhz(Ia) + B = _moment_to_mhz(Ib) + C = _moment_to_mhz(Ic) + + # --- Rotor classification --- + tol = 1e-3 + n_zero = sum(m < 1e-6 for m in (Ia, Ib, Ic)) + + if n_zero == 3: + rotor = "monoatomic" + elif n_zero == 2: + rotor = "linear" + elif abs(Ia - Ib) < tol and abs(Ib - Ic) < tol: + rotor = "spherical top" + elif abs(Ia - Ib) < tol: + rotor = "symmetric top (oblate)" + elif abs(Ib - Ic) < tol: + rotor = "symmetric top (prolate)" + else: + rotor = "asymmetric top" + + return RotationalConstants( + A=A, + B=B, + C=C, + A_cm=_mhz_to_cm(A), + B_cm=_mhz_to_cm(B), + C_cm=_mhz_to_cm(C), + moments=(Ia, Ib, Ic), + rotor_type=rotor, + ) \ No newline at end of file From 6f4bb2c073a1cba4b3c6b44371f289b39eb2ac17 Mon Sep 17 00:00:00 2001 From: crjacinto Date: Fri, 17 Apr 2026 15:58:58 +0200 Subject: [PATCH 2/2] Fixed missing type annotations --- src/opi/utils/rmsd.py | 35 +----- src/opi/utils/rotconst.py | 235 ++++++++++++++++++++++---------------- 2 files changed, 143 insertions(+), 127 deletions(-) diff --git a/src/opi/utils/rmsd.py b/src/opi/utils/rmsd.py index 221546d1..2c75ff68 100644 --- a/src/opi/utils/rmsd.py +++ b/src/opi/utils/rmsd.py @@ -1,11 +1,10 @@ -import numpy as np import os -from typing import Tuple, List, Union +from typing import List, Tuple, Union + +import numpy as np -def read_xyz( - data: Union[str, Tuple[List[str], np.ndarray]] -) -> Tuple[List[str], np.ndarray]: +def read_xyz(data: Union[str, Tuple[List[str], np.ndarray]]) -> Tuple[List[str], np.ndarray]: """ Read geometry from: - XYZ file path @@ -18,11 +17,7 @@ def read_xyz( coords : (N, 3) ndarray """ - # ------------------------- - # Case 1: string input - # ------------------------- if isinstance(data, str): - # Case 1a: a file path if os.path.isfile(data): with open(data, "r") as f: @@ -36,7 +31,7 @@ def read_xyz( symbols = [] coords = [] - for line in lines[2:2 + n_atoms]: + for line in lines[2 : 2 + n_atoms]: parts = line.split() sym = parts[0] xyz = [float(x) for x in parts[1:4]] @@ -46,26 +41,8 @@ def read_xyz( return symbols, np.array(coords, dtype=np.float64) - # ------------------------- - # Case 2: in-memory tuple - # ------------------------- - elif isinstance(data, tuple) and len(data) == 2: - symbols, coords = data - - coords = np.asarray(coords, dtype=np.float64) - - if coords.ndim != 2 or coords.shape[1] != 3: - raise ValueError("Coordinates must be an (N,3) array") - - if len(symbols) != len(coords): - raise ValueError("Symbols and coordinates must have same length") - - return list(symbols), coords - else: - raise TypeError( - "Input must be: file path, XYZ string, or (symbols, coords)" - ) + raise TypeError("Input must be: file path or XYZ string") def _validate_geometries(symA: List[str], symB: List[str]) -> None: diff --git a/src/opi/utils/rotconst.py b/src/opi/utils/rotconst.py index d43ad295..e82e9e78 100644 --- a/src/opi/utils/rotconst.py +++ b/src/opi/utils/rotconst.py @@ -1,17 +1,17 @@ -import numpy as np import os -from dataclasses import dataclass import warnings -from typing import Iterable, Union, Tuple, List +from dataclasses import dataclass +from typing import List, Tuple, Union +import numpy as np # ============================================================ # Physical constants # ============================================================ -_AMU_TO_KG = 1.66053906660e-27 -_ANGSTROM_TO_M = 1.0e-10 -_H_PLANCK = 6.62607015e-34 -_C_CM = 2.99792458e10 +_AMU_TO_KG = 1.66053906660e-27 +_ANGSTROM_TO_M = 1.0e-10 +_H_PLANCK = 6.62607015e-34 +_C_CM = 2.99792458e10 # ============================================================ @@ -20,61 +20,127 @@ ATOMIC_MASSES = { "X": 0.0, "PointCharge": 0.0, - - "H": 1.008, "He": 4.003, - "Li": 6.941, "Be": 9.012, "B": 10.810, "C": 12.011, "N": 14.007, - "O": 15.999, "F": 18.998, "Ne": 20.179, - - "Na": 22.990, "Mg": 24.305, "Al": 26.982, "Si": 28.086, - "P": 30.974, "S": 32.060, "Cl": 35.453, "Ar": 39.948, - - "K": 39.100, "Ca": 40.080, - - "Sc": 44.960, "Ti": 47.900, "V": 50.940, "Cr": 52.000, - "Mn": 54.940, "Fe": 55.850, "Co": 58.930, "Ni": 58.700, - "Cu": 63.550, "Zn": 65.380, - - "Ga": 69.720, "Ge": 72.590, "As": 74.920, "Se": 78.960, - "Br": 79.900, "Kr": 83.800, - - "Rb": 85.479, "Sr": 87.620, - - "Y": 88.910, "Zr": 91.220, "Nb": 92.910, "Mo": 95.940, - "Tc": 97.000, "Ru": 101.070, "Rh": 102.910, "Pd": 106.400, - "Ag": 107.870, "Cd": 112.410, - - "In": 114.820, "Sn": 118.690, "Sb": 121.750, "Te": 127.600, - "I": 126.900, "Xe": 131.300, - - "Cs": 132.9054, "Ba": 137.3300, - - "La": 138.9055, "Ce": 140.1200, "Pr": 140.9077, "Nd": 144.2400, - "Pm": 145.0000, "Sm": 150.4000, "Eu": 151.9600, "Gd": 157.2500, - "Tb": 158.9254, "Dy": 162.5000, "Ho": 164.9304, "Er": 167.2600, - "Tm": 168.9342, "Yb": 173.0400, "Lu": 174.9670, - - "Hf": 178.4900, "Ta": 180.9479, "W": 183.8500, "Re": 186.2070, - "Os": 190.2000, "Ir": 192.2200, "Pt": 195.0900, "Au": 196.9665, + "H": 1.008, + "He": 4.003, + "Li": 6.941, + "Be": 9.012, + "B": 10.810, + "C": 12.011, + "N": 14.007, + "O": 15.999, + "F": 18.998, + "Ne": 20.179, + "Na": 22.990, + "Mg": 24.305, + "Al": 26.982, + "Si": 28.086, + "P": 30.974, + "S": 32.060, + "Cl": 35.453, + "Ar": 39.948, + "K": 39.100, + "Ca": 40.080, + "Sc": 44.960, + "Ti": 47.900, + "V": 50.940, + "Cr": 52.000, + "Mn": 54.940, + "Fe": 55.850, + "Co": 58.930, + "Ni": 58.700, + "Cu": 63.550, + "Zn": 65.380, + "Ga": 69.720, + "Ge": 72.590, + "As": 74.920, + "Se": 78.960, + "Br": 79.900, + "Kr": 83.800, + "Rb": 85.479, + "Sr": 87.620, + "Y": 88.910, + "Zr": 91.220, + "Nb": 92.910, + "Mo": 95.940, + "Tc": 97.000, + "Ru": 101.070, + "Rh": 102.910, + "Pd": 106.400, + "Ag": 107.870, + "Cd": 112.410, + "In": 114.820, + "Sn": 118.690, + "Sb": 121.750, + "Te": 127.600, + "I": 126.900, + "Xe": 131.300, + "Cs": 132.9054, + "Ba": 137.3300, + "La": 138.9055, + "Ce": 140.1200, + "Pr": 140.9077, + "Nd": 144.2400, + "Pm": 145.0000, + "Sm": 150.4000, + "Eu": 151.9600, + "Gd": 157.2500, + "Tb": 158.9254, + "Dy": 162.5000, + "Ho": 164.9304, + "Er": 167.2600, + "Tm": 168.9342, + "Yb": 173.0400, + "Lu": 174.9670, + "Hf": 178.4900, + "Ta": 180.9479, + "W": 183.8500, + "Re": 186.2070, + "Os": 190.2000, + "Ir": 192.2200, + "Pt": 195.0900, + "Au": 196.9665, "Hg": 200.5900, - - "Tl": 204.3700, "Pb": 207.2000, "Bi": 208.9804, "Po": 209.0000, - "At": 210.0000, "Rn": 222.0000, - - "Fr": 223.0000, "Ra": 226.0254, - - "Ac": 227.0278, "Th": 232.0381, "Pa": 231.0359, "U": 238.0290, - "Np": 237.0482, "Pu": 244.0000, "Am": 243.0000, "Cm": 247.0000, - "Bk": 247.0000, "Cf": 251.0000, "Es": 252.0000, "Fm": 257.0000, - "Md": 258.0000, "No": 259.0000, "Lr": 262.0000, - - "Rf": 267.0000, "Db": 268.0000, "Sg": 269.0000, "Bh": 270.0000, - "Hs": 269.0000, "Mt": 278.0000, "Ds": 281.0000, "Rg": 281.0000, + "Tl": 204.3700, + "Pb": 207.2000, + "Bi": 208.9804, + "Po": 209.0000, + "At": 210.0000, + "Rn": 222.0000, + "Fr": 223.0000, + "Ra": 226.0254, + "Ac": 227.0278, + "Th": 232.0381, + "Pa": 231.0359, + "U": 238.0290, + "Np": 237.0482, + "Pu": 244.0000, + "Am": 243.0000, + "Cm": 247.0000, + "Bk": 247.0000, + "Cf": 251.0000, + "Es": 252.0000, + "Fm": 257.0000, + "Md": 258.0000, + "No": 259.0000, + "Lr": 262.0000, + "Rf": 267.0000, + "Db": 268.0000, + "Sg": 269.0000, + "Bh": 270.0000, + "Hs": 269.0000, + "Mt": 278.0000, + "Ds": 281.0000, + "Rg": 281.0000, "Cn": 285.0000, - - "Nh": 283.0000, "Fl": 289.0000, "Mc": 288.0000, - "Lv": 293.0000, "Ts": 294.0000, "Og": 294.0000, + "Nh": 283.0000, + "Fl": 289.0000, + "Mc": 288.0000, + "Lv": 293.0000, + "Ts": 294.0000, + "Og": 294.0000, } + # ============================================================ # Dataclass # ============================================================ @@ -90,7 +156,7 @@ class RotationalConstants: rotor_type: str def __str__(self) -> str: - def fmt(x, unit=""): + def fmt(x: float | None, unit: str = "") -> str: return f"{x:.6f} {unit}" if x is not None else "None" return ( @@ -117,9 +183,7 @@ def _normalize_symbol(s: str) -> str: return s.capitalize() -def _read_xyz( - data: Union[str, Tuple[List[str], np.ndarray]] -) -> Tuple[List[str], np.ndarray]: +def _read_xyz(data: Union[str, Tuple[List[str], np.ndarray]]) -> Tuple[List[str], np.ndarray]: """ Read geometry from: - XYZ file path @@ -136,7 +200,6 @@ def _read_xyz( # Case 1: string input # ------------------------- if isinstance(data, str): - # Case 1a: file path if os.path.isfile(data): with open(data, "r") as f: @@ -156,7 +219,7 @@ def _read_xyz( symbols = [] coords = [] - for i, line in enumerate(lines[2:2 + n_atoms], start=3): + for i, line in enumerate(lines[2 : 2 + n_atoms], start=3): parts = line.split() if len(parts) < 4: @@ -166,7 +229,7 @@ def _read_xyz( raw_sym = parts[0].strip() # Remove numeric labels (e.g. C1 → C) - sym = ''.join(filter(str.isalpha, raw_sym)) + sym = "".join(filter(str.isalpha, raw_sym)) # Capitalization (cap insensitive) sym = sym.capitalize() @@ -181,32 +244,8 @@ def _read_xyz( return symbols, np.array(coords, dtype=np.float64) - # ------------------------- - # Case 2: in-memory tuple - # ------------------------- - elif isinstance(data, tuple) and len(data) == 2: - symbols, coords = data - - coords = np.asarray(coords, dtype=np.float64) - - if coords.ndim != 2 or coords.shape[1] != 3: - raise ValueError("Coordinates must be an (N,3) array") - - if len(symbols) != len(coords): - raise ValueError("Symbols and coordinates must have same length") - - # Clean symbols here as well (consistency!) - clean_symbols = [] - for s in symbols: - s = ''.join(filter(str.isalpha, str(s))) - clean_symbols.append(s.capitalize()) - - return clean_symbols, coords - else: - raise TypeError( - "Input must be: file path, XYZ string, or (symbols, coords)" - ) + raise TypeError("Input must be: file path or XYZ string") # ============================================================ @@ -215,7 +254,7 @@ def _read_xyz( def rotational_constants( symbols: list[str] | None = None, coords: np.ndarray | None = None, - xyz: str | None = None, + xyz: str | None = None, masses: np.ndarray | None = None, weights: dict[str, float] | None = None, atom_weights: dict[int, float] | None = None, @@ -250,7 +289,7 @@ def rotational_constants( symbols = [_normalize_symbol(s) for s in symbols] # --- Prepare weights --- - weights = { _normalize_symbol(k): v for k, v in (weights or {}).items() } + weights = {_normalize_symbol(k): v for k, v in (weights or {}).items()} atom_weights = atom_weights or {} # --- Assign masses --- @@ -260,7 +299,6 @@ def rotational_constants( else: masses_list = [] for i, s in enumerate(symbols): - if i in atom_weights: m = atom_weights[i] @@ -294,6 +332,7 @@ def rotational_constants( # --- Inertia tensor --- inertia = np.zeros((3, 3), dtype=np.float64) + assert coords is not None for m, r in zip(masses, coords): inertia += m * (np.dot(r, r) * np.eye(3) - np.outer(r, r)) @@ -303,13 +342,13 @@ def rotational_constants( Ia, Ib, Ic = moments_raw # --- Convert to rotational constants --- - def _moment_to_mhz(I): - if I < 1e-6: + def _moment_to_mhz(inertia: float) -> float | None: + if inertia < 1e-6: return None - I_si = I * _AMU_TO_KG * (_ANGSTROM_TO_M ** 2) + I_si = inertia * _AMU_TO_KG * (_ANGSTROM_TO_M**2) return _H_PLANCK / (8.0 * np.pi**2 * I_si) / 1e6 - def _mhz_to_cm(mhz): + def _mhz_to_cm(mhz: float | None) -> float | None: return None if mhz is None else mhz * 1e6 / _C_CM A = _moment_to_mhz(Ia) @@ -342,4 +381,4 @@ def _mhz_to_cm(mhz): C_cm=_mhz_to_cm(C), moments=(Ia, Ib, Ic), rotor_type=rotor, - ) \ No newline at end of file + )