Commit c48f850

Merge pull request #408 from rostro36/main
Add use_custom_templates option
2 parents f37d0d9 + 7d22739

File tree

4 files changed: +114 additions, -60 deletions

docs/source/Inference.md

Lines changed: 7 additions & 6 deletions

```diff
@@ -138,6 +138,7 @@ Some commonly used command line flags are here. A full list of flags can be view
 - `--data_random_seed`: Specifies a random seed to use.
 - `--save_outputs`: Saves a copy of all outputs from the model, e.g. the output of the msa track, ptm heads.
 - `--experiment_config_json`: Specify configuration settings using a json file. For example, passing a json with `{globals.relax.max_iterations = 10}` specifies 10 as the maximum number of relaxation iterations. See [`openfold/config.py`](https://github.com/aqlaboratory/openfold/blob/main/openfold/config.py#L283) for the full dictionary of configuration settings. Any parameters that are not manually set in these configuration settings will refer to the defaults specified by your `config_preset`.
+- `--use_custom_template`: Uses all .cif files in `template_mmcif_dir` as template input. Make sure the chains of interest have the identifier _A_ and have the same length as the input sequence. The same templates will be read for all sequences that are passed for inference.
 
 
 ### Advanced Options for Increasing Efficiency
@@ -159,12 +160,12 @@ Note that chunking (as defined in section 1.11.8 of the AlphaFold 2 supplement)
 #### Long sequence inference
 To minimize memory usage during inference on long sequences, consider the following changes:
 
 - As noted in the AlphaFold-Multimer paper, the AlphaFold/OpenFold template stack is a major memory bottleneck for inference on long sequences. OpenFold supports two mutually exclusive inference modes to address this issue. One, `average_templates` in the `template` section of the config, is similar to the solution offered by AlphaFold-Multimer, which is simply to average individual template representations. Our version is modified slightly to accommodate weights trained using the standard template algorithm. Using said weights, we notice no significant difference in performance between our averaged template embeddings and the standard ones. The second, `offload_templates`, temporarily offloads individual template embeddings into CPU memory. The former is an approximation while the latter is slightly slower; both are memory-efficient and allow the model to utilize arbitrarily many templates across sequence lengths. Both are disabled by default, and it is up to the user to determine which best suits their needs, if either.
 - Inference-time low-memory attention (LMA) can be enabled in the model config. This setting trades off speed for vastly improved memory usage. By default, LMA is run with query and key chunk sizes of 1024 and 4096, respectively. These represent a favorable tradeoff in most memory-constrained cases. Power users can choose to tweak these settings in `openfold/model/primitives.py`. For more information on the LMA algorithm, see the aforementioned Staats & Rabe preprint.
 - Disable `tune_chunk_size` for long sequences. Past a certain point, it only wastes time.
 - As a last resort, consider enabling `offload_inference`. This enables more extensive CPU offloading at various bottlenecks throughout the model.
 - Disable FlashAttention, which seems unstable on long sequences.
 
 Using the most conservative settings, we were able to run inference on a 4600-residue complex with a single A100. Compared to AlphaFold's own memory offloading mode, ours is considerably faster; the same complex takes the more efficient AlphaFold-Multimer more than double the time. Use the `long_sequence_inference` config option to enable all of these interventions at once. The `run_pretrained_openfold.py` script can enable this config option with the `--long_sequence_inference` command line option.
 
 Input FASTA files containing multiple sequences are treated as complexes. In this case, the inference script runs AlphaFold-Gap, a hack proposed [here](https://twitter.com/minkbaek/status/1417538291709071362?lang=en), using the specified stock AlphaFold/OpenFold parameters (NOT AlphaFold-Multimer).
```
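Taken together, a custom-template run might look like the following. This is a sketch only: the input and output paths, the `config_preset`, and the device are placeholders, not values taken from this PR. The command is echoed rather than executed, since the placeholder directories do not exist.

```shell
# Hypothetical invocation of the new flag; all paths are placeholders.
CMD="python3 run_pretrained_openfold.py fasta_dir/ template_mmcif_dir/ \
  --use_custom_template \
  --output_dir out/ \
  --config_preset model_1_ptm \
  --model_device cuda:0"
echo "$CMD"
```

With `--use_custom_template` set, every `.cif` file in `template_mmcif_dir/` is used as a template for every input sequence, bypassing the template search.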

openfold/data/data_pipeline.py

Lines changed: 16 additions & 3 deletions

```diff
@@ -23,8 +23,19 @@
 from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union
 import numpy as np
 import torch
-from openfold.data import templates, parsers, mmcif_parsing, msa_identifiers, msa_pairing, feature_processing_multimer
-from openfold.data.templates import get_custom_template_features, empty_template_feats
+from openfold.data import (
+    templates,
+    parsers,
+    mmcif_parsing,
+    msa_identifiers,
+    msa_pairing,
+    feature_processing_multimer,
+)
+from openfold.data.templates import (
+    get_custom_template_features,
+    empty_template_feats,
+    CustomHitFeaturizer,
+)
 from openfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch
 from openfold.np import residue_constants, protein
 
@@ -38,7 +49,9 @@ def make_template_features(
     template_featurizer: Any,
 ) -> FeatureDict:
     hits_cat = sum(hits.values(), [])
-    if(len(hits_cat) == 0 or template_featurizer is None):
+    if template_featurizer is None or (
+        len(hits_cat) == 0 and not isinstance(template_featurizer, CustomHitFeaturizer)
+    ):
         template_features = empty_template_feats(len(input_sequence))
     else:
         templates_result = template_featurizer.get_templates(
```
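The rewritten guard in `make_template_features` can be exercised in isolation. Here `CustomHitFeaturizer` is a minimal stub standing in for the OpenFold class, and `should_emit_empty_features` is a hypothetical helper name that just isolates the condition:

```python
from typing import Any

class CustomHitFeaturizer:
    """Stub for openfold.data.templates.CustomHitFeaturizer."""
    pass

def should_emit_empty_features(hits_cat: list, template_featurizer: Any) -> bool:
    # Empty template features are produced when no featurizer is configured,
    # or when there are no hits AND the featurizer is not the custom one
    # (the custom featurizer ignores hits and reads templates from disk).
    return template_featurizer is None or (
        len(hits_cat) == 0
        and not isinstance(template_featurizer, CustomHitFeaturizer)
    )

print(should_emit_empty_features([], object()))               # True
print(should_emit_empty_features([], CustomHitFeaturizer()))  # False
```

The point of the change: previously an empty hit list always produced empty template features, which would have silenced the custom featurizer, since it never receives search hits.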

openfold/data/templates.py

Lines changed: 78 additions & 44 deletions

```diff
@@ -22,6 +22,7 @@
 import json
 import logging
 import os
+from pathlib import Path
 import re
 from typing import Any, Dict, Mapping, Optional, Sequence, Tuple
 
@@ -947,55 +948,71 @@ def _process_single_hit(
 
 
 def get_custom_template_features(
-    mmcif_path: str,
-    query_sequence: str,
-    pdb_id: str,
-    chain_id: str,
-    kalign_binary_path: str):
-
-    with open(mmcif_path, "r") as mmcif_path:
-        cif_string = mmcif_path.read()
-
-    mmcif_parse_result = mmcif_parsing.parse(
-        file_id=pdb_id, mmcif_string=cif_string
-    )
-    template_sequence = mmcif_parse_result.mmcif_object.chain_to_seqres[chain_id]
-
-    mapping = {x: x for x, _ in enumerate(query_sequence)}
-
-    features, warnings = _extract_template_features(
-        mmcif_object=mmcif_parse_result.mmcif_object,
-        pdb_id=pdb_id,
-        mapping=mapping,
-        template_sequence=template_sequence,
-        query_sequence=query_sequence,
-        template_chain_id=chain_id,
-        kalign_binary_path=kalign_binary_path,
-        _zero_center_positions=True
-    )
-    features["template_sum_probs"] = [1.0]
-
-    # TODO: clean up this logic
-    template_features = {}
-    for template_feature_name in TEMPLATE_FEATURES:
-        template_features[template_feature_name] = []
-
-    for k in template_features:
-        template_features[k].append(features[k])
-
-    for name in template_features:
-        template_features[name] = np.stack(
-            template_features[name], axis=0
-        ).astype(TEMPLATE_FEATURES[name])
+    mmcif_path: str,
+    query_sequence: str,
+    pdb_id: str,
+    chain_id: Optional[str] = "A",
+    kalign_binary_path: Optional[str] = None,
+):
+    if os.path.isfile(mmcif_path):
+        template_paths = [Path(mmcif_path)]
+    elif os.path.isdir(mmcif_path):
+        template_paths = list(Path(mmcif_path).glob("*.cif"))
+    else:
+        logging.error("Custom template path %s does not exist", mmcif_path)
+        raise ValueError(f"Custom template path {mmcif_path} does not exist")
+
+    warnings = []
+    template_features = dict()
+    for template_path in template_paths:
+        logging.info("Featurizing template: %s", template_path)
+        # pdb_id only for error reporting, take file name
+        pdb_id = Path(template_path).stem
+        with open(template_path, "r") as mmcif_path:
+            cif_string = mmcif_path.read()
+        mmcif_parse_result = mmcif_parsing.parse(
+            file_id=pdb_id, mmcif_string=cif_string
+        )
+        # mapping skipping "-"
+        mapping = {
+            x: x for x, curr_char in enumerate(query_sequence) if curr_char.isalnum()
+        }
+        realigned_sequence, realigned_mapping = _realign_pdb_template_to_query(
+            old_template_sequence=query_sequence,
+            template_chain_id=chain_id,
+            mmcif_object=mmcif_parse_result.mmcif_object,
+            old_mapping=mapping,
+            kalign_binary_path=kalign_binary_path,
+        )
+        curr_features, curr_warnings = _extract_template_features(
+            mmcif_object=mmcif_parse_result.mmcif_object,
+            pdb_id=pdb_id,
+            mapping=realigned_mapping,
+            template_sequence=realigned_sequence,
+            query_sequence=query_sequence,
+            template_chain_id=chain_id,
+            kalign_binary_path=kalign_binary_path,
+            _zero_center_positions=True,
+        )
+        curr_features["template_sum_probs"] = [
+            1.0
+        ]  # template given by user, 100% confident
+        template_features = {
+            curr_name: template_features.get(curr_name, []) + [curr_item]
+            for curr_name, curr_item in curr_features.items()
+        }
+        warnings.append(curr_warnings)
+    template_features = {
+        template_feature_name: np.stack(
+            template_features[template_feature_name], axis=0
+        ).astype(template_feature_type)
+        for template_feature_name, template_feature_type in TEMPLATE_FEATURES.items()
+    }
     return TemplateSearchResult(
         features=template_features, errors=None, warnings=warnings
     )
 
-
-
 @dataclasses.dataclass(frozen=True)
 class TemplateSearchResult:
     features: Mapping[str, Any]
 
@@ -1188,6 +1205,23 @@ def get_templates(
         )
 
 
+class CustomHitFeaturizer(TemplateHitFeaturizer):
+    """Featurizer for templates given in folder.
+    Chain of interest has to be chain A and of same sequence length as input sequence."""
+    def get_templates(
+        self,
+        query_sequence: str,
+        hits: Sequence[parsers.TemplateHit],
+    ) -> TemplateSearchResult:
+        """Computes the templates for given query sequence (more details above)."""
+        logging.info("Featurizing mmcif_dir: %s", self._mmcif_dir)
+        return get_custom_template_features(
+            self._mmcif_dir,
+            query_sequence=query_sequence,
+            pdb_id="test",
+            chain_id="A",
+            kalign_binary_path=self._kalign_binary_path,
+        )
 class HmmsearchHitFeaturizer(TemplateHitFeaturizer):
     def get_templates(
         self,
```
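The per-template accumulation and final stacking in the new `get_custom_template_features` follow a pattern that can be sketched on its own: append each template's per-name features to running lists, then stack each list along a new leading "template" axis with the dtype from the schema. The two-entry `TEMPLATE_FEATURES` below is a hypothetical trim of OpenFold's real schema, and `accumulate` is an illustrative helper name:

```python
import numpy as np

# Hypothetical, trimmed-down schema; the real TEMPLATE_FEATURES maps
# many more feature names to their dtypes.
TEMPLATE_FEATURES = {"template_aatype": np.int64, "template_sum_probs": np.float32}

def accumulate(per_template_feats):
    template_features = {}
    for curr_features in per_template_feats:
        # Append this template's features to the running per-name lists,
        # mirroring the dict comprehension inside the loop in the diff.
        template_features = {
            name: template_features.get(name, []) + [item]
            for name, item in curr_features.items()
        }
    # Stack each list along a new leading axis and cast to the schema dtype.
    return {
        name: np.stack(template_features[name], axis=0).astype(dtype)
        for name, dtype in TEMPLATE_FEATURES.items()
    }

feats = accumulate([
    {"template_aatype": np.zeros(4), "template_sum_probs": [1.0]},
    {"template_aatype": np.ones(4), "template_sum_probs": [1.0]},
])
print(feats["template_aatype"].shape)  # (2, 4)
```

This replaces the old single-template logic (flagged with a "TODO: clean up" comment) with one shape-agnostic path that works for any number of `.cif` files.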

run_pretrained_openfold.py

Lines changed: 13 additions & 7 deletions

```diff
@@ -202,8 +202,15 @@ def main(args):
     )
 
     is_multimer = "multimer" in args.config_preset
-
-    if is_multimer:
+    is_custom_template = "use_custom_template" in args and args.use_custom_template
+    if is_custom_template:
+        template_featurizer = templates.CustomHitFeaturizer(
+            mmcif_dir=args.template_mmcif_dir,
+            max_template_date="9999-12-31",  # just dummy, not used
+            max_hits=-1,  # just dummy, not used
+            kalign_binary_path=args.kalign_binary_path
+        )
+    elif is_multimer:
         template_featurizer = templates.HmmsearchHitFeaturizer(
             mmcif_dir=args.template_mmcif_dir,
             max_template_date=args.max_template_date,
@@ -221,11 +228,9 @@ def main(args):
             release_dates_path=args.release_dates_path,
             obsolete_pdbs_path=args.obsolete_pdbs_path
         )
-
     data_processor = data_pipeline.DataPipeline(
         template_featurizer=template_featurizer,
     )
-
     if is_multimer:
         data_processor = data_pipeline.DataPipelineMultimer(
             monomer_data_pipeline=data_processor,
@@ -238,7 +243,6 @@ def main(args):
 
     np.random.seed(random_seed)
     torch.manual_seed(random_seed + 1)
-
     feature_processor = feature_pipeline.FeaturePipeline(config.data)
     if not os.path.exists(output_dir_base):
         os.makedirs(output_dir_base)
@@ -313,7 +317,6 @@ def main(args):
         )
 
         feature_dicts[tag] = feature_dict
-
         processed_feature_dict = feature_processor.process_features(
             feature_dict, mode='predict', is_multimer=is_multimer
         )
@@ -400,6 +403,10 @@ def main(args):
         help="""Path to alignment directory. If provided, alignment computation
         is skipped and database path arguments are ignored."""
     )
+    parser.add_argument(
+        "--use_custom_template", action="store_true", default=False,
+        help="""Use mmcif given with "template_mmcif_dir" argument as template input."""
+    )
     parser.add_argument(
         "--use_single_seq_mode", action="store_true", default=False,
         help="""Use single sequence embeddings instead of MSAs."""
@@ -494,5 +501,4 @@ def main(args):
             """The model is being run on CPU. Consider specifying
             --model_device for better performance"""
         )
-
     main(args)
```
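The featurizer dispatch added to `main` reduces to the following sketch. The three classes are empty stubs rather than the real OpenFold featurizers, and `select_featurizer` is an illustrative helper name; the real code constructs each featurizer inline with its mmCIF directory and date/hit limits:

```python
import argparse

class CustomHitFeaturizer: ...     # stub: reads templates straight from disk
class HmmsearchHitFeaturizer: ...  # stub: multimer template search
class HhsearchHitFeaturizer: ...   # stub: monomer template search

def select_featurizer(args) -> object:
    # Custom templates take priority; otherwise fall back to the
    # multimer (hmmsearch) or monomer (hhsearch) template pipeline.
    if getattr(args, "use_custom_template", False):
        return CustomHitFeaturizer()
    if "multimer" in args.config_preset:
        return HmmsearchHitFeaturizer()
    return HhsearchHitFeaturizer()

parser = argparse.ArgumentParser()
parser.add_argument("--use_custom_template", action="store_true", default=False)
parser.add_argument("--config_preset", default="model_1_ptm")

args = parser.parse_args(["--use_custom_template"])
print(type(select_featurizer(args)).__name__)  # CustomHitFeaturizer
```

Note that `max_template_date` and `max_hits` are passed as dummies in the custom path: the custom featurizer never filters by date or hit count, since the user supplies the templates directly.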
