Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/graphrelax/LigandMPNN/openfold/np/relax/amber_minimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
residue_constants,
)
from openfold.np.relax import cleanup, utils
from tqdm import tqdm

try:
from simtk import openmm, unit
Expand Down Expand Up @@ -487,11 +488,14 @@ def _run_one_iteration(
start = time.perf_counter()
minimized = False
attempts = 0

while not minimized and attempts < max_attempts:
attempts += 1
try:
logging.info(
"Minimizing protein, attempt %d of %d.", attempts, max_attempts
"Minimizing protein, attempt %d of %d.",
attempts,
max_attempts,
)
ret = _openmm_minimize(
pdb_string,
Expand All @@ -506,6 +510,7 @@ def _run_one_iteration(
except Exception as e: # pylint: disable=broad-except
print(e)
logging.info(e)

if not minimized:
raise ValueError(f"Minimization failed after {max_attempts} attempts.")
ret["opt_time"] = time.perf_counter() - start
Expand All @@ -525,6 +530,7 @@ def run_pipeline(
max_attempts: int = 100,
checks: bool = True,
exclude_residues: Optional[Sequence[int]] = None,
pbar: Optional[tqdm] = None,
):
"""Run iterative amber relax.

Expand All @@ -548,6 +554,7 @@ def run_pipeline(
checks: Whether to perform cleaning checks.
exclude_residues: An optional list of zero-indexed residues to exclude from
restraints.
pbar: Optional progress bar for status updates.

Returns:
out: A dictionary of output values.
Expand Down Expand Up @@ -607,6 +614,7 @@ def run_pipeline(
ret["num_exclusions"],
)
iteration += 1

return ret


Expand Down
6 changes: 4 additions & 2 deletions src/graphrelax/LigandMPNN/openfold/np/relax/relax.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@
# limitations under the License.

"""Amber relaxation."""
from typing import Any, Dict, Sequence, Tuple
from typing import Any, Dict, Optional, Sequence, Tuple

import numpy as np
from openfold.np import protein
from openfold.np.relax import amber_minimize, utils
from tqdm import tqdm


class AmberRelaxation(object):
Expand Down Expand Up @@ -61,7 +62,7 @@ def __init__(
self._use_gpu = use_gpu

def process(
self, *, prot: protein.Protein
self, *, prot: protein.Protein, pbar: Optional[tqdm] = None
) -> Tuple[str, Dict[str, Any], np.ndarray]:
"""Runs Amber relax on a prediction, adds hydrogens, returns PDB string."""
out = amber_minimize.run_pipeline(
Expand All @@ -72,6 +73,7 @@ def process(
exclude_residues=self._exclude_residues,
max_outer_iterations=self._max_outer_iterations,
use_gpu=self._use_gpu,
pbar=pbar,
)
min_pos = out["pos"]
start_pos = out["posinit"]
Expand Down
4 changes: 4 additions & 0 deletions src/graphrelax/LigandMPNN/sc_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# flake8: noqa
# This is vendored code from LigandMPNN - style issues are preserved from upstream
import sys
from typing import Optional

import numpy as np
import torch
Expand All @@ -27,6 +28,7 @@
)
from openfold.utils import feats
from openfold.utils.rigid_utils import Rigid
from tqdm import tqdm

torch_pi = torch.tensor(np.pi, device="cpu")

Expand Down Expand Up @@ -66,6 +68,7 @@ def pack_side_chains(
num_samples=10,
repack_everything=True,
num_context_atoms=16,
pbar: Optional[tqdm] = None,
):
device = feature_dict["X"].device
torsion_dict = make_torsion_features(feature_dict, repack_everything)
Expand Down Expand Up @@ -101,6 +104,7 @@ def pack_side_chains(
feature_dict["h_V"] = h_V
feature_dict["h_E"] = h_E
feature_dict["E_idx"] = E_idx

for step in range(num_denoising_steps):
mean, concentration, mix_logits = model_sc.decode(feature_dict)
mix = D.Categorical(logits=mix_logits)
Expand Down
7 changes: 7 additions & 0 deletions src/graphrelax/designer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import numpy as np
import torch
from tqdm import tqdm

from graphrelax.config import DesignConfig
from graphrelax.resfile import ALL_AAS, DesignSpec, ResidueMode
Expand Down Expand Up @@ -130,6 +131,7 @@ def design(
pdb_path: Path,
design_spec: Optional[DesignSpec] = None,
design_all: bool = False,
pbar: Optional[tqdm] = None,
) -> dict:
"""
Run sequence design on a structure.
Expand All @@ -138,6 +140,7 @@ def design(
pdb_path: Path to input PDB
design_spec: Specification of which residues to design/fix
design_all: If True, design all residues (full redesign)
pbar: Optional progress bar for status updates

Returns:
Dictionary with designed sequence, structure, and scores
Expand Down Expand Up @@ -214,6 +217,7 @@ def design(
self.config.sc_num_denoising_steps,
self.config.sc_num_samples,
repack_everything=False,
pbar=pbar,
)
output_dict.update(sc_dict)

Expand Down Expand Up @@ -249,13 +253,15 @@ def repack(
self,
pdb_path: Path,
design_spec: Optional[DesignSpec] = None,
pbar: Optional[tqdm] = None,
) -> dict:
"""
Repack side chains without changing sequence.

Args:
pdb_path: Path to input PDB
design_spec: Specification (NATRO residues excluded from repacking)
pbar: Optional progress bar for status updates

Returns:
Dictionary with repacked structure
Expand Down Expand Up @@ -299,6 +305,7 @@ def repack(
self.config.sc_num_denoising_steps,
self.config.sc_num_samples,
repack_everything=True,
pbar=pbar,
)

# Get sequence
Expand Down
Loading
Loading