|
22 | 22 | import json
|
23 | 23 | import logging
|
24 | 24 | import os
|
| 25 | +from pathlib import Path |
25 | 26 | import re
|
26 | 27 | from typing import Any, Dict, Mapping, Optional, Sequence, Tuple
|
27 | 28 |
|
@@ -947,55 +948,71 @@ def _process_single_hit(
|
947 | 948 |
|
948 | 949 |
|
949 | 950 | def get_custom_template_features(
|
950 |
| - mmcif_path: str, |
951 |
| - query_sequence: str, |
952 |
| - pdb_id: str, |
953 |
| - chain_id: str, |
954 |
| - kalign_binary_path: str): |
955 |
| - |
956 |
| - with open(mmcif_path, "r") as mmcif_path: |
957 |
| - cif_string = mmcif_path.read() |
958 |
| - |
959 |
| - mmcif_parse_result = mmcif_parsing.parse( |
960 |
| - file_id=pdb_id, mmcif_string=cif_string |
961 |
| - ) |
962 |
| - template_sequence = mmcif_parse_result.mmcif_object.chain_to_seqres[chain_id] |
963 |
| - |
964 |
| - |
965 |
| - mapping = {x:x for x, _ in enumerate(query_sequence)} |
966 |
| - |
967 |
| - |
968 |
| - features, warnings = _extract_template_features( |
969 |
| - mmcif_object=mmcif_parse_result.mmcif_object, |
970 |
| - pdb_id=pdb_id, |
971 |
| - mapping=mapping, |
972 |
| - template_sequence=template_sequence, |
973 |
| - query_sequence=query_sequence, |
974 |
| - template_chain_id=chain_id, |
975 |
| - kalign_binary_path=kalign_binary_path, |
976 |
| - _zero_center_positions=True |
977 |
| - ) |
978 |
| - features["template_sum_probs"] = [1.0] |
979 |
| - |
980 |
| - # TODO: clean up this logic |
981 |
| - template_features = {} |
982 |
| - for template_feature_name in TEMPLATE_FEATURES: |
983 |
| - template_features[template_feature_name] = [] |
984 |
| - |
985 |
| - for k in template_features: |
986 |
| - template_features[k].append(features[k]) |
987 |
| - |
988 |
| - for name in template_features: |
989 |
| - template_features[name] = np.stack( |
990 |
| - template_features[name], axis=0 |
991 |
| - ).astype(TEMPLATE_FEATURES[name]) |
| 951 | + mmcif_path: str, |
| 952 | + query_sequence: str, |
| 953 | + pdb_id: str, |
| 954 | + chain_id: Optional[str] = "A", |
| 955 | + kalign_binary_path: Optional[str] = None, |
| 956 | +): |
| 957 | + if os.path.isfile(mmcif_path): |
| 958 | + template_paths = [Path(mmcif_path)] |
992 | 959 |
|
| 960 | + elif os.path.isdir(mmcif_path): |
| 961 | + template_paths = list(Path(mmcif_path).glob("*.cif")) |
| 962 | + else: |
| 963 | + logging.error("Custom template path %s does not exist", mmcif_path) |
| 964 | + raise ValueError(f"Custom template path {mmcif_path} does not exist") |
| 965 | + |
| 966 | + warnings = [] |
| 967 | + template_features = dict() |
| 968 | + for template_path in template_paths: |
| 969 | + logging.info("Featurizing template: %s", template_path) |
| 970 | + # pdb_id only for error reporting, take file name |
| 971 | + pdb_id = Path(template_path).stem |
| 972 | + with open(template_path, "r") as mmcif_path: |
| 973 | + cif_string = mmcif_path.read() |
| 974 | + mmcif_parse_result = mmcif_parsing.parse( |
| 975 | + file_id=pdb_id, mmcif_string=cif_string |
| 976 | + ) |
| 977 | + # mapping skipping "-" |
| 978 | + mapping = { |
| 979 | + x: x for x, curr_char in enumerate(query_sequence) if curr_char.isalnum() |
| 980 | + } |
| 981 | + realigned_sequence, realigned_mapping = _realign_pdb_template_to_query( |
| 982 | + old_template_sequence=query_sequence, |
| 983 | + template_chain_id=chain_id, |
| 984 | + mmcif_object=mmcif_parse_result.mmcif_object, |
| 985 | + old_mapping=mapping, |
| 986 | + kalign_binary_path=kalign_binary_path, |
| 987 | + ) |
| 988 | + curr_features, curr_warnings = _extract_template_features( |
| 989 | + mmcif_object=mmcif_parse_result.mmcif_object, |
| 990 | + pdb_id=pdb_id, |
| 991 | + mapping=realigned_mapping, |
| 992 | + template_sequence=realigned_sequence, |
| 993 | + query_sequence=query_sequence, |
| 994 | + template_chain_id=chain_id, |
| 995 | + kalign_binary_path=kalign_binary_path, |
| 996 | + _zero_center_positions=True, |
| 997 | + ) |
| 998 | + curr_features["template_sum_probs"] = [ |
| 999 | + 1.0 |
| 1000 | + ] # template given by user, 100% confident |
| 1001 | + template_features = { |
| 1002 | + curr_name: template_features.get(curr_name, []) + [curr_item] |
| 1003 | + for curr_name, curr_item in curr_features.items() |
| 1004 | + } |
| 1005 | + warnings.append(curr_warnings) |
| 1006 | + template_features = { |
| 1007 | + template_feature_name: np.stack( |
| 1008 | + template_features[template_feature_name], axis=0 |
| 1009 | + ).astype(template_feature_type) |
| 1010 | + for template_feature_name, template_feature_type in TEMPLATE_FEATURES.items() |
| 1011 | + } |
993 | 1012 | return TemplateSearchResult(
|
994 | 1013 | features=template_features, errors=None, warnings=warnings
|
995 | 1014 | )
|
996 | 1015 |
|
997 |
| - |
998 |
| - |
999 | 1016 | @dataclasses.dataclass(frozen=True)
|
1000 | 1017 | class TemplateSearchResult:
|
1001 | 1018 | features: Mapping[str, Any]
|
@@ -1188,6 +1205,23 @@ def get_templates(
|
1188 | 1205 | )
|
1189 | 1206 |
|
1190 | 1207 |
|
| 1208 | +class CustomHitFeaturizer(TemplateHitFeaturizer): |
| 1209 | + """Featurizer for templates given in folder. |
| 1210 | + Chain of interest has to be chain A and of same sequence length as input sequence.""" |
| 1211 | + def get_templates( |
| 1212 | + self, |
| 1213 | + query_sequence: str, |
| 1214 | + hits: Sequence[parsers.TemplateHit], |
| 1215 | + ) -> TemplateSearchResult: |
| 1216 | + """Computes the templates for given query sequence (more details above).""" |
| 1217 | + logging.info("Featurizing mmcif_dir: %s", self._mmcif_dir) |
| 1218 | + return get_custom_template_features( |
| 1219 | + self._mmcif_dir, |
| 1220 | + query_sequence=query_sequence, |
| 1221 | + pdb_id="test", |
| 1222 | + chain_id="A", |
| 1223 | + kalign_binary_path=self._kalign_binary_path, |
| 1224 | + ) |
1191 | 1225 | class HmmsearchHitFeaturizer(TemplateHitFeaturizer):
|
1192 | 1226 | def get_templates(
|
1193 | 1227 | self,
|
|
0 commit comments