From 53851e4dee3c9645b2d5da4224bd4f5ed1d9f84a Mon Sep 17 00:00:00 2001 From: Claude Date: Sun, 18 Jan 2026 19:18:09 +0000 Subject: [PATCH 1/2] Add tqdm progress bars for multi-iteration loops Show progress bars when iteration counts are greater than one: - Pipeline outputs loop (n_outputs > 1) - Pipeline iterations loop (n_iterations > 1) - AMBER relaxation outer iterations (max_outer_iterations > 1) - AMBER minimization retry attempts (max_attempts > 1) - Side-chain denoising steps (num_denoising_steps > 1) --- .../openfold/np/relax/amber_minimize.py | 144 +++++++++++------- src/graphrelax/LigandMPNN/sc_utils.py | 8 +- src/graphrelax/pipeline.py | 18 ++- 3 files changed, 110 insertions(+), 60 deletions(-) diff --git a/src/graphrelax/LigandMPNN/openfold/np/relax/amber_minimize.py b/src/graphrelax/LigandMPNN/openfold/np/relax/amber_minimize.py index 374b3b3..8a59c90 100644 --- a/src/graphrelax/LigandMPNN/openfold/np/relax/amber_minimize.py +++ b/src/graphrelax/LigandMPNN/openfold/np/relax/amber_minimize.py @@ -30,6 +30,7 @@ residue_constants, ) from openfold.np.relax import cleanup, utils +from tqdm import tqdm try: from simtk import openmm, unit @@ -487,25 +488,41 @@ def _run_one_iteration( start = time.perf_counter() minimized = False attempts = 0 - while not minimized and attempts < max_attempts: - attempts += 1 - try: - logging.info( - "Minimizing protein, attempt %d of %d.", attempts, max_attempts - ) - ret = _openmm_minimize( - pdb_string, - max_iterations=max_iterations, - tolerance=tolerance, - stiffness=stiffness, - restraint_set=restraint_set, - exclude_residues=exclude_residues, - use_gpu=use_gpu, - ) - minimized = True - except Exception as e: # pylint: disable=broad-except - print(e) - logging.info(e) + + pbar = None + if max_attempts > 1: + pbar = tqdm( + total=max_attempts, desc=" Minimization attempts", unit="attempt" + ) + + try: + while not minimized and attempts < max_attempts: + attempts += 1 + try: + logging.info( + "Minimizing protein, attempt %d of %d.", + attempts, + max_attempts, + ) + ret = _openmm_minimize( + pdb_string, + max_iterations=max_iterations, + tolerance=tolerance, + stiffness=stiffness, + restraint_set=restraint_set, + exclude_residues=exclude_residues, + use_gpu=use_gpu, + ) + minimized = True + except Exception as e: # pylint: disable=broad-except + print(e) + logging.info(e) + if pbar: + pbar.update(1) + finally: + if pbar: + pbar.close() + if not minimized: raise ValueError(f"Minimization failed after {max_attempts} attempts.") ret["opt_time"] = time.perf_counter() - start @@ -566,47 +583,60 @@ def run_pipeline( violations = np.inf iteration = 0 - while violations > 0 and iteration < max_outer_iterations: - ret = _run_one_iteration( - pdb_string=pdb_string, - exclude_residues=exclude_residues, - max_iterations=max_iterations, - tolerance=tolerance, - stiffness=stiffness, - restraint_set=restraint_set, - max_attempts=max_attempts, - use_gpu=use_gpu, + pbar = None + if max_outer_iterations > 1: + pbar = tqdm( + total=max_outer_iterations, desc=" Relax iterations", unit="iter" ) - headers = protein.get_pdb_headers(prot) - if len(headers) > 0: - ret["min_pdb"] = "\n".join(["\n".join(headers), ret["min_pdb"]]) + try: + while violations > 0 and iteration < max_outer_iterations: + ret = _run_one_iteration( + pdb_string=pdb_string, + exclude_residues=exclude_residues, + max_iterations=max_iterations, + tolerance=tolerance, + stiffness=stiffness, + restraint_set=restraint_set, + max_attempts=max_attempts, + use_gpu=use_gpu, + ) + + headers = protein.get_pdb_headers(prot) + if len(headers) > 0: + ret["min_pdb"] = "\n".join(["\n".join(headers), ret["min_pdb"]]) + + prot = protein.from_pdb_string(ret["min_pdb"]) + if place_hydrogens_every_iteration: + pdb_string = clean_protein(prot, checks=True) + else: + pdb_string = ret["min_pdb"] + ret.update(get_violation_metrics(prot)) + ret.update( + { + "num_exclusions": len(exclude_residues), + "iteration": iteration, + } + ) + violations = ret["violations_per_residue"] + exclude_residues = exclude_residues.union(ret["residue_violations"]) + + logging.info( + "Iteration completed: Einit %.2f Efinal %.2f Time %.2f s " + "num residue violations %d num residue exclusions %d ", + ret["einit"], + ret["efinal"], + ret["opt_time"], + ret["num_residue_violations"], + ret["num_exclusions"], + ) + iteration += 1 + if pbar: + pbar.update(1) + finally: + if pbar: + pbar.close() - prot = protein.from_pdb_string(ret["min_pdb"]) - if place_hydrogens_every_iteration: - pdb_string = clean_protein(prot, checks=True) - else: - pdb_string = ret["min_pdb"] - ret.update(get_violation_metrics(prot)) - ret.update( - { - "num_exclusions": len(exclude_residues), - "iteration": iteration, - } - ) - violations = ret["violations_per_residue"] - exclude_residues = exclude_residues.union(ret["residue_violations"]) - - logging.info( - "Iteration completed: Einit %.2f Efinal %.2f Time %.2f s " - "num residue violations %d num residue exclusions %d ", - ret["einit"], - ret["efinal"], - ret["opt_time"], - ret["num_residue_violations"], - ret["num_exclusions"], - ) - iteration += 1 return ret diff --git a/src/graphrelax/LigandMPNN/sc_utils.py b/src/graphrelax/LigandMPNN/sc_utils.py index fde88ad..b328e5f 100644 --- a/src/graphrelax/LigandMPNN/sc_utils.py +++ b/src/graphrelax/LigandMPNN/sc_utils.py @@ -27,6 +27,7 @@ ) from openfold.utils import feats from openfold.utils.rigid_utils import Rigid +from tqdm import tqdm torch_pi = torch.tensor(np.pi, device="cpu") @@ -101,7 +102,12 @@ def pack_side_chains( feature_dict["h_V"] = h_V feature_dict["h_E"] = h_E feature_dict["E_idx"] = E_idx - for step in range(num_denoising_steps): + + step_range = range(num_denoising_steps) + if num_denoising_steps > 1: + step_range = tqdm(step_range, desc=" Denoising steps", unit="step") + + for step in step_range: mean, concentration, mix_logits = model_sc.decode(feature_dict) mix = D.Categorical(logits=mix_logits) comp = D.VonMises(mean, concentration) diff --git a/src/graphrelax/pipeline.py b/src/graphrelax/pipeline.py index d76dcc4..0c5db7b 100644 --- a/src/graphrelax/pipeline.py +++ b/src/graphrelax/pipeline.py @@ -5,6 +5,8 @@ from pathlib import Path from typing import Optional +from tqdm import tqdm + from graphrelax.config import PipelineConfig, PipelineMode from graphrelax.designer import Designer from graphrelax.idealize import idealize_structure @@ -92,7 +94,13 @@ def run( all_results = [] all_scores = [] - for output_idx in range(1, self.config.n_outputs + 1): + output_range = range(1, self.config.n_outputs + 1) + if self.config.n_outputs > 1: + output_range = tqdm( + output_range, desc="Generating outputs", unit="output" + ) + + for output_idx in output_range: logger.info( f"Generating output {output_idx}/{self.config.n_outputs}" ) @@ -210,7 +218,13 @@ def _run_single_output( original_native_sequence = None try: - for iteration in range(1, self.config.n_iterations + 1): + iteration_range = range(1, self.config.n_iterations + 1) + if self.config.n_iterations > 1: + iteration_range = tqdm( + iteration_range, desc=" Iterations", unit="iter" + ) + + for iteration in iteration_range: logger.info( f" Iteration {iteration}/{self.config.n_iterations}" ) From 656ef2ea6d163f72be5ff9eb7d414c4b225bc97c Mon Sep 17 00:00:00 2001 From: Claude Date: Sun, 18 Jan 2026 20:05:24 +0000 Subject: [PATCH 2/2] Refactor to single unified tqdm progress bar Replace multiple nested progress bars with a single unified bar that tracks overall progress across all outputs and iterations. The bar shows: - Total progress: output X/Y, iter A/B - Current phase: designing/repacking/relaxing Changes: - pipeline.py: Create single progress bar with total = n_outputs * n_iterations - designer.py: Accept pbar parameter, pass to pack_side_chains - relaxer.py: Accept pbar parameter, pass to AmberRelaxation - relax.py: Accept pbar parameter, pass to amber_minimize - amber_minimize.py: Accept pbar parameter, remove standalone progress bars - sc_utils.py: Accept pbar parameter, remove standalone progress bar --- .../openfold/np/relax/amber_minimize.py | 144 +++++++---------- .../LigandMPNN/openfold/np/relax/relax.py | 6 +- src/graphrelax/LigandMPNN/sc_utils.py | 8 +- src/graphrelax/designer.py | 7 + src/graphrelax/pipeline.py | 153 +++++++++++------- src/graphrelax/relaxer.py | 19 ++- 6 files changed, 181 insertions(+), 156 deletions(-) diff --git a/src/graphrelax/LigandMPNN/openfold/np/relax/amber_minimize.py b/src/graphrelax/LigandMPNN/openfold/np/relax/amber_minimize.py index 8a59c90..5b752ff 100644 --- a/src/graphrelax/LigandMPNN/openfold/np/relax/amber_minimize.py +++ b/src/graphrelax/LigandMPNN/openfold/np/relax/amber_minimize.py @@ -489,39 +489,27 @@ def _run_one_iteration( minimized = False attempts = 0 - pbar = None - if max_attempts > 1: - pbar = tqdm( - total=max_attempts, desc=" Minimization attempts", unit="attempt" - ) - - try: - while not minimized and attempts < max_attempts: - attempts += 1 - try: - logging.info( - "Minimizing protein, attempt %d of %d.", - attempts, - max_attempts, - ) - ret = _openmm_minimize( - pdb_string, - max_iterations=max_iterations, - tolerance=tolerance, - stiffness=stiffness, - restraint_set=restraint_set, - exclude_residues=exclude_residues, - use_gpu=use_gpu, - ) - minimized = True - except Exception as e: # pylint: disable=broad-except - print(e) - logging.info(e) - if pbar: - pbar.update(1) - finally: - if pbar: - pbar.close() + while not minimized and attempts < max_attempts: + attempts += 1 + try: + logging.info( + "Minimizing protein, attempt %d of %d.", + attempts, + max_attempts, + ) + ret = _openmm_minimize( + pdb_string, + max_iterations=max_iterations, + tolerance=tolerance, + stiffness=stiffness, + restraint_set=restraint_set, + exclude_residues=exclude_residues, + use_gpu=use_gpu, + ) + minimized = True + except Exception as e: # pylint: disable=broad-except + print(e) + logging.info(e) if not minimized: raise ValueError(f"Minimization failed after {max_attempts} attempts.") @@ -542,6 +530,7 @@ def run_pipeline( max_attempts: int = 100, checks: bool = True, exclude_residues: Optional[Sequence[int]] = None, + pbar: Optional[tqdm] = None, ): """Run iterative amber relax. @@ -565,6 +554,7 @@ def run_pipeline( checks: Whether to perform cleaning checks. exclude_residues: An optional list of zero-indexed residues to exclude from restraints. + pbar: Optional progress bar for status updates. Returns: out: A dictionary of output values. @@ -583,59 +573,47 @@ def run_pipeline( violations = np.inf iteration = 0 - pbar = None - if max_outer_iterations > 1: - pbar = tqdm( - total=max_outer_iterations, desc=" Relax iterations", unit="iter" + while violations > 0 and iteration < max_outer_iterations: + ret = _run_one_iteration( + pdb_string=pdb_string, + exclude_residues=exclude_residues, + max_iterations=max_iterations, + tolerance=tolerance, + stiffness=stiffness, + restraint_set=restraint_set, + max_attempts=max_attempts, + use_gpu=use_gpu, ) - try: - while violations > 0 and iteration < max_outer_iterations: - ret = _run_one_iteration( - pdb_string=pdb_string, - exclude_residues=exclude_residues, - max_iterations=max_iterations, - tolerance=tolerance, - stiffness=stiffness, - restraint_set=restraint_set, - max_attempts=max_attempts, - use_gpu=use_gpu, - ) - - headers = protein.get_pdb_headers(prot) - if len(headers) > 0: - ret["min_pdb"] = "\n".join(["\n".join(headers), ret["min_pdb"]]) - - prot = protein.from_pdb_string(ret["min_pdb"]) - if place_hydrogens_every_iteration: - pdb_string = clean_protein(prot, checks=True) - else: - pdb_string = ret["min_pdb"] - ret.update(get_violation_metrics(prot)) - ret.update( - { - "num_exclusions": len(exclude_residues), - "iteration": iteration, - } - ) - violations = ret["violations_per_residue"] - exclude_residues = exclude_residues.union(ret["residue_violations"]) + headers = protein.get_pdb_headers(prot) + if len(headers) > 0: + ret["min_pdb"] = "\n".join(["\n".join(headers), ret["min_pdb"]]) - logging.info( - "Iteration completed: Einit %.2f Efinal %.2f Time %.2f s " - "num residue violations %d num residue exclusions %d ", - ret["einit"], - ret["efinal"], - ret["opt_time"], - ret["num_residue_violations"], - ret["num_exclusions"], - ) - iteration += 1 - if pbar: - pbar.update(1) - finally: - if pbar: - pbar.close() + prot = protein.from_pdb_string(ret["min_pdb"]) + if place_hydrogens_every_iteration: + pdb_string = clean_protein(prot, checks=True) + else: + pdb_string = ret["min_pdb"] + ret.update(get_violation_metrics(prot)) + ret.update( + { + "num_exclusions": len(exclude_residues), + "iteration": iteration, + } + ) + violations = ret["violations_per_residue"] + exclude_residues = exclude_residues.union(ret["residue_violations"]) + + logging.info( + "Iteration completed: Einit %.2f Efinal %.2f Time %.2f s " + "num residue violations %d num residue exclusions %d ", + ret["einit"], + ret["efinal"], + ret["opt_time"], + ret["num_residue_violations"], + ret["num_exclusions"], + ) + iteration += 1 return ret diff --git a/src/graphrelax/LigandMPNN/openfold/np/relax/relax.py b/src/graphrelax/LigandMPNN/openfold/np/relax/relax.py index 14ed73a..4ca3553 100644 --- a/src/graphrelax/LigandMPNN/openfold/np/relax/relax.py +++ b/src/graphrelax/LigandMPNN/openfold/np/relax/relax.py @@ -16,11 +16,12 @@ # limitations under the License. """Amber relaxation.""" -from typing import Any, Dict, Sequence, Tuple +from typing import Any, Dict, Optional, Sequence, Tuple import numpy as np from openfold.np import protein from openfold.np.relax import amber_minimize, utils +from tqdm import tqdm class AmberRelaxation(object): @@ -61,7 +62,7 @@ def __init__( self._use_gpu = use_gpu def process( - self, *, prot: protein.Protein + self, *, prot: protein.Protein, pbar: Optional[tqdm] = None ) -> Tuple[str, Dict[str, Any], np.ndarray]: """Runs Amber relax on a prediction, adds hydrogens, returns PDB string.""" out = amber_minimize.run_pipeline( @@ -72,6 +73,7 @@ def process( exclude_residues=self._exclude_residues, max_outer_iterations=self._max_outer_iterations, use_gpu=self._use_gpu, + pbar=pbar, ) min_pos = out["pos"] start_pos = out["posinit"] diff --git a/src/graphrelax/LigandMPNN/sc_utils.py b/src/graphrelax/LigandMPNN/sc_utils.py index b328e5f..c457067 100644 --- a/src/graphrelax/LigandMPNN/sc_utils.py +++ b/src/graphrelax/LigandMPNN/sc_utils.py @@ -1,6 +1,7 @@ # flake8: noqa # This is vendored code from LigandMPNN - style issues are preserved from upstream import sys +from typing import Optional import numpy as np import torch @@ -67,6 +68,7 @@ def pack_side_chains( num_samples=10, repack_everything=True, num_context_atoms=16, + pbar: Optional[tqdm] = None, ): device = feature_dict["X"].device torsion_dict = make_torsion_features(feature_dict, repack_everything) @@ -103,11 +105,7 @@ def pack_side_chains( feature_dict["h_E"] = h_E feature_dict["E_idx"] = E_idx - step_range = range(num_denoising_steps) - if num_denoising_steps > 1: - step_range = tqdm(step_range, desc=" Denoising steps", unit="step") - - for step in step_range: + for step in range(num_denoising_steps): mean, concentration, mix_logits = model_sc.decode(feature_dict) mix = D.Categorical(logits=mix_logits) comp = D.VonMises(mean, concentration) diff --git a/src/graphrelax/designer.py b/src/graphrelax/designer.py index b90f325..9e224a3 100644 --- a/src/graphrelax/designer.py +++ b/src/graphrelax/designer.py @@ -8,6 +8,7 @@ import numpy as np import torch +from tqdm import tqdm from graphrelax.config import DesignConfig from graphrelax.resfile import ALL_AAS, DesignSpec, ResidueMode @@ -131,6 +132,7 @@ def design( pdb_path: Path, design_spec: Optional[DesignSpec] = None, design_all: bool = False, + pbar: Optional[tqdm] = None, ) -> dict: """ Run sequence design on a structure. @@ -139,6 +141,7 @@ def design( pdb_path: Path to input PDB design_spec: Specification of which residues to design/fix design_all: If True, design all residues (full redesign) + pbar: Optional progress bar for status updates Returns: Dictionary with designed sequence, structure, and scores @@ -215,6 +218,7 @@ def design( self.config.sc_num_denoising_steps, self.config.sc_num_samples, repack_everything=False, + pbar=pbar, ) output_dict.update(sc_dict) @@ -250,6 +254,7 @@ def repack( self, pdb_path: Path, design_spec: Optional[DesignSpec] = None, + pbar: Optional[tqdm] = None, ) -> dict: """ Repack side chains without changing sequence. @@ -257,6 +262,7 @@ def repack( Args: pdb_path: Path to input PDB design_spec: Specification (NATRO residues excluded from repacking) + pbar: Optional progress bar for status updates Returns: Dictionary with repacked structure @@ -300,6 +306,7 @@ def repack( self.config.sc_num_denoising_steps, self.config.sc_num_samples, repack_everything=True, + pbar=pbar, ) # Get sequence diff --git a/src/graphrelax/pipeline.py b/src/graphrelax/pipeline.py index 0c5db7b..cc2407b 100644 --- a/src/graphrelax/pipeline.py +++ b/src/graphrelax/pipeline.py @@ -94,66 +94,72 @@ def run( all_results = [] all_scores = [] - output_range = range(1, self.config.n_outputs + 1) - if self.config.n_outputs > 1: - output_range = tqdm( - output_range, desc="Generating outputs", unit="output" - ) - - for output_idx in output_range: - logger.info( - f"Generating output {output_idx}/{self.config.n_outputs}" - ) - - result = self._run_single_output( - input_pdb=input_pdb, - design_spec=design_spec, - design_all=design_all, - input_format=input_format, - ) + # Calculate total steps for unified progress bar + total_steps = self.config.n_outputs * self.config.n_iterations + pbar = None + if total_steps > 1: + pbar = tqdm(total=total_steps, desc="Progress", unit="step") - # Format output path - out_path = format_output_path( - output_pdb, output_idx, self.config.n_outputs - ) - - # Convert to target format if needed and save - final_structure = result["final_pdb"] - if target_format != StructureFormat.PDB: - final_structure = convert_to_format( - final_structure, target_format + try: + for output_idx in range(1, self.config.n_outputs + 1): + logger.info( + f"Generating output {output_idx}/{self.config.n_outputs}" ) - save_pdb_string(final_structure, out_path) - - result["output_path"] = out_path - all_results.append(result) - - # Collect scores - score_dict = { - "total_score": result.get("final_energy", 0.0), - "openmm_energy": result.get("final_energy", 0.0), - } - - # Add energy breakdown if available - if "energy_breakdown" in result: - for key, val in result["energy_breakdown"].items(): - if key != "total_energy": - score_dict[key] = val - - # Add LigandMPNN score - if "ligandmpnn_loss" in result: - score_dict["ligandmpnn_score"] = compute_ligandmpnn_score( - result["ligandmpnn_loss"] + + result = self._run_single_output( + input_pdb=input_pdb, + design_spec=design_spec, + design_all=design_all, + input_format=input_format, + output_idx=output_idx, + pbar=pbar, ) - # Add sequence recovery - if result.get("sequence") and result.get("native_sequence"): - score_dict["seq_recovery"] = compute_sequence_recovery( - result["sequence"], result["native_sequence"] + # Format output path + out_path = format_output_path( + output_pdb, output_idx, self.config.n_outputs ) - score_dict["description"] = out_path.name - all_scores.append(score_dict) + # Convert to target format if needed and save + final_structure = result["final_pdb"] + if target_format != StructureFormat.PDB: + final_structure = convert_to_format( + final_structure, target_format + ) + save_pdb_string(final_structure, out_path) + + result["output_path"] = out_path + all_results.append(result) + + # Collect scores + score_dict = { + "total_score": result.get("final_energy", 0.0), + "openmm_energy": result.get("final_energy", 0.0), + } + + # Add energy breakdown if available + if "energy_breakdown" in result: + for key, val in result["energy_breakdown"].items(): + if key != "total_energy": + score_dict[key] = val + + # Add LigandMPNN score + if "ligandmpnn_loss" in result: + score_dict["ligandmpnn_score"] = compute_ligandmpnn_score( + result["ligandmpnn_loss"] + ) + + # Add sequence recovery + if result.get("sequence") and result.get("native_sequence"): + score_dict["seq_recovery"] = compute_sequence_recovery( + result["sequence"], result["native_sequence"] + ) + + score_dict["description"] = out_path.name + all_scores.append(score_dict) + finally: + if pbar: + pbar.close() # Write scorefile if requested if self.config.scorefile and all_scores: @@ -170,6 +176,8 @@ def _run_single_output( design_spec: Optional[DesignSpec], design_all: bool, input_format: StructureFormat = StructureFormat.PDB, + output_idx: int = 1, + pbar: Optional[tqdm] = None, ) -> dict: """Run pipeline for a single output.""" result = { @@ -218,23 +226,25 @@ def _run_single_output( original_native_sequence = None try: - iteration_range = range(1, self.config.n_iterations + 1) - if self.config.n_iterations > 1: - iteration_range = tqdm( - iteration_range, desc=" Iterations", unit="iter" - ) - - for iteration in iteration_range: + for iteration in range(1, self.config.n_iterations + 1): logger.info( f" Iteration {iteration}/{self.config.n_iterations}" ) + # Update progress bar description + if pbar: + pbar.set_description( + f"Output {output_idx}/{self.config.n_outputs}, " + f"iter {iteration}/{self.config.n_iterations}" + ) + iter_result = self._run_iteration( pdb_path=current_pdb_path, design_spec=design_spec, design_all=design_all, iteration=iteration, original_native_sequence=original_native_sequence, + pbar=pbar, ) result["iterations"].append(iter_result) @@ -261,6 +271,10 @@ def _run_single_output( f"RMSD={info['rmsd']:.3f}" ) + # Increment progress bar after each iteration completes + if pbar: + pbar.update(1) + # Store final results result["final_pdb"] = current_pdb if result["iterations"]: @@ -287,6 +301,7 @@ def _run_iteration( design_all: bool, iteration: int, original_native_sequence: Optional[str] = None, + pbar: Optional[tqdm] = None, ) -> dict: """Run a single iteration of design/repack + relax.""" iter_result = {} @@ -295,10 +310,13 @@ def _run_iteration( if self.config.mode in (PipelineMode.DESIGN, PipelineMode.DESIGN_ONLY): # Design mode logger.debug(" Running design...") + if pbar: + pbar.set_postfix_str("designing") design_result = self.designer.design( pdb_path=pdb_path, design_spec=design_spec, design_all=design_all, + pbar=pbar, ) pdb_string = self.designer.result_to_pdb_string(design_result) @@ -319,9 +337,12 @@ def _run_iteration( elif self.config.mode in (PipelineMode.RELAX, PipelineMode.REPACK_ONLY): # Repack mode logger.debug(" Running repack...") + if pbar: + pbar.set_postfix_str("repacking") repack_result = self.designer.repack( pdb_path=pdb_path, design_spec=design_spec, + pbar=pbar, ) pdb_string = self.designer.result_to_pdb_string(repack_result) @@ -342,7 +363,11 @@ def _run_iteration( PipelineMode.DESIGN, ): logger.debug(" Running relaxation...") - relaxed_pdb, relax_info, violations = self.relaxer.relax(pdb_string) + if pbar: + pbar.set_postfix_str("relaxing") + relaxed_pdb, relax_info, violations = self.relaxer.relax( + pdb_string, pbar=pbar + ) iter_result["output_pdb"] = relaxed_pdb iter_result["relax_info"] = relax_info @@ -357,4 +382,8 @@ def _run_iteration( # No relaxation iter_result["output_pdb"] = pdb_string + # Clear postfix when done with iteration + if pbar: + pbar.set_postfix_str("") + return iter_result diff --git a/src/graphrelax/relaxer.py b/src/graphrelax/relaxer.py index 3407ef3..327b844 100644 --- a/src/graphrelax/relaxer.py +++ b/src/graphrelax/relaxer.py @@ -10,6 +10,7 @@ from openmm import Platform from openmm import app as openmm_app from openmm import openmm, unit +from tqdm import tqdm from graphrelax.chain_gaps import ( detect_chain_gaps, @@ -54,7 +55,9 @@ def _check_gpu_available(self) -> bool: logger.info("OpenMM CUDA not available, using CPU") return False - def relax(self, pdb_string: str) -> Tuple[str, dict, np.ndarray]: + def relax( + self, pdb_string: str, pbar: Optional[tqdm] = None + ) -> Tuple[str, dict, np.ndarray]: """ Relax a structure from PDB string. @@ -70,6 +73,7 @@ def relax(self, pdb_string: str) -> Tuple[str, dict, np.ndarray]: Args: pdb_string: PDB file contents as string + pbar: Optional progress bar for status updates Returns: Tuple of (relaxed_pdb_string, debug_info, violations) @@ -93,7 +97,9 @@ def relax(self, pdb_string: str) -> Tuple[str, dict, np.ndarray]: if self.config.constrained: prot = protein.from_pdb_string(protein_pdb) - relaxed_pdb, debug_info, violations = self.relax_protein(prot) + relaxed_pdb, debug_info, violations = self.relax_protein( + prot, pbar=pbar + ) else: relaxed_pdb, debug_info, violations = self._relax_unconstrained( protein_pdb @@ -126,12 +132,15 @@ def relax_pdb_file(self, pdb_path: Path) -> Tuple[str, dict, np.ndarray]: pdb_string = f.read() return self.relax(pdb_string) - def relax_protein(self, prot) -> Tuple[str, dict, np.ndarray]: + def relax_protein( + self, prot, pbar: Optional[tqdm] = None + ) -> Tuple[str, dict, np.ndarray]: """ Relax a Protein object using OpenFold's AmberRelaxation. Args: prot: OpenFold Protein object + pbar: Optional progress bar for status updates Returns: Tuple of (relaxed_pdb_string, debug_info, violations) @@ -152,7 +161,9 @@ def relax_protein(self, prot) -> Tuple[str, dict, np.ndarray]: f"stiffness={self.config.stiffness}, gpu={use_gpu})" ) - relaxed_pdb, debug_data, violations = relaxer.process(prot=prot) + relaxed_pdb, debug_data, violations = relaxer.process( + prot=prot, pbar=pbar + ) logger.info( f"Relaxation complete: E_init={debug_data['initial_energy']:.2f}, "