diff --git a/src/graphrelax/LigandMPNN/openfold/np/relax/amber_minimize.py b/src/graphrelax/LigandMPNN/openfold/np/relax/amber_minimize.py index 374b3b3..5b752ff 100644 --- a/src/graphrelax/LigandMPNN/openfold/np/relax/amber_minimize.py +++ b/src/graphrelax/LigandMPNN/openfold/np/relax/amber_minimize.py @@ -30,6 +30,7 @@ residue_constants, ) from openfold.np.relax import cleanup, utils +from tqdm import tqdm try: from simtk import openmm, unit @@ -487,11 +488,14 @@ def _run_one_iteration( start = time.perf_counter() minimized = False attempts = 0 + while not minimized and attempts < max_attempts: attempts += 1 try: logging.info( - "Minimizing protein, attempt %d of %d.", attempts, max_attempts + "Minimizing protein, attempt %d of %d.", + attempts, + max_attempts, ) ret = _openmm_minimize( pdb_string, @@ -506,6 +510,7 @@ def _run_one_iteration( except Exception as e: # pylint: disable=broad-except print(e) logging.info(e) + if not minimized: raise ValueError(f"Minimization failed after {max_attempts} attempts.") ret["opt_time"] = time.perf_counter() - start @@ -525,6 +530,7 @@ def run_pipeline( max_attempts: int = 100, checks: bool = True, exclude_residues: Optional[Sequence[int]] = None, + pbar: Optional[tqdm] = None, ): """Run iterative amber relax. @@ -548,6 +554,7 @@ def run_pipeline( checks: Whether to perform cleaning checks. exclude_residues: An optional list of zero-indexed residues to exclude from restraints. + pbar: Optional progress bar for status updates. Returns: out: A dictionary of output values. @@ -607,6 +614,7 @@ def run_pipeline( ret["num_exclusions"], ) iteration += 1 + return ret diff --git a/src/graphrelax/LigandMPNN/openfold/np/relax/relax.py b/src/graphrelax/LigandMPNN/openfold/np/relax/relax.py index 14ed73a..4ca3553 100644 --- a/src/graphrelax/LigandMPNN/openfold/np/relax/relax.py +++ b/src/graphrelax/LigandMPNN/openfold/np/relax/relax.py @@ -16,11 +16,12 @@ # limitations under the License. """Amber relaxation.""" -from typing import Any, Dict, Sequence, Tuple +from typing import Any, Dict, Optional, Sequence, Tuple import numpy as np from openfold.np import protein from openfold.np.relax import amber_minimize, utils +from tqdm import tqdm class AmberRelaxation(object): @@ -61,7 +62,7 @@ def __init__( self._use_gpu = use_gpu def process( - self, *, prot: protein.Protein + self, *, prot: protein.Protein, pbar: Optional[tqdm] = None ) -> Tuple[str, Dict[str, Any], np.ndarray]: """Runs Amber relax on a prediction, adds hydrogens, returns PDB string.""" out = amber_minimize.run_pipeline( @@ -72,6 +73,7 @@ def process( exclude_residues=self._exclude_residues, max_outer_iterations=self._max_outer_iterations, use_gpu=self._use_gpu, + pbar=pbar, ) min_pos = out["pos"] start_pos = out["posinit"] diff --git a/src/graphrelax/LigandMPNN/sc_utils.py b/src/graphrelax/LigandMPNN/sc_utils.py index fde88ad..c457067 100644 --- a/src/graphrelax/LigandMPNN/sc_utils.py +++ b/src/graphrelax/LigandMPNN/sc_utils.py @@ -1,6 +1,7 @@ # flake8: noqa # This is vendored code from LigandMPNN - style issues are preserved from upstream import sys +from typing import Optional import numpy as np import torch @@ -27,6 +28,7 @@ ) from openfold.utils import feats from openfold.utils.rigid_utils import Rigid +from tqdm import tqdm torch_pi = torch.tensor(np.pi, device="cpu") @@ -66,6 +68,7 @@ def pack_side_chains( num_samples=10, repack_everything=True, num_context_atoms=16, + pbar: Optional[tqdm] = None, ): device = feature_dict["X"].device torsion_dict = make_torsion_features(feature_dict, repack_everything) @@ -101,6 +104,7 @@ def pack_side_chains( feature_dict["h_V"] = h_V feature_dict["h_E"] = h_E feature_dict["E_idx"] = E_idx + for step in range(num_denoising_steps): mean, concentration, mix_logits = model_sc.decode(feature_dict) mix = D.Categorical(logits=mix_logits) diff --git a/src/graphrelax/designer.py b/src/graphrelax/designer.py index 37e7f7e..e11f530 100644 --- a/src/graphrelax/designer.py +++ b/src/graphrelax/designer.py @@ -8,6 +8,7 @@ import numpy as np import torch +from tqdm import tqdm from graphrelax.config import DesignConfig from graphrelax.resfile import ALL_AAS, DesignSpec, ResidueMode @@ -130,6 +131,7 @@ def design( pdb_path: Path, design_spec: Optional[DesignSpec] = None, design_all: bool = False, + pbar: Optional[tqdm] = None, ) -> dict: """ Run sequence design on a structure. @@ -138,6 +140,7 @@ def design( pdb_path: Path to input PDB design_spec: Specification of which residues to design/fix design_all: If True, design all residues (full redesign) + pbar: Optional progress bar for status updates Returns: Dictionary with designed sequence, structure, and scores @@ -214,6 +217,7 @@ def design( self.config.sc_num_denoising_steps, self.config.sc_num_samples, repack_everything=False, + pbar=pbar, ) output_dict.update(sc_dict) @@ -249,6 +253,7 @@ def repack( self, pdb_path: Path, design_spec: Optional[DesignSpec] = None, + pbar: Optional[tqdm] = None, ) -> dict: """ Repack side chains without changing sequence. @@ -256,6 +261,7 @@ def repack( Args: pdb_path: Path to input PDB design_spec: Specification (NATRO residues excluded from repacking) + pbar: Optional progress bar for status updates Returns: Dictionary with repacked structure @@ -299,6 +305,7 @@ def repack( self.config.sc_num_denoising_steps, self.config.sc_num_samples, repack_everything=True, + pbar=pbar, ) # Get sequence diff --git a/src/graphrelax/pipeline.py b/src/graphrelax/pipeline.py index d76dcc4..cc2407b 100644 --- a/src/graphrelax/pipeline.py +++ b/src/graphrelax/pipeline.py @@ -5,6 +5,8 @@ from pathlib import Path from typing import Optional +from tqdm import tqdm + from graphrelax.config import PipelineConfig, PipelineMode from graphrelax.designer import Designer from graphrelax.idealize import idealize_structure @@ -92,60 +94,72 @@ def run( all_results = [] all_scores = [] - for output_idx in range(1, self.config.n_outputs + 1): - logger.info( - f"Generating output {output_idx}/{self.config.n_outputs}" - ) - - result = self._run_single_output( - input_pdb=input_pdb, - design_spec=design_spec, - design_all=design_all, - input_format=input_format, - ) - - # Format output path - out_path = format_output_path( - output_pdb, output_idx, self.config.n_outputs - ) + # Calculate total steps for unified progress bar + total_steps = self.config.n_outputs * self.config.n_iterations + pbar = None + if total_steps > 1: + pbar = tqdm(total=total_steps, desc="Progress", unit="step") - # Convert to target format if needed and save - final_structure = result["final_pdb"] - if target_format != StructureFormat.PDB: - final_structure = convert_to_format( - final_structure, target_format + try: + for output_idx in range(1, self.config.n_outputs + 1): + logger.info( + f"Generating output {output_idx}/{self.config.n_outputs}" ) - save_pdb_string(final_structure, out_path) - - result["output_path"] = out_path - all_results.append(result) - - # Collect scores - score_dict = { - "total_score": result.get("final_energy", 0.0), - "openmm_energy": result.get("final_energy", 0.0), - } - - # Add energy breakdown if available - if "energy_breakdown" in result: - for key, val in result["energy_breakdown"].items(): - if key != "total_energy": - score_dict[key] = val - - # Add LigandMPNN score - if "ligandmpnn_loss" in result: - score_dict["ligandmpnn_score"] = compute_ligandmpnn_score( - result["ligandmpnn_loss"] + + result = self._run_single_output( + input_pdb=input_pdb, + design_spec=design_spec, + design_all=design_all, + input_format=input_format, + output_idx=output_idx, + pbar=pbar, ) - # Add sequence recovery - if result.get("sequence") and result.get("native_sequence"): - score_dict["seq_recovery"] = compute_sequence_recovery( - result["sequence"], result["native_sequence"] + # Format output path + out_path = format_output_path( + output_pdb, output_idx, self.config.n_outputs ) - score_dict["description"] = out_path.name - all_scores.append(score_dict) + # Convert to target format if needed and save + final_structure = result["final_pdb"] + if target_format != StructureFormat.PDB: + final_structure = convert_to_format( + final_structure, target_format + ) + save_pdb_string(final_structure, out_path) + + result["output_path"] = out_path + all_results.append(result) + + # Collect scores + score_dict = { + "total_score": result.get("final_energy", 0.0), + "openmm_energy": result.get("final_energy", 0.0), + } + + # Add energy breakdown if available + if "energy_breakdown" in result: + for key, val in result["energy_breakdown"].items(): + if key != "total_energy": + score_dict[key] = val + + # Add LigandMPNN score + if "ligandmpnn_loss" in result: + score_dict["ligandmpnn_score"] = compute_ligandmpnn_score( + result["ligandmpnn_loss"] + ) + + # Add sequence recovery + if result.get("sequence") and result.get("native_sequence"): + score_dict["seq_recovery"] = compute_sequence_recovery( + result["sequence"], result["native_sequence"] + ) + + score_dict["description"] = out_path.name + all_scores.append(score_dict) + finally: + if pbar: + pbar.close() # Write scorefile if requested if self.config.scorefile and all_scores: @@ -162,6 +176,8 @@ def _run_single_output( design_spec: Optional[DesignSpec], design_all: bool, input_format: StructureFormat = StructureFormat.PDB, + output_idx: int = 1, + pbar: Optional[tqdm] = None, ) -> dict: """Run pipeline for a single output.""" result = { @@ -215,12 +231,20 @@ def _run_single_output( f" Iteration {iteration}/{self.config.n_iterations}" ) + # Update progress bar description + if pbar: + pbar.set_description( + f"Output {output_idx}/{self.config.n_outputs}, " + f"iter {iteration}/{self.config.n_iterations}" + ) + iter_result = self._run_iteration( pdb_path=current_pdb_path, design_spec=design_spec, design_all=design_all, iteration=iteration, original_native_sequence=original_native_sequence, + pbar=pbar, ) result["iterations"].append(iter_result) @@ -247,6 +271,10 @@ def _run_single_output( f"RMSD={info['rmsd']:.3f}" ) + # Increment progress bar after each iteration completes + if pbar: + pbar.update(1) + # Store final results result["final_pdb"] = current_pdb if result["iterations"]: @@ -273,6 +301,7 @@ def _run_iteration( design_all: bool, iteration: int, original_native_sequence: Optional[str] = None, + pbar: Optional[tqdm] = None, ) -> dict: """Run a single iteration of design/repack + relax.""" iter_result = {} @@ -281,10 +310,13 @@ def _run_iteration( if self.config.mode in (PipelineMode.DESIGN, PipelineMode.DESIGN_ONLY): # Design mode logger.debug(" Running design...") + if pbar: + pbar.set_postfix_str("designing") design_result = self.designer.design( pdb_path=pdb_path, design_spec=design_spec, design_all=design_all, + pbar=pbar, ) pdb_string = self.designer.result_to_pdb_string(design_result) @@ -305,9 +337,12 @@ def _run_iteration( elif self.config.mode in (PipelineMode.RELAX, PipelineMode.REPACK_ONLY): # Repack mode logger.debug(" Running repack...") + if pbar: + pbar.set_postfix_str("repacking") repack_result = self.designer.repack( pdb_path=pdb_path, design_spec=design_spec, + pbar=pbar, ) pdb_string = self.designer.result_to_pdb_string(repack_result) @@ -328,7 +363,11 @@ def _run_iteration( PipelineMode.DESIGN, ): logger.debug(" Running relaxation...") - relaxed_pdb, relax_info, violations = self.relaxer.relax(pdb_string) + if pbar: + pbar.set_postfix_str("relaxing") + relaxed_pdb, relax_info, violations = self.relaxer.relax( + pdb_string, pbar=pbar + ) iter_result["output_pdb"] = relaxed_pdb iter_result["relax_info"] = relax_info @@ -343,4 +382,8 @@ def _run_iteration( # No relaxation iter_result["output_pdb"] = pdb_string + # Clear postfix when done with iteration + if pbar: + pbar.set_postfix_str("") + return iter_result diff --git a/src/graphrelax/relaxer.py b/src/graphrelax/relaxer.py index 3407ef3..327b844 100644 --- a/src/graphrelax/relaxer.py +++ b/src/graphrelax/relaxer.py @@ -10,6 +10,7 @@ from openmm import Platform from openmm import app as openmm_app from openmm import openmm, unit +from tqdm import tqdm from graphrelax.chain_gaps import ( detect_chain_gaps, @@ -54,7 +55,9 @@ def _check_gpu_available(self) -> bool: logger.info("OpenMM CUDA not available, using CPU") return False - def relax(self, pdb_string: str) -> Tuple[str, dict, np.ndarray]: + def relax( + self, pdb_string: str, pbar: Optional[tqdm] = None + ) -> Tuple[str, dict, np.ndarray]: """ Relax a structure from PDB string. @@ -70,6 +73,7 @@ def relax(self, pdb_string: str) -> Tuple[str, dict, np.ndarray]: Args: pdb_string: PDB file contents as string + pbar: Optional progress bar for status updates Returns: Tuple of (relaxed_pdb_string, debug_info, violations) @@ -93,7 +97,9 @@ def relax(self, pdb_string: str) -> Tuple[str, dict, np.ndarray]: if self.config.constrained: prot = protein.from_pdb_string(protein_pdb) - relaxed_pdb, debug_info, violations = self.relax_protein(prot) + relaxed_pdb, debug_info, violations = self.relax_protein( + prot, pbar=pbar + ) else: relaxed_pdb, debug_info, violations = self._relax_unconstrained( protein_pdb @@ -126,12 +132,15 @@ def relax_pdb_file(self, pdb_path: Path) -> Tuple[str, dict, np.ndarray]: pdb_string = f.read() return self.relax(pdb_string) - def relax_protein(self, prot) -> Tuple[str, dict, np.ndarray]: + def relax_protein( + self, prot, pbar: Optional[tqdm] = None + ) -> Tuple[str, dict, np.ndarray]: """ Relax a Protein object using OpenFold's AmberRelaxation. Args: prot: OpenFold Protein object + pbar: Optional progress bar for status updates Returns: Tuple of (relaxed_pdb_string, debug_info, violations) @@ -152,7 +161,9 @@ def relax_protein(self, prot) -> Tuple[str, dict, np.ndarray]: f"stiffness={self.config.stiffness}, gpu={use_gpu})" ) - relaxed_pdb, debug_data, violations = relaxer.process(prot=prot) + relaxed_pdb, debug_data, violations = relaxer.process( + prot=prot, pbar=pbar + ) logger.info( f"Relaxation complete: E_init={debug_data['initial_energy']:.2f}, "