Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ dependencies = [
"ml-collections",
"dm-tree",
"openmm",
"tqdm",
]

# Note: pdbfixer is not on PyPI, install via conda:
Expand Down
6 changes: 6 additions & 0 deletions src/graphrelax/LigandMPNN/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,18 @@
# This is vendored code from LigandMPNN - style issues are preserved from upstream
from __future__ import print_function

import logging

import numpy as np
import torch
import torch.utils

# Completely disable all logging during prody import to suppress debug messages
logging.disable(logging.CRITICAL)
from prody import *

confProDy(verbosity="none")
logging.disable(logging.NOTSET)

restype_1to3 = {
"A": "ALA",
Expand Down
10 changes: 9 additions & 1 deletion src/graphrelax/LigandMPNN/openfold/np/relax/amber_minimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
residue_constants,
)
from openfold.np.relax import cleanup, utils
from tqdm import tqdm

try:
from simtk import openmm, unit
Expand Down Expand Up @@ -487,11 +488,14 @@ def _run_one_iteration(
start = time.perf_counter()
minimized = False
attempts = 0

while not minimized and attempts < max_attempts:
attempts += 1
try:
logging.info(
"Minimizing protein, attempt %d of %d.", attempts, max_attempts
"Minimizing protein, attempt %d of %d.",
attempts,
max_attempts,
)
ret = _openmm_minimize(
pdb_string,
Expand All @@ -506,6 +510,7 @@ def _run_one_iteration(
except Exception as e: # pylint: disable=broad-except
print(e)
logging.info(e)

if not minimized:
raise ValueError(f"Minimization failed after {max_attempts} attempts.")
ret["opt_time"] = time.perf_counter() - start
Expand All @@ -525,6 +530,7 @@ def run_pipeline(
max_attempts: int = 100,
checks: bool = True,
exclude_residues: Optional[Sequence[int]] = None,
pbar: Optional[tqdm] = None,
):
"""Run iterative amber relax.

Expand All @@ -548,6 +554,7 @@ def run_pipeline(
checks: Whether to perform cleaning checks.
exclude_residues: An optional list of zero-indexed residues to exclude from
restraints.
pbar: Optional progress bar for status updates.

Returns:
out: A dictionary of output values.
Expand Down Expand Up @@ -607,6 +614,7 @@ def run_pipeline(
ret["num_exclusions"],
)
iteration += 1

return ret


Expand Down
6 changes: 4 additions & 2 deletions src/graphrelax/LigandMPNN/openfold/np/relax/relax.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@
# limitations under the License.

"""Amber relaxation."""
from typing import Any, Dict, Sequence, Tuple
from typing import Any, Dict, Optional, Sequence, Tuple

import numpy as np
from openfold.np import protein
from openfold.np.relax import amber_minimize, utils
from tqdm import tqdm


class AmberRelaxation(object):
Expand Down Expand Up @@ -61,7 +62,7 @@ def __init__(
self._use_gpu = use_gpu

def process(
self, *, prot: protein.Protein
self, *, prot: protein.Protein, pbar: Optional[tqdm] = None
) -> Tuple[str, Dict[str, Any], np.ndarray]:
"""Runs Amber relax on a prediction, adds hydrogens, returns PDB string."""
out = amber_minimize.run_pipeline(
Expand All @@ -72,6 +73,7 @@ def process(
exclude_residues=self._exclude_residues,
max_outer_iterations=self._max_outer_iterations,
use_gpu=self._use_gpu,
pbar=pbar,
)
min_pos = out["pos"]
start_pos = out["posinit"]
Expand Down
4 changes: 4 additions & 0 deletions src/graphrelax/LigandMPNN/sc_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# flake8: noqa
# This is vendored code from LigandMPNN - style issues are preserved from upstream
import sys
from typing import Optional

import numpy as np
import torch
Expand All @@ -27,6 +28,7 @@
)
from openfold.utils import feats
from openfold.utils.rigid_utils import Rigid
from tqdm import tqdm

torch_pi = torch.tensor(np.pi, device="cpu")

Expand Down Expand Up @@ -66,6 +68,7 @@ def pack_side_chains(
num_samples=10,
repack_everything=True,
num_context_atoms=16,
pbar: Optional[tqdm] = None,
):
device = feature_dict["X"].device
torsion_dict = make_torsion_features(feature_dict, repack_everything)
Expand Down Expand Up @@ -101,6 +104,7 @@ def pack_side_chains(
feature_dict["h_V"] = h_V
feature_dict["h_E"] = h_E
feature_dict["E_idx"] = E_idx

for step in range(num_denoising_steps):
mean, concentration, mix_logits = model_sc.decode(feature_dict)
mix = D.Categorical(logits=mix_logits)
Expand Down
30 changes: 30 additions & 0 deletions src/graphrelax/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,36 @@ def setup_logging(verbose: bool):
level=level,
format="%(asctime)s [%(levelname)s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
force=True, # Override any existing configuration
)


def configure_third_party_loggers(verbose: bool):
    """Configure third-party loggers after they've been imported.

    ProDy is silenced at import time elsewhere; this re-applies the desired
    verbosity once imports are done, since ProDy (re)configures its own
    loggers on import.

    Args:
        verbose: If True, enable DEBUG output from ProDy; otherwise force
            ProDy loggers to ERROR so they stay quiet.
    """
    from prody import confProDy

    # Both branches tune the same ProDy logger names — keep a single list
    # instead of duplicating it per branch.
    prody_logger_names = (
        "prody",
        "prody.proteins",
        "prody.atomic",
        "prody.dynamics",
    )

    if verbose:
        # Enable ProDy logging when verbose.
        confProDy(verbosity="debug")
        level = logging.DEBUG
    else:
        # Ensure ProDy stays silent (already silenced at import, but reinforce).
        confProDy(verbosity="none")
        level = logging.ERROR

    for name in prody_logger_names:
        logging.getLogger(name).setLevel(level)


def _check_for_ligands(input_path: Path, fmt) -> bool:
"""
Check if input structure has ligands (non-water HETATM records).
Expand Down Expand Up @@ -319,6 +346,9 @@ def main(args=None) -> int:
from graphrelax.pipeline import Pipeline
from graphrelax.structure_io import detect_format

# Configure third-party loggers after import
configure_third_party_loggers(opts.verbose)

# Determine mode
if opts.repack_only:
mode = PipelineMode.REPACK_ONLY
Expand Down
7 changes: 7 additions & 0 deletions src/graphrelax/designer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import numpy as np
import torch
from tqdm import tqdm

from graphrelax.config import DesignConfig
from graphrelax.resfile import ALL_AAS, DesignSpec, ResidueMode
Expand Down Expand Up @@ -130,6 +131,7 @@ def design(
pdb_path: Path,
design_spec: Optional[DesignSpec] = None,
design_all: bool = False,
pbar: Optional[tqdm] = None,
) -> dict:
"""
Run sequence design on a structure.
Expand All @@ -138,6 +140,7 @@ def design(
pdb_path: Path to input PDB
design_spec: Specification of which residues to design/fix
design_all: If True, design all residues (full redesign)
pbar: Optional progress bar for status updates

Returns:
Dictionary with designed sequence, structure, and scores
Expand Down Expand Up @@ -214,6 +217,7 @@ def design(
self.config.sc_num_denoising_steps,
self.config.sc_num_samples,
repack_everything=False,
pbar=pbar,
)
output_dict.update(sc_dict)

Expand Down Expand Up @@ -249,13 +253,15 @@ def repack(
self,
pdb_path: Path,
design_spec: Optional[DesignSpec] = None,
pbar: Optional[tqdm] = None,
) -> dict:
"""
Repack side chains without changing sequence.

Args:
pdb_path: Path to input PDB
design_spec: Specification (NATRO residues excluded from repacking)
pbar: Optional progress bar for status updates

Returns:
Dictionary with repacked structure
Expand Down Expand Up @@ -299,6 +305,7 @@ def repack(
self.config.sc_num_denoising_steps,
self.config.sc_num_samples,
repack_everything=True,
pbar=pbar,
)

# Get sequence
Expand Down
Loading
Loading