diff --git a/abcfold/abcfold.py b/abcfold/abcfold.py
index 07a05dc..c32fb05 100644
--- a/abcfold/abcfold.py
+++ b/abcfold/abcfold.py
@@ -180,8 +180,8 @@ def run(args, config, defaults, config_file):
         )

         if boltz_success:
-            bolt_out_dir = list(args.output_dir.glob("boltz_results*"))[0]
-            bo = BoltzOutput(bolt_out_dir, input_params, name)
+            bolt_out_dirs = list(args.output_dir.glob("boltz_results*"))
+            bo = BoltzOutput(bolt_out_dirs, input_params, name, args.save_input)
             outputs.append(bo)
             successful_runs.append(boltz_success)

@@ -194,10 +194,9 @@ def run(args, config, defaults, config_file):
         elif args.templates:
             template_hits_path = make_dummy_m8_file(run_json, temp_dir)

-        chai_output_dir = args.output_dir.joinpath("chai1")
         chai_success = run_chai(
             input_json=run_json,
-            output_dir=chai_output_dir,
+            output_dir=args.output_dir,
             save_input=args.save_input,
             number_of_models=args.number_of_models,
             num_recycles=args.num_recycles,
@@ -205,7 +204,8 @@ def run(args, config, defaults, config_file):
         )

         if chai_success:
-            co = ChaiOutput(chai_output_dir, input_params, name, args.save_input)
+            chai_output_dirs = list(args.output_dir.glob("chai_output*"))
+            co = ChaiOutput(chai_output_dirs, input_params, name, args.save_input)
             outputs.append(co)
             successful_runs.append(chai_success)

@@ -259,28 +259,11 @@ def run(args, config, defaults, config_file):
     if args.boltz:
         if boltz_success:
             programs_run.append("Boltz")
-            for idx in bo.output.keys():
-                model = bo.output[idx]["cif"]
-                model.check_clashes()
-                score_file = bo.output[idx]["json"]
-                plddt = model.residue_plddts
-                if len(indicies) > 0:
-                    plddt = insert_none_by_minus_one(indicies[index_counter], plddt)
-                    index_counter += 1
-                model_data = get_model_data(
-                    model, plot_dict, "Boltz", plddt, score_file, args.output_dir
-                )
-                boltz_models["models"].append(model_data)
-
-    chai_models = {"models": []}
-    if args.chai1:
-        if chai_success:
-            programs_run.append("Chai-1")
-            for idx in co.output.keys():
-                if idx >= 0:
-                    model = co.output[idx]["cif"]
+            for seed in bo.output.keys():
+                for idx in bo.output[seed].keys():
+                    model = bo.output[seed][idx]["cif"]
                     model.check_clashes()
