67 changes: 38 additions & 29 deletions abcfold/abcfold.py
@@ -180,8 +180,8 @@ def run(args, config, defaults, config_file):
)

if boltz_success:
bolt_out_dir = list(args.output_dir.glob("boltz_results*"))[0]
bo = BoltzOutput(bolt_out_dir, input_params, name)
bolt_out_dirs = list(args.output_dir.glob("boltz_results*"))
bo = BoltzOutput(bolt_out_dirs, input_params, name, args.save_input)
outputs.append(bo)
successful_runs.append(boltz_success)

@@ -194,18 +194,18 @@ def run(args, config, defaults, config_file):
elif args.templates:
template_hits_path = make_dummy_m8_file(run_json, temp_dir)

chai_output_dir = args.output_dir.joinpath("chai1")
chai_success = run_chai(
input_json=run_json,
output_dir=chai_output_dir,
output_dir=args.output_dir,
save_input=args.save_input,
number_of_models=args.number_of_models,
num_recycles=args.num_recycles,
template_hits_path=template_hits_path,
)

if chai_success:
co = ChaiOutput(chai_output_dir, input_params, name, args.save_input)
chai_output_dirs = list(args.output_dir.glob("chai_output*"))
co = ChaiOutput(chai_output_dirs, input_params, name, args.save_input)
outputs.append(co)
successful_runs.append(chai_success)

@@ -259,28 +259,11 @@ def run(args, config, defaults, config_file):
if args.boltz:
if boltz_success:
programs_run.append("Boltz")
for idx in bo.output.keys():
model = bo.output[idx]["cif"]
model.check_clashes()
score_file = bo.output[idx]["json"]
plddt = model.residue_plddts
if len(indicies) > 0:
plddt = insert_none_by_minus_one(indicies[index_counter], plddt)
index_counter += 1
model_data = get_model_data(
model, plot_dict, "Boltz", plddt, score_file, args.output_dir
)
boltz_models["models"].append(model_data)

chai_models = {"models": []}
if args.chai1:
if chai_success:
programs_run.append("Chai-1")
for idx in co.output.keys():
if idx >= 0:
model = co.output[idx]["cif"]
for seed in bo.output.keys():
for idx in bo.output[seed].keys():
model = bo.output[seed][idx]["cif"]
model.check_clashes()
score_file = co.output[idx]["scores"]
score_file = bo.output[seed][idx]["json"]
plddt = model.residue_plddts
if len(indicies) > 0:
plddt = insert_none_by_minus_one(
@@ -290,12 +273,38 @@
model_data = get_model_data(
model,
plot_dict,
"Chai-1",
"Boltz",
plddt,
score_file,
args.output_dir,
args.output_dir
)
chai_models["models"].append(model_data)
boltz_models["models"].append(model_data)

chai_models = {"models": []}
if args.chai1:
if chai_success:
programs_run.append("Chai-1")
for seed in co.output.keys():
for idx in co.output[seed].keys():
if idx >= 0:
model = co.output[seed][idx]["cif"]
model.check_clashes()
score_file = co.output[seed][idx]["scores"]
plddt = model.residue_plddts
if len(indicies) > 0:
plddt = insert_none_by_minus_one(
indicies[index_counter], plddt
)
index_counter += 1
model_data = get_model_data(
model,
plot_dict,
"Chai-1",
plddt,
score_file,
args.output_dir,
)
chai_models["models"].append(model_data)

combined_models = (
alphafold_models["models"] + boltz_models["models"] + chai_models["models"]
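The report-building loops above now walk a two-level mapping: seed first, then model index. A minimal sketch of that traversal, assuming `bo.output` has the `{seed: {idx: {"cif": ..., "json": ...}}}` shape implied by the diff (ChaiOutput follows the same seed/index nesting, with a "scores" entry instead of "json"); the file names below are invented:

# Illustrative only: the nested per-seed output mapping that BoltzOutput is
# expected to expose after this change.
bo_output = {
    42: {0: {"cif": "model_seed-42_0.cif", "json": "scores_seed-42_0.json"}},
    7: {0: {"cif": "model_seed-7_0.cif", "json": "scores_seed-7_0.json"}},
}

# Seed-first, then model-index traversal, mirroring the updated run() loops.
for seed, models in bo_output.items():
    for idx, files in models.items():
        print(f"seed {seed}, model {idx}: cif={files['cif']}, scores={files['json']}")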
6 changes: 6 additions & 0 deletions abcfold/boltz/af3_to_boltz.py
@@ -19,6 +19,7 @@ def __init__(self, working_dir: Union[str, Path], create_files: bool = True):
self.working_dir = working_dir
self.yaml_string: str = ""
self.msa_file: Optional[Union[str, Path]] = "null"
self.seeds: list = [42]
self.__ids: List[Union[str, int]] = []
self.__id_char: str = "A"
self.__id_links: Dict[Union[str, int], list] = {}
@@ -75,6 +76,11 @@ def json_to_yaml(

self.yaml_string += self.add_version_number("1")
for key, value in json_dict.items():
if key == "modelSeeds":
if isinstance(value, list):
self.seeds = value
elif isinstance(value, int):
self.seeds = [value]
if key == "sequences":
if "sequences" not in self.yaml_string:
self.yaml_string += self.add_non_indented_string("sequences")
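A standalone sketch of the seed handling that json_to_yaml now performs: an AlphaFold3-style "modelSeeds" entry may be a single integer or a list, and either form is stored as a list on self.seeds, with [42] kept as the default when the key is absent. ChaiFasta.json_to_fasta further down applies the same normalisation. The helper below is hypothetical and only mirrors that logic:

def extract_seeds(json_dict: dict) -> list:
    """Normalise an AF3-style 'modelSeeds' value to a list, defaulting to [42]."""
    value = json_dict.get("modelSeeds", [42])
    if isinstance(value, int):
        return [value]
    if isinstance(value, list):
        return value
    return [42]


print(extract_seeds({"modelSeeds": 7}))          # [7]
print(extract_seeds({"modelSeeds": [1, 2, 3]}))  # [1, 2, 3]
print(extract_seeds({}))                         # [42]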
86 changes: 49 additions & 37 deletions abcfold/boltz/run_boltz.py
@@ -52,43 +52,51 @@ def run_boltz(

boltz_yaml = BoltzYaml(working_dir)
boltz_yaml.json_to_yaml(input_json)
out_file = working_dir.joinpath(f"{input_json.stem}.yaml")

boltz_yaml.write_yaml(out_file)
logger.info("Running Boltz")
cmd = (
generate_boltz_command(out_file, output_dir, number_of_models, num_recycles)
if not test
else generate_boltz_test_command()
)

with subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
) as proc:
stdout = ""
if proc.stdout:
for line in proc.stdout:
sys.stdout.write(line.decode())
sys.stdout.flush()
stdout += line.decode()
_, stderr = proc.communicate()
if proc.returncode != 0:
if proc.stderr:
logger.error(stderr.decode())
output_err_file = output_dir / "boltz_error.log"
with open(output_err_file, "w") as f:
f.write(stderr.decode())
logger.error(
"Boltz run failed. Error log is in %s", output_err_file
)
else:
logger.error("Boltz run failed")
return False
elif "WARNING: ran out of memory" in stdout:
logger.error("Boltz ran out of memory")
return False

for seed in boltz_yaml.seeds:
out_file = working_dir.joinpath(f"{input_json.stem}_seed-{seed}.yaml")

boltz_yaml.write_yaml(out_file)
logger.info("Running Boltz using seed: %s", seed)
cmd = (
generate_boltz_command(
out_file,
output_dir,
number_of_models,
num_recycles,
seed=seed,
)
if not test
else generate_boltz_test_command()
)

with subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
) as proc:
stdout = ""
if proc.stdout:
for line in proc.stdout:
sys.stdout.write(line.decode())
sys.stdout.flush()
stdout += line.decode()
_, stderr = proc.communicate()
if proc.returncode != 0:
if proc.stderr:
logger.error(stderr.decode())
output_err_file = output_dir / "boltz_error.log"
with open(output_err_file, "w") as f:
f.write(stderr.decode())
logger.error(
"Boltz run failed. Error log is in %s", output_err_file
)
else:
logger.error("Boltz run failed")
return False
elif "WARNING: ran out of memory" in stdout:
logger.error("Boltz ran out of memory")
return False

logger.info("Boltz run complete")
logger.info("Output files are in %s", output_dir)
@@ -100,6 +108,7 @@ def generate_boltz_command(
output_dir: Union[str, Path],
number_of_models: int = 5,
num_recycles: int = 10,
seed: int = 42,
) -> list:
"""
Generate the Boltz command
@@ -108,6 +117,7 @@
input_yaml (Union[str, Path]): Path to the input YAML file
output_dir (Union[str, Path]): Path to the output directory
number_of_models (int): Number of models to generate
seed (int): Seed for the random number generator

Returns:
list: The Boltz command
@@ -125,6 +135,8 @@
str(number_of_models),
"--recycling_steps",
str(num_recycles),
"--seed",
str(seed),
]


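To make the control flow above concrete, here is a rough sketch of the per-seed loop in run_boltz: one YAML file per seed is written, and the seed is threaded through to the command builder as a --seed flag. The paths and seed values are invented, and only the seed-related tail of the command is spelled out:

from pathlib import Path

seeds = [42, 1234]                      # hypothetical, from BoltzYaml.seeds
working_dir = Path("/tmp/boltz_work")   # hypothetical working directory
stem = "input"                          # stem of the input JSON file

for seed in seeds:
    out_file = working_dir / f"{stem}_seed-{seed}.yaml"
    # In the real code this is generate_boltz_command(out_file, output_dir,
    # number_of_models, num_recycles, seed=seed); here only the seed-related
    # arguments are shown.
    seed_args = ["--seed", str(seed)]
    print(out_file, seed_args)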
7 changes: 7 additions & 0 deletions abcfold/chai1/af3_to_chai.py
@@ -40,6 +40,7 @@ def __init__(self, working_dir: Union[str, Path], create_files: bool = True):
self.fasta = Path(working_dir) / "chai1.fasta"
self.constraints = Path(working_dir) / "chai1_constraints.csv"
self.msa_file: Optional[Union[str, Path]] = None
self.seeds: list = [42]
self.__ids: List[Union[str, int]] = []
self.__create_files = create_files

@@ -203,6 +204,12 @@ def json_to_fasta(self, json_file_or_dict: Union[dict, str, Path]):
bonded_pairs = json_dict["bondedAtomPairs"]
self.bonded_pairs_to_file(bonded_pairs, fasta_data)

if "modelSeeds" in json_dict.keys():
if isinstance(json_dict["modelSeeds"], int):
self.seeds = [json_dict["modelSeeds"]]
elif isinstance(json_dict["modelSeeds"], list):
self.seeds = json_dict["modelSeeds"]

if not self.__create_files:
self.fasta.unlink()

75 changes: 40 additions & 35 deletions abcfold/chai1/run_chai1.py
@@ -48,12 +48,10 @@ def run_chai(

with tempfile.TemporaryDirectory() as temp_dir:
working_dir = Path(temp_dir)
chai_output_dir = output_dir
if save_input:
logger.info("Saving input fasta file and msa to the output directory")
working_dir = output_dir
working_dir.mkdir(parents=True, exist_ok=True)
chai_output_dir = output_dir / "chai_output"

chai_fasta = ChaiFasta(working_dir)
chai_fasta.json_to_fasta(input_json)
@@ -62,42 +60,46 @@
msa_dir = chai_fasta.working_dir
out_constraints = chai_fasta.constraints

cmd = (
generate_chai_command(
out_fasta,
msa_dir,
out_constraints,
chai_output_dir,
number_of_models,
num_recycles=num_recycles,
use_templates_server=use_templates_server,
template_hits_path=template_hits_path,
for seed in chai_fasta.seeds:
chai_output_dir = output_dir / f"chai_output_seed-{seed}"

logger.info(f"Running Chai-1 using seed {seed}")
cmd = (
generate_chai_command(
out_fasta,
msa_dir,
out_constraints,
chai_output_dir,
number_of_models,
num_recycles=num_recycles,
seed=seed,
use_templates_server=use_templates_server,
template_hits_path=template_hits_path,
)
if not test
else generate_chai_test_command()
)
if not test
else generate_chai_test_command()
)

logger.info("Running Chai-1")
with subprocess.Popen(
cmd,
stdout=sys.stdout,
stderr=subprocess.PIPE,
) as proc:
_, stderr = proc.communicate()
if proc.returncode != 0:
if proc.stderr:
if chai_output_dir.exists():
output_err_file = chai_output_dir / "chai_error.log"
with subprocess.Popen(
cmd,
stdout=sys.stdout,
stderr=subprocess.PIPE,
) as proc:
_, stderr = proc.communicate()
if proc.returncode != 0:
if proc.stderr:
if chai_output_dir.exists():
output_err_file = chai_output_dir / "chai_error.log"
else:
output_err_file = chai_output_dir.parent / "chai_error.log"
with open(output_err_file, "w") as f:
f.write(stderr.decode())
logger.error(
"Chai-1 run failed. Error log is in %s", output_err_file
)
else:
output_err_file = chai_output_dir.parent / "chai_error.log"
with open(output_err_file, "w") as f:
f.write(stderr.decode())
logger.error(
"Chai-1 run failed. Error log is in %s", output_err_file
)
else:
logger.error("Chai-1 run failed")
return False
logger.error("Chai-1 run failed")
return False

logger.info("Chai-1 run complete")
return True
@@ -110,6 +112,7 @@ def generate_chai_command(
output_dir: Union[str, Path],
number_of_models: int = 5,
num_recycles: int = 10,
seed: int = 42,
use_templates_server: bool = False,
template_hits_path: Path | None = None,
) -> list:
@@ -123,6 +126,7 @@
output_dir (Union[str, Path]): Path to the output directory
number_of_models (int): Number of models to generate
num_recycles (int): Number of trunk recycles
seed (int): Seed for the random number generator
use_templates_server (bool): If True, use templates from the server
template_hits_path (Path): Path to the template hits m8 file

@@ -141,6 +145,7 @@

cmd += ["--num-diffn-samples", str(number_of_models)]
cmd += ["--num-trunk-recycles", str(num_recycles)]
cmd += ["--seed", str(seed)]

assert not (
use_templates_server and template_hits_path
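Finally, a small sketch tying the per-seed Chai-1 output directories back to how abcfold.run collects them: run_chai writes each seed's results into a chai_output_seed-<seed> directory under the run's output directory, and the caller later gathers every such directory with a single glob. The directory names and seeds below are illustrative:

from pathlib import Path

output_dir = Path("/tmp/abcfold_run")   # hypothetical run output directory
seeds = [42, 1234]                      # hypothetical, from ChaiFasta.seeds

# One output directory per seed, as created inside run_chai().
chai_dirs = [output_dir / f"chai_output_seed-{seed}" for seed in seeds]
print(chai_dirs)

# abcfold.run() later collects all per-seed directories in one pass:
collected = list(output_dir.glob("chai_output*"))
print(collected)  # empty here because nothing was actually created on disk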