diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2c3eef7..629a9bd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -33,7 +33,7 @@ repos: - "--enable=C0415" - "--score=n" name: pylint (no lazy imports) - exclude: "^(src/graphrelax/__init__.py|src/graphrelax/cli.py|src/graphrelax/designer.py|src/graphrelax/relaxer.py|src/graphrelax/utils.py|src/graphrelax/LigandMPNN/|tests/)" + exclude: "^(src/graphrelax/__init__.py|src/graphrelax/cli.py|src/graphrelax/designer.py|src/graphrelax/relaxer.py|src/graphrelax/utils.py|src/graphrelax/ligand_utils.py|src/graphrelax/pipeline.py|src/graphrelax/LigandMPNN/|tests/)" - repo: https://github.com/pre-commit/mirrors-prettier rev: v4.0.0-alpha.8 diff --git a/README.md b/README.md index 4bacf51..ddfe753 100644 --- a/README.md +++ b/README.md @@ -6,49 +6,66 @@ GraphRelax combines **LigandMPNN** (for sequence design and side-chain packing) ## Installation -GraphRelax requires pdbfixer, which is only available via conda-forge. We recommend using conda/mamba for installation. +GraphRelax requires several packages that are only available via conda-forge. We recommend using mamba for faster installation. ### From PyPI (Latest Release) ```bash -# First, install pdbfixer via conda (required) -conda install -c conda-forge pdbfixer +# First, install conda-forge dependencies (not available on PyPI) +mamba install -c conda-forge pdbfixer openmmforcefields openff-toolkit rdkit # Then install graphrelax from PyPI pip install graphrelax ``` -This installs the latest stable release. - ### From Source (Latest Development Version) ```bash -# First, install pdbfixer via conda (required) -conda install -c conda-forge pdbfixer - # Clone the repository git clone https://github.com/delalamo/GraphRelax.git cd GraphRelax -# Install in editable mode +# Option A: Use environment.yml with mamba (recommended) +mamba env create -f environment.yml +conda activate graphrelax pip install -e . -``` -This installs the latest development version with all recent changes. +# Option B: Manual installation +mamba install -c conda-forge pdbfixer openmmforcefields openff-toolkit rdkit +pip install -e . +``` LigandMPNN model weights (~40MB) are downloaded automatically on first run. -### Platform-specific Installation +> **Note:** `mamba` is a drop-in replacement for `conda` that's 10-100x faster. Install it with `conda install -n base -c conda-forge mamba`, or use `micromamba` as a standalone tool. -```bash -# CPU-only (smaller install, no GPU dependencies) -pip install "graphrelax[cpu]" +### Dependencies -# With CUDA 11 GPU support -pip install "graphrelax[cuda11]" +Core dependencies (installed automatically via pip): -# With CUDA 12 GPU support -pip install "graphrelax[cuda12]" +- Python >= 3.9 +- PyTorch >= 2.0 +- NumPy < 2.0 (PyTorch <2.5 is incompatible with NumPy 2.x) +- OpenMM +- BioPython +- ProDy +- dm-tree +- absl-py +- ml-collections + +Conda-forge only (must be installed via mamba/conda, not pip): + +| Package | Purpose | +| ----------------- | -------------------------------- | +| pdbfixer | Structure preparation | +| openmmforcefields | Small molecule force fields | +| openff-toolkit | OpenFF Molecule parameterization | +| rdkit | Bond perception from PDB | + +Install all conda-forge dependencies with: + +```bash +mamba install -c conda-forge pdbfixer openmmforcefields openff-toolkit rdkit ``` ### Docker @@ -75,23 +92,13 @@ docker build -t graphrelax . docker run --rm graphrelax --help ``` -### Dependencies - -Core dependencies (installed automatically via pip): - -- Python >= 3.9 -- PyTorch >= 2.0 -- NumPy < 2.0 (PyTorch <2.5 is incompatible with NumPy 2.x) -- OpenMM -- BioPython -- ProDy -- dm-tree -- absl-py -- ml-collections - -Required (must be installed separately via conda): +## Features -- pdbfixer (conda-forge only, not on PyPI) +- **FastRelax-like protocol**: Alternate between side-chain repacking and energy minimization +- **Sequence design**: Full redesign or residue-specific control via Rosetta-style resfiles +- **Multiple output modes**: Relax-only, repack-only, design-only, or combinations +- **GPU acceleration**: Automatic GPU detection for both LigandMPNN and OpenMM +- **Scorefile output**: Rosetta-compatible scorefiles with energy terms and sequence metrics ## Usage @@ -162,49 +169,111 @@ graphrelax -i input.pdb -o relaxed.pdb --constrained-minimization --stiffness 5. | Default (unconstrained) | No | No | Fast | No | | `--constrained-minimization` | Yes (harmonic) | Yes | Slower | Yes | -**Important:** When your input PDB contains ligands or other non-standard residues (HETATM records other than water), you **must** use `--constrained-minimization`. The unconstrained mode uses AMBER force field parameters which don't include templates for non-standard molecules. Constrained mode uses OpenFold's AmberRelaxation which can handle ligands properly. +### Crystallography Artifact Removal + +By default, GraphRelax **automatically removes** common crystallography and cryo-EM artifacts from input structures: + +- **Buffers**: SO4, PO4, CIT, ACT, MES, HEPES, Tris, etc. +- **Cryoprotectants**: GOL (glycerol), EDO (ethylene glycol), MPD, PEG variants, DMSO +- **Detergents**: DDM, OG, LDAO, CHAPS, SDS, etc. +- **Lipids/Fatty acids**: PLM (palmitic), MYR (myristic), OLA (oleic), etc. +- **Reducing agents**: BME, DTT +- **Halide ions**: CL, BR, IOD + +**Biologically relevant metal ions** (ZN, MG, CA, FE, MN, CU, etc.) are **preserved**. + +```bash +# Artifacts are removed by default +graphrelax -i crystal_structure.pdb -o relaxed.pdb + +# Keep all HETATM records including artifacts +graphrelax -i crystal_structure.pdb -o relaxed.pdb --keep-all-ligands + +# Keep specific artifacts that you need (comma-separated) +graphrelax -i crystal_structure.pdb -o relaxed.pdb --keep-ligand GOL,SO4 +``` ### Working with Ligands -When designing proteins with bound ligands (e.g., heme, cofactors, small molecules), use `ligand_mpnn` model type with constrained minimization: +Biologically relevant ligands (cofactors, substrates, inhibitors) are **auto-detected** and parameterized using openmmforcefields. No special flags are needed. ```bash -# Design around a ligand +# Ligands are automatically included and parameterized +graphrelax -i protein_with_ligand.pdb -o designed.pdb --design + +# Choose a specific force field for ligands +graphrelax -i protein_with_ligand.pdb -o designed.pdb \ + --ligand-forcefield gaff-2.11 + +# Provide SMILES for unknown ligands graphrelax -i protein_with_ligand.pdb -o designed.pdb \ - --design --model-type ligand_mpnn --constrained-minimization + --ligand-smiles "LIG:c1ccccc1" -# Repack side chains around a ligand -graphrelax -i protein_with_ligand.pdb -o repacked.pdb \ - --relax --model-type ligand_mpnn --constrained-minimization +# Strip all ligands (protein-only processing) +graphrelax -i protein_with_ligand.pdb -o designed.pdb --ignore-ligands ``` -**Note:** If you attempt to use unconstrained minimization with a PDB containing ligands, GraphRelax will exit with an error message directing you to use `--constrained-minimization`. +For design around ligands, use `ligand_mpnn` model type: -### Pre-Idealization +```bash +graphrelax -i protein_with_ligand.pdb -o designed.pdb \ + --design --model-type ligand_mpnn +``` -GraphRelax can optionally idealize backbone geometry before processing. This is useful for structures with distorted bond lengths or angles (e.g., from homology modeling or low-resolution experimental data). The idealization step: +#### Ligand Force Field Options + +| Force Field | Flag | Description | +| --------------- | ------------------------------------ | -------------------------- | +| OpenFF Sage 2.0 | `--ligand-forcefield openff-2.0.0` | Modern, accurate (default) | +| GAFF 2.11 | `--ligand-forcefield gaff-2.11` | Well-tested, robust | +| Espaloma 0.3 | `--ligand-forcefield espaloma-0.3.0` | ML-based, fast | + +#### Alternative: Constrained Minimization + +Use `--constrained-minimization` for position-restrained minimization (does not require openmmforcefields): + +```bash +graphrelax -i protein_with_ligand.pdb -o designed.pdb \ + --design --constrained-minimization +``` + +**Requires:** `mamba install -c conda-forge pdbfixer` + +### Backbone Idealization + +By default, GraphRelax idealizes backbone geometry before processing. This corrects distorted bond lengths and angles commonly found in experimental structures or homology models. The idealization step: 1. Corrects backbone bond lengths and angles to ideal values 2. Preserves phi/psi/omega dihedral angles -3. Adds missing atoms and optionally missing residues from SEQRES +3. Adds missing atoms 4. Runs constrained minimization to relieve local strain 5. By default, closes chain breaks (gaps) in the structure +**Important:** By default, GraphRelax adds missing residues from SEQRES records during idealization and renumbers all residues sequentially starting from 1 for each chain. This ensures consistent numbering regardless of the original PDB numbering scheme. + +If you're using a resfile, the residue numbers must match the **idealized** structure numbering (sequential from 1), not the original PDB numbering. To preserve original numbering for resfile compatibility, use one of these options: + +- `--ignore-missing-residues`: Keep original numbering, don't add missing residues +- `--no-idealize`: Skip idealization entirely (preserves original geometry and numbering) + ```bash -# Idealize before relaxation -graphrelax -i input.pdb -o relaxed.pdb --pre-idealize +# Default: idealization is enabled +graphrelax -i input.pdb -o relaxed.pdb -# Idealize but don't add missing residues from SEQRES -graphrelax -i input.pdb -o relaxed.pdb --pre-idealize --ignore-missing-residues +# Skip idealization (use input geometry as-is) +graphrelax -i input.pdb -o relaxed.pdb --no-idealize -# Idealize but keep chain breaks as separate chains (don't close gaps) -graphrelax -i input.pdb -o relaxed.pdb --pre-idealize --retain-chainbreaks +# Don't add missing residues (preserve original numbering for resfiles) +graphrelax -i input.pdb -o relaxed.pdb --ignore-missing-residues + +# Keep chain breaks as separate chains (don't close gaps) +graphrelax -i input.pdb -o relaxed.pdb --retain-chainbreaks # Combine with design -graphrelax -i input.pdb -o designed.pdb --pre-idealize --design +graphrelax -i input.pdb -o designed.pdb --design ``` -**Note:** Pre-idealization requires pdbfixer (`conda install -c conda-forge pdbfixer`). +**Note:** Idealization requires pdbfixer (`mamba install -c conda-forge pdbfixer`). ### Resfile Format @@ -270,22 +339,33 @@ Relaxation options: --constrained-minimization Use constrained minimization with position restraints (AlphaFold-style). Default is unconstrained. Requires pdbfixer. - **Required when input PDB contains ligands.** --stiffness K Restraint stiffness in kcal/mol/A^2 (default: 10.0) Only applies to constrained minimization. --max-iterations N Max L-BFGS iterations, 0=unlimited (default: 0) + --ignore-ligands Strip all ligands before processing. By default, + ligands are auto-detected and parameterized. + --ligand-forcefield FF Force field for ligands: openff-2.0.0 (default), + gaff-2.11, or espaloma-0.3.0 + --ligand-smiles RES:SMILES Provide SMILES for a ligand residue. Can be + used multiple times for multiple ligands. Input preprocessing: --keep-waters Keep water molecules in input (default: removed) - --pre-idealize Idealize backbone geometry before processing. - Corrects bond lengths/angles while preserving - dihedral angles. By default, chain breaks are closed. - Requires pdbfixer. + --keep-all-ligands Keep all HETATM records including crystallography + artifacts. By default, common artifacts are removed. + --keep-ligand RES1,RES2,... + Keep specific ligand residues (comma-separated). + Example: --keep-ligand GOL,SO4 + --no-idealize Skip backbone idealization. By default, backbone + geometry is idealized (bond lengths/angles corrected + while preserving dihedrals). Requires pdbfixer. --ignore-missing-residues Do not add missing residues from SEQRES during - pre-idealization. By default, missing terminal and - internal loop residues are added. - --retain-chainbreaks Do not close chain breaks during pre-idealization. + processing. By default, missing terminal and + internal loop residues are added during relaxation + and idealization. Use this flag to preserve + original PDB residue numbering for resfile compatibility. + --retain-chainbreaks Do not close chain breaks during idealization. By default, chain breaks are closed by treating all segments as a single chain. @@ -295,6 +375,7 @@ Scoring: General: -v, --verbose Verbose output --seed N Random seed for reproducibility + --overwrite Overwrite output files if they exist (default: error) ``` ### Scorefile Output @@ -328,7 +409,7 @@ config = PipelineConfig( constrained=False, # Default: unconstrained minimization ), idealize=IdealizeConfig( - enabled=True, # Enable pre-idealization + enabled=True, # Idealization enabled by default add_missing_residues=True, # Add missing residues from SEQRES close_chainbreaks=True, # Close chain breaks (default) ), diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000..1398256 --- /dev/null +++ b/environment.yml @@ -0,0 +1,48 @@ +# GraphRelax Conda Environment +# ============================ +# +# Creates a complete environment with all dependencies for GraphRelax, +# including ligand support. +# +# Usage (mamba is 10-100x faster than conda): +# mamba env create -f environment.yml +# conda activate graphrelax +# pip install -e . +# +# Or with micromamba (standalone, no conda required): +# micromamba create -f environment.yml +# micromamba activate graphrelax +# pip install -e . +# +# For GPU support, also install CUDA-enabled PyTorch: +# pip install torch --index-url https://download.pytorch.org/whl/cu118 + +name: graphrelax +channels: + - conda-forge + - defaults +dependencies: + # Python version + - python>=3.9 + + # Conda-forge only packages (not on PyPI) + - pdbfixer + - openmmforcefields>=0.13.0 + - openff-toolkit>=0.14.0 + - rdkit>=2023.09.1 + + # Core dependencies (also available via pip) + - openmm + - numpy<2 + - biopython + - prody + + # pip dependencies + - pip + - pip: + - torch>=2.0 + - absl-py + - ml-collections + - dm-tree + # Install graphrelax itself + - -e . diff --git a/pyproject.toml b/pyproject.toml index 82b2ac4..2ee2d87 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,8 +34,8 @@ dependencies = [ "openmm", ] -# Note: pdbfixer is not on PyPI, install via conda: -# conda install -c conda-forge pdbfixer +# Note: Some dependencies are only available via conda-forge: +# conda install -c conda-forge pdbfixer openmmforcefields openff-toolkit rdkit [project.urls] Homepage = "https://github.com/delalamo/GraphRelax" diff --git a/src/graphrelax/LigandMPNN/openfold/np/relax/cleanup.py b/src/graphrelax/LigandMPNN/openfold/np/relax/cleanup.py index 6f062a0..72b69f0 100644 --- a/src/graphrelax/LigandMPNN/openfold/np/relax/cleanup.py +++ b/src/graphrelax/LigandMPNN/openfold/np/relax/cleanup.py @@ -57,6 +57,9 @@ def fix_pdb(pdbfile, alterations_info): _remove_heterogens(fixer, alterations_info, keep_water=False) fixer.findMissingResidues() alterations_info["missing_residues"] = fixer.missingResidues + # Clear missing residues to preserve original residue numbering + # Missing residues can be added via --pre-idealize if desired + fixer.missingResidues = {} fixer.findMissingAtoms() alterations_info["missing_heavy_atoms"] = fixer.missingAtoms alterations_info["missing_terminals"] = fixer.missingTerminals diff --git a/src/graphrelax/LigandMPNN/openfold/utils/tensor_utils.py b/src/graphrelax/LigandMPNN/openfold/utils/tensor_utils.py index ccc7e7a..bb6fc97 100644 --- a/src/graphrelax/LigandMPNN/openfold/utils/tensor_utils.py +++ b/src/graphrelax/LigandMPNN/openfold/utils/tensor_utils.py @@ -91,7 +91,7 @@ def batched_gather(data, inds, dim=0, no_batch_dims=0): ] remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds ranges.extend(remaining_dims) - return data[ranges] + return data[tuple(ranges)] # With tree_map, a poor man's JAX tree_map diff --git a/src/graphrelax/artifacts.py b/src/graphrelax/artifacts.py new file mode 100644 index 0000000..f97e48c --- /dev/null +++ b/src/graphrelax/artifacts.py @@ -0,0 +1,336 @@ +"""Constants and utilities for crystallography artifact detection and removal. + +Common crystallography and cryo-EM artifacts (buffers, cryoprotectants, +detergents, lipids) are stripped by default during preprocessing. +Use --keep-all-ligands to preserve them. +""" + +import logging +from collections import defaultdict +from typing import Dict, Optional, Set, Tuple + +logger = logging.getLogger(__name__) + +# ============================================================================= +# Artifact Categories +# ============================================================================= + +# Buffer components commonly found in crystallization conditions +BUFFER_ARTIFACTS = frozenset( + { + # Sulfate/Phosphate + "SO4", + "PO4", + "PO3", + "PI", + "2HP", + "2PO", + "SPO", + "SPH", + # Organic acids + "CIT", + "FLC", + "ACT", + "ACE", + "FMT", + "FOR", + "NO3", + "AZI", + # HEPES, MES, Tris, etc. + "MES", + "EPE", + "TRS", + "CAC", + "BIS", + "HEZ", + "MOH", + "BTB", + "TAM", + "MRD", + "PGO", + "144", + "IPS", + # Malonate, Tartrate, etc. + "MLI", + "TAR", + "MAL", + "SUC", + "FUM", + # Borate + "BO3", + "BO4", + # Imidazole (from His-tag purification) + "IMD", + "1MZ", + } +) + +# Cryoprotectants used for flash-cooling crystals +CRYOPROTECTANT_ARTIFACTS = frozenset( + { + # Glycerol and glycols + "GOL", + "EDO", + "EGL", + "PGR", + "PGQ", + "GLC", + # MPD (2-Methyl-2,4-pentanediol) + "MPD", + "1PG", + "PDO", + # PEG variants (polyethylene glycol fragments) + "PEG", + "PGE", + "PG4", + "1PE", + "P6G", + "P33", + "PE4", + "PG0", + "2PE", + "PEU", + "PE8", + "PE5", + "XPE", + "12P", + "15P", + "PG5", + "PG6", + "PEE", + "PE3", + "P4G", + "P2E", + # DMSO + "DMS", + # Isopropanol + "IPA", + } +) + +# Detergents used for membrane protein crystallization +DETERGENT_ARTIFACTS = frozenset( + { + # Maltosides (DDM, DM, UDM, etc.) + "LMT", + "MLA", + "BMA", + "TRE", + "DDM", + "DXM", + # Glucosides (OG, NG, etc.) + "BOG", + "BGC", + "NDG", + "HTG", + "OGA", + "NGS", + # LDAO (Lauryldimethylamine-N-oxide) + "LDA", + "DAO", + # CHAPS/CHAPSO + "CPS", + "CHT", + "SDS", + "CHP", + # Triton, digitonin + "TRT", + "T3A", + "D10", + "D12", + "DGT", + # LMNG, cymals + "MNG", + "CYC", + "LMG", + # C12E8/9 polyoxyethylene + "CE9", + "C8E", + "C10", + "C12", + # Octyl glucoside + "OLC", + } +) + +# Lipids and fatty acids (from LCP crystallization or membrane proteins) +LIPID_ARTIFACTS = frozenset( + { + # Common fatty acids + "PLM", + "MYR", + "OLA", + "STE", + "PAL", + "LNL", + "ARA", + "DCA", + "UND", + "MYS", + "MYA", + "LNO", + "EOA", + "PEF", + # Monoolein (lipidic cubic phase crystallization) + "OLB", + "9OL", + "MPG", + "OLI", + # Phospholipids and fragments + "PC", + "PE", + "PG", + "PS", + "PLC", + "EPH", + "CDL", + # Cholesterol + "CLR", + "CHO", + } +) + +# Reducing agents from protein preparation +REDUCING_AGENT_ARTIFACTS = frozenset( + { + "BME", + "DTT", + "DTU", + "TCE", + "TRO", + "GSH", + } +) + +# Halide ions (often crystallization additives, not biologically relevant) +HALIDE_ARTIFACTS = frozenset( + { + "CL", + "BR", + "IOD", + "F", + } +) + +# Unknown/placeholder atoms +UNKNOWN_ARTIFACTS = frozenset( + { + "UNX", + "UNL", + "UNK", + "DUM", + } +) + +# ============================================================================= +# Combined Sets +# ============================================================================= + +# Master set of all artifacts to remove by default +CRYSTALLOGRAPHY_ARTIFACTS = ( + BUFFER_ARTIFACTS + | CRYOPROTECTANT_ARTIFACTS + | DETERGENT_ARTIFACTS + | LIPID_ARTIFACTS + | REDUCING_AGENT_ARTIFACTS + | HALIDE_ARTIFACTS + | UNKNOWN_ARTIFACTS +) + +# Biologically relevant ions - NOT stripped by default +# These are often structurally/functionally important +BIOLOGICALLY_RELEVANT_IONS = frozenset( + { + "ZN", + "MG", + "CA", + "FE", + "FE2", + "MN", + "CU", + "CO", + "NI", + "MO", + "NA", + "K", # Often biologically relevant in channels/pumps + } +) + +# Water residues (handled separately by remove_waters) +WATER_RESIDUES = frozenset({"HOH", "WAT", "SOL", "TIP3", "TIP4", "SPC"}) + + +# ============================================================================= +# Removal Functions +# ============================================================================= + + +def remove_artifacts( + pdb_string: str, + keep_residues: Optional[Set[str]] = None, +) -> Tuple[str, Dict[str, int]]: + """ + Remove crystallography artifacts from PDB string. + + Artifacts are identified by their residue name (columns 17-20 in PDB + format). HETATM records with residue names in CRYSTALLOGRAPHY_ARTIFACTS + are removed unless they appear in keep_residues. + + Args: + pdb_string: PDB file contents as string + keep_residues: Set of residue names to preserve (whitelist) + + Returns: + Tuple of: + - filtered_pdb: PDB string with artifacts removed + - removed_counts: Dict mapping residue name to atom count + """ + if keep_residues is None: + keep_residues = set() + + # Normalize whitelist to uppercase + keep_residues = {r.upper() for r in keep_residues} + + # Residues to remove = artifacts minus whitelist + to_remove = CRYSTALLOGRAPHY_ARTIFACTS - keep_residues + + kept_lines = [] + removed_counts = defaultdict(int) + + for line in pdb_string.split("\n"): + if line.startswith("HETATM"): + resname = line[17:20].strip().upper() + if resname in to_remove: + removed_counts[resname] += 1 + continue + kept_lines.append(line) + + filtered_pdb = "\n".join(kept_lines) + + return filtered_pdb, dict(removed_counts) + + +def is_artifact(resname: str) -> bool: + """ + Check if a residue name is a known crystallography artifact. + + Args: + resname: Three-letter residue code + + Returns: + True if the residue is in CRYSTALLOGRAPHY_ARTIFACTS + """ + return resname.upper() in CRYSTALLOGRAPHY_ARTIFACTS + + +def is_biologically_relevant_ion(resname: str) -> bool: + """ + Check if a residue name is a biologically relevant ion. + + Args: + resname: Three-letter residue code + + Returns: + True if the residue is in BIOLOGICALLY_RELEVANT_IONS + """ + return resname.upper() in BIOLOGICALLY_RELEVANT_IONS diff --git a/src/graphrelax/chain_gaps.py b/src/graphrelax/chain_gaps.py index e7404b7..a4987ee 100644 --- a/src/graphrelax/chain_gaps.py +++ b/src/graphrelax/chain_gaps.py @@ -278,63 +278,6 @@ def restore_chain_ids( return "\n".join(output_lines) -def add_ter_records_at_gaps(pdb_string: str, gaps: List[ChainGap]) -> str: - """ - Add TER records at gap locations to signal chain breaks. - - This is an alternative to chain splitting that works with force fields - that recognize TER as chain termination. - - Args: - pdb_string: PDB file contents as string - gaps: List of detected gaps - - Returns: - PDB string with TER records inserted at gap locations - """ - if not gaps: - return pdb_string - - # Build set of (chain_id, resnum, icode) where TER should be inserted - # TER goes after residue_before - ter_locations = { - (g.chain_id, g.residue_before, g.icode_before) for g in gaps - } - - output_lines = [] - prev_chain = None - prev_resnum = None - prev_icode = None - - for line in pdb_string.split("\n"): - if line.startswith(("ATOM", "HETATM")) and len(line) > 26: - chain_id = line[21] - resnum = int(line[22:26].strip()) - icode = line[26].strip() if len(line) > 26 else "" - - # Check if we need to insert TER before this line - # (i.e., previous residue was at a gap location) - if ( - prev_chain is not None - and (prev_chain, prev_resnum, prev_icode) in ter_locations - and (chain_id != prev_chain or resnum != prev_resnum) - ): - # Insert TER record - ter_line = "TER " - output_lines.append(ter_line) - logger.debug( - f"Inserted TER after {prev_chain}:{prev_resnum}{prev_icode}" - ) - - prev_chain = chain_id - prev_resnum = resnum - prev_icode = icode - - output_lines.append(line) - - return "\n".join(output_lines) - - def get_gap_summary(gaps: List[ChainGap]) -> str: """ Generate a human-readable summary of detected gaps. diff --git a/src/graphrelax/cli.py b/src/graphrelax/cli.py index 8c75e60..010b7c2 100644 --- a/src/graphrelax/cli.py +++ b/src/graphrelax/cli.py @@ -19,6 +19,43 @@ def setup_logging(verbose: bool): ) +def _get_output_paths(output_path: Path, n_outputs: int) -> list: + """ + Get list of all output file paths that will be written. + + Args: + output_path: Base output path + n_outputs: Number of outputs to generate + + Returns: + List of Path objects for all output files + """ + if n_outputs == 1: + return [output_path] + + paths = [] + stem = output_path.stem + suffix = output_path.suffix + for i in range(1, n_outputs + 1): + paths.append(output_path.parent / f"{stem}_{i}{suffix}") + return paths + + +def _check_output_exists(output_path: Path, n_outputs: int) -> list: + """ + Check if any output files already exist. + + Args: + output_path: Base output path + n_outputs: Number of outputs to generate + + Returns: + List of existing file paths + """ + paths = _get_output_paths(output_path, n_outputs) + return [p for p in paths if p.exists()] + + def _check_for_ligands(input_path: Path, fmt) -> bool: """ Check if input structure has ligands (non-water HETATM records). @@ -216,6 +253,44 @@ def create_parser() -> argparse.ArgumentParser: "to prevent artificial gap closure during minimization." ), ) + relax_group.add_argument( + "--ignore-ligands", + action="store_true", + help=( + "Strip all ligands (HETATM records) before processing. " + "By default, ligands are auto-detected and parameterized " + "using openmmforcefields." + ), + ) + relax_group.add_argument( + "--ligand-forcefield", + choices=["openff-2.0.0", "gaff-2.11", "espaloma-0.3.0"], + default="openff-2.0.0", + metavar="FF", + help=( + "Force field for ligand parameterization (default: openff-2.0.0)." + " Options: openff-2.0.0 (Sage), gaff-2.11 (GAFF2), espaloma-0.3.0" + ), + ) + relax_group.add_argument( + "--ligand-smiles", + type=str, + action="append", + metavar="RESNAME:SMILES", + help=( + "Provide SMILES for a ligand residue. Format: RESNAME:SMILES. " + "Can be used multiple times. Example: --ligand-smiles LIG:SMILES" + ), + ) + relax_group.add_argument( + "--no-fetch-ligand-smiles", + action="store_true", + help=( + "Disable automatic SMILES lookup from PDBe Chemical Component " + "Dictionary. By default, if RDKit cannot infer bond topology from " + "coordinates, SMILES are fetched from PDBe." + ), + ) # Scoring options score_group = parser.add_argument_group("Scoring options") @@ -234,28 +309,47 @@ def create_parser() -> argparse.ArgumentParser: help="Keep water molecules in input (default: waters are removed)", ) preprocess_group.add_argument( - "--pre-idealize", + "--keep-all-ligands", + action="store_true", + help=( + "Keep all HETATM records including crystallography artifacts " + "(buffers, cryoprotectants, detergents, lipids). " + "By default, common artifacts are removed." + ), + ) + preprocess_group.add_argument( + "--keep-ligand", + type=str, + metavar="RESNAMES", + help=( + "Keep specific ligand residue(s) that would otherwise be stripped " + "as artifacts. Comma-separated. Example: --keep-ligand GOL,SO4" + ), + ) + preprocess_group.add_argument( + "--no-idealize", action="store_true", help=( - "Idealize backbone geometry before processing. " - "Runs constrained minimization to fix local geometry while " - "preserving dihedral angles. By default, chain breaks are closed." + "Skip backbone idealization. By default, backbone geometry is " + "idealized before processing (corrects bond lengths/angles while " + "preserving dihedral angles). Use this flag to skip idealization." ), ) preprocess_group.add_argument( "--ignore-missing-residues", action="store_true", help=( - "Do not add missing residues from SEQRES during pre-idealization. " - "By default, missing N/C-terminal residues and internal loops are " - "added based on SEQRES records." + "Do not add missing residues from SEQRES. By default, missing " + "N/C-terminal residues and internal loops are added. Use this " + "to preserve original PDB residue numbering for resfile " + "compatibility." ), ) preprocess_group.add_argument( "--retain-chainbreaks", action="store_true", help=( - "Do not close chain breaks during pre-idealization. " + "Do not close chain breaks during idealization. " "By default, chain breaks are closed by treating all segments " "as a single chain. Use this to preserve gaps." ), @@ -275,6 +369,11 @@ def create_parser() -> argparse.ArgumentParser: metavar="N", help="Random seed for reproducibility", ) + general_group.add_argument( + "--overwrite", + action="store_true", + help="Overwrite output files if they exist (default: error if exists)", + ) return parser @@ -296,6 +395,23 @@ def main(args=None) -> int: logger.error(f"Resfile not found: {opts.resfile}") return 1 + # Check if output files already exist (unless --overwrite is set) + if not opts.overwrite: + existing = _check_output_exists(opts.output, opts.n_outputs) + if existing: + if len(existing) == 1: + logger.error( + f"Output file already exists: {existing[0]}. " + "Use --overwrite to replace." + ) + else: + files = ", ".join(str(p) for p in existing) + logger.error( + f"Output files already exist: {files}. " + "Use --overwrite to replace." + ) + return 1 + # Validate input format input_suffix = opts.input.suffix.lower() if input_suffix not in (".pdb", ".cif", ".mmcif"): @@ -335,19 +451,28 @@ def main(args=None) -> int: input_format = detect_format(opts.input) has_ligands = _check_for_ligands(opts.input, input_format) - # Validate: ligand_mpnn with ligands requires constrained minimization + # Log ligand handling info uses_relaxation = mode in ( PipelineMode.RELAX, PipelineMode.NO_REPACK, PipelineMode.DESIGN, ) - if has_ligands and uses_relaxation and not opts.constrained_minimization: - logger.error( - "Input PDB contains ligands (HETATM records). " - "Unconstrained minimization cannot handle non-standard residues. " - "Please use --constrained-minimization flag." + if has_ligands and uses_relaxation: + if opts.ignore_ligands: + logger.info("Ligands will be stripped (--ignore-ligands)") + elif opts.constrained_minimization: + logger.info("Ligands handled via constrained minimization") + else: + logger.info("Ligands auto-detected, using openmmforcefields") + + # Warn about resfile + idealization interaction + if opts.resfile and not opts.no_idealize: + logger.warning( + "Using resfile with idealization enabled. Residue numbers in " + "the resfile should match the idealized structure (sequential " + "numbering starting from 1). Use --no-idealize or " + "--ignore-missing-residues to preserve original numbering." ) - return 1 logger.info(f"Running GraphRelax in {mode.value} mode") logger.info(f"Input: {opts.input}") @@ -361,15 +486,41 @@ def main(args=None) -> int: seed=opts.seed, ) + # Parse ligand SMILES if provided + ligand_smiles = {} + if opts.ligand_smiles: + for entry in opts.ligand_smiles: + if ":" not in entry: + logger.error( + f"Invalid --ligand-smiles format: '{entry}'. " + "Expected format: RESNAME:SMILES" + ) + return 1 + resname, smiles = entry.split(":", 1) + ligand_smiles[resname.strip().upper()] = smiles.strip() + relax_config = RelaxConfig( stiffness=opts.stiffness, max_iterations=opts.max_iterations, constrained=opts.constrained_minimization, split_chains_at_gaps=not opts.no_split_gaps, + add_missing_residues=not opts.ignore_missing_residues, + ignore_ligands=opts.ignore_ligands, + ligand_forcefield=opts.ligand_forcefield, + ligand_smiles=ligand_smiles, + fetch_pdbe_smiles=not opts.no_fetch_ligand_smiles, ) + # Build keep_residues set from --keep-ligand flag (comma-separated) + keep_residues = set() + if opts.keep_ligand: + for resname in opts.keep_ligand.split(","): + resname = resname.strip().upper() + if resname: + keep_residues.add(resname) + idealize_config = IdealizeConfig( - enabled=opts.pre_idealize, + enabled=not opts.no_idealize, add_missing_residues=not opts.ignore_missing_residues, close_chainbreaks=not opts.retain_chainbreaks, ) @@ -381,6 +532,8 @@ def main(args=None) -> int: scorefile=opts.scorefile, verbose=opts.verbose, remove_waters=not opts.keep_waters, + remove_artifacts=not opts.keep_all_ligands, + keep_residues=keep_residues, design=design_config, relax=relax_config, idealize=idealize_config, diff --git a/src/graphrelax/config.py b/src/graphrelax/config.py index 1a13895..2bb3ff3 100644 --- a/src/graphrelax/config.py +++ b/src/graphrelax/config.py @@ -3,7 +3,9 @@ from dataclasses import dataclass, field from enum import Enum from pathlib import Path -from typing import Literal, Optional +from typing import Dict, Literal, Optional, Set + +LigandForceField = Literal["openff-2.0.0", "gaff-2.11", "espaloma-0.3.0"] class PipelineMode(Enum): @@ -42,14 +44,25 @@ class RelaxConfig: max_outer_iterations: int = 3 # Violation-fixing iterations constrained: bool = False # Use constrained (AmberRelaxation) minimization split_chains_at_gaps: bool = True # Split chains at gaps to prevent closure + add_missing_residues: bool = True # Add missing residues from SEQRES # GPU is auto-detected and used when available + # Ligand support options (ligands are auto-detected) + ignore_ligands: bool = False # If True, strip all ligands before processing + ligand_forcefield: LigandForceField = ( + "openff-2.0.0" # Force field for ligands + ) + ligand_smiles: Dict[str, str] = field( + default_factory=dict + ) # {resname: SMILES} + fetch_pdbe_smiles: bool = True # Auto-fetch SMILES from PDBe CCD + @dataclass class IdealizeConfig: """Configuration for structure idealization preprocessing.""" - enabled: bool = False # Idealization disabled by default + enabled: bool = True # Idealization enabled by default fix_cis_omega: bool = True # Correct non-trans peptide bonds (except Pro) post_idealize_stiffness: float = 10.0 # kcal/mol/A^2 for restraint add_missing_residues: bool = True # Add missing residues from SEQRES @@ -66,6 +79,8 @@ class PipelineConfig: scorefile: Optional[Path] = None # If set, write scores to this file verbose: bool = False remove_waters: bool = True # Remove water molecules from input + remove_artifacts: bool = True # Remove crystallography artifacts by default + keep_residues: Set[str] = field(default_factory=set) # Whitelist residues design: DesignConfig = field(default_factory=DesignConfig) relax: RelaxConfig = field(default_factory=RelaxConfig) idealize: IdealizeConfig = field(default_factory=IdealizeConfig) diff --git a/src/graphrelax/idealize.py b/src/graphrelax/idealize.py index 49e1236..af7b8c0 100644 --- a/src/graphrelax/idealize.py +++ b/src/graphrelax/idealize.py @@ -20,6 +20,7 @@ from openmm import openmm, unit from pdbfixer import PDBFixer +from graphrelax.artifacts import WATER_RESIDUES from graphrelax.chain_gaps import ( ChainGap, detect_chain_gaps, @@ -27,6 +28,7 @@ split_chains_at_gaps, ) from graphrelax.config import IdealizeConfig +from graphrelax.utils import check_gpu_available # Add vendored LigandMPNN to path for OpenFold imports LIGANDMPNN_PATH = Path(__file__).parent / "LigandMPNN" @@ -37,9 +39,6 @@ logger = logging.getLogger(__name__) -# Water residue names to preserve with protein -WATER_RESIDUES = {"HOH", "WAT", "SOL", "TIP3", "TIP4", "SPC"} - @dataclass class DihedralAngles: @@ -482,11 +481,7 @@ def minimize_with_constraints( ) # Check for GPU - use_gpu = False - for i in range(Platform.getNumPlatforms()): - if Platform.getPlatform(i).getName() == "CUDA": - use_gpu = True - break + use_gpu = check_gpu_available() # Create simulation integrator = openmm.LangevinIntegrator(0, 0.01, 0.0) @@ -503,7 +498,7 @@ def minimize_with_constraints( state = simulation.context.getState(getPositions=True) output = io.StringIO() openmm_app.PDBFile.writeFile( - simulation.topology, state.getPositions(), output + simulation.topology, state.getPositions(), output, keepIds=True ) logger.debug("Post-idealization minimization complete") @@ -564,9 +559,15 @@ def idealize_structure( 2. Detect chain gaps 3. Split chains at gaps 4. Extract dihedrals and optionally correct cis-omega - 5. Run constrained minimization to relieve local strain - 6. Restore original chain IDs - 7. Restore ligands + 5. Add missing residues and atoms (protein only) + 6. Run constrained minimization on protein only + 7. Restore original chain IDs + 8. Reintroduce ligands (ligands are kept at original positions) + 9. Renumber residues sequentially + + Note: Ligands are NOT minimized during idealization. They are extracted + before protein minimization and restored afterward at their original + positions. This avoids the need for ligand force field parameterization. Args: pdb_string: Input PDB file contents @@ -579,7 +580,8 @@ def idealize_structure( # Step 1: Extract ligands protein_pdb, ligand_lines = extract_ligands(pdb_string) - if ligand_lines.strip(): + has_ligands = bool(ligand_lines.strip()) + if has_ligands: logger.info("Extracted ligands for separate handling") # Step 2: Detect chain gaps (only if we want to retain them) @@ -604,21 +606,89 @@ def idealize_structure( if residues: _idealize_chain_segment(residues, config) - # Step 5: Run constrained minimization - # This is the key step - it fixes local geometry issues while - # keeping the overall structure close to the original + # Step 5-6: Add missing residues/atoms and run constrained minimization + # (protein only - ligands are not included yet) minimized_pdb = minimize_with_constraints( protein_pdb, stiffness=config.post_idealize_stiffness, add_missing_residues=config.add_missing_residues, ) - # Step 6: Restore original chain IDs + # Step 7: Restore original chain IDs if chain_mapping: minimized_pdb = restore_chain_ids(minimized_pdb, chain_mapping) - # Step 7: Restore ligands - final_pdb = restore_ligands(minimized_pdb, ligand_lines) + # Step 8: Reintroduce ligands + if has_ligands: + # Restore ligands to the minimized protein + # Ligands are kept at their original positions (not minimized) + # This avoids the need for ligand force field parameterization + logger.info( + "Restoring ligands to minimized protein (ligands held fixed)" + ) + final_pdb = restore_ligands(minimized_pdb, ligand_lines) + else: + final_pdb = minimized_pdb + + # Step 9: Renumber residues sequentially per chain + # This fixes issues where pdbfixer assigns non-sequential numbers + final_pdb = renumber_residues_sequential(final_pdb) logger.info("Structure idealization complete") return final_pdb, gaps + + +def renumber_residues_sequential(pdb_string: str) -> str: + """ + Renumber residues sequentially starting from 1 for each chain. + + This fixes issues where pdbfixer assigns non-sequential residue numbers + when adding missing residues. Sequential numbering is required for + proper chain gap detection and LigandMPNN processing. + + HETATM records are renumbered to continue after the last ATOM residue + in each chain. + + Args: + pdb_string: PDB file contents + + Returns: + PDB string with sequential residue numbering + """ + lines = pdb_string.split("\n") + output_lines = [] + + # Track residue numbering per chain + chain_residue_count = {} # chain_id -> next_resnum + residue_map = {} # (chain_id, old_resnum, icode) -> new_resnum + + # First pass: assign new numbers to unique residues + for line in lines: + if line.startswith("ATOM") or line.startswith("HETATM"): + chain_id = line[21] + old_resnum = line[22:26].strip() + icode = line[26] + key = (chain_id, old_resnum, icode) + + if key not in residue_map: + if chain_id not in chain_residue_count: + chain_residue_count[chain_id] = 1 + residue_map[key] = chain_residue_count[chain_id] + chain_residue_count[chain_id] += 1 + + # Second pass: apply new numbers + for line in lines: + if line.startswith("ATOM") or line.startswith("HETATM"): + chain_id = line[21] + old_resnum = line[22:26].strip() + icode = line[26] + key = (chain_id, old_resnum, icode) + + new_resnum = residue_map[key] + # Format residue number right-justified in columns 23-26 + new_line = line[:22] + f"{new_resnum:>4}" + " " + line[27:] + output_lines.append(new_line) + else: + output_lines.append(line) + + return "\n".join(output_lines) diff --git a/src/graphrelax/ligand_utils.py b/src/graphrelax/ligand_utils.py new file mode 100644 index 0000000..a369fdf --- /dev/null +++ b/src/graphrelax/ligand_utils.py @@ -0,0 +1,427 @@ +"""Utilities for ligand parameterization and handling.""" + +import io +import json +import logging +import urllib.error +import urllib.request +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple + +from openmm import app as openmm_app + +# Ligand parameterization dependencies (conda-forge only) +# conda install -c conda-forge openff-toolkit=0.14.0 rdkit=2023.09.1 +try: + from openff.toolkit import Molecule + from rdkit import Chem + from rdkit.Chem import AllChem + + LIGAND_DEPS_AVAILABLE = True +except ImportError: + LIGAND_DEPS_AVAILABLE = False + Molecule = None + Chem = None + AllChem = None + +logger = logging.getLogger(__name__) + +# Import WATER_RESIDUES from artifacts to avoid duplication +from graphrelax.artifacts import WATER_RESIDUES # noqa: E402 + + +def _check_ligand_deps(): + """Raise ImportError if ligand dependencies are not available.""" + if not LIGAND_DEPS_AVAILABLE: + raise ImportError( + "Ligand parameterization requires openff-toolkit and rdkit.\n" + "These packages are only available via conda-forge:\n\n" + " conda install -c conda-forge " + "openff-toolkit=0.14.0 rdkit=2023.09.1\n\n" + "See the README for full installation instructions." + ) + + +@dataclass +class LigandInfo: + """Information about a ligand extracted from PDB.""" + + resname: str + chain_id: str + resnum: int + pdb_lines: List[str] + smiles: Optional[str] = None + + +def extract_ligands_from_pdb( + pdb_string: str, + exclude_artifacts: bool = True, +) -> Tuple[str, List[LigandInfo]]: + """ + Separate protein ATOM records from ligand HETATM records. + + By default, crystallography artifacts (buffers, cryoprotectants, detergents, + lipids) are excluded from the ligand list since they cannot be meaningfully + parameterized for minimization. + + Args: + pdb_string: Full PDB string with protein and ligands + exclude_artifacts: If True, skip known crystallography artifacts + + Returns: + Tuple of (protein_only_pdb, list_of_ligand_info) + """ + # Import here to avoid circular imports + if exclude_artifacts: + from graphrelax.artifacts import CRYSTALLOGRAPHY_ARTIFACTS + else: + CRYSTALLOGRAPHY_ARTIFACTS = set() + + protein_lines = [] + ligand_lines_by_residue = {} + + for line in pdb_string.split("\n"): + if line.startswith("HETATM"): + resname = line[17:20].strip().upper() + chain_id = line[21] if len(line) > 21 else " " + try: + resnum = int(line[22:26].strip()) + except ValueError: + resnum = 0 + + # Skip water + if resname in WATER_RESIDUES: + continue + + # Skip known artifacts (they won't be parameterized) + if resname in CRYSTALLOGRAPHY_ARTIFACTS: + continue + + key = (chain_id, resnum, resname) + if key not in ligand_lines_by_residue: + ligand_lines_by_residue[key] = [] + ligand_lines_by_residue[key].append(line) + elif line.startswith("END"): + pass # Skip, will add back + else: + protein_lines.append(line) + + protein_pdb = "\n".join(protein_lines) + "\nEND\n" + + ligands = [] + for (chain_id, resnum, resname), lines in ligand_lines_by_residue.items(): + ligands.append( + LigandInfo( + resname=resname, + chain_id=chain_id, + resnum=resnum, + pdb_lines=lines, + ) + ) + + return protein_pdb, ligands + + +def get_ion_smiles() -> Dict[str, str]: + """ + Return SMILES for common ions. + + Ions are single atoms that cannot be parsed from PDB coordinates, + so we need explicit SMILES for them. + """ + return { + "ZN": "[Zn+2]", + "MG": "[Mg+2]", + "CA": "[Ca+2]", + "FE": "[Fe+2]", + "FE2": "[Fe+2]", + "MN": "[Mn+2]", + "CU": "[Cu+2]", + "CO": "[Co+2]", + "NA": "[Na+]", + "K": "[K+]", + "CL": "[Cl-]", + } + + +# Cofactors that contain metals or unusual chemistry that cannot be +# parameterized by standard force fields (GAFF, OpenFF, etc.) +# These ligands will be excluded from minimization and restored unchanged. +UNPARAMETERIZABLE_COFACTORS = { + # Heme and porphyrins (contain Fe) + "HEM", + "HEC", # Heme C + "HEA", # Heme A + "HEB", # Heme B + "1HE", # Heme variant + "2HE", + "DHE", # Deuteroheme + "HAS", # Heme-AS + "HDD", # Hydroxyheme + "HEO", # Heme O + "HNI", # Heme N + "SRM", # Siroheme + # Iron-sulfur clusters + "SF4", # 4Fe-4S cluster + "FES", # 2Fe-2S cluster + "F3S", # 3Fe-4S cluster + # Other metallocofactors + "CLA", # Chlorophyll A + "CLB", # Chlorophyll B + "BCL", # Bacteriochlorophyll + "BPH", # Bacteriopheophytin + "PHO", # Pheophytin + "CHL", # Chlorophyll + "B12", # Vitamin B12 / Cobalamin + "COB", # Cobalamin + "PQQ", # Pyrroloquinoline quinone + "MTE", # Methanopterin + "F43", # Coenzyme F430 (Ni) + "MO7", # Molybdopterin + "MGD", # Molybdopterin guanine dinucleotide + # Copper centers + "CU1", # Copper site + "CUA", # CuA center + "CUB", # CuB center +} + + +def is_unparameterizable_cofactor(resname: str) -> bool: + """Check if a residue is a known unparameterizable cofactor.""" + return resname.upper() in UNPARAMETERIZABLE_COFACTORS + + +# In-memory cache for PDBe SMILES lookups +_PDBE_SMILES_CACHE = {} + + +def fetch_pdbe_smiles(resname: str) -> Optional[str]: + """ + Fetch SMILES from PDBe Chemical Component Dictionary. + + Args: + resname: Three-letter ligand code (e.g., "ATP", "3JD") + + Returns: + SMILES string if found, None otherwise + """ + resname_upper = resname.upper() + + # Check cache first + if resname_upper in _PDBE_SMILES_CACHE: + cached = _PDBE_SMILES_CACHE[resname_upper] + if cached is not None: + logger.debug(f"Using cached SMILES for {resname_upper}") + return cached + + url = f"https://www.ebi.ac.uk/pdbe/api/pdb/compound/summary/{resname_upper}" + + try: + with urllib.request.urlopen(url, timeout=5) as response: + data = json.loads(response.read().decode("utf-8")) + + if resname_upper in data and data[resname_upper]: + compound_data = data[resname_upper][0] + if "smiles" in compound_data and compound_data["smiles"]: + smiles = compound_data["smiles"][0]["name"] + _PDBE_SMILES_CACHE[resname_upper] = smiles + logger.info(f"Fetched SMILES for {resname_upper} from PDBe") + return smiles + + except urllib.error.HTTPError as e: + if e.code == 404: + logger.debug(f"Ligand {resname_upper} not found in PDBe CCD") + else: + logger.warning( + f"HTTP error fetching SMILES for {resname_upper}: {e}" + ) + except urllib.error.URLError as e: + logger.warning( + f"Network error fetching SMILES for {resname_upper}: {e}" + ) + except json.JSONDecodeError as e: + logger.warning(f"JSON parse error for {resname_upper}: {e}") + except Exception as e: + logger.warning( + f"Unexpected error fetching SMILES for {resname_upper}: {e}" + ) + + # Cache the failure too + _PDBE_SMILES_CACHE[resname_upper] = None + return None + + +def is_single_atom_ligand(ligand: LigandInfo) -> bool: + """Check if ligand is a single atom (ion).""" + return len(ligand.pdb_lines) == 1 + + +def create_openff_molecule( + ligand: LigandInfo, + smiles: Optional[str] = None, + fetch_pdbe: bool = True, +): + """ + Create an OpenFF Toolkit Molecule from ligand info. + + Attempts to create the molecule in this order: + 1. From user-provided SMILES (if given) + 2. From ion lookup table (for single-atom ligands) + 3. From PDBe Chemical Component Dictionary (if fetch_pdbe=True) + 4. From PDB coordinates via RDKit bond perception (fallback) + + Note: PDBe lookup is preferred over RDKit because RDKit's bond perception + from 3D coordinates often fails or produces incorrect molecules for complex + organic ligands. PDBe has correct SMILES for all standard PDB ligands. + + Requires openff-toolkit and rdkit (install via conda-forge). + + Args: + ligand: LigandInfo with PDB coordinates + smiles: Optional SMILES string (overrides automatic detection) + fetch_pdbe: If True, try PDBe CCD before RDKit (recommended) + + Returns: + openff.toolkit.Molecule + + Raises: + ImportError: If openff-toolkit or rdkit are not installed. + ValueError: If molecule cannot be created by any method. + """ + _check_ligand_deps() + + # 1. User-provided SMILES takes precedence + if smiles: + try: + mol = Molecule.from_smiles(smiles, allow_undefined_stereo=True) + mol.name = ligand.resname # Required for openmmforcefields matching + logger.debug(f"Created molecule for {ligand.resname} from SMILES") + return mol + except Exception as e: + logger.warning(f"Failed to create from provided SMILES: {e}") + + # 2. Handle ions (single atoms) - need explicit SMILES + if is_single_atom_ligand(ligand): + ion_smiles = get_ion_smiles() + if ligand.resname in ion_smiles: + mol = Molecule.from_smiles( + ion_smiles[ligand.resname], allow_undefined_stereo=True + ) + mol.name = ligand.resname # Required for openmmforcefields matching + logger.debug(f"Created ion molecule for {ligand.resname}") + return mol + else: + raise ValueError( + f"Unknown ion '{ligand.resname}'. Provide SMILES via " + f"ligand_smiles={{'{ligand.resname}': '[Element+charge]'}}" + ) + + # 3. Try PDBe Chemical Component Dictionary first (most reliable) + if fetch_pdbe: + pdbe_smiles = fetch_pdbe_smiles(ligand.resname) + if pdbe_smiles: + try: + mol = Molecule.from_smiles( + pdbe_smiles, allow_undefined_stereo=True + ) + mol.name = ( + ligand.resname + ) # Required for openmmforcefields matching + logger.info( + f"Created molecule for {ligand.resname} using PDBe SMILES" + ) + return mol + except Exception as e: + logger.warning(f"Failed to create from PDBe SMILES: {e}") + + # 4. Fallback: Try RDKit bond perception from PDB coordinates + # Note: This often produces incorrect molecules for complex organic ligands + pdb_block = "\n".join(ligand.pdb_lines) + "\nEND\n" + rdkit_error = None + try: + mol = _create_molecule_via_rdkit(pdb_block) + mol.name = ligand.resname # Required for openmmforcefields matching + logger.debug(f"Created molecule for {ligand.resname} via RDKit") + return mol + except Exception as e: + rdkit_error = e + logger.debug(f"RDKit parsing failed: {e}") + + # 5. Last resort: Try direct OpenFF PDB parsing + openff_error = None + try: + mol = Molecule.from_pdb_file( + io.StringIO(pdb_block), + allow_undefined_stereo=True, + ) + mol.name = ligand.resname # Required for openmmforcefields matching + logger.debug(f"Created molecule for {ligand.resname} from PDB") + return mol + except Exception as e: + openff_error = e + logger.debug(f"OpenFF PDB parsing failed: {e}") + + # All methods failed + raise ValueError( + f"Could not create molecule for ligand {ligand.resname}.\n" + f"RDKit error: {rdkit_error}\n" + f"OpenFF error: {openff_error}\n\n" + f"You can provide SMILES manually via --ligand-smiles " + f"'{ligand.resname}:YOUR_SMILES_HERE'" + ) + + +def _create_molecule_via_rdkit(pdb_block: str): + """Create OpenFF Molecule via RDKit from PDB block.""" + # Parse PDB with RDKit + mol = Chem.MolFromPDBBlock(pdb_block, removeHs=False, sanitize=False) + + if mol is None: + raise ValueError("RDKit could not parse PDB block") + + # Try to sanitize + try: + Chem.SanitizeMol(mol) + except Exception: + # Try without hydrogens and re-add them + mol = Chem.MolFromPDBBlock(pdb_block, removeHs=True, sanitize=True) + if mol is None: + raise ValueError("RDKit sanitization failed") + mol = Chem.AddHs(mol, addCoords=True) + AllChem.EmbedMolecule(mol, randomSeed=42) + + # Convert to OpenFF Molecule + return Molecule.from_rdkit(mol, allow_undefined_stereo=True) + + +def ligand_pdb_to_topology(ligand: LigandInfo, strip_hydrogens: bool = False): + """ + Convert ligand PDB lines to OpenMM topology and positions. + + Args: + ligand: LigandInfo with PDB lines + strip_hydrogens: If True, remove hydrogen atoms from the topology. + Useful when hydrogens will be added later by addHydrogens(). + + Returns: + Tuple of (topology, positions) + """ + if strip_hydrogens: + # Filter out hydrogen atoms (element column is at position 77-78) + heavy_atom_lines = [] + for line in ligand.pdb_lines: + if len(line) >= 78: + element = line[76:78].strip() + if element != "H": + heavy_atom_lines.append(line) + else: + # Try to detect hydrogen from atom name (starts with H) + atom_name = line[12:16].strip() if len(line) >= 16 else "" + if not atom_name.startswith("H"): + heavy_atom_lines.append(line) + pdb_block = "\n".join(heavy_atom_lines) + "\nEND\n" + else: + pdb_block = "\n".join(ligand.pdb_lines) + "\nEND\n" + + pdb_file = openmm_app.PDBFile(io.StringIO(pdb_block)) + return pdb_file.topology, pdb_file.positions diff --git a/src/graphrelax/pipeline.py b/src/graphrelax/pipeline.py index d76dcc4..d7c258d 100644 --- a/src/graphrelax/pipeline.py +++ b/src/graphrelax/pipeline.py @@ -183,6 +183,21 @@ def _run_single_output( if removed > 0: logger.info(f"Removed {removed} water-related lines from input") + # Remove crystallography artifacts if requested + if self.config.remove_artifacts: + from graphrelax.artifacts import remove_artifacts + + current_structure, removed_counts = remove_artifacts( + current_structure, + keep_residues=self.config.keep_residues, + ) + if removed_counts: + total = sum(removed_counts.values()) + artifacts = ", ".join( + f"{k}({v})" for k, v in sorted(removed_counts.items()) + ) + logger.info(f"Removed {total} artifact atoms: {artifacts}") + # Convert to PDB format for internal processing if needed current_pdb = ensure_pdb_format(current_structure, input_pdb) @@ -190,7 +205,8 @@ def _run_single_output( if self.config.idealize.enabled: logger.info("Running pre-idealization...") current_pdb, gaps = idealize_structure( - current_pdb, self.config.idealize + current_pdb, + self.config.idealize, ) if gaps: logger.info( diff --git a/src/graphrelax/relaxer.py b/src/graphrelax/relaxer.py index 3407ef3..9359def 100644 --- a/src/graphrelax/relaxer.py +++ b/src/graphrelax/relaxer.py @@ -65,8 +65,8 @@ def relax(self, pdb_string: str) -> Tuple[str, dict, np.ndarray]: gaps before minimization to prevent artificial gap closure. Ligands (non-water HETATM records) are extracted before relaxation - and restored afterward, since standard AMBER force fields cannot - parameterize arbitrary ligands. + and restored afterward. Protein atoms near ligands are restrained + to prevent clashes when ligands are restored. Args: pdb_string: PDB file contents as string @@ -76,10 +76,13 @@ def relax(self, pdb_string: str) -> Tuple[str, dict, np.ndarray]: """ # Extract ligands before relaxation (AMBER can't parameterize them) protein_pdb, ligand_lines = extract_ligands(pdb_string) + ligand_coords = None if ligand_lines.strip(): logger.debug( "Extracted ligands for separate handling during relaxation" ) + # Parse ligand coordinates to restrain nearby protein atoms + ligand_coords = self._parse_ligand_coords(ligand_lines) # Detect and handle chain gaps if configured chain_mapping = {} @@ -96,7 +99,7 @@ def relax(self, pdb_string: str) -> Tuple[str, dict, np.ndarray]: relaxed_pdb, debug_info, violations = self.relax_protein(prot) else: relaxed_pdb, debug_info, violations = self._relax_unconstrained( - protein_pdb + protein_pdb, ligand_coords=ligand_coords ) # Restore original chain IDs if chains were split @@ -112,6 +115,20 @@ def relax(self, pdb_string: str) -> Tuple[str, dict, np.ndarray]: return relaxed_pdb, debug_info, violations + def _parse_ligand_coords(self, ligand_lines: str) -> np.ndarray: + """Parse ligand atom coordinates from HETATM lines.""" + coords = [] + for line in ligand_lines.split("\n"): + if line.startswith("HETATM"): + try: + x = float(line[30:38]) + y = float(line[38:46]) + z = float(line[46:54]) + coords.append([x, y, z]) + except (ValueError, IndexError): + continue + return np.array(coords) if coords else None + def relax_pdb_file(self, pdb_path: Path) -> Tuple[str, dict, np.ndarray]: """ Relax a PDB file. @@ -163,18 +180,19 @@ def relax_protein(self, prot) -> Tuple[str, dict, np.ndarray]: return relaxed_pdb, debug_data, violations def _relax_unconstrained( - self, pdb_string: str + self, pdb_string: str, ligand_coords: np.ndarray = None ) -> Tuple[str, dict, np.ndarray]: """ Bare-bones unconstrained OpenMM minimization. - No position restraints, no violation checking, uses OpenMM defaults. - This is the default minimization mode. - - Note: Ligands are extracted at the relax() level before calling this. + No position restraints on protein, uses OpenMM defaults. + If ligand_coords is provided, adds fixed "dummy" particles at those + positions with LJ repulsion to prevent protein from clashing with + ligand positions. Args: pdb_string: PDB file contents as string (protein-only) + ligand_coords: Optional array of ligand atom positions (Angstroms) Returns: Tuple of (relaxed_pdb_string, debug_info, violations) @@ -184,9 +202,11 @@ def _relax_unconstrained( use_gpu = self._check_gpu_available() + has_ligand = ligand_coords is not None and len(ligand_coords) > 0 logger.info( f"Running unconstrained OpenMM minimization " - f"(max_iter={self.config.max_iterations}, gpu={use_gpu})" + f"(max_iter={self.config.max_iterations}, gpu={use_gpu}" + f"{', with ligand exclusion zone' if has_ligand else ''})" ) # Use pdbfixer to add missing atoms and terminal groups @@ -211,6 +231,43 @@ def _relax_unconstrained( modeller.topology, constraints=openmm_app.HBonds ) + n_protein_atoms = system.getNumParticles() + + # Add ligand atoms as fixed dummy particles with LJ repulsion + ligand_particle_indices = [] + if has_ligand: + # Add a custom nonbonded force for ligand-protein repulsion + # Using soft-core LJ potential to prevent singularities + ligand_repulsion = openmm.CustomNonbondedForce( + "epsilon * (sigma/r)^12; " + "sigma=0.3; epsilon=4.0" # 3 Angstrom radius, 4 kJ/mol + ) + ligand_repulsion.setNonbondedMethod( + openmm.CustomNonbondedForce.CutoffNonPeriodic + ) + ligand_repulsion.setCutoffDistance(1.2 * unit.nanometers) + + # Add all protein atoms to the force + for _ in range(n_protein_atoms): + ligand_repulsion.addParticle([]) + + # Add ligand dummy particles to the system + for _ in ligand_coords: + # Add massless particle (won't move) + idx = system.addParticle(0.0) + ligand_particle_indices.append(idx) + ligand_repulsion.addParticle([]) + + # Set interaction groups: protein interacts with ligand dummies + protein_set = set(range(n_protein_atoms)) + ligand_set = set(ligand_particle_indices) + ligand_repulsion.addInteractionGroup(protein_set, ligand_set) + + system.addForce(ligand_repulsion) + logger.debug( + f"Added {len(ligand_coords)} ligand exclusion particles" + ) + # Create integrator and simulation integrator = openmm.LangevinIntegrator(0, 0.01, 0.0) platform = openmm.Platform.getPlatformByName( @@ -219,7 +276,18 @@ def _relax_unconstrained( simulation = openmm_app.Simulation( modeller.topology, system, integrator, platform ) - simulation.context.setPositions(modeller.positions) + + # Set positions: protein from modeller, ligand dummies from coords + positions = list(modeller.positions) + if has_ligand: + for coord in ligand_coords: + # Convert Angstroms to nanometers + positions.append( + openmm.Vec3(coord[0], coord[1], coord[2]) + * 0.1 + * unit.nanometers + ) + simulation.context.setPositions(positions) # Get initial energy state = simulation.context.getState(getEnergy=True, getPositions=True) @@ -237,13 +305,18 @@ def _relax_unconstrained( efinal = state.getPotentialEnergy().value_in_unit(ENERGY) pos = state.getPositions(asNumpy=True).value_in_unit(LENGTH) - # Calculate RMSD - rmsd = np.sqrt(np.sum((posinit - pos) ** 2) / len(posinit)) + # Calculate RMSD (protein atoms only) + rmsd = np.sqrt( + np.sum((posinit[:n_protein_atoms] - pos[:n_protein_atoms]) ** 2) + / n_protein_atoms + ) - # Write output PDB + # Write output PDB (protein only - exclude dummy ligand particles) output = io.StringIO() + # Get only protein positions + protein_positions = state.getPositions()[:n_protein_atoms] openmm_app.PDBFile.writeFile( - simulation.topology, state.getPositions(), output + modeller.topology, protein_positions, output ) relaxed_pdb = output.getvalue() @@ -253,6 +326,8 @@ def _relax_unconstrained( "rmsd": rmsd, "attempts": 1, } + if has_ligand: + debug_data["ligand_exclusion_atoms"] = len(ligand_coords) logger.info( f"Minimization complete: E_init={einit:.2f}, " diff --git a/src/graphrelax/utils.py b/src/graphrelax/utils.py index 9e2ae91..0eec674 100644 --- a/src/graphrelax/utils.py +++ b/src/graphrelax/utils.py @@ -5,10 +5,40 @@ from pathlib import Path from typing import Optional +from graphrelax.artifacts import WATER_RESIDUES from graphrelax.structure_io import StructureFormat logger = logging.getLogger(__name__) +# Cache for GPU availability check +_gpu_available = None + + +def check_gpu_available() -> bool: + """ + Check if CUDA is available for OpenMM. + + Results are cached after first check. + + Returns: + True if CUDA platform is available + """ + global _gpu_available + if _gpu_available is not None: + return _gpu_available + + from openmm import Platform + + for i in range(Platform.getNumPlatforms()): + if Platform.getPlatform(i).getName() == "CUDA": + _gpu_available = True + logger.info("OpenMM CUDA platform detected, using GPU") + return True + + _gpu_available = False + logger.info("OpenMM CUDA not available, using CPU") + return False + def remove_waters(structure_string: str, fmt: StructureFormat = None) -> str: """ @@ -39,7 +69,6 @@ def _remove_waters_pdb(pdb_string: str) -> str: Returns: PDB string with water molecules removed """ - water_residues = {"HOH", "WAT", "SOL", "TIP3", "TIP4", "SPC"} filtered_lines = [] for line in pdb_string.splitlines(): @@ -48,13 +77,13 @@ def _remove_waters_pdb(pdb_string: str) -> str: # Residue name is in columns 17-20 (0-indexed: 17:20) if len(line) >= 20: resname = line[17:20].strip() - if resname in water_residues: + if resname in WATER_RESIDUES: continue # Check TER records that might reference water elif line.startswith("TER"): if len(line) >= 20: resname = line[17:20].strip() - if resname in water_residues: + if resname in WATER_RESIDUES: continue filtered_lines.append(line) @@ -77,11 +106,9 @@ def _remove_waters_cif(cif_string: str) -> str: from Bio.PDB import MMCIFIO, MMCIFParser, Select - water_residues = {"HOH", "WAT", "SOL", "TIP3", "TIP4", "SPC"} - class WaterRemover(Select): def accept_residue(self, residue): - return residue.get_resname().strip() not in water_residues + return residue.get_resname().strip() not in WATER_RESIDUES # MMCIFParser requires a file path with tempfile.NamedTemporaryFile( diff --git a/tests/test_artifacts.py b/tests/test_artifacts.py new file mode 100644 index 0000000..3d8184a --- /dev/null +++ b/tests/test_artifacts.py @@ -0,0 +1,404 @@ +"""Tests for crystallography artifact detection and removal.""" + +from graphrelax.artifacts import ( + BIOLOGICALLY_RELEVANT_IONS, + BUFFER_ARTIFACTS, + CRYOPROTECTANT_ARTIFACTS, + CRYSTALLOGRAPHY_ARTIFACTS, + DETERGENT_ARTIFACTS, + HALIDE_ARTIFACTS, + LIPID_ARTIFACTS, + REDUCING_AGENT_ARTIFACTS, + WATER_RESIDUES, + is_artifact, + is_biologically_relevant_ion, + remove_artifacts, +) + + +class TestArtifactConstants: + """Tests for artifact constant definitions.""" + + def test_common_buffers_included(self): + """Test that common buffer artifacts are in the set.""" + buffers = ["SO4", "PO4", "CIT", "ACT", "MES", "EPE", "TRS"] + for buf in buffers: + assert ( + buf in BUFFER_ARTIFACTS + ), f"{buf} should be in BUFFER_ARTIFACTS" + + def test_common_cryoprotectants_included(self): + """Test that common cryoprotectants are in the set.""" + cryo = ["GOL", "EDO", "MPD", "PEG", "DMS"] + for c in cryo: + assert ( + c in CRYOPROTECTANT_ARTIFACTS + ), f"{c} should be in CRYOPROTECTANT_ARTIFACTS" + + def test_common_detergents_included(self): + """Test that common detergents are in the set.""" + detergents = ["BOG", "LDA", "SDS"] + for d in detergents: + assert ( + d in DETERGENT_ARTIFACTS + ), f"{d} should be in DETERGENT_ARTIFACTS" + + def test_common_lipids_included(self): + """Test that common lipids/fatty acids are in the set.""" + lipids = ["PLM", "MYR", "OLA", "STE"] + for lip in lipids: + assert lip in LIPID_ARTIFACTS, f"{lip} should be in LIPID_ARTIFACTS" + + def test_reducing_agents_included(self): + """Test that common reducing agents are in the set.""" + agents = ["BME", "DTT"] + for a in agents: + assert ( + a in REDUCING_AGENT_ARTIFACTS + ), f"{a} should be in REDUCING_AGENT_ARTIFACTS" + + def test_halides_included(self): + """Test that halide ions are in the set.""" + halides = ["CL", "BR", "IOD", "F"] + for h in halides: + assert h in HALIDE_ARTIFACTS, f"{h} should be in HALIDE_ARTIFACTS" + + def test_metal_ions_not_in_artifacts(self): + """Test that biologically relevant metal ions are NOT artifacts.""" + metals = ["ZN", "MG", "CA", "FE", "MN", "CU"] + for m in metals: + assert ( + m not in CRYSTALLOGRAPHY_ARTIFACTS + ), f"{m} should NOT be in CRYSTALLOGRAPHY_ARTIFACTS" + + def test_biologically_relevant_ions_defined(self): + """Test that biologically relevant ions are in their own set.""" + ions = ["ZN", "MG", "CA", "FE", "MN", "CU", "CO", "NI"] + for ion in ions: + assert ( + ion in BIOLOGICALLY_RELEVANT_IONS + ), f"{ion} should be in BIOLOGICALLY_RELEVANT_IONS" + + def test_master_set_contains_all_categories(self): + """Test that CRYSTALLOGRAPHY_ARTIFACTS combines all category sets.""" + all_categories = ( + BUFFER_ARTIFACTS + | CRYOPROTECTANT_ARTIFACTS + | DETERGENT_ARTIFACTS + | LIPID_ARTIFACTS + | REDUCING_AGENT_ARTIFACTS + | HALIDE_ARTIFACTS + ) + for artifact in all_categories: + assert ( + artifact in CRYSTALLOGRAPHY_ARTIFACTS + ), f"{artifact} should be in CRYSTALLOGRAPHY_ARTIFACTS" + + def test_water_residues_separate(self): + """Test that water residues are in their own set.""" + waters = ["HOH", "WAT", "SOL"] + for w in waters: + assert w in WATER_RESIDUES, f"{w} should be in WATER_RESIDUES" + # Waters should NOT be in crystallography artifacts + assert ( + w not in CRYSTALLOGRAPHY_ARTIFACTS + ), f"{w} should NOT be in CRYSTALLOGRAPHY_ARTIFACTS" + + +class TestIsArtifact: + """Tests for is_artifact function.""" + + def test_glycerol_is_artifact(self): + """Test that glycerol is detected as artifact.""" + assert is_artifact("GOL") + assert is_artifact("gol") # case insensitive + + def test_sulfate_is_artifact(self): + """Test that sulfate is detected as artifact.""" + assert is_artifact("SO4") + + def test_zinc_is_not_artifact(self): + """Test that zinc is not detected as artifact.""" + assert not is_artifact("ZN") + + def test_heme_is_not_artifact(self): + """Test that heme is not detected as artifact.""" + assert not is_artifact("HEM") + + +class TestIsBiologicallyRelevantIon: + """Tests for is_biologically_relevant_ion function.""" + + def test_zinc_is_relevant(self): + """Test that zinc is detected as relevant.""" + assert is_biologically_relevant_ion("ZN") + assert is_biologically_relevant_ion("zn") # case insensitive + + def test_magnesium_is_relevant(self): + """Test that magnesium is detected as relevant.""" + assert is_biologically_relevant_ion("MG") + + def test_glycerol_is_not_relevant_ion(self): + """Test that glycerol is not a relevant ion.""" + assert not is_biologically_relevant_ion("GOL") + + +class TestRemoveArtifacts: + """Tests for remove_artifacts function.""" + + def test_removes_glycerol(self): + """Test that glycerol atoms are removed.""" + # fmt: off + pdb = ( + "ATOM 1 N ALA A 1 0.0 0.0 0.0 1.00 0.00\n" + "ATOM 2 CA ALA A 1 1.5 0.0 0.0 1.00 0.00\n" + "HETATM 3 O1 GOL A 100 5.0 5.0 5.0 1.00 0.00\n" + "HETATM 4 O2 GOL A 100 6.0 5.0 5.0 1.00 0.00\n" + "END\n" + ) + # fmt: on + result, removed = remove_artifacts(pdb) + + assert "GOL" not in result + assert "GOL" in removed + assert removed["GOL"] == 2 + assert "ATOM" in result # Protein preserved + + def test_removes_sulfate(self): + """Test that sulfate atoms are removed.""" + # fmt: off + pdb = ( + "ATOM 1 N ALA A 1 0.0 0.0 0.0 1.00 0.00\n" + "HETATM 2 S SO4 A 200 5.0 5.0 5.0 1.00 0.00\n" + "HETATM 3 O1 SO4 A 200 6.0 5.0 5.0 1.00 0.00\n" + "HETATM 4 O2 SO4 A 200 7.0 5.0 5.0 1.00 0.00\n" + "HETATM 5 O3 SO4 A 200 8.0 5.0 5.0 1.00 0.00\n" + "HETATM 6 O4 SO4 A 200 9.0 5.0 5.0 1.00 0.00\n" + "END\n" + ) + # fmt: on + result, removed = remove_artifacts(pdb) + + assert "SO4" not in result + assert removed["SO4"] == 5 + + def test_keeps_zinc(self): + """Test that zinc ions are preserved.""" + # fmt: off + pdb = ( + "ATOM 1 N ALA A 1 0.0 0.0 0.0 1.00 0.00\n" + "HETATM 2 ZN ZN A 200 5.0 5.0 5.0 1.00 0.00\n" + "END\n" + ) + # fmt: on + result, removed = remove_artifacts(pdb) + + assert "ZN" in result + assert "ZN" not in removed + + def test_keeps_heme(self): + """Test that heme is preserved (not an artifact).""" + # fmt: off + pdb = ( + "ATOM 1 N ALA A 1 0.0 0.0 0.0 1.00 0.00\n" + "HETATM 2 FE HEM A 200 5.0 5.0 5.0 1.00 0.00\n" + "HETATM 3 NA HEM A 200 6.0 5.0 5.0 1.00 0.00\n" + "END\n" + ) + # fmt: on + result, removed = remove_artifacts(pdb) + + assert "HEM" in result + assert "HEM" not in removed + + def test_whitelist_preserves_artifact(self): + """Test that whitelisted artifacts are preserved.""" + # fmt: off + pdb = ( + "ATOM 1 N ALA A 1 0.0 0.0 0.0 1.00 0.00\n" + "HETATM 2 O1 GOL A 100 5.0 5.0 5.0 1.00 0.00\n" + "END\n" + ) + # fmt: on + result, removed = remove_artifacts(pdb, keep_residues={"GOL"}) + + assert "GOL" in result + assert "GOL" not in removed + + def test_whitelist_case_insensitive(self): + """Test that whitelist is case insensitive.""" + # fmt: off + pdb = ( + "ATOM 1 N ALA A 1 0.0 0.0 0.0 1.00 0.00\n" + "HETATM 2 O1 GOL A 100 5.0 5.0 5.0 1.00 0.00\n" + "END\n" + ) + # fmt: on + result, removed = remove_artifacts(pdb, keep_residues={"gol"}) + + assert "GOL" in result + + def test_multiple_artifact_types(self): + """Test removal of multiple artifact types.""" + # fmt: off + pdb = ( + "ATOM 1 N ALA A 1 0.0 0.0 0.0 1.00 0.00\n" + "HETATM 2 O1 GOL A 100 5.0 5.0 5.0 1.00 0.00\n" + "HETATM 3 S SO4 A 200 6.0 6.0 6.0 1.00 0.00\n" + "HETATM 4 CL CL A 300 7.0 7.0 7.0 1.00 0.00\n" + "END\n" + ) + # fmt: on + result, removed = remove_artifacts(pdb) + + assert "GOL" not in result + assert "SO4" not in result + assert "CL" not in result + assert removed["GOL"] == 1 + assert removed["SO4"] == 1 + assert removed["CL"] == 1 + + def test_empty_pdb(self): + """Test handling of empty PDB.""" + pdb = "" + result, removed = remove_artifacts(pdb) + assert result == "" + assert len(removed) == 0 + + def test_protein_only_pdb(self): + """Test handling of PDB with no artifacts.""" + # fmt: off + pdb = ( + "ATOM 1 N ALA A 1 0.0 0.0 0.0 1.00 0.00\n" + "ATOM 2 CA ALA A 1 1.5 0.0 0.0 1.00 0.00\n" + "END\n" + ) + # fmt: on + result, removed = remove_artifacts(pdb) + + assert result == pdb + assert len(removed) == 0 + + +class TestRemoveArtifactsDetergentsAndLipids: + """Tests specifically for detergent and lipid removal.""" + + def test_removes_palmitic_acid(self): + """Test that palmitic acid is removed.""" + # fmt: off + pdb = ( + "ATOM 1 N ALA A 1 0.0 0.0 0.0 1.00 0.00\n" + "HETATM 2 C1 PLM A 100 5.0 5.0 5.0 1.00 0.00\n" + "END\n" + ) + # fmt: on + result, removed = remove_artifacts(pdb) + + assert "PLM" not in result + assert removed["PLM"] == 1 + + def test_removes_octyl_glucoside(self): + """Test that octyl glucoside detergent is removed.""" + # fmt: off + pdb = ( + "ATOM 1 N ALA A 1 0.0 0.0 0.0 1.00 0.00\n" + "HETATM 2 C1 BOG A 100 5.0 5.0 5.0 1.00 0.00\n" + "END\n" + ) + # fmt: on + result, removed = remove_artifacts(pdb) + + assert "BOG" not in result + assert removed["BOG"] == 1 + + def test_removes_ldao(self): + """Test that LDAO detergent is removed.""" + # fmt: off + pdb = ( + "ATOM 1 N ALA A 1 0.0 0.0 0.0 1.00 0.00\n" + "HETATM 2 N LDA A 100 5.0 5.0 5.0 1.00 0.00\n" + "END\n" + ) + # fmt: on + result, removed = remove_artifacts(pdb) + + assert "LDA" not in result + assert removed["LDA"] == 1 + + +class TestCLIKeepLigandParsing: + """Tests for CLI --keep-ligand argument parsing.""" + + def test_single_ligand(self): + """Test parsing a single ligand name.""" + from graphrelax.cli import create_parser + + parser = create_parser() + args = parser.parse_args( + ["-i", "in.pdb", "-o", "out.pdb", "--keep-ligand", "GOL"] + ) + assert args.keep_ligand == "GOL" + + def test_comma_separated_ligands(self): + """Test parsing comma-separated ligand names.""" + from graphrelax.cli import create_parser + + parser = create_parser() + args = parser.parse_args( + ["-i", "in.pdb", "-o", "out.pdb", "--keep-ligand", "GOL,SO4,EDO"] + ) + assert args.keep_ligand == "GOL,SO4,EDO" + + # Test the parsing logic that builds the keep_residues set + keep_residues = set() + if args.keep_ligand: + for resname in args.keep_ligand.split(","): + resname = resname.strip().upper() + if resname: + keep_residues.add(resname) + + assert keep_residues == {"GOL", "SO4", "EDO"} + + def test_comma_separated_with_spaces(self): + """Test parsing comma-separated ligand names with spaces.""" + from graphrelax.cli import create_parser + + parser = create_parser() + args = parser.parse_args( + ["-i", "in.pdb", "-o", "out.pdb", "--keep-ligand", "GOL, SO4, EDO"] + ) + + keep_residues = set() + if args.keep_ligand: + for resname in args.keep_ligand.split(","): + resname = resname.strip().upper() + if resname: + keep_residues.add(resname) + + assert keep_residues == {"GOL", "SO4", "EDO"} + + def test_lowercase_normalized_to_uppercase(self): + """Test that lowercase ligand names are normalized to uppercase.""" + from graphrelax.cli import create_parser + + parser = create_parser() + args = parser.parse_args( + ["-i", "in.pdb", "-o", "out.pdb", "--keep-ligand", "gol,so4"] + ) + + keep_residues = set() + if args.keep_ligand: + for resname in args.keep_ligand.split(","): + resname = resname.strip().upper() + if resname: + keep_residues.add(resname) + + assert keep_residues == {"GOL", "SO4"} + + def test_no_keep_ligand(self): + """Test that keep_ligand is None when not provided.""" + from graphrelax.cli import create_parser + + parser = create_parser() + args = parser.parse_args(["-i", "in.pdb", "-o", "out.pdb"]) + assert args.keep_ligand is None diff --git a/tests/test_chain_gaps.py b/tests/test_chain_gaps.py index 53b29d7..c5119a2 100644 --- a/tests/test_chain_gaps.py +++ b/tests/test_chain_gaps.py @@ -4,7 +4,6 @@ from graphrelax.chain_gaps import ( ChainGap, - add_ter_records_at_gaps, detect_chain_gaps, get_gap_summary, restore_chain_ids, @@ -302,27 +301,6 @@ def test_roundtrip_preserves_atoms(self, gapped_peptide_pdb): assert orig_atoms == restored_atoms -class TestAddTerRecordsAtGaps: - """Tests for add_ter_records_at_gaps function.""" - - def test_no_ter_without_gaps(self, continuous_peptide_pdb): - """Test no TER records added to continuous peptide.""" - gaps = detect_chain_gaps(continuous_peptide_pdb, check_distance=False) - result = add_ter_records_at_gaps(continuous_peptide_pdb, gaps) - assert result == continuous_peptide_pdb - - def test_adds_ter_at_gap(self, gapped_peptide_pdb): - """Test TER record is added at gap location.""" - gaps = detect_chain_gaps(gapped_peptide_pdb, check_distance=False) - result = add_ter_records_at_gaps(gapped_peptide_pdb, gaps) - - # Count TER records - ter_count = sum( - 1 for line in result.split("\n") if line.startswith("TER") - ) - assert ter_count >= 1 - - class TestGetGapSummary: """Tests for get_gap_summary function.""" diff --git a/tests/test_ligand_utils.py b/tests/test_ligand_utils.py new file mode 100644 index 0000000..d4c7b9d --- /dev/null +++ b/tests/test_ligand_utils.py @@ -0,0 +1,228 @@ +"""Unit tests for ligand utilities.""" + +import pytest + +from graphrelax.ligand_utils import ( + _PDBE_SMILES_CACHE, + WATER_RESIDUES, + LigandInfo, + extract_ligands_from_pdb, + fetch_pdbe_smiles, + get_ion_smiles, + is_single_atom_ligand, +) + + +class TestExtractLigands: + """Tests for ligand extraction from PDB.""" + + def test_extract_no_ligands(self, small_peptide_pdb_string): + """Test extraction from protein-only PDB.""" + protein_pdb, ligands = extract_ligands_from_pdb( + small_peptide_pdb_string + ) + + assert len(ligands) == 0 + assert "ATOM" in protein_pdb + assert "END" in protein_pdb + + def test_extract_with_ligand(self): + """Test extraction of a ligand from PDB.""" + # fmt: off + pdb_with_ligand = ( + "ATOM 1 N ALA A 1 0.0 0.0 0.0 1.00 0.00\n" + "ATOM 2 CA ALA A 1 1.5 0.0 0.0 1.00 0.00\n" + "ATOM 3 C ALA A 1 2.0 1.4 0.0 1.00 0.00\n" + "ATOM 4 O ALA A 1 1.2 2.4 0.0 1.00 0.00\n" + "HETATM 5 FE HEM A 200 5.0 5.0 5.0 1.00 0.00\n" + "HETATM 6 NA HEM A 200 3.5 5.0 5.0 1.00 0.00\n" + "HETATM 7 NB HEM A 200 5.0 3.5 5.0 1.00 0.00\n" + "END\n" + ) + # fmt: on + protein_pdb, ligands = extract_ligands_from_pdb(pdb_with_ligand) + + assert len(ligands) == 1 + assert ligands[0].resname == "HEM" + assert ligands[0].chain_id == "A" + assert ligands[0].resnum == 200 + assert len(ligands[0].pdb_lines) == 3 # FE, NA, NB + + # Check protein PDB doesn't have HETATM + assert "HETATM" not in protein_pdb + assert "ATOM" in protein_pdb + + def test_water_excluded(self): + """Test that water is not extracted as ligand.""" + # fmt: off + pdb_with_water = ( + "ATOM 1 N ALA A 1 0.0 0.0 0.0 1.00 0.00\n" + "ATOM 2 CA ALA A 1 1.5 0.0 0.0 1.00 0.00\n" + "HETATM 5 O HOH A 301 10.0 10.0 10.0 1.00 0.00\n" + "HETATM 6 O WAT A 302 11.0 10.0 10.0 1.00 0.00\n" + "HETATM 7 O SOL A 303 12.0 10.0 10.0 1.00 0.00\n" + "END\n" + ) + # fmt: on + protein_pdb, ligands = extract_ligands_from_pdb(pdb_with_water) + + # Waters should not be in ligand list + assert len(ligands) == 0 + assert not any(lig.resname in WATER_RESIDUES for lig in ligands) + + def test_multiple_ligands(self): + """Test extraction of multiple different ligands.""" + # fmt: off + pdb_with_multiple = ( + "ATOM 1 N ALA A 1 0.0 0.0 0.0 1.00 0.00\n" + "HETATM 5 FE HEM A 200 5.0 5.0 5.0 1.00 0.00\n" + "HETATM 6 N1 NAD B 301 8.0 8.0 8.0 1.00 0.00\n" + "END\n" + ) + # fmt: on + protein_pdb, ligands = extract_ligands_from_pdb(pdb_with_multiple) + + assert len(ligands) == 2 + resnames = {lig.resname for lig in ligands} + assert resnames == {"HEM", "NAD"} + + +class TestIonSmiles: + """Tests for ion SMILES lookup.""" + + def test_common_ions(self): + """Test that common metal ions are defined.""" + smiles = get_ion_smiles() + ions = ["ZN", "MG", "CA", "FE", "MN", "CU"] + for ion in ions: + assert ion in smiles, f"{ion} should be in ion SMILES" + + def test_ion_smiles_format(self): + """Test that ion SMILES have proper format with charge.""" + smiles = get_ion_smiles() + # All ions should have brackets indicating charged species + for ion, smi in smiles.items(): + assert smi.startswith("["), f"{ion} SMILES should start with [" + assert smi.endswith("]"), f"{ion} SMILES should end with ]" + + def test_halide_ions(self): + """Test that halide ions are defined.""" + smiles = get_ion_smiles() + assert "CL" in smiles + assert smiles["CL"] == "[Cl-]" + + +class TestSingleAtomLigand: + """Tests for single atom ligand detection.""" + + def test_ion_is_single_atom(self): + """Test that ions are detected as single atoms.""" + ligand = LigandInfo( + resname="ZN", + chain_id="A", + resnum=1, + pdb_lines=["HETATM 1 ZN ZN A 1 0.0 0.0 0.0"], + ) + assert is_single_atom_ligand(ligand) + + def test_heme_is_not_single_atom(self): + """Test that multi-atom ligands are not single atoms.""" + ligand = LigandInfo( + resname="HEM", + chain_id="A", + resnum=1, + pdb_lines=[ + "HETATM 1 FE HEM A 1 0.0 0.0 0.0", + "HETATM 2 NA HEM A 1 1.0 0.0 0.0", + ], + ) + assert not is_single_atom_ligand(ligand) + + +class TestLigandInfo: + """Tests for LigandInfo dataclass.""" + + def test_ligand_info_creation(self): + """Test creating a LigandInfo object.""" + ligand = LigandInfo( + resname="HEM", + chain_id="A", + resnum=200, + pdb_lines=[ + "HETATM 5 FE HEM A 200 5.000 5.000 5.000" + ], + ) + assert ligand.resname == "HEM" + assert ligand.chain_id == "A" + assert ligand.resnum == 200 + assert ligand.smiles is None + + def test_ligand_info_with_smiles(self): + """Test LigandInfo with optional SMILES.""" + ligand = LigandInfo( + resname="BEN", + chain_id="B", + resnum=1, + pdb_lines=[ + "HETATM 1 C1 BEN B 1 0.000 0.000 0.000" + ], + smiles="c1ccccc1", + ) + assert ligand.smiles == "c1ccccc1" + + +class TestWaterResidues: + """Tests for water residue constants.""" + + def test_common_water_names(self): + """Test that common water names are included.""" + assert "HOH" in WATER_RESIDUES + assert "WAT" in WATER_RESIDUES + assert "SOL" in WATER_RESIDUES + + def test_tip_models(self): + """Test that TIP water models are included.""" + assert "TIP3" in WATER_RESIDUES + assert "TIP4" in WATER_RESIDUES + assert "SPC" in WATER_RESIDUES + + +class TestFetchPdbeSmiles: + """Tests for PDBe SMILES fetching.""" + + def setup_method(self): + """Clear cache before each test.""" + _PDBE_SMILES_CACHE.clear() + + @pytest.mark.network + def test_fetch_known_ligand(self): + """Test fetching SMILES for a known PDB ligand.""" + # ATP is a well-known ligand + smiles = fetch_pdbe_smiles("ATP") + assert smiles is not None + assert len(smiles) > 0 + # Cache should be populated + assert "ATP" in _PDBE_SMILES_CACHE + + @pytest.mark.network + def test_fetch_unknown_ligand(self): + """Test fetching SMILES for a non-existent ligand.""" + smiles = fetch_pdbe_smiles("ZZZZZ") + assert smiles is None + # Failure should also be cached + assert "ZZZZZ" in _PDBE_SMILES_CACHE + assert _PDBE_SMILES_CACHE["ZZZZZ"] is None + + def test_cache_hit(self): + """Test that cached SMILES are returned.""" + # Pre-populate cache + _PDBE_SMILES_CACHE["CACHED"] = "CCC" + smiles = fetch_pdbe_smiles("CACHED") + assert smiles == "CCC" + + def test_case_insensitive(self): + """Test that lookup is case-insensitive.""" + _PDBE_SMILES_CACHE["TEST"] = "C=O" + # Lowercase should find uppercase cache entry + smiles = fetch_pdbe_smiles("test") + assert smiles == "C=O" diff --git a/tests/test_relaxer_integration.py b/tests/test_relaxer_integration.py index d62d92d..fb37596 100644 --- a/tests/test_relaxer_integration.py +++ b/tests/test_relaxer_integration.py @@ -20,29 +20,44 @@ def relaxer(): return Relaxer(config) +@pytest.fixture +def unconstrained_relaxer(): + """Create a Relaxer instance with unconstrained minimization.""" + from graphrelax.relaxer import Relaxer + + config = RelaxConfig(max_iterations=50, stiffness=10.0, constrained=False) + return Relaxer(config) + + @pytest.mark.integration class TestRelaxerGPUDetection: """Tests for GPU detection logic.""" - def test_check_gpu_available_returns_bool(self, relaxer): + def test_check_gpu_available_returns_bool(self): """Test that GPU check returns a boolean.""" - result = relaxer._check_gpu_available() + from graphrelax.utils import check_gpu_available + + result = check_gpu_available() assert isinstance(result, bool) - def test_check_gpu_available_cached(self, relaxer): + def test_check_gpu_available_cached(self): """Test that GPU check result is cached.""" - result1 = relaxer._check_gpu_available() - result2 = relaxer._check_gpu_available() + from graphrelax.utils import check_gpu_available + + result1 = check_gpu_available() + result2 = check_gpu_available() assert result1 == result2 @pytest.mark.integration -class TestRelaxDirect: - """Tests for direct OpenMM minimization.""" +class TestRelaxUnconstrained: + """Tests for unconstrained OpenMM minimization.""" - def test_relax_direct_runs(self, relaxer, small_peptide_pdb_string): - """Test that direct relaxation completes without error.""" - relaxed_pdb, debug_info, violations = relaxer._relax_direct( + def test_relax_unconstrained_runs( + self, unconstrained_relaxer, small_peptide_pdb_string + ): + """Test that unconstrained relaxation completes without error.""" + relaxed_pdb, debug_info, violations = unconstrained_relaxer.relax( small_peptide_pdb_string ) @@ -50,39 +65,44 @@ def test_relax_direct_runs(self, relaxer, small_peptide_pdb_string): assert isinstance(relaxed_pdb, str) assert len(relaxed_pdb) > 0 - def test_relax_direct_returns_debug_info( - self, relaxer, small_peptide_pdb_string + def test_relax_unconstrained_returns_debug_info( + self, unconstrained_relaxer, small_peptide_pdb_string ): """Test that debug info contains expected keys.""" - _, debug_info, _ = relaxer._relax_direct(small_peptide_pdb_string) + _, debug_info, _ = unconstrained_relaxer.relax(small_peptide_pdb_string) assert "initial_energy" in debug_info assert "final_energy" in debug_info assert "rmsd" in debug_info - assert "attempts" in debug_info - def test_relax_direct_energy_types(self, relaxer, small_peptide_pdb_string): + def test_relax_unconstrained_energy_types( + self, unconstrained_relaxer, small_peptide_pdb_string + ): """Test that energy values are numeric.""" - _, debug_info, _ = relaxer._relax_direct(small_peptide_pdb_string) + _, debug_info, _ = unconstrained_relaxer.relax(small_peptide_pdb_string) assert isinstance(debug_info["initial_energy"], (int, float)) assert isinstance(debug_info["final_energy"], (int, float)) assert isinstance(debug_info["rmsd"], (int, float)) - def test_relax_direct_pdb_format(self, relaxer, small_peptide_pdb_string): + def test_relax_unconstrained_pdb_format( + self, unconstrained_relaxer, small_peptide_pdb_string + ): """Test that output is valid PDB format.""" - relaxed_pdb, _, _ = relaxer._relax_direct(small_peptide_pdb_string) + relaxed_pdb, _, _ = unconstrained_relaxer.relax( + small_peptide_pdb_string + ) # Should contain ATOM records assert "ATOM" in relaxed_pdb or "HETATM" in relaxed_pdb - def test_relax_direct_violations_array( - self, relaxer, small_peptide_pdb_string + def test_relax_unconstrained_violations_array( + self, unconstrained_relaxer, small_peptide_pdb_string ): """Test that violations is a numpy array.""" import numpy as np - _, _, violations = relaxer._relax_direct(small_peptide_pdb_string) + _, _, violations = unconstrained_relaxer.relax(small_peptide_pdb_string) assert isinstance(violations, np.ndarray) @@ -135,30 +155,28 @@ def test_get_energy_breakdown_has_total( class TestRelaxerConfig: """Tests for Relaxer with different configurations.""" - def test_high_stiffness(self, small_peptide_pdb_string): - """Test relaxation with high stiffness (more restrained).""" + def test_high_stiffness_constrained(self, small_peptide_pdb_string): + """Test constrained relaxation with high stiffness (more restrained).""" from graphrelax.relaxer import Relaxer - config = RelaxConfig(max_iterations=50, stiffness=100.0) + config = RelaxConfig( + max_iterations=50, stiffness=100.0, constrained=True + ) relaxer = Relaxer(config) - relaxed_pdb, debug_info, _ = relaxer._relax_direct( - small_peptide_pdb_string - ) + relaxed_pdb, debug_info, _ = relaxer.relax(small_peptide_pdb_string) # High stiffness should result in small RMSD assert debug_info["rmsd"] < 1.0 # Less than 1 Angstrom - def test_zero_stiffness(self, small_peptide_pdb_string): - """Test relaxation with no restraints.""" + def test_unconstrained_minimization(self, small_peptide_pdb_string): + """Test unconstrained minimization (no restraints).""" from graphrelax.relaxer import Relaxer - config = RelaxConfig(max_iterations=50, stiffness=0.0) + config = RelaxConfig(max_iterations=50, constrained=False) relaxer = Relaxer(config) - relaxed_pdb, debug_info, _ = relaxer._relax_direct( - small_peptide_pdb_string - ) + relaxed_pdb, debug_info, _ = relaxer.relax(small_peptide_pdb_string) assert relaxed_pdb is not None @@ -166,11 +184,11 @@ def test_limited_iterations(self, small_peptide_pdb_string): """Test relaxation with limited iterations.""" from graphrelax.relaxer import Relaxer - config = RelaxConfig(max_iterations=10, stiffness=10.0) + config = RelaxConfig( + max_iterations=10, stiffness=10.0, constrained=False + ) relaxer = Relaxer(config) - relaxed_pdb, debug_info, _ = relaxer._relax_direct( - small_peptide_pdb_string - ) + relaxed_pdb, debug_info, _ = relaxer.relax(small_peptide_pdb_string) assert relaxed_pdb is not None