-                    score_file = co.output[idx]["scores"]
+                    score_file = bo.output[seed][idx]["json"]
                     plddt = model.residue_plddts
                     if len(indicies) > 0:
                         plddt = insert_none_by_minus_one(
@@ -290,12 +273,38 @@ def run(args, config, defaults, config_file):
                         index_counter += 1
                     model_data = get_model_data(
                         model,
                         plot_dict,
-                        "Chai-1",
+                        "Boltz",
                         plddt,
                         score_file,
-                        args.output_dir,
+                        args.output_dir
                     )
-                    chai_models["models"].append(model_data)
+                    boltz_models["models"].append(model_data)
+
+    chai_models = {"models": []}
+    if args.chai1:
+        if chai_success:
+            programs_run.append("Chai-1")
+            for seed in co.output.keys():
+                for idx in co.output[seed].keys():
+                    if idx >= 0:
+                        model = co.output[seed][idx]["cif"]
+                        model.check_clashes()
+                        score_file = co.output[seed][idx]["scores"]
+                        plddt = model.residue_plddts
+                        if len(indicies) > 0:
+                            plddt = insert_none_by_minus_one(
+                                indicies[index_counter], plddt
+                            )
+                            index_counter += 1
+                        model_data = get_model_data(
+                            model,
+                            plot_dict,
+                            "Chai-1",
+                            plddt,
+                            score_file,
+                            args.output_dir,
+                        )
+                        chai_models["models"].append(model_data)
     combined_models = (
         alphafold_models["models"] + boltz_models["models"] + chai_models["models"]
diff --git a/abcfold/boltz/af3_to_boltz.py b/abcfold/boltz/af3_to_boltz.py
index f807270..7d26114 100644
--- a/abcfold/boltz/af3_to_boltz.py
+++ b/abcfold/boltz/af3_to_boltz.py
@@ -19,6 +19,7 @@ def __init__(self, working_dir: Union[str, Path], create_files: bool = True):
self.working_dir = working_dir self.yaml_string: str = "" self.msa_file: Optional[Union[str, Path]] = "null" + self.seeds: list = [42] self.__ids: List[Union[str, int]] = [] self.__id_char: str = "A" self.__id_links: Dict[Union[str, int], list] = {} @@ -75,6 +76,11 @@ def json_to_yaml( self.yaml_string += self.add_version_number("1") for key, value in json_dict.items(): + if key == "modelSeeds": + if isinstance(value, list): + self.seeds = value + elif isinstance(value, int): + self.seeds = [value] if key == "sequences": if "sequences" not in self.yaml_string: self.yaml_string += self.add_non_indented_string("sequences") diff --git a/abcfold/boltz/run_boltz.py b/abcfold/boltz/run_boltz.py index 65c3502..2dd3d48 100644 --- a/abcfold/boltz/run_boltz.py +++ b/abcfold/boltz/run_boltz.py @@ -52,43 +52,51 @@ def run_boltz( boltz_yaml = BoltzYaml(working_dir) boltz_yaml.json_to_yaml(input_json) - out_file = working_dir.joinpath(f"{input_json.stem}.yaml") - - boltz_yaml.write_yaml(out_file) - logger.info("Running Boltz") - cmd = ( - generate_boltz_command(out_file, output_dir, number_of_models, num_recycles) - if not test - else generate_boltz_test_command() - ) - - with subprocess.Popen( - cmd, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) as proc: - stdout = "" - if proc.stdout: - for line in proc.stdout: - sys.stdout.write(line.decode()) - sys.stdout.flush() - stdout += line.decode() - _, stderr = proc.communicate() - if proc.returncode != 0: - if proc.stderr: - logger.error(stderr.decode()) - output_err_file = output_dir / "boltz_error.log" - with open(output_err_file, "w") as f: - f.write(stderr.decode()) - logger.error( - "Boltz run failed. Error log is in %s", output_err_file - ) - else: - logger.error("Boltz run failed") - return False - elif "WARNING: ran out of memory" in stdout: - logger.error("Boltz ran out of memory") - return False + + for seed in boltz_yaml.seeds: + out_file = working_dir.joinpath(f"{input_json.stem}_seed-{seed}.yaml") + + boltz_yaml.write_yaml(out_file) + logger.info("Running Boltz using seed: %s", seed) + cmd = ( + generate_boltz_command( + out_file, + output_dir, + number_of_models, + num_recycles, + seed=seed, + ) + if not test + else generate_boltz_test_command() + ) + + with subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) as proc: + stdout = "" + if proc.stdout: + for line in proc.stdout: + sys.stdout.write(line.decode()) + sys.stdout.flush() + stdout += line.decode() + _, stderr = proc.communicate() + if proc.returncode != 0: + if proc.stderr: + logger.error(stderr.decode()) + output_err_file = output_dir / "boltz_error.log" + with open(output_err_file, "w") as f: + f.write(stderr.decode()) + logger.error( + "Boltz run failed. 
Error log is in %s", output_err_file + ) + else: + logger.error("Boltz run failed") + return False + elif "WARNING: ran out of memory" in stdout: + logger.error("Boltz ran out of memory") + return False logger.info("Boltz run complete") logger.info("Output files are in %s", output_dir) @@ -100,6 +108,7 @@ def generate_boltz_command( output_dir: Union[str, Path], number_of_models: int = 5, num_recycles: int = 10, + seed: int = 42, ) -> list: """ Generate the Boltz command @@ -108,6 +117,7 @@ def generate_boltz_command( input_yaml (Union[str, Path]): Path to the input YAML file output_dir (Union[str, Path]): Path to the output directory number_of_models (int): Number of models to generate + seed (int): Seed for the random number generator Returns: list: The Boltz command @@ -125,6 +135,8 @@ def generate_boltz_command( str(number_of_models), "--recycling_steps", str(num_recycles), + "--seed", + str(seed), ] diff --git a/abcfold/chai1/af3_to_chai.py b/abcfold/chai1/af3_to_chai.py index 9f80c37..7ab40dd 100644 --- a/abcfold/chai1/af3_to_chai.py +++ b/abcfold/chai1/af3_to_chai.py @@ -40,6 +40,7 @@ def __init__(self, working_dir: Union[str, Path], create_files: bool = True): self.fasta = Path(working_dir) / "chai1.fasta" self.constraints = Path(working_dir) / "chai1_constraints.csv" self.msa_file: Optional[Union[str, Path]] = None + self.seeds: list = [42] self.__ids: List[Union[str, int]] = [] self.__create_files = create_files @@ -203,6 +204,12 @@ def json_to_fasta(self, json_file_or_dict: Union[dict, str, Path]): bonded_pairs = json_dict["bondedAtomPairs"] self.bonded_pairs_to_file(bonded_pairs, fasta_data) + if "modelSeeds" in json_dict.keys(): + if isinstance(json_dict["modelSeeds"], int): + self.seeds = [json_dict["modelSeeds"]] + elif isinstance(json_dict["modelSeeds"], list): + self.seeds = json_dict["modelSeeds"] + if not self.__create_files: self.fasta.unlink() diff --git a/abcfold/chai1/run_chai1.py b/abcfold/chai1/run_chai1.py index 69bc20d..4528ed3 100644 --- a/abcfold/chai1/run_chai1.py +++ b/abcfold/chai1/run_chai1.py @@ -48,12 +48,10 @@ def run_chai( with tempfile.TemporaryDirectory() as temp_dir: working_dir = Path(temp_dir) - chai_output_dir = output_dir if save_input: logger.info("Saving input fasta file and msa to the output directory") working_dir = output_dir working_dir.mkdir(parents=True, exist_ok=True) - chai_output_dir = output_dir / "chai_output" chai_fasta = ChaiFasta(working_dir) chai_fasta.json_to_fasta(input_json) @@ -62,42 +60,46 @@ def run_chai( msa_dir = chai_fasta.working_dir out_constraints = chai_fasta.constraints - cmd = ( - generate_chai_command( - out_fasta, - msa_dir, - out_constraints, - chai_output_dir, - number_of_models, - num_recycles=num_recycles, - use_templates_server=use_templates_server, - template_hits_path=template_hits_path, + for seed in chai_fasta.seeds: + chai_output_dir = output_dir / f"chai_output_seed-{seed}" + + logger.info(f"Running Chai-1 using seed {seed}") + cmd = ( + generate_chai_command( + out_fasta, + msa_dir, + out_constraints, + chai_output_dir, + number_of_models, + num_recycles=num_recycles, + seed=seed, + use_templates_server=use_templates_server, + template_hits_path=template_hits_path, + ) + if not test + else generate_chai_test_command() ) - if not test - else generate_chai_test_command() - ) - logger.info("Running Chai-1") - with subprocess.Popen( - cmd, - stdout=sys.stdout, - stderr=subprocess.PIPE, - ) as proc: - _, stderr = proc.communicate() - if proc.returncode != 0: - if proc.stderr: - if chai_output_dir.exists(): 
- output_err_file = chai_output_dir / "chai_error.log" + with subprocess.Popen( + cmd, + stdout=sys.stdout, + stderr=subprocess.PIPE, + ) as proc: + _, stderr = proc.communicate() + if proc.returncode != 0: + if proc.stderr: + if chai_output_dir.exists(): + output_err_file = chai_output_dir / "chai_error.log" + else: + output_err_file = chai_output_dir.parent / "chai_error.log" + with open(output_err_file, "w") as f: + f.write(stderr.decode()) + logger.error( + "Chai-1 run failed. Error log is in %s", output_err_file + ) else: - output_err_file = chai_output_dir.parent / "chai_error.log" - with open(output_err_file, "w") as f: - f.write(stderr.decode()) - logger.error( - "Chai-1 run failed. Error log is in %s", output_err_file - ) - else: - logger.error("Chai-1 run failed") - return False + logger.error("Chai-1 run failed") + return False logger.info("Chai-1 run complete") return True @@ -110,6 +112,7 @@ def generate_chai_command( output_dir: Union[str, Path], number_of_models: int = 5, num_recycles: int = 10, + seed: int = 42, use_templates_server: bool = False, template_hits_path: Path | None = None, ) -> list: @@ -123,6 +126,7 @@ def generate_chai_command( output_dir (Union[str, Path]): Path to the output directory number_of_models (int): Number of models to generate num_recycles (int): Number of trunk recycles + seed (int): Seed for the random number generator use_templates_server (bool): If True, use templates from the server template_hits_path (Path): Path to the template hits m8 file @@ -141,6 +145,7 @@ def generate_chai_command( cmd += ["--num-diffn-samples", str(number_of_models)] cmd += ["--num-trunk-recycles", str(num_recycles)] + cmd += ["--seed", str(seed)] assert not ( use_templates_server and template_hits_path diff --git a/abcfold/html/html_utils.py b/abcfold/html/html_utils.py index 2a5b349..0e7964e 100644 --- a/abcfold/html/html_utils.py +++ b/abcfold/html/html_utils.py @@ -232,36 +232,46 @@ def get_all_cif_files(outputs) -> Dict[str, list]: method_cif_objs["Alphafold3"] = [] method_cif_objs["Alphafold3"].extend(output.cif_files[seed]) elif isinstance(output, BoltzOutput): - - method_cif_objs["Boltz"] = output.cif_files + for seed in output.seeds: + if "Boltz" not in method_cif_objs: + method_cif_objs["Boltz"] = [] + method_cif_objs["Boltz"].extend(output.cif_files[seed]) elif isinstance(output, ChaiOutput): - method_cif_objs["Chai-1"] = output.cif_files + for seed in output.seeds: + if "Chai-1" not in method_cif_objs: + method_cif_objs["Chai-1"] = [] + method_cif_objs["Chai-1"].extend(output.cif_files[seed]) return method_cif_objs def parse_scores(score_file: Union[ConfidenceJsonFile, NpzFile]) -> tuple: """ - Parse the scores from the score file + Parse the scores from the score file. Args: score_file (Union[ConfidenceJsonFile, NpzFile]): The score file object. Returns: - tuple: A tuple containing ptm_score and iptm_score as floats. 
+        tuple: A tuple containing ptm_score and iptm_score as floats, or None if invalid
    """
    ptm_score = None
    iptm_score = None
    if isinstance(score_file, ConfidenceJsonFile):
        data = score_file.load_json_file()
-        if "ptm" in data and "iptm" in data:
-            ptm_score = round(float(data["ptm"]), 2)
-            iptm_score = round(float(data["iptm"]), 2)
    elif isinstance(score_file, NpzFile):
        data = score_file.load_npz_file()
-        if "ptm" in data and "iptm" in data:
-            ptm_score = round(float(data["ptm"]), 2)
-            iptm_score = round(float(data["iptm"]), 2)
+    else:
+        return ptm_score, iptm_score
+    for key in ("ptm", "iptm"):
+        try:
+            value = float(data[key])
+            if key == "ptm":
+                ptm_score = round(value, 2)
+            else:
+                iptm_score = round(value, 2)
+        except (KeyError, TypeError, ValueError):
+            continue

    return ptm_score, iptm_score
diff --git a/abcfold/output/alphafold3.py b/abcfold/output/alphafold3.py
index b4b10fe..fa3f692 100644
--- a/abcfold/output/alphafold3.py
+++ b/abcfold/output/alphafold3.py
@@ -42,7 +42,6 @@ def __init__(
            },
            etc...
        }
-        This is different to the boltz and chai equivalent as they do not have seeds
        """
        self.output_dir = Path(af3_output_dir)
        self.input_params = input_params
diff --git a/abcfold/output/boltz.py b/abcfold/output/boltz.py
index d70ff0e..c885d50 100644
--- a/abcfold/output/boltz.py
+++ b/abcfold/output/boltz.py
@@ -13,43 +13,48 @@ class BoltzOutput:
    def __init__(
        self,
-        boltz_output_dir: Union[str, Path],
+        boltz_output_dirs: list[Union[str, Path]],
        input_params: dict,
        name: str,
+        save_input: bool = False,
    ):
        """
        Object to process the output of an Boltz run

        Args:
-            boltz_output_dir (Union[str, Path]): Path to the Boltz output directory
+            boltz_output_dirs (list[Union[str, Path]]): Paths to the Boltz
+                output directories
            input_params (dict): Dictionary containing the input parameters used for
            the Boltz run
            name (str): Name given to the Boltz run
+            save_input (bool): If True, Boltz was run with the save_input flag

        Attributes:
-            output_dir (Path): Path to the Boltz output directory
+            output_dirs (list): List of paths to the Boltz output directories
            input_params (dict): Dictionary containing the input parameters used for
            the Boltz run
            name (str): Name given to the Boltz run
            output (dict): Dictionary containing the processed output the contents
-            of the Boltz output directory. The dictionary is structured as follows:
+            of the Boltz output directories. The dictionary is structured as follows:
            {
-                1: {
-                    "pae": NpzFile,
-                    "plddt": NpzFile,
-                    "pde": NpzFile,
-                    "cif": CifFile,
-                    "json": ConfidenceJsonFile
+                "seed-1": {
+                    1: {
+                        "pae": NpzFile,
+                        "plddt": NpzFile,
+                        "pde": NpzFile,
+                        "cif": CifFile,
+                        "json": ConfidenceJsonFile
+                    },
+                    2: {
+                        "pae": NpzFile,
+                        "plddt": NpzFile,
+                        "pde": NpzFile,
+                        "cif": CifFile,
+                        "json": ConfidenceJsonFile
+                    },
                },
-                2: {
-                    "pae": NpzFile,
-                    "plddt": NpzFile,
-                    "pde": NpzFile,
-                    "cif": CifFile,
-                    "json": ConfidenceJsonFile
-                },
-                ...
+                etc...
} pae_files (list): Ordered list of NpzFile objects containing the PAE data plddt_files (list): Ordered list of NpzFile objects containing the PLDDT @@ -58,77 +63,123 @@ def __init__( cif_files (list): Ordered list of CifFile objects containing the model data scores_files (list): Ordered list of ConfidenceJsonFile objects containing the model scores - """ - self.output_dir = Path(boltz_output_dir) + self.output_dirs = [Path(x) for x in boltz_output_dirs] self.input_params = input_params self.name = name + self.save_input = save_input + + parent_dir = self.output_dirs[0].parent + new_parent = parent_dir / f"boltz_{self.name}" + new_parent.mkdir(parents=True, exist_ok=True) + + if self.save_input: + boltz_yaml = list(parent_dir.glob("*.yaml"))[0] + if boltz_yaml.exists(): + boltz_yaml.rename(new_parent / "boltz_input.yaml") + boltz_msa = list(parent_dir.glob("*.a3m"))[0] + if boltz_msa.exists(): + boltz_msa.rename(new_parent / boltz_msa.name) + + new_output_dirs = [] + for output_dir in self.output_dirs: + if output_dir.name.startswith("boltz_results_"): + new_path = new_parent / output_dir.name + output_dir.rename(new_path) + new_output_dirs.append(new_path) + else: + new_output_dirs.append(output_dir) + self.output_dirs = new_output_dirs - if self.output_dir.name.startswith("boltz_results_"): - self.output_dir = self.output_dir.rename( - self.output_dir.parent / f"boltz_{name}" - ) self.yaml_input_obj = self.get_input_yaml() - self.output = self.process_boltz_output() - self.pae_files = [value["pae"] for value in self.output.values()] - self.cif_files = [value["cif"] for value in self.output.values()] + self.output = self.process_boltz_output() + self.seeds = list(self.output.keys()) + self.pae_files = { + seed: [value["pae"] for value in self.output[seed].values()] + for seed in self.seeds + } + self.cif_files = { + seed: [value["cif"] for value in self.output[seed].values()] + for seed in self.seeds + } + self.plddt_files = { + seed: [value["plddt"] for value in self.output[seed].values()] + for seed in self.seeds + } + self.pde_files = { + seed: [value["pde"] for value in self.output[seed].values()] + for seed in self.seeds + } + self.scores_files = { + seed: [value["json"] for value in self.output[seed].values()] + for seed in self.seeds + } self.pae_to_af3() - self.af3_pae_files = [value["af3_pae"] for value in self.output.values()] - self.plddt_files = [value["plddt"] for value in self.output.values()] - self.pde_files = [value["pde"] for value in self.output.values()] - self.scores_files = [value["json"] for value in self.output.values()] + self.af3_pae_files = { + seed: [value["af3_pae"] for value in self.output[seed].values()] + for seed in self.seeds + } def process_boltz_output(self): """ Function to process the output of a Boltz run """ file_groups = {} - for pathway in self.output_dir.rglob("*"): - number = pathway.stem.split("_model_")[-1] - if not number.isdigit(): - continue - number = int(number) - - file_type = pathway.suffix[1:] - if file_type == FileTypes.NPZ.value: - file_ = NpzFile(str(pathway)) - elif file_type == FileTypes.CIF.value: - file_ = CifFile(str(pathway), self.input_params) - - elif file_type == FileTypes.JSON.value: - file_ = ConfidenceJsonFile(str(pathway)) - else: - continue - if number not in file_groups: - file_groups[number] = [file_] - else: - file_groups[number].append(file_) - - model_number_file_type_file = {} - for model_number, files in file_groups.items(): - intermediate_dict = {} - for file_ in sorted(files, key=lambda x: x.suffix): - if 
file_.pathway.stem.startswith("pae"): - intermediate_dict["pae"] = file_ - elif file_.pathway.stem.startswith("plddt"): - intermediate_dict["plddt"] = file_ - elif file_.pathway.stem.startswith("pde"): - intermediate_dict["pde"] = file_ - elif file_.pathway.suffix == ".cif": - file_.name = f"Boltz_{model_number}" - file_ = self.update_chain_labels(file_) - intermediate_dict["cif"] = file_ + for pathway in self.output_dirs: + seed = pathway.name.split("_")[-1] + if seed not in file_groups: + file_groups[seed] = {} + + for output in pathway.rglob("*"): + number = output.stem.split("_model_")[-1] + if not number.isdigit(): + continue + number = int(number) + + file_type = output.suffix[1:] + + if file_type == FileTypes.NPZ.value: + file_ = NpzFile(str(output)) + elif file_type == FileTypes.CIF.value: + file_ = CifFile(str(output), self.input_params) + elif file_type == FileTypes.JSON.value: + file_ = ConfidenceJsonFile(str(output)) else: - intermediate_dict[file_.suffix] = file_ - - model_number_file_type_file[model_number] = intermediate_dict + continue + if number not in file_groups[seed]: + file_groups[seed][number] = [file_] + else: + file_groups[seed][number].append(file_) + + seed_dict = {} + for seed, models in file_groups.items(): + model_number_file_type_file = {} + for model_number, files in models.items(): + intermediate_dict = {} + for file_ in sorted(files, key=lambda x: x.suffix): + if file_.pathway.stem.startswith("pae"): + intermediate_dict["pae"] = file_ + elif file_.pathway.stem.startswith("plddt"): + intermediate_dict["plddt"] = file_ + elif file_.pathway.stem.startswith("pde"): + intermediate_dict["pde"] = file_ + elif file_.pathway.suffix == ".cif": + file_.name = f"Boltz_{seed}_{model_number}" + file_ = self.update_chain_labels(file_) + intermediate_dict["cif"] = file_ + else: + intermediate_dict[file_.suffix] = file_ + + model_number_file_type_file[model_number] = intermediate_dict + + model_number_file_type_file = { + key: model_number_file_type_file[key] + for key in sorted(model_number_file_type_file) + } + seed_dict[seed] = model_number_file_type_file - model_number_file_type_file = { - key: model_number_file_type_file[key] - for key in sorted(model_number_file_type_file) - } - return model_number_file_type_file + return seed_dict def add_plddt_to_cif(self): """ @@ -178,19 +229,33 @@ def pae_to_af3(self): Returns: None """ - for i, (pae_file, cif_file) in enumerate(zip(self.pae_files, self.cif_files)): - pae = Af3Pae.from_boltz( - pae_file.data, - cif_file, - ) - - out_name = cif_file.pathway.parent.joinpath( - cif_file.pathway.stem + "_af3_pae.json" - ) - - pae.to_file(out_name) - - self.output[i]["af3_pae"] = ConfidenceJsonFile(out_name) + new_pae_files = {} + for seed in self.seeds: + for (pae_file, cif_file) in zip(self.pae_files[seed], self.cif_files[seed]): + pae = Af3Pae.from_boltz( + pae_file.data, + cif_file, + ) + + out_name = pae_file.pathway + + pae.to_file(out_name) + + if seed not in new_pae_files: + new_pae_files[seed] = [] + new_pae_files[seed].append(ConfidenceJsonFile(out_name)) + + self.output = { + seed: { + i: { + "cif": cif_file, + "af3_pae": new_pae_files[seed][i], + "json": self.output[seed][i]["json"], + } + for i, cif_file in enumerate(self.cif_files[seed]) + } + for seed in self.seeds + } def update_chain_labels(self, cif_file) -> CifFile: """ @@ -213,7 +278,7 @@ def get_input_yaml(self) -> BoltzYaml: BoltzYaml: Object containing the input yaml file """ - by = BoltzYaml(self.output_dir, create_files=False) + by = 
BoltzYaml(self.output_dirs[0], create_files=False) by.json_to_yaml(self.input_params) return by diff --git a/abcfold/output/chai.py b/abcfold/output/chai.py index e8b920d..37fcc2f 100644 --- a/abcfold/output/chai.py +++ b/abcfold/output/chai.py @@ -1,4 +1,5 @@ import logging +import shutil from pathlib import Path from typing import Union @@ -13,7 +14,7 @@ class ChaiOutput: def __init__( self, - chai_output_dir: Union[str, Path], + chai_output_dirs: list[Union[str, Path]], input_params: dict, name: str, save_input: bool = False, @@ -22,7 +23,8 @@ def __init__( Object to process the output of an Chai-1 run Args: - chai_output_dir (Union[str, Path]): Path to the Chai-1 output directory + chai_output_dirs (list[Union[str, Path]]): Path to the Chai-1 + output directory input_params (dict): Dictionary containing the input parameters used for the Chai-1 run name (str): Name given to the Chai-1 run @@ -31,23 +33,24 @@ def __init__( Attributes: input_params (dict): Dictionary containing the input parameters used for the Chai-1 run - output_dir (Path): Path to the Chai-1 output directory + output_dirs (Path): List of paths to the Chai-1 output directory(s) name (str): Name given to the Chai-1 run output (dict): Dictionary containing the processed output the contents - of the Chai-1 output directory. The dictionary is structured as follows: + of the Chai-1 output directory(s). The dictionary is structured as follows: { - 1: { - "pae": NpzFile, - "cif": CifFile, - "scores": NpyFile - }, - 2: { - "pae": NpzFile, - "cif": CifFile, - "scores": NpyFile - }, - ... + "seed-1": { + 1: { + "pae": NpzFile, + "cif": CifFile, + "scores": NpyFile + }, + 2: { + "pae": NpzFile, + "cif": CifFile, + "scores": NpyFile + }, + etc... } pae_files (list): Ordered list of NpzFile objects containing the PAE data cif_files (list): Ordered list of CifFile objects containing the CIF data @@ -56,105 +59,165 @@ def __init__( """ self.input_params = input_params - self.output_dir = Path(chai_output_dir) + self.output_dirs = [Path(x) for x in chai_output_dirs] self.name = name self.save_input = save_input - if not self.output_dir.name.startswith("chai1_" + self.name): - self.output_dir = self.output_dir.rename( - self.output_dir.parent / f"chai1_{self.name}" - ) + parent_dir = self.output_dirs[0].parent + new_parent = parent_dir / f"chai1_{self.name}" + new_parent.mkdir(parents=True, exist_ok=True) + + if self.save_input: + chai_fasta = parent_dir / "chai1.fasta" + if chai_fasta.exists(): + chai_fasta.rename(new_parent / "chai1.fasta") + chai_msa = list(parent_dir.glob("*.aligned.pqt"))[0] + if chai_msa.exists(): + chai_msa.rename(new_parent / chai_msa.name) + + new_output_dirs = [] + for output_dir in self.output_dirs: + if output_dir.name.startswith("chai_output_"): + new_path = new_parent / output_dir.name + output_dir.rename(new_path) + new_output_dirs.append(new_path) + else: + new_output_dirs.append(output_dir) + self.output_dirs = new_output_dirs self.input_fasta = self.get_input_fasta() self.output = self.process_chai_output() - self.pae_files = [ - value["pae"] for value in self.output.values() if "pae" in value - ] - self.cif_files = [ - value["cif"] for value in self.output.values() if "cif" in value - ] + self.seeds = list(self.output.keys()) + + self.pae_files = { + seed: [value["pae"] for value in self.output[seed].values() + if "pae" in value] for seed in self.seeds + } + self.cif_files = { + seed: [value["cif"] for value in self.output[seed].values()] + for seed in self.seeds + } + self.scores_files = { + seed: 
[value["scores"] for value in self.output[seed].values()] + for seed in self.seeds + } self.pae_to_af3() - self.scores_files = [ - value["scores"] for value in self.output.values() if "scores" in value - ] - self.af3_pae_files = [ - value["af3_pae"] for value in self.output.values() if "af3_pae" in value - ] + self.af3_pae_files = { + seed: [value["af3_pae"] for value in self.output[seed].values()] + for seed in self.seeds + } def process_chai_output(self): file_groups = {} - if self.save_input: - self.output_dir = self.output_dir / "chai_output" - - for pathway in self.output_dir.iterdir(): - number = pathway.stem.split("model_idx_")[-1] - if number.isdigit(): - number = int(number) - - file_type = pathway.suffix[1:] - - if file_type == FileTypes.NPZ.value: - file_ = NpzFile(str(pathway)) - - elif file_type == FileTypes.CIF.value: - file_ = CifFile(str(pathway), self.input_params) - file_ = self.update_chain_labels(file_) - - elif file_type == FileTypes.NPY.value: - file_ = NpyFile(str(pathway)) - else: - continue - - if isinstance(number, str): - number = -1 - - if number not in file_groups: - file_groups[number] = [file_] - else: - file_groups[number].append(file_) - - model_number_file_type_file = {} - for model_number, files in file_groups.items(): - intermediate_dict = {} - for file_ in sorted(files, key=lambda x: x.suffix): - if file_.pathway.stem.startswith("scores.model"): - intermediate_dict["scores"] = file_ - elif file_.pathway.stem.startswith("pred.model"): - file_.name = f"Chai-1_{model_number}" - # Chai cif not recognised by pae-viewer, so we load and save - file_.to_file(file_.pathway) - intermediate_dict["cif"] = file_ - elif file_.pathway.stem.startswith("pae_scores"): - intermediate_dict["pae"] = file_ - - model_number_file_type_file[model_number] = intermediate_dict - - model_number_file_type_file = { - model_number: model_number_file_type_file[model_number] - for model_number in sorted(model_number_file_type_file) - } + for pathway in self.output_dirs: + seed = pathway.name.split("_")[-1] + if seed not in file_groups: + file_groups[seed] = {} + + for output in pathway.rglob("*"): + number = output.stem.split("model_idx_")[-1] + if number.isdigit(): + number = int(number) + + file_type = output.suffix[1:] + + if file_type == FileTypes.NPZ.value: + file_ = NpzFile(str(output)) + + elif file_type == FileTypes.CIF.value: + file_ = CifFile(str(output), self.input_params) + file_ = self.update_chain_labels(file_) + + elif file_type == FileTypes.NPY.value: + file_ = NpyFile(str(output)) + else: + continue + + if isinstance(number, str): + number = -1 + + if number not in file_groups[seed]: + file_groups[seed][number] = [file_] + else: + file_groups[seed][number].append(file_) + + seed_dict = {} + for seed, models in file_groups.items(): + model_number_file_type_file = {} + pae_file = None + if -1 in models: + for file_ in models[-1]: + if file_.pathway.stem.startswith("pae_scores"): + pae_file = file_ + break + + for model_number, files in models.items(): + if model_number == -1: + continue + intermediate_dict = {} + for file_ in sorted(files, key=lambda x: x.suffix): + if file_.pathway.stem.startswith("scores.model"): + intermediate_dict["scores"] = file_ + elif file_.pathway.stem.startswith("pred.model"): + file_.name = f"Chai-1_{seed}_{model_number}" + # Chai cif not recognised by pae-viewer, so we load and save + file_.to_file(file_.pathway) + intermediate_dict["cif"] = file_ + if model_number != -1 and pae_file is not None: + new_pae_path = ( + file_.pathway.parent / 
f"pae_scores_model_{model_number}.npy" + ) + shutil.copy(pae_file.pathway, new_pae_path) + intermediate_dict["pae"] = NpyFile(str(new_pae_path)) + + model_number_file_type_file[model_number] = intermediate_dict + + model_number_file_type_file = { + model_number: model_number_file_type_file[model_number] + for model_number in sorted(model_number_file_type_file) + } + seed_dict[seed] = model_number_file_type_file - return model_number_file_type_file + return seed_dict - def pae_to_af3(self) -> None: + def pae_to_af3(self): """ Convert the Chai-1 PAE data to the format expected by AlphaFold3 + Returns: + None """ - - pae_file = self.pae_files[-1] - for i, cif_file in enumerate(self.cif_files): - pae = Af3Pae.from_chai1( - pae_file.data[i], - cif_file, - ) - - out_name = self.output_dir.joinpath(cif_file.pathway.stem + "_af3_pae.json") - pae.to_file(out_name) - - self.output[i]["af3_pae"] = ConfidenceJsonFile(out_name) + new_pae_files = {} + for seed in self.seeds: + for i, (pae_file, cif_file) in enumerate( + zip(self.pae_files[seed], self.cif_files[seed]) + ): + pae = Af3Pae.from_chai1( + pae_file.data[i], + cif_file, + ) + + out_name = pae_file.pathway + + pae.to_file(out_name) + + if seed not in new_pae_files: + new_pae_files[seed] = [] + new_pae_files[seed].append(ConfidenceJsonFile(out_name)) + + self.output = { + seed: { + i: { + "cif": cif_file, + "af3_pae": new_pae_files[seed][i], + "scores": self.output[seed][i]["scores"], + } + for i, cif_file in enumerate(self.cif_files[seed]) + } + for seed in self.seeds + } def get_input_fasta(self) -> ChaiFasta: """ @@ -165,7 +228,7 @@ def get_input_fasta(self) -> ChaiFasta: """ - ch = ChaiFasta(self.output_dir, create_files=False) + ch = ChaiFasta(self.output_dirs[0], create_files=False) ch.json_to_fasta(self.input_params) return ch diff --git a/abcfold/output/file_handlers.py b/abcfold/output/file_handlers.py index 61fffaa..6851f2f 100644 --- a/abcfold/output/file_handlers.py +++ b/abcfold/output/file_handlers.py @@ -100,7 +100,7 @@ def __init__(self, npz_file: Union[str, Path]): self.data = self.load_npz_file() def load_npz_file(self) -> dict: - return dict(np.load(self.npz_file)) + return dict(np.load(self.npz_file, allow_pickle=True)) class NpyFile(FileBase): @@ -121,7 +121,7 @@ def __init__(self, npy_file: Union[str, Path]): self.data = self.load_npy_file() def load_npy_file(self) -> np.ndarray: - return np.load(self.npy_file) + return np.load(self.npy_file, allow_pickle=True) class CifFile(FileBase): diff --git a/abcfold/plots/pae_plot.py b/abcfold/plots/pae_plot.py index c6293b6..c827631 100644 --- a/abcfold/plots/pae_plot.py +++ b/abcfold/plots/pae_plot.py @@ -97,6 +97,20 @@ def create_pae_plots( ) run_script(cmd) + for seed in output.seeds: + run_scripts.extend( + prepare_scripts( + output.cif_files[seed], + output.af3_pae_files[seed], + plots_dir, + pathway_plot, + template_file, + True, + ) + ) + + continue + elif isinstance(output, ChaiOutput): css_path = CSSPATHS["C"] template_file = plots_dir.joinpath("chai_template.html") @@ -109,6 +123,20 @@ def create_pae_plots( ) run_script(cmd) + for seed in output.seeds: + run_scripts.extend( + prepare_scripts( + output.cif_files[seed], + output.af3_pae_files[seed], + plots_dir, + pathway_plot, + template_file, + True, + ) + ) + + continue + elif isinstance(output, AlphafoldOutput): css_path = CSSPATHS["A"] template_file = plots_dir.joinpath("af3_template.html") @@ -134,21 +162,11 @@ def create_pae_plots( ) continue + else: logger.error("Invalid output type") raise ValueError() - 
run_scripts.extend( - prepare_scripts( - output.cif_files, - output.af3_pae_files, - plots_dir, - pathway_plot, - template_file, - False, - ) - ) - processes = [Process(target=run_script, args=(script,)) for script in run_scripts] for process in processes: process.start() diff --git a/pyproject.toml b/pyproject.toml index 8b4baf1..0df08cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "ABCFold" -version = "1.0.6" +version = "1.0.7" description = "Input processing tools for AlphaFold3, Boltz and Chai-1" readme = "README.md" license = { text = "BSD License" } diff --git a/tests/conftest.py b/tests/conftest.py index ab7cc3b..97c7348 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,7 @@ import json import logging +import shutil +import tempfile from collections import namedtuple from pathlib import Path @@ -58,36 +60,46 @@ def output_objs(): d = {} adir = data_dir.joinpath("alphafold3_6BJ9") - bdir = data_dir.joinpath("boltz_6BJ9") - cdir = data_dir.joinpath("chai1_6BJ9") + bdir = data_dir.joinpath("boltz_6BJ9_seed-1") + cdir = data_dir.joinpath("chai1_6BJ9_seed-1") name = "6BJ9" input_params = adir.joinpath("6bj9_data.json") - with open(input_params, "r") as f: - input_params = json.load(f) - - af3_output = AlphafoldOutput( - adir, - input_params.copy(), - name, - ) - boltz_output = BoltzOutput( - bdir, - input_params.copy(), - name, - ) - - chai_output = ChaiOutput( - cdir, - input_params.copy(), - name, - ) - - d["af3_output"] = af3_output - d["boltz_output"] = boltz_output - d["chai_output"] = chai_output - - nt = namedtuple("output_objs", d) - n = nt(**d) - - yield n + # Create temporary directories + with tempfile.TemporaryDirectory() as temp_dir: + temp_adir = Path(temp_dir) / "alphafold3_6BJ9" + temp_bdir = Path(temp_dir) / "boltz_6BJ9_seed-1" + temp_cdir = Path(temp_dir) / "chai1_6BJ9_seed-1" + + shutil.copytree(adir, temp_adir) + shutil.copytree(bdir, temp_bdir) + shutil.copytree(cdir, temp_cdir) + + with open(input_params, "r") as f: + input_params = json.load(f) + + af3_output = AlphafoldOutput( + temp_adir, + input_params.copy(), + name, + ) + boltz_output = BoltzOutput( + [temp_bdir], + input_params.copy(), + name, + ) + + chai_output = ChaiOutput( + [temp_cdir], + input_params.copy(), + name, + ) + + d["af3_output"] = af3_output + d["boltz_output"] = boltz_output + d["chai_output"] = chai_output + + nt = namedtuple("output_objs", d) + n = nt(**d) + + yield n diff --git a/tests/test_af3_output.py b/tests/test_af3_output.py index 92f20b8..25dbd31 100644 --- a/tests/test_af3_output.py +++ b/tests/test_af3_output.py @@ -1,6 +1,11 @@ +from pathlib import Path + + def test_process_af3_output(test_data, output_objs): af3_output = output_objs.af3_output - assert str(af3_output.output_dir) == test_data.test_alphafold3_6BJ9_ + assert af3_output.output_dir.relative_to( + af3_output.output_dir.parent + ) == Path(test_data.test_alphafold3_6BJ9_).relative_to("tests/test_data") assert "seed-1" in af3_output.output diff --git a/tests/test_boltz_output.py b/tests/test_boltz_output.py index 8d453d8..ca03d0d 100644 --- a/tests/test_boltz_output.py +++ b/tests/test_boltz_output.py @@ -7,49 +7,51 @@ def test_process_boltz_output(test_data, output_objs): boltz_output = output_objs.boltz_output - assert str(boltz_output.output_dir) == str( - Path(test_data.test_boltz_6BJ9_).parent.joinpath("boltz_6BJ9") - ) + assert boltz_output.output_dirs[0].relative_to( + boltz_output.output_dirs[0].parent + ) == 
Path(test_data.test_boltz_6BJ9_seed_1_).relative_to("tests/test_data") assert boltz_output.name == "6BJ9" - assert 0 in boltz_output.output - assert 1 in boltz_output.output + assert 0 in boltz_output.output['seed-1'] + assert 1 in boltz_output.output['seed-1'] - assert "plddt" in boltz_output.output[0] - assert "pae" in boltz_output.output[0] - assert "pde" in boltz_output.output[0] - assert "cif" in boltz_output.output[0] - assert "json" in boltz_output.output[0] + assert "af3_pae" in boltz_output.output['seed-1'][0] + assert "cif" in boltz_output.output['seed-1'][0] + assert "json" in boltz_output.output['seed-1'][0] - assert "plddt" in boltz_output.output[1] - assert "pae" in boltz_output.output[1] - assert "pde" in boltz_output.output[1] - assert "cif" in boltz_output.output[1] - assert "json" in boltz_output.output[1] + assert "af3_pae" in boltz_output.output['seed-1'][1] + assert "cif" in boltz_output.output['seed-1'][1] + assert "json" in boltz_output.output['seed-1'][1] - assert all(isinstance(pae_file, NpzFile) for pae_file in boltz_output.pae_files) assert all( - isinstance(plddt_file, NpzFile) for plddt_file in boltz_output.plddt_files + isinstance(pae_file, NpzFile) for pae_file in boltz_output.pae_files['seed-1'] + ) + assert all( + isinstance(plddt_file, NpzFile) + for plddt_file in boltz_output.plddt_files['seed-1'] ) - assert all(isinstance(pde_file, NpzFile) for pde_file in boltz_output.pde_files) - assert all(isinstance(cif_file, CifFile) for cif_file in boltz_output.cif_files) + assert all( + isinstance(pde_file, NpzFile) for pde_file in boltz_output.pde_files['seed-1'] + ) + assert all( + isinstance(cif_file, CifFile) for cif_file in boltz_output.cif_files['seed-1'] + ) assert all( isinstance(scores_file, ConfidenceJsonFile) - for scores_file in boltz_output.scores_files + for scores_file in boltz_output.scores_files['seed-1'] ) - assert boltz_output.cif_files[0].chain_lengths() == { + assert boltz_output.cif_files['seed-1'][0].chain_lengths() == { "A": 393, "B": 393, "C": 1, "D": 1, } - # boltz_output.add_plddt_to_cif() with tempfile.TemporaryDirectory() as temp_dir_str: temp_dir = Path(temp_dir_str) - for i, cif_file in enumerate(boltz_output.cif_files): + for i, cif_file in enumerate(boltz_output.cif_files['seed-1']): cif_file.to_file(temp_dir / f"{i}.cif") assert (temp_dir / f"{i}.cif").exists() @@ -57,7 +59,8 @@ def test_process_boltz_output(test_data, output_objs): def test_boltz_pae_to_af3_pae(test_data, output_objs): comparison_af3_output = output_objs.af3_output.af3_pae_files["seed-1"][0].data for pae_file, cif_file in zip( - output_objs.boltz_output.pae_files, output_objs.boltz_output.cif_files + output_objs.boltz_output.pae_files['seed-1'], + output_objs.boltz_output.cif_files['seed-1'] ): pae = Af3Pae.from_boltz( pae_file.data, diff --git a/tests/test_chai_output.py b/tests/test_chai_output.py index 265985a..1db6b43 100644 --- a/tests/test_chai_output.py +++ b/tests/test_chai_output.py @@ -8,26 +8,34 @@ def test_process_chai_output(test_data, output_objs): chai_output = output_objs.chai_output - assert str(chai_output.output_dir) == str(test_data.test_chai1_6BJ9_) + assert chai_output.output_dirs[0].relative_to( + chai_output.output_dirs[0].parent + ) == Path(test_data.test_chai1_6BJ9_seed_1_).relative_to("tests/test_data") - assert -1 in chai_output.output - assert 0 in chai_output.output - assert 1 in chai_output.output + assert 0 in chai_output.output['seed-1'] + assert 1 in chai_output.output['seed-1'] - assert all(isinstance(pae_file, NpyFile) for 
pae_file in chai_output.pae_files) - assert all(isinstance(cif_file, CifFile) for cif_file in chai_output.cif_files) assert all( - isinstance(scores_file, NpzFile) for scores_file in chai_output.scores_files + isinstance(pae_file, NpyFile) for pae_file in chai_output.pae_files['seed-1'] + ) + assert all( + isinstance(cif_file, CifFile) for cif_file in chai_output.cif_files['seed-1'] + ) + assert all( + isinstance(scores_file, NpzFile) + for scores_file in chai_output.scores_files['seed-1'] ) def test_chai_pae_to_af3_pae(output_objs): comparison_af3_output = output_objs.af3_output.af3_pae_files["seed-1"][0].data - pae_file = output_objs.chai_output.pae_files[-1] - for i, cif_file in enumerate(output_objs.chai_output.cif_files): + for pae_file, cif_file in zip( + output_objs.chai_output.pae_files['seed-1'], + output_objs.chai_output.cif_files['seed-1'] + ): assert cif_file.input_params pae = Af3Pae.from_chai1( - pae_file.data[i], + pae_file.data, cif_file, ) @@ -38,9 +46,9 @@ def test_chai_pae_to_af3_pae(output_objs): # for some reason the lengths are different for atom - realted things # If it isn't breaking the output page generation, then it's fine - assert len(pae.scores["pae"]) == len(comparison_af3_output["pae"]) + assert len(pae.scores["pae"][0]) == len(comparison_af3_output["pae"]) - assert len(pae.scores["contact_probs"]) == len( + assert len(pae.scores["contact_probs"][0]) == len( comparison_af3_output["contact_probs"] ) assert len(pae.scores["token_chain_ids"]) == len( diff --git a/tests/test_data/boltz_6BJ9/predictions/test_mmseqs/confidence_test_mmseqs_model_0.json b/tests/test_data/boltz_6BJ9_seed-1/predictions/test_mmseqs/confidence_test_mmseqs_model_0.json similarity index 100% rename from tests/test_data/boltz_6BJ9/predictions/test_mmseqs/confidence_test_mmseqs_model_0.json rename to tests/test_data/boltz_6BJ9_seed-1/predictions/test_mmseqs/confidence_test_mmseqs_model_0.json diff --git a/tests/test_data/boltz_6BJ9/predictions/test_mmseqs/confidence_test_mmseqs_model_1.json b/tests/test_data/boltz_6BJ9_seed-1/predictions/test_mmseqs/confidence_test_mmseqs_model_1.json similarity index 100% rename from tests/test_data/boltz_6BJ9/predictions/test_mmseqs/confidence_test_mmseqs_model_1.json rename to tests/test_data/boltz_6BJ9_seed-1/predictions/test_mmseqs/confidence_test_mmseqs_model_1.json diff --git a/tests/test_data/boltz_6BJ9/predictions/test_mmseqs/pae_test_mmseqs_model_0.npz b/tests/test_data/boltz_6BJ9_seed-1/predictions/test_mmseqs/pae_test_mmseqs_model_0.npz similarity index 100% rename from tests/test_data/boltz_6BJ9/predictions/test_mmseqs/pae_test_mmseqs_model_0.npz rename to tests/test_data/boltz_6BJ9_seed-1/predictions/test_mmseqs/pae_test_mmseqs_model_0.npz diff --git a/tests/test_data/boltz_6BJ9/predictions/test_mmseqs/pae_test_mmseqs_model_1.npz b/tests/test_data/boltz_6BJ9_seed-1/predictions/test_mmseqs/pae_test_mmseqs_model_1.npz similarity index 100% rename from tests/test_data/boltz_6BJ9/predictions/test_mmseqs/pae_test_mmseqs_model_1.npz rename to tests/test_data/boltz_6BJ9_seed-1/predictions/test_mmseqs/pae_test_mmseqs_model_1.npz diff --git a/tests/test_data/boltz_6BJ9/predictions/test_mmseqs/pde_test_mmseqs_model_0.npz b/tests/test_data/boltz_6BJ9_seed-1/predictions/test_mmseqs/pde_test_mmseqs_model_0.npz similarity index 100% rename from tests/test_data/boltz_6BJ9/predictions/test_mmseqs/pde_test_mmseqs_model_0.npz rename to tests/test_data/boltz_6BJ9_seed-1/predictions/test_mmseqs/pde_test_mmseqs_model_0.npz diff --git 
a/tests/test_data/boltz_6BJ9/predictions/test_mmseqs/pde_test_mmseqs_model_1.npz b/tests/test_data/boltz_6BJ9_seed-1/predictions/test_mmseqs/pde_test_mmseqs_model_1.npz similarity index 100% rename from tests/test_data/boltz_6BJ9/predictions/test_mmseqs/pde_test_mmseqs_model_1.npz rename to tests/test_data/boltz_6BJ9_seed-1/predictions/test_mmseqs/pde_test_mmseqs_model_1.npz diff --git a/tests/test_data/boltz_6BJ9/predictions/test_mmseqs/plddt_test_mmseqs_model_0.npz b/tests/test_data/boltz_6BJ9_seed-1/predictions/test_mmseqs/plddt_test_mmseqs_model_0.npz similarity index 100% rename from tests/test_data/boltz_6BJ9/predictions/test_mmseqs/plddt_test_mmseqs_model_0.npz rename to tests/test_data/boltz_6BJ9_seed-1/predictions/test_mmseqs/plddt_test_mmseqs_model_0.npz diff --git a/tests/test_data/boltz_6BJ9/predictions/test_mmseqs/plddt_test_mmseqs_model_1.npz b/tests/test_data/boltz_6BJ9_seed-1/predictions/test_mmseqs/plddt_test_mmseqs_model_1.npz similarity index 100% rename from tests/test_data/boltz_6BJ9/predictions/test_mmseqs/plddt_test_mmseqs_model_1.npz rename to tests/test_data/boltz_6BJ9_seed-1/predictions/test_mmseqs/plddt_test_mmseqs_model_1.npz diff --git a/tests/test_data/boltz_6BJ9/predictions/test_mmseqs/test_mmseqs_model_0.cif b/tests/test_data/boltz_6BJ9_seed-1/predictions/test_mmseqs/test_mmseqs_model_0.cif similarity index 100% rename from tests/test_data/boltz_6BJ9/predictions/test_mmseqs/test_mmseqs_model_0.cif rename to tests/test_data/boltz_6BJ9_seed-1/predictions/test_mmseqs/test_mmseqs_model_0.cif diff --git a/tests/test_data/boltz_6BJ9/predictions/test_mmseqs/test_mmseqs_model_0_af3_pae.json b/tests/test_data/boltz_6BJ9_seed-1/predictions/test_mmseqs/test_mmseqs_model_0_af3_pae.json similarity index 100% rename from tests/test_data/boltz_6BJ9/predictions/test_mmseqs/test_mmseqs_model_0_af3_pae.json rename to tests/test_data/boltz_6BJ9_seed-1/predictions/test_mmseqs/test_mmseqs_model_0_af3_pae.json diff --git a/tests/test_data/boltz_6BJ9/predictions/test_mmseqs/test_mmseqs_model_1.cif b/tests/test_data/boltz_6BJ9_seed-1/predictions/test_mmseqs/test_mmseqs_model_1.cif similarity index 100% rename from tests/test_data/boltz_6BJ9/predictions/test_mmseqs/test_mmseqs_model_1.cif rename to tests/test_data/boltz_6BJ9_seed-1/predictions/test_mmseqs/test_mmseqs_model_1.cif diff --git a/tests/test_data/boltz_6BJ9/predictions/test_mmseqs/test_mmseqs_model_1_af3_pae.json b/tests/test_data/boltz_6BJ9_seed-1/predictions/test_mmseqs/test_mmseqs_model_1_af3_pae.json similarity index 100% rename from tests/test_data/boltz_6BJ9/predictions/test_mmseqs/test_mmseqs_model_1_af3_pae.json rename to tests/test_data/boltz_6BJ9_seed-1/predictions/test_mmseqs/test_mmseqs_model_1_af3_pae.json diff --git a/tests/test_data/chai1_6BJ9/msa_depth.pdf b/tests/test_data/chai1_6BJ9_seed-1/msa_depth.pdf similarity index 100% rename from tests/test_data/chai1_6BJ9/msa_depth.pdf rename to tests/test_data/chai1_6BJ9_seed-1/msa_depth.pdf diff --git a/tests/test_data/chai1_6BJ9/pae_scores.npy b/tests/test_data/chai1_6BJ9_seed-1/pae_scores.npy similarity index 100% rename from tests/test_data/chai1_6BJ9/pae_scores.npy rename to tests/test_data/chai1_6BJ9_seed-1/pae_scores.npy diff --git a/tests/test_data/chai1_6BJ9/pred.model_idx_0.cif b/tests/test_data/chai1_6BJ9_seed-1/pred.model_idx_0.cif similarity index 100% rename from tests/test_data/chai1_6BJ9/pred.model_idx_0.cif rename to tests/test_data/chai1_6BJ9_seed-1/pred.model_idx_0.cif diff --git a/tests/test_data/chai1_6BJ9/pred.model_idx_0_af3_pae.json 
b/tests/test_data/chai1_6BJ9_seed-1/pred.model_idx_0_af3_pae.json similarity index 100% rename from tests/test_data/chai1_6BJ9/pred.model_idx_0_af3_pae.json rename to tests/test_data/chai1_6BJ9_seed-1/pred.model_idx_0_af3_pae.json diff --git a/tests/test_data/chai1_6BJ9/pred.model_idx_1.cif b/tests/test_data/chai1_6BJ9_seed-1/pred.model_idx_1.cif similarity index 100% rename from tests/test_data/chai1_6BJ9/pred.model_idx_1.cif rename to tests/test_data/chai1_6BJ9_seed-1/pred.model_idx_1.cif diff --git a/tests/test_data/chai1_6BJ9/pred.model_idx_1_af3_pae.json b/tests/test_data/chai1_6BJ9_seed-1/pred.model_idx_1_af3_pae.json similarity index 100% rename from tests/test_data/chai1_6BJ9/pred.model_idx_1_af3_pae.json rename to tests/test_data/chai1_6BJ9_seed-1/pred.model_idx_1_af3_pae.json diff --git a/tests/test_data/chai1_6BJ9/scores.model_idx_0.npz b/tests/test_data/chai1_6BJ9_seed-1/scores.model_idx_0.npz similarity index 100% rename from tests/test_data/chai1_6BJ9/scores.model_idx_0.npz rename to tests/test_data/chai1_6BJ9_seed-1/scores.model_idx_0.npz diff --git a/tests/test_data/chai1_6BJ9/scores.model_idx_1.npz b/tests/test_data/chai1_6BJ9_seed-1/scores.model_idx_1.npz similarity index 100% rename from tests/test_data/chai1_6BJ9/scores.model_idx_1.npz rename to tests/test_data/chai1_6BJ9_seed-1/scores.model_idx_1.npz diff --git a/tests/test_file_handlers.py b/tests/test_file_handlers.py index 09274b0..8fc835b 100644 --- a/tests/test_file_handlers.py +++ b/tests/test_file_handlers.py @@ -9,7 +9,7 @@ def test_npz_file(test_data): - test_npz = Path(test_data.test_boltz_6BJ9_).joinpath( + test_npz = Path(test_data.test_boltz_6BJ9_seed_1_).joinpath( "predictions/test_mmseqs/pae_test_mmseqs_model_1.npz" ) npz_file = file_handlers.NpzFile(test_npz) @@ -19,7 +19,7 @@ def test_npz_file(test_data): def test_npy_file(test_data): - test_npy = Path(test_data.test_chai1_6BJ9_).joinpath("pae_scores.npy") + test_npy = Path(test_data.test_chai1_6BJ9_seed_1_).joinpath("pae_scores.npy") npy_file = file_handlers.NpyFile(test_npy) assert npy_file.data.shape == (2, 888, 888) @@ -152,7 +152,7 @@ def test_superpose_models(test_data): model_1 = Path(test_data.test_alphafold3_6BJ9_).joinpath( "seed-1_sample-0/model.cif" ) - model_2 = Path(test_data.test_boltz_6BJ9_).joinpath( + model_2 = Path(test_data.test_boltz_6BJ9_seed_1_).joinpath( "predictions/test_mmseqs/test_mmseqs_model_0.cif" ) diff --git a/tests/test_plots.py b/tests/test_plots.py index f8b4685..837e7f5 100644 --- a/tests/test_plots.py +++ b/tests/test_plots.py @@ -9,8 +9,8 @@ def test_plddt_plot(output_objs): af3_files = output_objs.af3_output.cif_files["seed-1"] - boltz_files = output_objs.boltz_output.cif_files - chai_files = output_objs.chai_output.cif_files + boltz_files = output_objs.boltz_output.cif_files["seed-1"] + chai_files = output_objs.chai_output.cif_files["seed-1"] assert len(af3_files) == len(boltz_files) == len(chai_files) plot_files = { @@ -44,41 +44,47 @@ def test_pae_plots(output_objs): assert "confidences_seed-1_sample-0_af3_pae_plot.html" in values assert "confidences_seed-1_sample-1_af3_pae_plot.html" in values - assert "test_mmseqs_model_0_af3_pae_pae_plot.html" in values - assert "test_mmseqs_model_1_af3_pae_pae_plot.html" in values - assert "pred.model_idx_0_af3_pae_pae_plot.html" in values - assert "pred.model_idx_1_af3_pae_pae_plot.html" in values + assert "pae_test_mmseqs_model_0_test_mmseqs_af3_pae_plot.html" in values + assert "pae_test_mmseqs_model_1_test_mmseqs_af3_pae_plot.html" in values + assert 
"pae_scores_model_0_chai1_6BJ9_seed-1_af3_pae_plot.html" in values + assert "pae_scores_model_1_chai1_6BJ9_seed-1_af3_pae_plot.html" in values assert ( - "tests/test_data/alphafold3_6BJ9/seed-1_sample-0/\ -model.cif" - in plot_pathways + any( + "alphafold3_6BJ9/seed-1_sample-0/model.cif" in x for x in plot_pathways + ) ) assert ( - "tests/test_data/alphafold3_6BJ9/seed-1_sample-1/\ -model.cif" - in plot_pathways + any( + "alphafold3_6BJ9/seed-1_sample-1/model.cif" in x for x in plot_pathways + ) ) assert ( - "tests/test_data/boltz_6BJ9/predictions/test_mmseqs/\ -test_mmseqs_model_0.cif" - in plot_pathways + any( + "boltz_6BJ9_seed-1/predictions/test_mmseqs/test_mmseqs_model_0.cif" + in x for x in plot_pathways + ) ) assert ( - "tests/test_data/boltz_6BJ9/predictions/test_mmseqs/\ -test_mmseqs_model_1.cif" - in plot_pathways + any( + "boltz_6BJ9_seed-1/predictions/test_mmseqs/test_mmseqs_model_1.cif" + in x for x in plot_pathways + ) + ) + assert ( + any("chai1_6BJ9_seed-1/pred.model_idx_0.cif" in x for x in plot_pathways) + ) + assert ( + any("chai1_6BJ9_seed-1/pred.model_idx_1.cif" in x for x in plot_pathways) ) - assert "tests/test_data/chai1_6BJ9/pred.model_idx_0.cif" in plot_pathways - assert "tests/test_data/chai1_6BJ9/pred.model_idx_1.cif" in plot_pathways assert len(list(temp_dir.glob("*.html"))) == 6 def test_get_sequence_data(output_objs): af3_files = output_objs.af3_output.cif_files["seed-1"] - boltz_files = output_objs.boltz_output.cif_files - chai_files = output_objs.chai_output.cif_files + boltz_files = output_objs.boltz_output.cif_files["seed-1"] + chai_files = output_objs.chai_output.cif_files["seed-1"] cif_files = [] diff --git a/tests/test_utilities.py b/tests/test_utilities.py index 4e670db..145e762 100644 --- a/tests/test_utilities.py +++ b/tests/test_utilities.py @@ -71,7 +71,7 @@ def test_check_input_json(test_data): def test_clash_checker(test_data): - cif_file = Path(test_data.test_boltz_6BJ9_).joinpath( + cif_file = Path(test_data.test_boltz_6BJ9_seed_1_).joinpath( "predictions", "test_mmseqs", "test_mmseqs_model_0.cif" )