diff --git a/README.md b/README.md
index 635c295a..020ca3b3 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,11 @@
-# ColabDesign
+# ColabDesign (EXPERIMENTAL justktln2 REMIX)
+
+The only modifications to the forked branch are the addition of I/O via MDTraj and the vectorization of ProteinMPNN sampling across trajectories; a minimal usage sketch appears further below. In principle this works with any conformational ensemble that shares a fixed protein topology, so it could also be used with structures generated by AlphaFold MSA subsampling.
+You will therefore need to install MDTraj to use this fork.
+The important parts of this will be refactored and folded into either [ciMIST](https://github.com/justktln2/ciMIST/) or ColabDesign.
+
+The original README, which I had nothing to do with, follows.
+
 ### Making Protein Design accessible to all via Google Colab!
 - P(structure | sequence) - [TrDesign](/tr) - using TrRosetta for design
diff --git a/build/lib/colabdesign/__init__.py b/build/lib/colabdesign/__init__.py
new file mode 100644
index 00000000..4fa3969e
--- /dev/null
+++ b/build/lib/colabdesign/__init__.py
@@ -0,0 +1,16 @@
+import os,jax
+# disable triton_gemm for jax versions > 0.3
+if int(jax.__version__.split(".")[1]) > 3:
+  os.environ["XLA_FLAGS"] = "--xla_gpu_enable_triton_gemm=false"
+
+import warnings
+warnings.simplefilter(action='ignore', category=FutureWarning)
+
+from colabdesign.shared.utils import clear_mem
+from colabdesign.af.model import mk_af_model
+from colabdesign.tr.model import mk_tr_model
+from colabdesign.mpnn.model import mk_mpnn_model
+
+# backward compatibility
+mk_design_model = mk_afdesign_model = mk_af_model
+mk_trdesign_model = mk_tr_model
\ No newline at end of file
diff --git a/build/lib/colabdesign/af/__init__.py b/build/lib/colabdesign/af/__init__.py
new file mode 100644
index 00000000..bc92f615
--- /dev/null
+++ b/build/lib/colabdesign/af/__init__.py
@@ -0,0 +1,13 @@
+import os,jax
+# disable triton_gemm for jax versions > 0.3
+if int(jax.__version__.split(".")[1]) > 3:
+  os.environ["XLA_FLAGS"] = "--xla_gpu_enable_triton_gemm=false"
+
+import warnings
+warnings.simplefilter(action='ignore', category=FutureWarning)
+
+from colabdesign.shared.utils import clear_mem
+from colabdesign.af.model import mk_af_model
+
+# backward compatibility
+mk_design_model = mk_afdesign_model = mk_af_model
\ No newline at end of file
diff --git a/build/lib/colabdesign/af/alphafold/__init__.py b/build/lib/colabdesign/af/alphafold/__init__.py
new file mode 100644
index 00000000..a0fd7f82
--- /dev/null
+++ b/build/lib/colabdesign/af/alphafold/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
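# ---------------------------------------------------------------------------
# Minimal usage sketch for the MDTraj I/O and ProteinMPNN sampling described
# in the README hunk above. Illustrative only: the file names are
# placeholders, and it loops over frames using the standard mdtraj and
# ColabDesign mk_mpnn_model calls rather than the fork's vectorized
# trajectory interface, which is not shown in this diff.
import mdtraj as md
from colabdesign import mk_mpnn_model

# Load a conformational ensemble with a fixed protein topology.
traj = md.load("ensemble.xtc", top="topology.pdb")

mpnn = mk_mpnn_model()
samples = []
for i in range(traj.n_frames):
  traj[i].save_pdb("frame.pdb")                 # write one frame to PDB
  mpnn.prep_inputs(pdb_filename="frame.pdb")    # featurize that frame
  samples.append(mpnn.sample(temperature=0.1))  # dict of sampled seqs/scores
# ---------------------------------------------------------------------------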
+"""An implementation of the inference pipeline of AlphaFold v2.0.""" diff --git a/build/lib/colabdesign/af/alphafold/common/__init__.py b/build/lib/colabdesign/af/alphafold/common/__init__.py new file mode 100644 index 00000000..d3c65d69 --- /dev/null +++ b/build/lib/colabdesign/af/alphafold/common/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Common data types and constants used within Alphafold.""" diff --git a/build/lib/colabdesign/af/alphafold/common/confidence.py b/build/lib/colabdesign/af/alphafold/common/confidence.py new file mode 100644 index 00000000..1d566d86 --- /dev/null +++ b/build/lib/colabdesign/af/alphafold/common/confidence.py @@ -0,0 +1,169 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Functions for processing confidence metrics.""" + +import jax.numpy as jnp +import jax +import numpy as np +from colabdesign.af.alphafold.common import residue_constants +import scipy.special + +def compute_tol(prev_pos, current_pos, mask, use_jnp=False): + # Early stopping criteria based on criteria used in + # AF2Complex: https://www.nature.com/articles/s41467-022-29394-2 + _np = jnp if use_jnp else np + dist = lambda x:_np.sqrt(((x[:,None] - x[None,:])**2).sum(-1)) + ca_idx = residue_constants.atom_order['CA'] + sq_diff = _np.square(dist(prev_pos[:,ca_idx])-dist(current_pos[:,ca_idx])) + mask_2d = mask[:,None] * mask[None,:] + return _np.sqrt((sq_diff * mask_2d).sum()/mask_2d.sum() + 1e-8) + + +def compute_plddt(logits, use_jnp=False): + """Computes per-residue pLDDT from logits. + Args: + logits: [num_res, num_bins] output from the PredictedLDDTHead. + Returns: + plddt: [num_res] per-residue pLDDT. + """ + if use_jnp: + _np, _softmax = jnp, jax.nn.softmax + else: + _np, _softmax = np, scipy.special.softmax + + num_bins = logits.shape[-1] + bin_width = 1.0 / num_bins + bin_centers = _np.arange(start=0.5 * bin_width, stop=1.0, step=bin_width) + probs = _softmax(logits, axis=-1) + predicted_lddt_ca = (probs * bin_centers[None, :]).sum(-1) + return predicted_lddt_ca * 100 + +def _calculate_bin_centers(breaks, use_jnp=False): + """Gets the bin centers from the bin edges. + Args: + breaks: [num_bins - 1] the error bin edges. + Returns: + bin_centers: [num_bins] the error bin centers. + """ + _np = jnp if use_jnp else np + step = breaks[1] - breaks[0] + + # Add half-step to get the center + bin_centers = breaks + step / 2 + + # Add a catch-all bin at the end. 
+ return _np.append(bin_centers, bin_centers[-1] + step) + +def _calculate_expected_aligned_error( + alignment_confidence_breaks, + aligned_distance_error_probs, + use_jnp=False): + """Calculates expected aligned distance errors for every pair of residues. + Args: + alignment_confidence_breaks: [num_bins - 1] the error bin edges. + aligned_distance_error_probs: [num_res, num_res, num_bins] the predicted + probs for each error bin, for each pair of residues. + Returns: + predicted_aligned_error: [num_res, num_res] the expected aligned distance + error for each pair of residues. + max_predicted_aligned_error: The maximum predicted error possible. + """ + bin_centers = _calculate_bin_centers(alignment_confidence_breaks, use_jnp=use_jnp) + # Tuple of expected aligned distance error and max possible error. + pae = (aligned_distance_error_probs * bin_centers).sum(-1) + return (pae, bin_centers[-1]) + +def compute_predicted_aligned_error(logits, breaks, use_jnp=False): + """Computes aligned confidence metrics from logits. + Args: + logits: [num_res, num_res, num_bins] the logits output from + PredictedAlignedErrorHead. + breaks: [num_bins - 1] the error bin edges. + + Returns: + aligned_confidence_probs: [num_res, num_res, num_bins] the predicted + aligned error probabilities over bins for each residue pair. + predicted_aligned_error: [num_res, num_res] the expected aligned distance + error for each pair of residues. + max_predicted_aligned_error: The maximum predicted error possible. + """ + _softmax = jax.nn.softmax if use_jnp else scipy.special.softmax + aligned_confidence_probs = _softmax(logits,axis=-1) + predicted_aligned_error, max_predicted_aligned_error = \ + _calculate_expected_aligned_error(breaks, aligned_confidence_probs, use_jnp=use_jnp) + + return { + 'aligned_confidence_probs': aligned_confidence_probs, + 'predicted_aligned_error': predicted_aligned_error, + 'max_predicted_aligned_error': max_predicted_aligned_error, + } + +def predicted_tm_score(logits, breaks, residue_weights = None, + asym_id = None, use_jnp=False): + """Computes predicted TM alignment or predicted interface TM alignment score. + + Args: + logits: [num_res, num_res, num_bins] the logits output from + PredictedAlignedErrorHead. + breaks: [num_bins] the error bins. + residue_weights: [num_res] the per residue weights to use for the + expectation. + asym_id: [num_res] the asymmetric unit ID - the chain ID. Only needed for + ipTM calculation. + + Returns: + ptm_score: The predicted TM alignment or the predicted iTM score. + """ + if use_jnp: + _np, _softmax = jnp, jax.nn.softmax + else: + _np, _softmax = np, scipy.special.softmax + + # residue_weights has to be in [0, 1], but can be floating-point, i.e. the + # exp. resolved head's probability. + if residue_weights is None: + residue_weights = _np.ones(logits.shape[0]) + + bin_centers = _calculate_bin_centers(breaks, use_jnp=use_jnp) + num_res = residue_weights.shape[0] + + # Clip num_res to avoid negative/undefined d0. + clipped_num_res = _np.maximum(residue_weights.sum(), 19) + + # Compute d_0(num_res) as defined by TM-score, eqn. (5) in Yang & Skolnick + # "Scoring function for automated assessment of protein structure template + # quality", 2004: http://zhanglab.ccmb.med.umich.edu/papers/2004_3.pdf + d0 = 1.24 * (clipped_num_res - 15) ** (1./3) - 1.8 + + # Convert logits to probs. + probs = _softmax(logits, axis=-1) + + # TM-Score term for every bin. + tm_per_bin = 1. / (1 + _np.square(bin_centers) / _np.square(d0)) + # E_distances tm(distance). 
+ predicted_tm_term = (probs * tm_per_bin).sum(-1) + + if asym_id is None: + pair_mask = _np.full((num_res,num_res),True) + else: + pair_mask = asym_id[:, None] != asym_id[None, :] + + predicted_tm_term *= pair_mask + + pair_residue_weights = pair_mask * (residue_weights[None, :] * residue_weights[:, None]) + normed_residue_mask = pair_residue_weights / (1e-8 + pair_residue_weights.sum(-1, keepdims=True)) + per_alignment = (predicted_tm_term * normed_residue_mask).sum(-1) + + return (per_alignment * residue_weights).max() \ No newline at end of file diff --git a/build/lib/colabdesign/af/alphafold/common/protein.py b/build/lib/colabdesign/af/alphafold/common/protein.py new file mode 100644 index 00000000..d1ae59bb --- /dev/null +++ b/build/lib/colabdesign/af/alphafold/common/protein.py @@ -0,0 +1,229 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Protein data type.""" +import dataclasses +import io +from typing import Any, Mapping, Optional +from colabdesign.af.alphafold.common import residue_constants +from Bio.PDB import PDBParser +import numpy as np + +FeatureDict = Mapping[str, np.ndarray] +ModelOutput = Mapping[str, Any] # Is a nested dict. + + +@dataclasses.dataclass(frozen=True) +class Protein: + """Protein structure representation.""" + + # Cartesian coordinates of atoms in angstroms. The atom types correspond to + # residue_constants.atom_types, i.e. the first three are N, CA, CB. + atom_positions: np.ndarray # [num_res, num_atom_type, 3] + + # Amino-acid type for each residue represented as an integer between 0 and + # 20, where 20 is 'X'. + aatype: np.ndarray # [num_res] + + # Binary float mask to indicate presence of a particular atom. 1.0 if an atom + # is present and 0.0 if not. This should be used for loss masking. + atom_mask: np.ndarray # [num_res, num_atom_type] + + # Residue index as used in PDB. It is not necessarily continuous or 0-indexed. + residue_index: np.ndarray # [num_res] + + # B-factors, or temperature factors, of each residue (in sq. angstroms units), + # representing the displacement of the residue from its ground truth mean + # value. + b_factors: np.ndarray # [num_res, num_atom_type] + + +def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein: + """Takes a PDB string and constructs a Protein object. + + WARNING: All non-standard residue types will be converted into UNK. All + non-standard atoms will be ignored. + + Args: + pdb_str: The contents of the pdb file + chain_id: If None, then the pdb file must contain a single chain (which + will be parsed). If chain_id is specified (e.g. A), then only that chain + is parsed. + + Returns: + A new `Protein` parsed from the pdb contents. + """ + pdb_fh = io.StringIO(pdb_str) + parser = PDBParser(QUIET=True) + structure = parser.get_structure('none', pdb_fh) + models = list(structure.get_models()) + if len(models) != 1: + raise ValueError( + f'Only single model PDBs are supported. 
Found {len(models)} models.') + model = models[0] + + if chain_id is not None: + chain = model[chain_id] + else: + chains = list(model.get_chains()) + if len(chains) != 1: + raise ValueError( + 'Only single chain PDBs are supported when chain_id not specified. ' + f'Found {len(chains)} chains.') + else: + chain = chains[0] + + atom_positions = [] + aatype = [] + atom_mask = [] + residue_index = [] + b_factors = [] + + for res in chain: + if res.id[2] != ' ': + raise ValueError( + f'PDB contains an insertion code at chain {chain.id} and residue ' + f'index {res.id[1]}. These are not supported.') + res_shortname = residue_constants.restype_3to1.get(res.resname, 'X') + restype_idx = residue_constants.restype_order.get( + res_shortname, residue_constants.restype_num) + pos = np.zeros((residue_constants.atom_type_num, 3)) + mask = np.zeros((residue_constants.atom_type_num,)) + res_b_factors = np.zeros((residue_constants.atom_type_num,)) + for atom in res: + if atom.name not in residue_constants.atom_types: + continue + pos[residue_constants.atom_order[atom.name]] = atom.coord + mask[residue_constants.atom_order[atom.name]] = 1. + res_b_factors[residue_constants.atom_order[atom.name]] = atom.bfactor + if np.sum(mask) < 0.5: + # If no known atom positions are reported for the residue then skip it. + continue + aatype.append(restype_idx) + atom_positions.append(pos) + atom_mask.append(mask) + residue_index.append(res.id[1]) + b_factors.append(res_b_factors) + + return Protein( + atom_positions=np.array(atom_positions), + atom_mask=np.array(atom_mask), + aatype=np.array(aatype), + residue_index=np.array(residue_index), + b_factors=np.array(b_factors)) + + +def to_pdb(prot: Protein) -> str: + """Converts a `Protein` instance to a PDB string. + + Args: + prot: The protein to convert to PDB. + + Returns: + PDB string. + """ + restypes = residue_constants.restypes + ['X'] + res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], 'UNK') + atom_types = residue_constants.atom_types + + pdb_lines = [] + + atom_mask = prot.atom_mask + aatype = prot.aatype + atom_positions = prot.atom_positions + residue_index = prot.residue_index.astype(np.int32) + b_factors = prot.b_factors + + if np.any(aatype > residue_constants.restype_num): + raise ValueError('Invalid aatypes.') + + pdb_lines.append('MODEL 1') + atom_index = 1 + chain_id = 'A' + # Add all atom sites. + for i in range(aatype.shape[0]): + res_name_3 = res_1to3(aatype[i]) + for atom_name, pos, mask, b_factor in zip( + atom_types, atom_positions[i], atom_mask[i], b_factors[i]): + if mask < 0.5: + continue + + record_type = 'ATOM' + name = atom_name if len(atom_name) == 4 else f' {atom_name}' + alt_loc = '' + insertion_code = '' + occupancy = 1.00 + element = atom_name[0] # Protein supports only C, N, O, S, this works. + charge = '' + # PDB is a columnar format, every space matters here! + atom_line = (f'{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}' + f'{res_name_3:>3} {chain_id:>1}' + f'{residue_index[i]:>4}{insertion_code:>1} ' + f'{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}' + f'{occupancy:>6.2f}{b_factor:>6.2f} ' + f'{element:>2}{charge:>2}') + pdb_lines.append(atom_line) + atom_index += 1 + + # Close the chain. 
+ chain_end = 'TER' + chain_termination_line = ( + f'{chain_end:<6}{atom_index:>5} {res_1to3(aatype[-1]):>3} ' + f'{chain_id:>1}{residue_index[-1]:>4}') + pdb_lines.append(chain_termination_line) + pdb_lines.append('ENDMDL') + + pdb_lines.append('END') + pdb_lines.append('') + return '\n'.join(pdb_lines) + + +def ideal_atom_mask(prot: Protein) -> np.ndarray: + """Computes an ideal atom mask. + + `Protein.atom_mask` typically is defined according to the atoms that are + reported in the PDB. This function computes a mask according to heavy atoms + that should be present in the given sequence of amino acids. + + Args: + prot: `Protein` whose fields are `numpy.ndarray` objects. + + Returns: + An ideal atom mask. + """ + return residue_constants.STANDARD_ATOM_MASK[prot.aatype] + + +def from_prediction(features: FeatureDict, result: ModelOutput, + b_factors: Optional[np.ndarray] = None) -> Protein: + """Assembles a protein from a prediction. + + Args: + features: Dictionary holding model inputs. + result: Dictionary holding model outputs. + b_factors: (Optional) B-factors to use for the protein. + + Returns: + A protein instance. + """ + fold_output = result['structure_module'] + if b_factors is None: + b_factors = np.zeros_like(fold_output['final_atom_mask']) + + return Protein( + aatype=features['aatype'][0], + atom_positions=fold_output['final_atom_positions'], + atom_mask=fold_output['final_atom_mask'], + residue_index=features['residue_index'][0] + 1, + b_factors=b_factors) diff --git a/build/lib/colabdesign/af/alphafold/common/residue_constants.py b/build/lib/colabdesign/af/alphafold/common/residue_constants.py new file mode 100644 index 00000000..05659758 --- /dev/null +++ b/build/lib/colabdesign/af/alphafold/common/residue_constants.py @@ -0,0 +1,911 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Constants used in AlphaFold.""" + +import collections +import functools +from typing import List, Mapping, Tuple + +import numpy as np +import tree + +# Internal import (35fd). + + +# Distance from one CA to next CA [trans configuration: omega = 180]. +ca_ca = 3.80209737096 + +# Format: The list for each AA type contains chi1, chi2, chi3, chi4 in +# this order (or a relevant subset from chi1 onwards). ALA and GLY don't have +# chi angles so their chi angle lists are empty. +chi_angles_atoms = { + 'ALA': [], + # Chi5 in arginine is always 0 +- 5 degrees, so ignore it. 
+ 'ARG': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'], + ['CB', 'CG', 'CD', 'NE'], ['CG', 'CD', 'NE', 'CZ']], + 'ASN': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'OD1']], + 'ASP': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'OD1']], + 'CYS': [['N', 'CA', 'CB', 'SG']], + 'GLN': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'], + ['CB', 'CG', 'CD', 'OE1']], + 'GLU': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'], + ['CB', 'CG', 'CD', 'OE1']], + 'GLY': [], + 'HIS': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'ND1']], + 'ILE': [['N', 'CA', 'CB', 'CG1'], ['CA', 'CB', 'CG1', 'CD1']], + 'LEU': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']], + 'LYS': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'], + ['CB', 'CG', 'CD', 'CE'], ['CG', 'CD', 'CE', 'NZ']], + 'MET': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'SD'], + ['CB', 'CG', 'SD', 'CE']], + 'PHE': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']], + 'PRO': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD']], + 'SER': [['N', 'CA', 'CB', 'OG']], + 'THR': [['N', 'CA', 'CB', 'OG1']], + 'TRP': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']], + 'TYR': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']], + 'VAL': [['N', 'CA', 'CB', 'CG1']], +} + +# If chi angles given in fixed-length array, this matrix determines how to mask +# them for each AA type. The order is as per restype_order (see below). +chi_angles_mask = [ + [0.0, 0.0, 0.0, 0.0], # ALA + [1.0, 1.0, 1.0, 1.0], # ARG + [1.0, 1.0, 0.0, 0.0], # ASN + [1.0, 1.0, 0.0, 0.0], # ASP + [1.0, 0.0, 0.0, 0.0], # CYS + [1.0, 1.0, 1.0, 0.0], # GLN + [1.0, 1.0, 1.0, 0.0], # GLU + [0.0, 0.0, 0.0, 0.0], # GLY + [1.0, 1.0, 0.0, 0.0], # HIS + [1.0, 1.0, 0.0, 0.0], # ILE + [1.0, 1.0, 0.0, 0.0], # LEU + [1.0, 1.0, 1.0, 1.0], # LYS + [1.0, 1.0, 1.0, 0.0], # MET + [1.0, 1.0, 0.0, 0.0], # PHE + [1.0, 1.0, 0.0, 0.0], # PRO + [1.0, 0.0, 0.0, 0.0], # SER + [1.0, 0.0, 0.0, 0.0], # THR + [1.0, 1.0, 0.0, 0.0], # TRP + [1.0, 1.0, 0.0, 0.0], # TYR + [1.0, 0.0, 0.0, 0.0], # VAL +] + +# The following chi angles are pi periodic: they can be rotated by a multiple +# of pi without affecting the structure. +chi_pi_periodic = [ + [0.0, 0.0, 0.0, 0.0], # ALA + [0.0, 0.0, 0.0, 0.0], # ARG + [0.0, 0.0, 0.0, 0.0], # ASN + [0.0, 1.0, 0.0, 0.0], # ASP + [0.0, 0.0, 0.0, 0.0], # CYS + [0.0, 0.0, 0.0, 0.0], # GLN + [0.0, 0.0, 1.0, 0.0], # GLU + [0.0, 0.0, 0.0, 0.0], # GLY + [0.0, 0.0, 0.0, 0.0], # HIS + [0.0, 0.0, 0.0, 0.0], # ILE + [0.0, 0.0, 0.0, 0.0], # LEU + [0.0, 0.0, 0.0, 0.0], # LYS + [0.0, 0.0, 0.0, 0.0], # MET + [0.0, 1.0, 0.0, 0.0], # PHE + [0.0, 0.0, 0.0, 0.0], # PRO + [0.0, 0.0, 0.0, 0.0], # SER + [0.0, 0.0, 0.0, 0.0], # THR + [0.0, 0.0, 0.0, 0.0], # TRP + [0.0, 1.0, 0.0, 0.0], # TYR + [0.0, 0.0, 0.0, 0.0], # VAL + [0.0, 0.0, 0.0, 0.0], # UNK +] + +# Atoms positions relative to the 8 rigid groups, defined by the pre-omega, phi, +# psi and chi angles: +# 0: 'backbone group', +# 1: 'pre-omega-group', (empty) +# 2: 'phi-group', (currently empty, because it defines only hydrogens) +# 3: 'psi-group', +# 4,5,6,7: 'chi1,2,3,4-group' +# The atom positions are relative to the axis-end-atom of the corresponding +# rotation axis. The x-axis is in direction of the rotation axis, and the y-axis +# is defined such that the dihedral-angle-definiting atom (the last entry in +# chi_angles_atoms above) is in the xy-plane (with a positive y-coordinate). 
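# ---------------------------------------------------------------------------
# Illustrative consistency check of the constants above (editorial sketch,
# not upstream AlphaFold code): for each residue type, the number of
# chi-angle atom quadruples in `chi_angles_atoms` matches the number of 1.0
# entries in the corresponding `chi_angles_mask` row.
from colabdesign.af.alphafold.common import residue_constants as rc

for i, one_letter in enumerate(rc.restypes):
  resname = rc.restype_1to3[one_letter]        # e.g. 'R' -> 'ARG'
  n_chi = len(rc.chi_angles_atoms[resname])    # 4 for ARG/LYS, 0 for ALA/GLY
  assert n_chi == int(sum(rc.chi_angles_mask[i]))
# ---------------------------------------------------------------------------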
+# format: [atomname, group_idx, rel_position] +rigid_group_atom_positions = { + 'ALA': [ + ['N', 0, (-0.525, 1.363, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, -0.000, -0.000)], + ['CB', 0, (-0.529, -0.774, -1.205)], + ['O', 3, (0.627, 1.062, 0.000)], + ], + 'ARG': [ + ['N', 0, (-0.524, 1.362, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, -0.000, -0.000)], + ['CB', 0, (-0.524, -0.778, -1.209)], + ['O', 3, (0.626, 1.062, 0.000)], + ['CG', 4, (0.616, 1.390, -0.000)], + ['CD', 5, (0.564, 1.414, 0.000)], + ['NE', 6, (0.539, 1.357, -0.000)], + ['NH1', 7, (0.206, 2.301, 0.000)], + ['NH2', 7, (2.078, 0.978, -0.000)], + ['CZ', 7, (0.758, 1.093, -0.000)], + ], + 'ASN': [ + ['N', 0, (-0.536, 1.357, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, -0.000, -0.000)], + ['CB', 0, (-0.531, -0.787, -1.200)], + ['O', 3, (0.625, 1.062, 0.000)], + ['CG', 4, (0.584, 1.399, 0.000)], + ['ND2', 5, (0.593, -1.188, 0.001)], + ['OD1', 5, (0.633, 1.059, 0.000)], + ], + 'ASP': [ + ['N', 0, (-0.525, 1.362, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.527, 0.000, -0.000)], + ['CB', 0, (-0.526, -0.778, -1.208)], + ['O', 3, (0.626, 1.062, -0.000)], + ['CG', 4, (0.593, 1.398, -0.000)], + ['OD1', 5, (0.610, 1.091, 0.000)], + ['OD2', 5, (0.592, -1.101, -0.003)], + ], + 'CYS': [ + ['N', 0, (-0.522, 1.362, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.524, 0.000, 0.000)], + ['CB', 0, (-0.519, -0.773, -1.212)], + ['O', 3, (0.625, 1.062, -0.000)], + ['SG', 4, (0.728, 1.653, 0.000)], + ], + 'GLN': [ + ['N', 0, (-0.526, 1.361, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, 0.000, 0.000)], + ['CB', 0, (-0.525, -0.779, -1.207)], + ['O', 3, (0.626, 1.062, -0.000)], + ['CG', 4, (0.615, 1.393, 0.000)], + ['CD', 5, (0.587, 1.399, -0.000)], + ['NE2', 6, (0.593, -1.189, -0.001)], + ['OE1', 6, (0.634, 1.060, 0.000)], + ], + 'GLU': [ + ['N', 0, (-0.528, 1.361, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, -0.000, -0.000)], + ['CB', 0, (-0.526, -0.781, -1.207)], + ['O', 3, (0.626, 1.062, 0.000)], + ['CG', 4, (0.615, 1.392, 0.000)], + ['CD', 5, (0.600, 1.397, 0.000)], + ['OE1', 6, (0.607, 1.095, -0.000)], + ['OE2', 6, (0.589, -1.104, -0.001)], + ], + 'GLY': [ + ['N', 0, (-0.572, 1.337, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.517, -0.000, -0.000)], + ['O', 3, (0.626, 1.062, -0.000)], + ], + 'HIS': [ + ['N', 0, (-0.527, 1.360, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, 0.000, 0.000)], + ['CB', 0, (-0.525, -0.778, -1.208)], + ['O', 3, (0.625, 1.063, 0.000)], + ['CG', 4, (0.600, 1.370, -0.000)], + ['CD2', 5, (0.889, -1.021, 0.003)], + ['ND1', 5, (0.744, 1.160, -0.000)], + ['CE1', 5, (2.030, 0.851, 0.002)], + ['NE2', 5, (2.145, -0.466, 0.004)], + ], + 'ILE': [ + ['N', 0, (-0.493, 1.373, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.527, -0.000, -0.000)], + ['CB', 0, (-0.536, -0.793, -1.213)], + ['O', 3, (0.627, 1.062, -0.000)], + ['CG1', 4, (0.534, 1.437, -0.000)], + ['CG2', 4, (0.540, -0.785, -1.199)], + ['CD1', 5, (0.619, 1.391, 0.000)], + ], + 'LEU': [ + ['N', 0, (-0.520, 1.363, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, -0.000, -0.000)], + ['CB', 0, (-0.522, -0.773, -1.214)], + ['O', 3, (0.625, 1.063, -0.000)], + ['CG', 4, (0.678, 1.371, 0.000)], + ['CD1', 5, (0.530, 1.430, -0.000)], + ['CD2', 5, (0.535, -0.774, 1.200)], + ], + 'LYS': [ + ['N', 0, (-0.526, 1.362, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, 0.000, 0.000)], + ['CB', 
0, (-0.524, -0.778, -1.208)], + ['O', 3, (0.626, 1.062, -0.000)], + ['CG', 4, (0.619, 1.390, 0.000)], + ['CD', 5, (0.559, 1.417, 0.000)], + ['CE', 6, (0.560, 1.416, 0.000)], + ['NZ', 7, (0.554, 1.387, 0.000)], + ], + 'MET': [ + ['N', 0, (-0.521, 1.364, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, 0.000, 0.000)], + ['CB', 0, (-0.523, -0.776, -1.210)], + ['O', 3, (0.625, 1.062, -0.000)], + ['CG', 4, (0.613, 1.391, -0.000)], + ['SD', 5, (0.703, 1.695, 0.000)], + ['CE', 6, (0.320, 1.786, -0.000)], + ], + 'PHE': [ + ['N', 0, (-0.518, 1.363, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.524, 0.000, -0.000)], + ['CB', 0, (-0.525, -0.776, -1.212)], + ['O', 3, (0.626, 1.062, -0.000)], + ['CG', 4, (0.607, 1.377, 0.000)], + ['CD1', 5, (0.709, 1.195, -0.000)], + ['CD2', 5, (0.706, -1.196, 0.000)], + ['CE1', 5, (2.102, 1.198, -0.000)], + ['CE2', 5, (2.098, -1.201, -0.000)], + ['CZ', 5, (2.794, -0.003, -0.001)], + ], + 'PRO': [ + ['N', 0, (-0.566, 1.351, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.527, -0.000, 0.000)], + ['CB', 0, (-0.546, -0.611, -1.293)], + ['O', 3, (0.621, 1.066, 0.000)], + ['CG', 4, (0.382, 1.445, 0.0)], + # ['CD', 5, (0.427, 1.440, 0.0)], + ['CD', 5, (0.477, 1.424, 0.0)], # manually made angle 2 degrees larger + ], + 'SER': [ + ['N', 0, (-0.529, 1.360, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, -0.000, -0.000)], + ['CB', 0, (-0.518, -0.777, -1.211)], + ['O', 3, (0.626, 1.062, -0.000)], + ['OG', 4, (0.503, 1.325, 0.000)], + ], + 'THR': [ + ['N', 0, (-0.517, 1.364, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, 0.000, -0.000)], + ['CB', 0, (-0.516, -0.793, -1.215)], + ['O', 3, (0.626, 1.062, 0.000)], + ['CG2', 4, (0.550, -0.718, -1.228)], + ['OG1', 4, (0.472, 1.353, 0.000)], + ], + 'TRP': [ + ['N', 0, (-0.521, 1.363, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, -0.000, 0.000)], + ['CB', 0, (-0.523, -0.776, -1.212)], + ['O', 3, (0.627, 1.062, 0.000)], + ['CG', 4, (0.609, 1.370, -0.000)], + ['CD1', 5, (0.824, 1.091, 0.000)], + ['CD2', 5, (0.854, -1.148, -0.005)], + ['CE2', 5, (2.186, -0.678, -0.007)], + ['CE3', 5, (0.622, -2.530, -0.007)], + ['NE1', 5, (2.140, 0.690, -0.004)], + ['CH2', 5, (3.028, -2.890, -0.013)], + ['CZ2', 5, (3.283, -1.543, -0.011)], + ['CZ3', 5, (1.715, -3.389, -0.011)], + ], + 'TYR': [ + ['N', 0, (-0.522, 1.362, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.524, -0.000, -0.000)], + ['CB', 0, (-0.522, -0.776, -1.213)], + ['O', 3, (0.627, 1.062, -0.000)], + ['CG', 4, (0.607, 1.382, -0.000)], + ['CD1', 5, (0.716, 1.195, -0.000)], + ['CD2', 5, (0.713, -1.194, -0.001)], + ['CE1', 5, (2.107, 1.200, -0.002)], + ['CE2', 5, (2.104, -1.201, -0.003)], + ['OH', 5, (4.168, -0.002, -0.005)], + ['CZ', 5, (2.791, -0.001, -0.003)], + ], + 'VAL': [ + ['N', 0, (-0.494, 1.373, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.527, -0.000, -0.000)], + ['CB', 0, (-0.533, -0.795, -1.213)], + ['O', 3, (0.627, 1.062, -0.000)], + ['CG1', 4, (0.540, 1.429, -0.000)], + ['CG2', 4, (0.533, -0.776, 1.203)], + ], +} + +# A list of atoms (excluding hydrogen) for each AA type. PDB naming convention. 
+residue_atoms = { + 'ALA': ['C', 'CA', 'CB', 'N', 'O'], + 'ARG': ['C', 'CA', 'CB', 'CG', 'CD', 'CZ', 'N', 'NE', 'O', 'NH1', 'NH2'], + 'ASP': ['C', 'CA', 'CB', 'CG', 'N', 'O', 'OD1', 'OD2'], + 'ASN': ['C', 'CA', 'CB', 'CG', 'N', 'ND2', 'O', 'OD1'], + 'CYS': ['C', 'CA', 'CB', 'N', 'O', 'SG'], + 'GLU': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'O', 'OE1', 'OE2'], + 'GLN': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'NE2', 'O', 'OE1'], + 'GLY': ['C', 'CA', 'N', 'O'], + 'HIS': ['C', 'CA', 'CB', 'CG', 'CD2', 'CE1', 'N', 'ND1', 'NE2', 'O'], + 'ILE': ['C', 'CA', 'CB', 'CG1', 'CG2', 'CD1', 'N', 'O'], + 'LEU': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'N', 'O'], + 'LYS': ['C', 'CA', 'CB', 'CG', 'CD', 'CE', 'N', 'NZ', 'O'], + 'MET': ['C', 'CA', 'CB', 'CG', 'CE', 'N', 'O', 'SD'], + 'PHE': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'N', 'O'], + 'PRO': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'O'], + 'SER': ['C', 'CA', 'CB', 'N', 'O', 'OG'], + 'THR': ['C', 'CA', 'CB', 'CG2', 'N', 'O', 'OG1'], + 'TRP': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE2', 'CE3', 'CZ2', 'CZ3', + 'CH2', 'N', 'NE1', 'O'], + 'TYR': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'N', 'O', + 'OH'], + 'VAL': ['C', 'CA', 'CB', 'CG1', 'CG2', 'N', 'O'] +} + +# Naming swaps for ambiguous atom names. +# Due to symmetries in the amino acids the naming of atoms is ambiguous in +# 4 of the 20 amino acids. +# (The LDDT paper lists 7 amino acids as ambiguous, but the naming ambiguities +# in LEU, VAL and ARG can be resolved by using the 3d constellations of +# the 'ambiguous' atoms and their neighbours) +residue_atom_renaming_swaps = { + 'ASP': {'OD1': 'OD2'}, + 'GLU': {'OE1': 'OE2'}, + 'PHE': {'CD1': 'CD2', 'CE1': 'CE2'}, + 'TYR': {'CD1': 'CD2', 'CE1': 'CE2'}, +} + +# Van der Waals radii [Angstroem] of the atoms (from Wikipedia) +van_der_waals_radius = { + 'C': 1.7, + 'N': 1.55, + 'O': 1.52, + 'S': 1.8, +} + +Bond = collections.namedtuple( + 'Bond', ['atom1_name', 'atom2_name', 'length', 'stddev']) +BondAngle = collections.namedtuple( + 'BondAngle', + ['atom1_name', 'atom2_name', 'atom3name', 'angle_rad', 'stddev']) + + +@functools.lru_cache(maxsize=None) +def load_stereo_chemical_props() -> Tuple[Mapping[str, List[Bond]], + Mapping[str, List[Bond]], + Mapping[str, List[BondAngle]]]: + """Load stereo_chemical_props.txt into a nice structure. + + Load literature values for bond lengths and bond angles and translate + bond angles into the length of the opposite edge of the triangle + ("residue_virtual_bonds"). + + Returns: + residue_bonds: dict that maps resname --> list of Bond tuples + residue_virtual_bonds: dict that maps resname --> list of Bond tuples + residue_bond_angles: dict that maps resname --> list of BondAngle tuples + """ + stereo_chemical_props_path = ( + 'alphafold/common/stereo_chemical_props.txt') + with open(stereo_chemical_props_path, 'rt') as f: + stereo_chemical_props = f.read() + lines_iter = iter(stereo_chemical_props.splitlines()) + # Load bond lengths. + residue_bonds = {} + next(lines_iter) # Skip header line. + for line in lines_iter: + if line.strip() == '-': + break + bond, resname, length, stddev = line.split() + atom1, atom2 = bond.split('-') + if resname not in residue_bonds: + residue_bonds[resname] = [] + residue_bonds[resname].append( + Bond(atom1, atom2, float(length), float(stddev))) + residue_bonds['UNK'] = [] + + # Load bond angles. + residue_bond_angles = {} + next(lines_iter) # Skip empty line. + next(lines_iter) # Skip header line. 
+ for line in lines_iter: + if line.strip() == '-': + break + bond, resname, angle_degree, stddev_degree = line.split() + atom1, atom2, atom3 = bond.split('-') + if resname not in residue_bond_angles: + residue_bond_angles[resname] = [] + residue_bond_angles[resname].append( + BondAngle(atom1, atom2, atom3, + float(angle_degree) / 180. * np.pi, + float(stddev_degree) / 180. * np.pi)) + residue_bond_angles['UNK'] = [] + + def make_bond_key(atom1_name, atom2_name): + """Unique key to lookup bonds.""" + return '-'.join(sorted([atom1_name, atom2_name])) + + # Translate bond angles into distances ("virtual bonds"). + residue_virtual_bonds = {} + for resname, bond_angles in residue_bond_angles.items(): + # Create a fast lookup dict for bond lengths. + bond_cache = {} + for b in residue_bonds[resname]: + bond_cache[make_bond_key(b.atom1_name, b.atom2_name)] = b + residue_virtual_bonds[resname] = [] + for ba in bond_angles: + bond1 = bond_cache[make_bond_key(ba.atom1_name, ba.atom2_name)] + bond2 = bond_cache[make_bond_key(ba.atom2_name, ba.atom3name)] + + # Compute distance between atom1 and atom3 using the law of cosines + # c^2 = a^2 + b^2 - 2ab*cos(gamma). + gamma = ba.angle_rad + length = np.sqrt(bond1.length**2 + bond2.length**2 + - 2 * bond1.length * bond2.length * np.cos(gamma)) + + # Propagation of uncertainty assuming uncorrelated errors. + dl_outer = 0.5 / length + dl_dgamma = (2 * bond1.length * bond2.length * np.sin(gamma)) * dl_outer + dl_db1 = (2 * bond1.length - 2 * bond2.length * np.cos(gamma)) * dl_outer + dl_db2 = (2 * bond2.length - 2 * bond1.length * np.cos(gamma)) * dl_outer + stddev = np.sqrt((dl_dgamma * ba.stddev)**2 + + (dl_db1 * bond1.stddev)**2 + + (dl_db2 * bond2.stddev)**2) + residue_virtual_bonds[resname].append( + Bond(ba.atom1_name, ba.atom3name, length, stddev)) + + return (residue_bonds, + residue_virtual_bonds, + residue_bond_angles) + + +# Between-residue bond lengths for general bonds (first element) and for Proline +# (second element). +between_res_bond_length_c_n = [1.329, 1.341] +between_res_bond_length_stddev_c_n = [0.014, 0.016] + +# Between-residue cos_angles. +between_res_cos_angles_c_n_ca = [-0.5203, 0.0353] # degrees: 121.352 +- 2.315 +between_res_cos_angles_ca_c_n = [-0.4473, 0.0311] # degrees: 116.568 +- 1.995 + +# This mapping is used when we need to store atom data in a format that requires +# fixed atom data size for every residue (e.g. a numpy array). +atom_types = [ + 'N', 'CA', 'C', 'CB', 'O', 'CG', 'CG1', 'CG2', 'OG', 'OG1', 'SG', 'CD', + 'CD1', 'CD2', 'ND1', 'ND2', 'OD1', 'OD2', 'SD', 'CE', 'CE1', 'CE2', 'CE3', + 'NE', 'NE1', 'NE2', 'OE1', 'OE2', 'CH2', 'NH1', 'NH2', 'OH', 'CZ', 'CZ2', + 'CZ3', 'NZ', 'OXT' +] +atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)} +atom_type_num = len(atom_types) # := 37. 
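# ---------------------------------------------------------------------------
# Illustrative lookup using the constants above (editorial sketch):
# `atom_order` maps a PDB atom name to its column in the fixed 37-column
# ("atom37") atom representation, e.g. when indexing Protein.atom_positions
# of shape [num_res, 37, 3].
from colabdesign.af.alphafold.common import residue_constants as rc

ca_idx = rc.atom_order["CA"]         # 1 -> CA is the second atom37 column
assert rc.atom_types[ca_idx] == "CA"
assert rc.atom_type_num == 37        # total number of fixed atom-type columns
# ---------------------------------------------------------------------------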
+ +# A compact atom encoding with 14 columns +# pylint: disable=line-too-long +# pylint: disable=bad-whitespace +restype_name_to_atom14_names = { + 'ALA': ['N', 'CA', 'C', 'O', 'CB', '', '', '', '', '', '', '', '', ''], + 'ARG': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'NE', 'CZ', 'NH1', 'NH2', '', '', ''], + 'ASN': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'ND2', '', '', '', '', '', ''], + 'ASP': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'OD2', '', '', '', '', '', ''], + 'CYS': ['N', 'CA', 'C', 'O', 'CB', 'SG', '', '', '', '', '', '', '', ''], + 'GLN': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'NE2', '', '', '', '', ''], + 'GLU': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'OE2', '', '', '', '', ''], + 'GLY': ['N', 'CA', 'C', 'O', '', '', '', '', '', '', '', '', '', ''], + 'HIS': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'ND1', 'CD2', 'CE1', 'NE2', '', '', '', ''], + 'ILE': ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', 'CD1', '', '', '', '', '', ''], + 'LEU': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', '', '', '', '', '', ''], + 'LYS': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'CE', 'NZ', '', '', '', '', ''], + 'MET': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'SD', 'CE', '', '', '', '', '', ''], + 'PHE': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', '', '', ''], + 'PRO': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', '', '', '', '', '', '', ''], + 'SER': ['N', 'CA', 'C', 'O', 'CB', 'OG', '', '', '', '', '', '', '', ''], + 'THR': ['N', 'CA', 'C', 'O', 'CB', 'OG1', 'CG2', '', '', '', '', '', '', ''], + 'TRP': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'NE1', 'CE2', 'CE3', 'CZ2', 'CZ3', 'CH2'], + 'TYR': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'OH', '', ''], + 'VAL': ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', '', '', '', '', '', '', ''], + 'UNK': ['', '', '', '', '', '', '', '', '', '', '', '', '', ''], + +} +# pylint: enable=line-too-long +# pylint: enable=bad-whitespace + + +# This is the standard residue order when coding AA type as a number. +# Reproduce it by taking 3-letter AA codes and sorting them alphabetically. +restypes = [ + 'A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I', 'L', 'K', 'M', 'F', 'P', + 'S', 'T', 'W', 'Y', 'V' +] +restype_order = {restype: i for i, restype in enumerate(restypes)} +restype_num = len(restypes) # := 20. +unk_restype_index = restype_num # Catch-all index for unknown restypes. + +restypes_with_x = restypes + ['X'] +restype_order_with_x = {restype: i for i, restype in enumerate(restypes_with_x)} + + +def sequence_to_onehot( + sequence: str, + mapping: Mapping[str, int], + map_unknown_to_x: bool = False) -> np.ndarray: + """Maps the given sequence into a one-hot encoded matrix. + + Args: + sequence: An amino acid sequence. + mapping: A dictionary mapping amino acids to integers. + map_unknown_to_x: If True, any amino acid that is not in the mapping will be + mapped to the unknown amino acid 'X'. If the mapping doesn't contain + amino acid 'X', an error will be thrown. If False, any amino acid not in + the mapping will throw an error. + + Returns: + A numpy array of shape (seq_len, num_unique_aas) with one-hot encoding of + the sequence. + + Raises: + ValueError: If the mapping doesn't contain values from 0 to + num_unique_aas - 1 without any gaps. + """ + num_entries = max(mapping.values()) + 1 + + if sorted(set(mapping.values())) != list(range(num_entries)): + raise ValueError('The mapping must have values from 0 to num_unique_aas-1 ' + 'without any gaps. 
Got: %s' % sorted(mapping.values())) + + one_hot_arr = np.zeros((len(sequence), num_entries), dtype=np.int32) + + for aa_index, aa_type in enumerate(sequence): + if map_unknown_to_x: + if aa_type.isalpha() and aa_type.isupper(): + aa_id = mapping.get(aa_type, mapping['X']) + else: + raise ValueError(f'Invalid character in the sequence: {aa_type}') + else: + aa_id = mapping[aa_type] + one_hot_arr[aa_index, aa_id] = 1 + + return one_hot_arr + + +restype_1to3 = { + 'A': 'ALA', + 'R': 'ARG', + 'N': 'ASN', + 'D': 'ASP', + 'C': 'CYS', + 'Q': 'GLN', + 'E': 'GLU', + 'G': 'GLY', + 'H': 'HIS', + 'I': 'ILE', + 'L': 'LEU', + 'K': 'LYS', + 'M': 'MET', + 'F': 'PHE', + 'P': 'PRO', + 'S': 'SER', + 'T': 'THR', + 'W': 'TRP', + 'Y': 'TYR', + 'V': 'VAL', +} + + +# NB: restype_3to1 differs from Bio.PDB.protein_letters_3to1 by being a simple +# 1-to-1 mapping of 3 letter names to one letter names. The latter contains +# many more, and less common, three letter names as keys and maps many of these +# to the same one letter name (including 'X' and 'U' which we don't use here). +restype_3to1 = {v: k for k, v in restype_1to3.items()} + +# Define a restype name for all unknown residues. +unk_restype = 'UNK' + +resnames = [restype_1to3[r] for r in restypes] + [unk_restype] +resname_to_idx = {resname: i for i, resname in enumerate(resnames)} + + +# The mapping here uses hhblits convention, so that B is mapped to D, J and O +# are mapped to X, U is mapped to C, and Z is mapped to E. Other than that the +# remaining 20 amino acids are kept in alphabetical order. +# There are 2 non-amino acid codes, X (representing any amino acid) and +# "-" representing a missing amino acid in an alignment. The id for these +# codes is put at the end (20 and 21) so that they can easily be ignored if +# desired. +HHBLITS_AA_TO_ID = { + 'A': 0, + 'B': 2, + 'C': 1, + 'D': 2, + 'E': 3, + 'F': 4, + 'G': 5, + 'H': 6, + 'I': 7, + 'J': 20, + 'K': 8, + 'L': 9, + 'M': 10, + 'N': 11, + 'O': 20, + 'P': 12, + 'Q': 13, + 'R': 14, + 'S': 15, + 'T': 16, + 'U': 1, + 'V': 17, + 'W': 18, + 'X': 20, + 'Y': 19, + 'Z': 3, + '-': 21, +} + +# Partial inversion of HHBLITS_AA_TO_ID. +ID_TO_HHBLITS_AA = { + 0: 'A', + 1: 'C', # Also U. + 2: 'D', # Also B. + 3: 'E', # Also Z. + 4: 'F', + 5: 'G', + 6: 'H', + 7: 'I', + 8: 'K', + 9: 'L', + 10: 'M', + 11: 'N', + 12: 'P', + 13: 'Q', + 14: 'R', + 15: 'S', + 16: 'T', + 17: 'V', + 18: 'W', + 19: 'Y', + 20: 'X', # Includes J and O. + 21: '-', +} + +restypes_with_x_and_gap = restypes + ['X', '-'] +MAP_HHBLITS_AATYPE_TO_OUR_AATYPE = tuple( + restypes_with_x_and_gap.index(ID_TO_HHBLITS_AA[i]) + for i in range(len(restypes_with_x_and_gap))) + + +def _make_standard_atom_mask() -> np.ndarray: + """Returns [num_res_types, num_atom_types] mask array.""" + # +1 to account for unknown (all 0s). + mask = np.zeros([restype_num + 1, atom_type_num], dtype=np.int32) + for restype, restype_letter in enumerate(restypes): + restype_name = restype_1to3[restype_letter] + atom_names = residue_atoms[restype_name] + for atom_name in atom_names: + atom_type = atom_order[atom_name] + mask[restype, atom_type] = 1 + return mask + + +STANDARD_ATOM_MASK = _make_standard_atom_mask() + + +# A one hot representation for the first and second atoms defining the axis +# of rotation for each chi-angle in each residue. 
+def chi_angle_atom(atom_index: int) -> np.ndarray: + """Define chi-angle rigid groups via one-hot representations.""" + chi_angles_index = {} + one_hots = [] + + for k, v in chi_angles_atoms.items(): + indices = [atom_types.index(s[atom_index]) for s in v] + indices.extend([-1]*(4-len(indices))) + chi_angles_index[k] = indices + + for r in restypes: + res3 = restype_1to3[r] + one_hot = np.eye(atom_type_num)[chi_angles_index[res3]] + one_hots.append(one_hot) + + one_hots.append(np.zeros([4, atom_type_num])) # Add zeros for residue `X`. + one_hot = np.stack(one_hots, axis=0) + one_hot = np.transpose(one_hot, [0, 2, 1]) + + return one_hot + +chi_atom_1_one_hot = chi_angle_atom(1) +chi_atom_2_one_hot = chi_angle_atom(2) + +# An array like chi_angles_atoms but using indices rather than names. +chi_angles_atom_indices = [chi_angles_atoms[restype_1to3[r]] for r in restypes] +chi_angles_atom_indices = tree.map_structure( + lambda atom_name: atom_order[atom_name], chi_angles_atom_indices) +chi_angles_atom_indices = np.array([ + chi_atoms + ([[0, 0, 0, 0]] * (4 - len(chi_atoms))) + for chi_atoms in chi_angles_atom_indices]) + +# Mapping from (res_name, atom_name) pairs to the atom's chi group index +# and atom index within that group. +chi_groups_for_atom = collections.defaultdict(list) +for res_name, chi_angle_atoms_for_res in chi_angles_atoms.items(): + for chi_group_i, chi_group in enumerate(chi_angle_atoms_for_res): + for atom_i, atom in enumerate(chi_group): + chi_groups_for_atom[(res_name, atom)].append((chi_group_i, atom_i)) +chi_groups_for_atom = dict(chi_groups_for_atom) + + +def _make_rigid_transformation_4x4(ex, ey, translation): + """Create a rigid 4x4 transformation matrix from two axes and transl.""" + # Normalize ex. + ex_normalized = ex / np.linalg.norm(ex) + + # make ey perpendicular to ex + ey_normalized = ey - np.dot(ey, ex_normalized) * ex_normalized + ey_normalized /= np.linalg.norm(ey_normalized) + + # compute ez as cross product + eznorm = np.cross(ex_normalized, ey_normalized) + m = np.stack([ex_normalized, ey_normalized, eznorm, translation]).transpose() + m = np.concatenate([m, [[0., 0., 0., 1.]]], axis=0) + return m + + +# create an array with (restype, atomtype) --> rigid_group_idx +# and an array with (restype, atomtype, coord) for the atom positions +# and compute affine transformation matrices (4,4) from one rigid group to the +# previous group +restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=np.int32) +restype_atom37_mask = np.zeros([21, 37], dtype=np.float32) +restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32) +restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=np.int32) +restype_atom14_mask = np.zeros([21, 14], dtype=np.float32) +restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32) +restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32) + +############################################### +restype_atom14_to_atom37 = [] +restype_atom37_to_atom14 = [] +for rt in restypes: + atom_names = restype_name_to_atom14_names[restype_1to3[rt]] + restype_atom14_to_atom37.append([(atom_order[name] if name else 0) for name in atom_names]) + atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)} + restype_atom37_to_atom14.append([(atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0) for name in atom_types]) +restype_atom14_to_atom37.append([0] * 14) +restype_atom37_to_atom14.append([0] * 37) +restype_atom14_to_atom37 = np.array(restype_atom14_to_atom37, dtype=np.int32) 
+restype_atom37_to_atom14 = np.array(restype_atom37_to_atom14, dtype=np.int32) +################################################ + +def _make_rigid_group_constants(): + """Fill the arrays above.""" + + + for restype, restype_letter in enumerate(restypes): + resname = restype_1to3[restype_letter] + for atomname, group_idx, atom_position in rigid_group_atom_positions[resname]: + atomtype = atom_order[atomname] + restype_atom37_to_rigid_group[restype, atomtype] = group_idx + restype_atom37_mask[restype, atomtype] = 1 + restype_atom37_rigid_group_positions[restype, atomtype, :] = atom_position + + atom14idx = restype_name_to_atom14_names[resname].index(atomname) + restype_atom14_to_rigid_group[restype, atom14idx] = group_idx + restype_atom14_mask[restype, atom14idx] = 1 + restype_atom14_rigid_group_positions[restype, atom14idx, :] = atom_position + + atom_names = residue_atoms[resname] + atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)} + + for restype, restype_letter in enumerate(restypes): + resname = restype_1to3[restype_letter] + atom_positions = {name: np.array(pos) for name, _, pos + in rigid_group_atom_positions[resname]} + + # backbone to backbone is the identity transform + restype_rigid_group_default_frame[restype, 0, :, :] = np.eye(4) + + # pre-omega-frame to backbone (currently dummy identity matrix) + restype_rigid_group_default_frame[restype, 1, :, :] = np.eye(4) + + # phi-frame to backbone + mat = _make_rigid_transformation_4x4( + ex=atom_positions['N'] - atom_positions['CA'], + ey=np.array([1., 0., 0.]), + translation=atom_positions['N']) + restype_rigid_group_default_frame[restype, 2, :, :] = mat + + # psi-frame to backbone + mat = _make_rigid_transformation_4x4( + ex=atom_positions['C'] - atom_positions['CA'], + ey=atom_positions['CA'] - atom_positions['N'], + translation=atom_positions['C']) + restype_rigid_group_default_frame[restype, 3, :, :] = mat + + # chi1-frame to backbone + if chi_angles_mask[restype][0]: + base_atom_names = chi_angles_atoms[resname][0] + base_atom_positions = [atom_positions[name] for name in base_atom_names] + mat = _make_rigid_transformation_4x4( + ex=base_atom_positions[2] - base_atom_positions[1], + ey=base_atom_positions[0] - base_atom_positions[1], + translation=base_atom_positions[2]) + restype_rigid_group_default_frame[restype, 4, :, :] = mat + + # chi2-frame to chi1-frame + # chi3-frame to chi2-frame + # chi4-frame to chi3-frame + # luckily all rotation axes for the next frame start at (0,0,0) of the + # previous frame + for chi_idx in range(1, 4): + if chi_angles_mask[restype][chi_idx]: + axis_end_atom_name = chi_angles_atoms[resname][chi_idx][2] + axis_end_atom_position = atom_positions[axis_end_atom_name] + mat = _make_rigid_transformation_4x4( + ex=axis_end_atom_position, + ey=np.array([-1., 0., 0.]), + translation=axis_end_atom_position) + restype_rigid_group_default_frame[restype, 4 + chi_idx, :, :] = mat + + +_make_rigid_group_constants() + + +def make_atom14_dists_bounds(overlap_tolerance=1.5, + bond_length_tolerance_factor=15): + """compute upper and lower bounds for bonds to assess violations.""" + restype_atom14_bond_lower_bound = np.zeros([21, 14, 14], np.float32) + restype_atom14_bond_upper_bound = np.zeros([21, 14, 14], np.float32) + restype_atom14_bond_stddev = np.zeros([21, 14, 14], np.float32) + residue_bonds, residue_virtual_bonds, _ = load_stereo_chemical_props() + for restype, restype_letter in enumerate(restypes): + resname = restype_1to3[restype_letter] + atom_list = 
restype_name_to_atom14_names[resname] + + # create lower and upper bounds for clashes + for atom1_idx, atom1_name in enumerate(atom_list): + if not atom1_name: + continue + atom1_radius = van_der_waals_radius[atom1_name[0]] + for atom2_idx, atom2_name in enumerate(atom_list): + if (not atom2_name) or atom1_idx == atom2_idx: + continue + atom2_radius = van_der_waals_radius[atom2_name[0]] + lower = atom1_radius + atom2_radius - overlap_tolerance + upper = 1e10 + restype_atom14_bond_lower_bound[restype, atom1_idx, atom2_idx] = lower + restype_atom14_bond_lower_bound[restype, atom2_idx, atom1_idx] = lower + restype_atom14_bond_upper_bound[restype, atom1_idx, atom2_idx] = upper + restype_atom14_bond_upper_bound[restype, atom2_idx, atom1_idx] = upper + + # overwrite lower and upper bounds for bonds and angles + for b in residue_bonds[resname] + residue_virtual_bonds[resname]: + atom1_idx = atom_list.index(b.atom1_name) + atom2_idx = atom_list.index(b.atom2_name) + lower = b.length - bond_length_tolerance_factor * b.stddev + upper = b.length + bond_length_tolerance_factor * b.stddev + restype_atom14_bond_lower_bound[restype, atom1_idx, atom2_idx] = lower + restype_atom14_bond_lower_bound[restype, atom2_idx, atom1_idx] = lower + restype_atom14_bond_upper_bound[restype, atom1_idx, atom2_idx] = upper + restype_atom14_bond_upper_bound[restype, atom2_idx, atom1_idx] = upper + restype_atom14_bond_stddev[restype, atom1_idx, atom2_idx] = b.stddev + restype_atom14_bond_stddev[restype, atom2_idx, atom1_idx] = b.stddev + return {'lower_bound': restype_atom14_bond_lower_bound, # shape (21,14,14) + 'upper_bound': restype_atom14_bond_upper_bound, # shape (21,14,14) + 'stddev': restype_atom14_bond_stddev, # shape (21,14,14) + } diff --git a/build/lib/colabdesign/af/alphafold/data/__init__.py b/build/lib/colabdesign/af/alphafold/data/__init__.py new file mode 100644 index 00000000..9821d212 --- /dev/null +++ b/build/lib/colabdesign/af/alphafold/data/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Data pipeline for model features.""" diff --git a/build/lib/colabdesign/af/alphafold/data/mmcif_parsing.py b/build/lib/colabdesign/af/alphafold/data/mmcif_parsing.py new file mode 100644 index 00000000..18375165 --- /dev/null +++ b/build/lib/colabdesign/af/alphafold/data/mmcif_parsing.py @@ -0,0 +1,384 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Parses the mmCIF file format.""" +import collections +import dataclasses +import io +from typing import Any, Mapping, Optional, Sequence, Tuple + +from absl import logging +from Bio import PDB +from Bio.Data import SCOPData + +# Type aliases: +ChainId = str +PdbHeader = Mapping[str, Any] +PdbStructure = PDB.Structure.Structure +SeqRes = str +MmCIFDict = Mapping[str, Sequence[str]] + + +@dataclasses.dataclass(frozen=True) +class Monomer: + id: str + num: int + + +# Note - mmCIF format provides no guarantees on the type of author-assigned +# sequence numbers. They need not be integers. +@dataclasses.dataclass(frozen=True) +class AtomSite: + residue_name: str + author_chain_id: str + mmcif_chain_id: str + author_seq_num: str + mmcif_seq_num: int + insertion_code: str + hetatm_atom: str + model_num: int + + +# Used to map SEQRES index to a residue in the structure. +@dataclasses.dataclass(frozen=True) +class ResiduePosition: + chain_id: str + residue_number: int + insertion_code: str + + +@dataclasses.dataclass(frozen=True) +class ResidueAtPosition: + position: Optional[ResiduePosition] + name: str + is_missing: bool + hetflag: str + + +@dataclasses.dataclass(frozen=True) +class MmcifObject: + """Representation of a parsed mmCIF file. + + Contains: + file_id: A meaningful name, e.g. a pdb_id. Should be unique amongst all + files being processed. + header: Biopython header. + structure: Biopython structure. + chain_to_seqres: Dict mapping chain_id to 1 letter amino acid sequence. E.g. + {'A': 'ABCDEFG'} + seqres_to_structure: Dict; for each chain_id contains a mapping between + SEQRES index and a ResidueAtPosition. e.g. {'A': {0: ResidueAtPosition, + 1: ResidueAtPosition, + ...}} + raw_string: The raw string used to construct the MmcifObject. + """ + file_id: str + header: PdbHeader + structure: PdbStructure + chain_to_seqres: Mapping[ChainId, SeqRes] + seqres_to_structure: Mapping[ChainId, Mapping[int, ResidueAtPosition]] + raw_string: Any + + +@dataclasses.dataclass(frozen=True) +class ParsingResult: + """Returned by the parse function. + + Contains: + mmcif_object: A MmcifObject, may be None if no chain could be successfully + parsed. + errors: A dict mapping (file_id, chain_id) to any exception generated. + """ + mmcif_object: Optional[MmcifObject] + errors: Mapping[Tuple[str, str], Any] + + +class ParseError(Exception): + """An error indicating that an mmCIF file could not be parsed.""" + + +def mmcif_loop_to_list(prefix: str, + parsed_info: MmCIFDict) -> Sequence[Mapping[str, str]]: + """Extracts loop associated with a prefix from mmCIF data as a list. + + Reference for loop_ in mmCIF: + http://mmcif.wwpdb.org/docs/tutorials/mechanics/pdbx-mmcif-syntax.html + + Args: + prefix: Prefix shared by each of the data items in the loop. + e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num, + _entity_poly_seq.mon_id. Should include the trailing period. + parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a Biopython + parser. + + Returns: + Returns a list of dicts; each dict represents 1 entry from an mmCIF loop. 
+ """ + cols = [] + data = [] + for key, value in parsed_info.items(): + if key.startswith(prefix): + cols.append(key) + data.append(value) + + assert all([len(xs) == len(data[0]) for xs in data]), ( + 'mmCIF error: Not all loops are the same length: %s' % cols) + + return [dict(zip(cols, xs)) for xs in zip(*data)] + + +def mmcif_loop_to_dict(prefix: str, + index: str, + parsed_info: MmCIFDict, + ) -> Mapping[str, Mapping[str, str]]: + """Extracts loop associated with a prefix from mmCIF data as a dictionary. + + Args: + prefix: Prefix shared by each of the data items in the loop. + e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num, + _entity_poly_seq.mon_id. Should include the trailing period. + index: Which item of loop data should serve as the key. + parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a Biopython + parser. + + Returns: + Returns a dict of dicts; each dict represents 1 entry from an mmCIF loop, + indexed by the index column. + """ + entries = mmcif_loop_to_list(prefix, parsed_info) + return {entry[index]: entry for entry in entries} + + +def parse(*, + file_id: str, + mmcif_string: str, + catch_all_errors: bool = True) -> ParsingResult: + """Entry point, parses an mmcif_string. + + Args: + file_id: A string identifier for this file. Should be unique within the + collection of files being processed. + mmcif_string: Contents of an mmCIF file. + catch_all_errors: If True, all exceptions are caught and error messages are + returned as part of the ParsingResult. If False exceptions will be allowed + to propagate. + + Returns: + A ParsingResult. + """ + errors = {} + try: + parser = PDB.MMCIFParser(QUIET=True) + handle = io.StringIO(mmcif_string) + full_structure = parser.get_structure('', handle) + first_model_structure = _get_first_model(full_structure) + # Extract the _mmcif_dict from the parser, which contains useful fields not + # reflected in the Biopython structure. + parsed_info = parser._mmcif_dict # pylint:disable=protected-access + + # Ensure all values are lists, even if singletons. + for key, value in parsed_info.items(): + if not isinstance(value, list): + parsed_info[key] = [value] + + header = _get_header(parsed_info) + + # Determine the protein chains, and their start numbers according to the + # internal mmCIF numbering scheme (likely but not guaranteed to be 1). + valid_chains = _get_protein_chains(parsed_info=parsed_info) + if not valid_chains: + return ParsingResult( + None, {(file_id, ''): 'No protein chains found in this file.'}) + seq_start_num = {chain_id: min([monomer.num for monomer in seq]) + for chain_id, seq in valid_chains.items()} + + # Loop over the atoms for which we have coordinates. Populate two mappings: + # -mmcif_to_author_chain_id (maps internal mmCIF chain ids to chain ids used + # the authors / Biopython). + # -seq_to_structure_mappings (maps idx into sequence to ResidueAtPosition). + mmcif_to_author_chain_id = {} + seq_to_structure_mappings = {} + for atom in _get_atom_site_list(parsed_info): + if atom.model_num != '1': + # We only process the first model at the moment. + continue + + mmcif_to_author_chain_id[atom.mmcif_chain_id] = atom.author_chain_id + + if atom.mmcif_chain_id in valid_chains: + hetflag = ' ' + if atom.hetatm_atom == 'HETATM': + # Water atoms are assigned a special hetflag of W in Biopython. We + # need to do the same, so that this hetflag can be used to fetch + # a residue from the Biopython structure by id. 
+ if atom.residue_name in ('HOH', 'WAT'): + hetflag = 'W' + else: + hetflag = 'H_' + atom.residue_name + insertion_code = atom.insertion_code + if not _is_set(atom.insertion_code): + insertion_code = ' ' + position = ResiduePosition(chain_id=atom.author_chain_id, + residue_number=int(atom.author_seq_num), + insertion_code=insertion_code) + seq_idx = int(atom.mmcif_seq_num) - seq_start_num[atom.mmcif_chain_id] + current = seq_to_structure_mappings.get(atom.author_chain_id, {}) + current[seq_idx] = ResidueAtPosition(position=position, + name=atom.residue_name, + is_missing=False, + hetflag=hetflag) + seq_to_structure_mappings[atom.author_chain_id] = current + + # Add missing residue information to seq_to_structure_mappings. + for chain_id, seq_info in valid_chains.items(): + author_chain = mmcif_to_author_chain_id[chain_id] + current_mapping = seq_to_structure_mappings[author_chain] + for idx, monomer in enumerate(seq_info): + if idx not in current_mapping: + current_mapping[idx] = ResidueAtPosition(position=None, + name=monomer.id, + is_missing=True, + hetflag=' ') + + author_chain_to_sequence = {} + for chain_id, seq_info in valid_chains.items(): + author_chain = mmcif_to_author_chain_id[chain_id] + seq = [] + for monomer in seq_info: + code = SCOPData.protein_letters_3to1.get(monomer.id, 'X') + seq.append(code if len(code) == 1 else 'X') + seq = ''.join(seq) + author_chain_to_sequence[author_chain] = seq + + mmcif_object = MmcifObject( + file_id=file_id, + header=header, + structure=first_model_structure, + chain_to_seqres=author_chain_to_sequence, + seqres_to_structure=seq_to_structure_mappings, + raw_string=parsed_info) + + return ParsingResult(mmcif_object=mmcif_object, errors=errors) + except Exception as e: # pylint:disable=broad-except + errors[(file_id, '')] = e + if not catch_all_errors: + raise + return ParsingResult(mmcif_object=None, errors=errors) + + +def _get_first_model(structure: PdbStructure) -> PdbStructure: + """Returns the first model in a Biopython structure.""" + return next(structure.get_models()) + +_MIN_LENGTH_OF_CHAIN_TO_BE_COUNTED_AS_PEPTIDE = 21 + + +def get_release_date(parsed_info: MmCIFDict) -> str: + """Returns the oldest revision date.""" + revision_dates = parsed_info['_pdbx_audit_revision_history.revision_date'] + return min(revision_dates) + + +def _get_header(parsed_info: MmCIFDict) -> PdbHeader: + """Returns a basic header containing method, release date and resolution.""" + header = {} + + experiments = mmcif_loop_to_list('_exptl.', parsed_info) + header['structure_method'] = ','.join([ + experiment['_exptl.method'].lower() for experiment in experiments]) + + # Note: The release_date here corresponds to the oldest revision. We prefer to + # use this for dataset filtering over the deposition_date. 
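+  # (Editor's illustrative example, not in the upstream file: for a typical
+  #  X-ray entry the header assembled here looks roughly like
+  #  {'structure_method': 'x-ray diffraction',
+  #   'release_date': '2000-01-12',
+  #   'resolution': 1.8}.)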
+ if '_pdbx_audit_revision_history.revision_date' in parsed_info: + header['release_date'] = get_release_date(parsed_info) + else: + logging.warning('Could not determine release_date: %s', + parsed_info['_entry.id']) + + header['resolution'] = 0.00 + for res_key in ('_refine.ls_d_res_high', '_em_3d_reconstruction.resolution', + '_reflns.d_resolution_high'): + if res_key in parsed_info: + try: + raw_resolution = parsed_info[res_key][0] + header['resolution'] = float(raw_resolution) + except ValueError: + logging.warning('Invalid resolution format: %s', parsed_info[res_key]) + + return header + + +def _get_atom_site_list(parsed_info: MmCIFDict) -> Sequence[AtomSite]: + """Returns list of atom sites; contains data not present in the structure.""" + return [AtomSite(*site) for site in zip( # pylint:disable=g-complex-comprehension + parsed_info['_atom_site.label_comp_id'], + parsed_info['_atom_site.auth_asym_id'], + parsed_info['_atom_site.label_asym_id'], + parsed_info['_atom_site.auth_seq_id'], + parsed_info['_atom_site.label_seq_id'], + parsed_info['_atom_site.pdbx_PDB_ins_code'], + parsed_info['_atom_site.group_PDB'], + parsed_info['_atom_site.pdbx_PDB_model_num'], + )] + + +def _get_protein_chains( + *, parsed_info: Mapping[str, Any]) -> Mapping[ChainId, Sequence[Monomer]]: + """Extracts polymer information for protein chains only. + + Args: + parsed_info: _mmcif_dict produced by the Biopython parser. + + Returns: + A dict mapping mmcif chain id to a list of Monomers. + """ + # Get polymer information for each entity in the structure. + entity_poly_seqs = mmcif_loop_to_list('_entity_poly_seq.', parsed_info) + + polymers = collections.defaultdict(list) + for entity_poly_seq in entity_poly_seqs: + polymers[entity_poly_seq['_entity_poly_seq.entity_id']].append( + Monomer(id=entity_poly_seq['_entity_poly_seq.mon_id'], + num=int(entity_poly_seq['_entity_poly_seq.num']))) + + # Get chemical compositions. Will allow us to identify which of these polymers + # are proteins. + chem_comps = mmcif_loop_to_dict('_chem_comp.', '_chem_comp.id', parsed_info) + + # Get chains information for each entity. Necessary so that we can return a + # dict keyed on chain id rather than entity. + struct_asyms = mmcif_loop_to_list('_struct_asym.', parsed_info) + + entity_to_mmcif_chains = collections.defaultdict(list) + for struct_asym in struct_asyms: + chain_id = struct_asym['_struct_asym.id'] + entity_id = struct_asym['_struct_asym.entity_id'] + entity_to_mmcif_chains[entity_id].append(chain_id) + + # Identify and return the valid protein chains. + valid_chains = {} + for entity_id, seq_info in polymers.items(): + chain_ids = entity_to_mmcif_chains[entity_id] + + # Reject polymers without any peptide-like components, such as DNA/RNA. + if any(['peptide' in chem_comps[monomer.id]['_chem_comp.type'] + for monomer in seq_info]): + for chain_id in chain_ids: + valid_chains[chain_id] = seq_info + return valid_chains + + +def _is_set(data: str) -> bool: + """Returns False if data is a special mmCIF character indicating 'unset'.""" + return data not in ('.', '?') diff --git a/build/lib/colabdesign/af/alphafold/data/parsers.py b/build/lib/colabdesign/af/alphafold/data/parsers.py new file mode 100644 index 00000000..edc21bbe --- /dev/null +++ b/build/lib/colabdesign/af/alphafold/data/parsers.py @@ -0,0 +1,364 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Functions for parsing various file formats.""" +import collections +import dataclasses +import re +import string +from typing import Dict, Iterable, List, Optional, Sequence, Tuple + +DeletionMatrix = Sequence[Sequence[int]] + + +@dataclasses.dataclass(frozen=True) +class TemplateHit: + """Class representing a template hit.""" + index: int + name: str + aligned_cols: int + sum_probs: float + query: str + hit_sequence: str + indices_query: List[int] + indices_hit: List[int] + + +def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]: + """Parses FASTA string and returns list of strings with amino-acid sequences. + + Arguments: + fasta_string: The string contents of a FASTA file. + + Returns: + A tuple of two lists: + * A list of sequences. + * A list of sequence descriptions taken from the comment lines. In the + same order as the sequences. + """ + sequences = [] + descriptions = [] + index = -1 + for line in fasta_string.splitlines(): + line = line.strip() + if line.startswith('>'): + index += 1 + descriptions.append(line[1:]) # Remove the '>' at the beginning. + sequences.append('') + continue + elif not line: + continue # Skip blank lines. + sequences[index] += line + + return sequences, descriptions + + +def parse_stockholm( + stockholm_string: str +) -> Tuple[Sequence[str], DeletionMatrix, Sequence[str]]: + """Parses sequences and deletion matrix from stockholm format alignment. + + Args: + stockholm_string: The string contents of a stockholm file. The first + sequence in the file should be the query sequence. + + Returns: + A tuple of: + * A list of sequences that have been aligned to the query. These + might contain duplicates. + * The deletion matrix for the alignment as a list of lists. The element + at `deletion_matrix[i][j]` is the number of residues deleted from + the aligned sequence i at residue position j. + * The names of the targets matched, including the jackhmmer subsequence + suffix. + """ + name_to_sequence = collections.OrderedDict() + for line in stockholm_string.splitlines(): + line = line.strip() + if not line or line.startswith(('#', '//')): + continue + name, sequence = line.split() + if name not in name_to_sequence: + name_to_sequence[name] = '' + name_to_sequence[name] += sequence + + msa = [] + deletion_matrix = [] + + query = '' + keep_columns = [] + for seq_index, sequence in enumerate(name_to_sequence.values()): + if seq_index == 0: + # Gather the columns with gaps from the query + query = sequence + keep_columns = [i for i, res in enumerate(query) if res != '-'] + + # Remove the columns with gaps in the query from all sequences. + aligned_sequence = ''.join([sequence[c] for c in keep_columns]) + + msa.append(aligned_sequence) + + # Count the number of deletions w.r.t. query. 
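+      # (Editor's worked example, not in the upstream file: for query 'MK-V'
+      #  and aligned sequence 'MKWV', the kept columns are M, K, V and the 'W'
+      #  aligned to the query gap counts as one deletion, giving a deletion
+      #  vector of [0, 0, 1].)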
+ deletion_vec = [] + deletion_count = 0 + for seq_res, query_res in zip(sequence, query): + if seq_res != '-' or query_res != '-': + if query_res == '-': + deletion_count += 1 + else: + deletion_vec.append(deletion_count) + deletion_count = 0 + deletion_matrix.append(deletion_vec) + + return msa, deletion_matrix, list(name_to_sequence.keys()) + + +def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]: + """Parses sequences and deletion matrix from a3m format alignment. + + Args: + a3m_string: The string contents of a a3m file. The first sequence in the + file should be the query sequence. + + Returns: + A tuple of: + * A list of sequences that have been aligned to the query. These + might contain duplicates. + * The deletion matrix for the alignment as a list of lists. The element + at `deletion_matrix[i][j]` is the number of residues deleted from + the aligned sequence i at residue position j. + """ + sequences, _ = parse_fasta(a3m_string) + deletion_matrix = [] + for msa_sequence in sequences: + deletion_vec = [] + deletion_count = 0 + for j in msa_sequence: + if j.islower(): + deletion_count += 1 + else: + deletion_vec.append(deletion_count) + deletion_count = 0 + deletion_matrix.append(deletion_vec) + + # Make the MSA matrix out of aligned (deletion-free) sequences. + deletion_table = str.maketrans('', '', string.ascii_lowercase) + aligned_sequences = [s.translate(deletion_table) for s in sequences] + return aligned_sequences, deletion_matrix + + +def _convert_sto_seq_to_a3m( + query_non_gaps: Sequence[bool], sto_seq: str) -> Iterable[str]: + for is_query_res_non_gap, sequence_res in zip(query_non_gaps, sto_seq): + if is_query_res_non_gap: + yield sequence_res + elif sequence_res != '-': + yield sequence_res.lower() + + +def convert_stockholm_to_a3m(stockholm_format: str, + max_sequences: Optional[int] = None) -> str: + """Converts MSA in Stockholm format to the A3M format.""" + descriptions = {} + sequences = {} + reached_max_sequences = False + + for line in stockholm_format.splitlines(): + reached_max_sequences = max_sequences and len(sequences) >= max_sequences + if line.strip() and not line.startswith(('#', '//')): + # Ignore blank lines, markup and end symbols - remainder are alignment + # sequence parts. + seqname, aligned_seq = line.split(maxsplit=1) + if seqname not in sequences: + if reached_max_sequences: + continue + sequences[seqname] = '' + sequences[seqname] += aligned_seq + + for line in stockholm_format.splitlines(): + if line[:4] == '#=GS': + # Description row - example format is: + # #=GS UniRef90_Q9H5Z4/4-78 DE [subseq from] cDNA: FLJ22755 ... + columns = line.split(maxsplit=3) + seqname, feature = columns[1:3] + value = columns[3] if len(columns) == 4 else '' + if feature != 'DE': + continue + if reached_max_sequences and seqname not in sequences: + continue + descriptions[seqname] = value + if len(descriptions) == len(sequences): + break + + # Convert sto format to a3m line by line + a3m_sequences = {} + # query_sequence is assumed to be the first sequence + query_sequence = next(iter(sequences.values())) + query_non_gaps = [res != '-' for res in query_sequence] + for seqname, sto_sequence in sequences.items(): + a3m_sequences[seqname] = ''.join( + _convert_sto_seq_to_a3m(query_non_gaps, sto_sequence)) + + fasta_chunks = (f">{k} {descriptions.get(k, '')}\n{a3m_sequences[k]}" + for k in a3m_sequences) + return '\n'.join(fasta_chunks) + '\n' # Include terminating newline. 
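+
+# (Editor's illustrative sketch, not part of the upstream file. In A3M format,
+#  lowercase letters mark insertions relative to the query; parse_a3m strips
+#  them and records them in the deletion matrix, e.g.
+#
+#    >>> parse_a3m(">query\nMKVL\n>hit\nMKwVL\n")
+#    (['MKVL', 'MKVL'], [[0, 0, 0, 0], [0, 0, 1, 0]])
+#
+#  i.e. the hit carries one inserted residue before its third aligned column.)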
+ + +def _get_hhr_line_regex_groups( + regex_pattern: str, line: str) -> Sequence[Optional[str]]: + match = re.match(regex_pattern, line) + if match is None: + raise RuntimeError(f'Could not parse query line {line}') + return match.groups() + + +def _update_hhr_residue_indices_list( + sequence: str, start_index: int, indices_list: List[int]): + """Computes the relative indices for each residue with respect to the original sequence.""" + counter = start_index + for symbol in sequence: + if symbol == '-': + indices_list.append(-1) + else: + indices_list.append(counter) + counter += 1 + + +def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit: + """Parses the detailed HMM HMM comparison section for a single Hit. + + This works on .hhr files generated from both HHBlits and HHSearch. + + Args: + detailed_lines: A list of lines from a single comparison section between 2 + sequences (which each have their own HMM's) + + Returns: + A dictionary with the information from that detailed comparison section + + Raises: + RuntimeError: If a certain line cannot be processed + """ + # Parse first 2 lines. + number_of_hit = int(detailed_lines[0].split()[-1]) + name_hit = detailed_lines[1][1:] + + # Parse the summary line. + pattern = ( + 'Probab=(.*)[\t ]*E-value=(.*)[\t ]*Score=(.*)[\t ]*Aligned_cols=(.*)[\t' + ' ]*Identities=(.*)%[\t ]*Similarity=(.*)[\t ]*Sum_probs=(.*)[\t ' + ']*Template_Neff=(.*)') + match = re.match(pattern, detailed_lines[2]) + if match is None: + raise RuntimeError( + 'Could not parse section: %s. Expected this: \n%s to contain summary.' % + (detailed_lines, detailed_lines[2])) + (prob_true, e_value, _, aligned_cols, _, _, sum_probs, + neff) = [float(x) for x in match.groups()] + + # The next section reads the detailed comparisons. These are in a 'human + # readable' format which has a fixed length. The strategy employed is to + # assume that each block starts with the query sequence line, and to parse + # that with a regexp in order to deduce the fixed length used for that block. + query = '' + hit_sequence = '' + indices_query = [] + indices_hit = [] + length_block = None + + for line in detailed_lines[3:]: + # Parse the query sequence line + if (line.startswith('Q ') and not line.startswith('Q ss_dssp') and + not line.startswith('Q ss_pred') and + not line.startswith('Q Consensus')): + # Thus the first 17 characters must be 'Q ', and we can parse + # everything after that. + # start sequence end total_sequence_length + patt = r'[\t ]*([0-9]*) ([A-Z-]*)[\t ]*([0-9]*) \([0-9]*\)' + groups = _get_hhr_line_regex_groups(patt, line[17:]) + + # Get the length of the parsed block using the start and finish indices, + # and ensure it is the same as the actual block length. + start = int(groups[0]) - 1 # Make index zero based. + delta_query = groups[1] + end = int(groups[2]) + num_insertions = len([x for x in delta_query if x == '-']) + length_block = end - start + num_insertions + assert length_block == len(delta_query) + + # Update the query sequence and indices list. + query += delta_query + _update_hhr_residue_indices_list(delta_query, start, indices_query) + + elif line.startswith('T '): + # Parse the hit sequence. + if (not line.startswith('T ss_dssp') and + not line.startswith('T ss_pred') and + not line.startswith('T Consensus')): + # Thus the first 17 characters must be 'T ', and we can + # parse everything after that. 
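+          # (Editor's worked example, not in the upstream file: for an aligned
+          #  fragment 'NV-GT' starting at residue 12,
+          #  _update_hhr_residue_indices_list appends [12, 13, -1, 14, 15],
+          #  with -1 marking the gap column.)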
+ # start sequence end total_sequence_length + patt = r'[\t ]*([0-9]*) ([A-Z-]*)[\t ]*[0-9]* \([0-9]*\)' + groups = _get_hhr_line_regex_groups(patt, line[17:]) + start = int(groups[0]) - 1 # Make index zero based. + delta_hit_sequence = groups[1] + assert length_block == len(delta_hit_sequence) + + # Update the hit sequence and indices list. + hit_sequence += delta_hit_sequence + _update_hhr_residue_indices_list(delta_hit_sequence, start, indices_hit) + + return TemplateHit( + index=number_of_hit, + name=name_hit, + aligned_cols=int(aligned_cols), + sum_probs=sum_probs, + query=query, + hit_sequence=hit_sequence, + indices_query=indices_query, + indices_hit=indices_hit, + ) + + +def parse_hhr(hhr_string: str) -> Sequence[TemplateHit]: + """Parses the content of an entire HHR file.""" + lines = hhr_string.splitlines() + + # Each .hhr file starts with a results table, then has a sequence of hit + # "paragraphs", each paragraph starting with a line 'No '. We + # iterate through each paragraph to parse each hit. + + block_starts = [i for i, line in enumerate(lines) if line.startswith('No ')] + + hits = [] + if block_starts: + block_starts.append(len(lines)) # Add the end of the final block. + for i in range(len(block_starts) - 1): + hits.append(_parse_hhr_hit(lines[block_starts[i]:block_starts[i + 1]])) + return hits + + +def parse_e_values_from_tblout(tblout: str) -> Dict[str, float]: + """Parse target to e-value mapping parsed from Jackhmmer tblout string.""" + e_values = {'query': 0} + lines = [line for line in tblout.splitlines() if line[0] != '#'] + # As per http://eddylab.org/software/hmmer/Userguide.pdf fields are + # space-delimited. Relevant fields are (1) target name: and + # (5) E-value (full sequence) (numbering from 1). + for line in lines: + fields = line.split() + e_value = fields[4] + target_name = fields[0] + e_values[target_name] = float(e_value) + return e_values diff --git a/build/lib/colabdesign/af/alphafold/data/pipeline.py b/build/lib/colabdesign/af/alphafold/data/pipeline.py new file mode 100644 index 00000000..e9388227 --- /dev/null +++ b/build/lib/colabdesign/af/alphafold/data/pipeline.py @@ -0,0 +1,72 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Functions for building the input features for the AlphaFold model.""" + +import os +from typing import Mapping, Optional, Sequence +from absl import logging +from colabdesign.af.alphafold.common import residue_constants +from colabdesign.af.alphafold.data import parsers +import numpy as np + +# Internal import (7716). 
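+
+# (Editor's illustrative sketch, not part of the upstream file: a minimal
+#  single-sequence feature dict can be assembled from the two helpers below,
+#  e.g.
+#
+#    seq = 'MKTAYIAKQR'
+#    feats = {**make_sequence_features(seq, 'query', len(seq)),
+#             **make_msa_features(msas=[[seq]],
+#                                 deletion_matrices=[[[0] * len(seq)]])}
+#
+#  which yields 'aatype', 'residue_index', 'msa', 'deletion_matrix_int' and
+#  related arrays for a query with no alignment hits.)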
+ +FeatureDict = Mapping[str, np.ndarray] +def make_sequence_features( + sequence: str, description: str, num_res: int) -> FeatureDict: + """Constructs a feature dict of sequence features.""" + features = {} + features['aatype'] = residue_constants.sequence_to_onehot( + sequence=sequence, + mapping=residue_constants.restype_order_with_x, + map_unknown_to_x=True) + features['between_segment_residues'] = np.zeros((num_res,), dtype=np.int32) + features['domain_name'] = np.array([description.encode('utf-8')], + dtype=np.object_) + features['residue_index'] = np.array(range(num_res), dtype=np.int32) + features['seq_length'] = np.array([num_res] * num_res, dtype=np.int32) + features['sequence'] = np.array([sequence.encode('utf-8')], dtype=np.object_) + return features + + +def make_msa_features( + msas: Sequence[Sequence[str]], + deletion_matrices: Sequence[parsers.DeletionMatrix]) -> FeatureDict: + """Constructs a feature dict of MSA features.""" + if not msas: + raise ValueError('At least one MSA must be provided.') + + int_msa = [] + deletion_matrix = [] + seen_sequences = set() + for msa_index, msa in enumerate(msas): + if not msa: + raise ValueError(f'MSA {msa_index} must contain at least one sequence.') + for sequence_index, sequence in enumerate(msa): + if sequence in seen_sequences: + continue + seen_sequences.add(sequence) + int_msa.append( + [residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence]) + deletion_matrix.append(deletion_matrices[msa_index][sequence_index]) + + num_res = len(msas[0][0]) + num_alignments = len(int_msa) + features = {} + features['deletion_matrix_int'] = np.array(deletion_matrix, dtype=np.int32) + features['msa'] = np.array(int_msa, dtype=np.int32) + features['num_alignments'] = np.array( + [num_alignments] * num_res, dtype=np.int32) + return features \ No newline at end of file diff --git a/build/lib/colabdesign/af/alphafold/data/pipeline_multimer.py b/build/lib/colabdesign/af/alphafold/data/pipeline_multimer.py new file mode 100644 index 00000000..012408cd --- /dev/null +++ b/build/lib/colabdesign/af/alphafold/data/pipeline_multimer.py @@ -0,0 +1,284 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Functions for building the features for the AlphaFold multimer model.""" + +import collections +import contextlib +import copy +import dataclasses +import json +import os +import tempfile +from typing import Mapping, MutableMapping, Sequence + +from absl import logging +from alphafold.common import protein +from alphafold.common import residue_constants +from alphafold.data import feature_processing +from alphafold.data import msa_pairing +from alphafold.data import parsers +from alphafold.data import pipeline +from alphafold.data.tools import jackhmmer +import numpy as np + +# Internal import (7716). 
+ + +@dataclasses.dataclass(frozen=True) +class _FastaChain: + sequence: str + description: str + + +def _make_chain_id_map(*, + sequences: Sequence[str], + descriptions: Sequence[str], + ) -> Mapping[str, _FastaChain]: + """Makes a mapping from PDB-format chain ID to sequence and description.""" + if len(sequences) != len(descriptions): + raise ValueError('sequences and descriptions must have equal length. ' + f'Got {len(sequences)} != {len(descriptions)}.') + if len(sequences) > protein.PDB_MAX_CHAINS: + raise ValueError('Cannot process more chains than the PDB format supports. ' + f'Got {len(sequences)} chains.') + chain_id_map = {} + for chain_id, sequence, description in zip( + protein.PDB_CHAIN_IDS, sequences, descriptions): + chain_id_map[chain_id] = _FastaChain( + sequence=sequence, description=description) + return chain_id_map + + +@contextlib.contextmanager +def temp_fasta_file(fasta_str: str): + with tempfile.NamedTemporaryFile('w', suffix='.fasta') as fasta_file: + fasta_file.write(fasta_str) + fasta_file.seek(0) + yield fasta_file.name + + +def convert_monomer_features( + monomer_features: pipeline.FeatureDict, + chain_id: str) -> pipeline.FeatureDict: + """Reshapes and modifies monomer features for multimer models.""" + converted = {} + converted['auth_chain_id'] = np.asarray(chain_id, dtype=np.object_) + unnecessary_leading_dim_feats = { + 'sequence', 'domain_name', 'num_alignments', 'seq_length'} + for feature_name, feature in monomer_features.items(): + if feature_name in unnecessary_leading_dim_feats: + # asarray ensures it's a np.ndarray. + feature = np.asarray(feature[0], dtype=feature.dtype) + elif feature_name == 'aatype': + # The multimer model performs the one-hot operation itself. + feature = np.argmax(feature, axis=-1).astype(np.int32) + elif feature_name == 'template_aatype': + feature = np.argmax(feature, axis=-1).astype(np.int32) + new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE + feature = np.take(new_order_list, feature.astype(np.int32), axis=0) + elif feature_name == 'template_all_atom_mask': + feature_name = 'template_all_atom_mask' + converted[feature_name] = feature + return converted + + +def int_id_to_str_id(num: int) -> str: + """Encodes a number as a string, using reverse spreadsheet style naming. + + Args: + num: A positive integer. + + Returns: + A string that encodes the positive integer using reverse spreadsheet style, + naming e.g. 1 = A, 2 = B, ..., 27 = AA, 28 = BA, 29 = CA, ... This is the + usual way to encode chain IDs in mmCIF files. + """ + if num <= 0: + raise ValueError(f'Only positive integers allowed, got {num}.') + + num = num - 1 # 1-based indexing. + output = [] + while num >= 0: + output.append(chr(num % 26 + ord('A'))) + num = num // 26 - 1 + return ''.join(output) + + +def add_assembly_features( + all_chain_features: MutableMapping[str, pipeline.FeatureDict], + ) -> MutableMapping[str, pipeline.FeatureDict]: + """Add features to distinguish between chains. + + Args: + all_chain_features: A dictionary which maps chain_id to a dictionary of + features for each chain. + + Returns: + all_chain_features: A dictionary which maps strings of the form + `_` to the corresponding chain features. E.g. two + chains from a homodimer would have keys A_1 and A_2. Two chains from a + heterodimer would have keys A_1 and B_1. 
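+
+    Example (illustrative, added by the editor of this remix): for the homodimer
+    case above, both copies receive entity_id = 1, asym_id 1 and 2 respectively,
+    and sym_id 1 and 2, each broadcast to a per-residue array of length
+    seq_length.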
+ """ + # Group the chains by sequence + seq_to_entity_id = {} + grouped_chains = collections.defaultdict(list) + for chain_id, chain_features in all_chain_features.items(): + seq = str(chain_features['sequence']) + if seq not in seq_to_entity_id: + seq_to_entity_id[seq] = len(seq_to_entity_id) + 1 + grouped_chains[seq_to_entity_id[seq]].append(chain_features) + + new_all_chain_features = {} + chain_id = 1 + for entity_id, group_chain_features in grouped_chains.items(): + for sym_id, chain_features in enumerate(group_chain_features, start=1): + new_all_chain_features[ + f'{int_id_to_str_id(entity_id)}_{sym_id}'] = chain_features + seq_length = chain_features['seq_length'] + chain_features['asym_id'] = chain_id * np.ones(seq_length) + chain_features['sym_id'] = sym_id * np.ones(seq_length) + chain_features['entity_id'] = entity_id * np.ones(seq_length) + chain_id += 1 + + return new_all_chain_features + + +def pad_msa(np_example, min_num_seq): + np_example = dict(np_example) + num_seq = np_example['msa'].shape[0] + if num_seq < min_num_seq: + for feat in ('msa', 'deletion_matrix', 'bert_mask', 'msa_mask'): + np_example[feat] = np.pad( + np_example[feat], ((0, min_num_seq - num_seq), (0, 0))) + np_example['cluster_bias_mask'] = np.pad( + np_example['cluster_bias_mask'], ((0, min_num_seq - num_seq),)) + return np_example + + +class DataPipeline: + """Runs the alignment tools and assembles the input features.""" + + def __init__(self, + monomer_data_pipeline: pipeline.DataPipeline, + jackhmmer_binary_path: str, + uniprot_database_path: str, + max_uniprot_hits: int = 50000, + use_precomputed_msas: bool = False): + """Initializes the data pipeline. + + Args: + monomer_data_pipeline: An instance of pipeline.DataPipeline - that runs + the data pipeline for the monomer AlphaFold system. + jackhmmer_binary_path: Location of the jackhmmer binary. + uniprot_database_path: Location of the unclustered uniprot sequences, that + will be searched with jackhmmer and used for MSA pairing. + max_uniprot_hits: The maximum number of hits to return from uniprot. + use_precomputed_msas: Whether to use pre-existing MSAs; see run_alphafold. + """ + self._monomer_data_pipeline = monomer_data_pipeline + self._uniprot_msa_runner = jackhmmer.Jackhmmer( + binary_path=jackhmmer_binary_path, + database_path=uniprot_database_path) + self._max_uniprot_hits = max_uniprot_hits + self.use_precomputed_msas = use_precomputed_msas + + def _process_single_chain( + self, + chain_id: str, + sequence: str, + description: str, + msa_output_dir: str, + is_homomer_or_monomer: bool) -> pipeline.FeatureDict: + """Runs the monomer pipeline on a single chain.""" + chain_fasta_str = f'>chain_{chain_id}\n{sequence}\n' + chain_msa_output_dir = os.path.join(msa_output_dir, chain_id) + if not os.path.exists(chain_msa_output_dir): + os.makedirs(chain_msa_output_dir) + with temp_fasta_file(chain_fasta_str) as chain_fasta_path: + logging.info('Running monomer pipeline on chain %s: %s', + chain_id, description) + chain_features = self._monomer_data_pipeline.process( + input_fasta_path=chain_fasta_path, + msa_output_dir=chain_msa_output_dir) + + # We only construct the pairing features if there are 2 or more unique + # sequences. 
+ if not is_homomer_or_monomer: + all_seq_msa_features = self._all_seq_msa_features(chain_fasta_path, + chain_msa_output_dir) + chain_features.update(all_seq_msa_features) + return chain_features + + def _all_seq_msa_features(self, input_fasta_path, msa_output_dir): + """Get MSA features for unclustered uniprot, for pairing.""" + out_path = os.path.join(msa_output_dir, 'uniprot_hits.sto') + result = pipeline.run_msa_tool( + self._uniprot_msa_runner, input_fasta_path, out_path, 'sto', + self.use_precomputed_msas) + msa = parsers.parse_stockholm(result['sto']) + msa = msa.truncate(max_seqs=self._max_uniprot_hits) + all_seq_features = pipeline.make_msa_features([msa]) + valid_feats = msa_pairing.MSA_FEATURES + ( + 'msa_species_identifiers', + ) + feats = {f'{k}_all_seq': v for k, v in all_seq_features.items() + if k in valid_feats} + return feats + + def process(self, + input_fasta_path: str, + msa_output_dir: str) -> pipeline.FeatureDict: + """Runs alignment tools on the input sequences and creates features.""" + with open(input_fasta_path) as f: + input_fasta_str = f.read() + input_seqs, input_descs = parsers.parse_fasta(input_fasta_str) + + chain_id_map = _make_chain_id_map(sequences=input_seqs, + descriptions=input_descs) + chain_id_map_path = os.path.join(msa_output_dir, 'chain_id_map.json') + with open(chain_id_map_path, 'w') as f: + chain_id_map_dict = {chain_id: dataclasses.asdict(fasta_chain) + for chain_id, fasta_chain in chain_id_map.items()} + json.dump(chain_id_map_dict, f, indent=4, sort_keys=True) + + all_chain_features = {} + sequence_features = {} + is_homomer_or_monomer = len(set(input_seqs)) == 1 + for chain_id, fasta_chain in chain_id_map.items(): + if fasta_chain.sequence in sequence_features: + all_chain_features[chain_id] = copy.deepcopy( + sequence_features[fasta_chain.sequence]) + continue + chain_features = self._process_single_chain( + chain_id=chain_id, + sequence=fasta_chain.sequence, + description=fasta_chain.description, + msa_output_dir=msa_output_dir, + is_homomer_or_monomer=is_homomer_or_monomer) + + chain_features = convert_monomer_features(chain_features, + chain_id=chain_id) + all_chain_features[chain_id] = chain_features + sequence_features[fasta_chain.sequence] = chain_features + + all_chain_features = add_assembly_features(all_chain_features) + + np_example = feature_processing.pair_and_merge( + all_chain_features=all_chain_features) + + # Pad MSA to avoid zero-sized extra_msa. 
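+    # (Editor's note, not in the upstream file: pad_msa zero-pads 'msa',
+    #  'deletion_matrix', 'bert_mask', 'msa_mask' and 'cluster_bias_mask' up to
+    #  512 rows, so the later split into clustered and extra MSA never produces
+    #  an empty extra_msa block.)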
+ np_example = pad_msa(np_example, 512) + + return np_example diff --git a/build/lib/colabdesign/af/alphafold/data/prep_inputs.py b/build/lib/colabdesign/af/alphafold/data/prep_inputs.py new file mode 100644 index 00000000..f74b6f0f --- /dev/null +++ b/build/lib/colabdesign/af/alphafold/data/prep_inputs.py @@ -0,0 +1,132 @@ +import numpy as np +from colabdesign.af.alphafold.common import residue_constants + +def make_atom14_positions(batch): + """Constructs denser atom positions (14 dimensions instead of 37).""" + restype_atom14_to_atom37 = [] # mapping (restype, atom14) --> atom37 + restype_atom37_to_atom14 = [] # mapping (restype, atom37) --> atom14 + restype_atom14_mask = [] + + for rt in residue_constants.restypes: + atom_names = residue_constants.restype_name_to_atom14_names[ + residue_constants.restype_1to3[rt]] + + restype_atom14_to_atom37.append([ + (residue_constants.atom_order[name] if name else 0) + for name in atom_names + ]) + + atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)} + restype_atom37_to_atom14.append([ + (atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0) + for name in residue_constants.atom_types + ]) + + restype_atom14_mask.append([(1. if name else 0.) for name in atom_names]) + + # Add dummy mapping for restype 'UNK'. + restype_atom14_to_atom37.append([0] * 14) + restype_atom37_to_atom14.append([0] * 37) + restype_atom14_mask.append([0.] * 14) + + restype_atom14_to_atom37 = np.array(restype_atom14_to_atom37, dtype=np.int32) + restype_atom37_to_atom14 = np.array(restype_atom37_to_atom14, dtype=np.int32) + restype_atom14_mask = np.array(restype_atom14_mask, dtype=np.float32) + + # Create the mapping for (residx, atom14) --> atom37, i.e. an array + # with shape (num_res, 14) containing the atom37 indices for this protein. + residx_atom14_to_atom37 = restype_atom14_to_atom37[batch["aatype"]] + residx_atom14_mask = restype_atom14_mask[batch["aatype"]] + + # Create a mask for known ground truth positions. + residx_atom14_gt_mask = residx_atom14_mask * np.take_along_axis( + batch["all_atom_mask"], residx_atom14_to_atom37, axis=1).astype(np.float32) + + # Gather the ground truth positions. + residx_atom14_gt_positions = residx_atom14_gt_mask[:, :, None] * ( + np.take_along_axis(batch["all_atom_positions"], + residx_atom14_to_atom37[..., None], + axis=1)) + + prot = {} + prot["atom14_atom_exists"] = residx_atom14_mask + prot["atom14_gt_exists"] = residx_atom14_gt_mask + prot["atom14_gt_positions"] = residx_atom14_gt_positions + + prot["residx_atom14_to_atom37"] = residx_atom14_to_atom37 + + # Create the gather indices for mapping back. + residx_atom37_to_atom14 = restype_atom37_to_atom14[batch["aatype"]] + prot["residx_atom37_to_atom14"] = residx_atom37_to_atom14 + + # Create the corresponding mask. 
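+  # (Editor's illustrative note, not in the upstream file: atom37 reserves a
+  #  fixed slot per atom name, while atom14 packs each residue's atoms densely;
+  #  glycine, for instance, only has N, CA, C and O, so its atom14 mask is 1
+  #  for the first four slots and 0 for the remaining ten.)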
+ restype_atom37_mask = np.zeros([21, 37], dtype=np.float32) + for restype, restype_letter in enumerate(residue_constants.restypes): + restype_name = residue_constants.restype_1to3[restype_letter] + atom_names = residue_constants.residue_atoms[restype_name] + for atom_name in atom_names: + atom_type = residue_constants.atom_order[atom_name] + restype_atom37_mask[restype, atom_type] = 1 + + residx_atom37_mask = restype_atom37_mask[batch["aatype"]] + prot["atom37_atom_exists"] = residx_atom37_mask + + # As the atom naming is ambiguous for 7 of the 20 amino acids, provide + # alternative ground truth coordinates where the naming is swapped + restype_3 = [ + residue_constants.restype_1to3[res] for res in residue_constants.restypes + ] + restype_3 += ["UNK"] + + # Matrices for renaming ambiguous atoms. + all_matrices = {res: np.eye(14, dtype=np.float32) for res in restype_3} + for resname, swap in residue_constants.residue_atom_renaming_swaps.items(): + correspondences = np.arange(14) + for source_atom_swap, target_atom_swap in swap.items(): + source_index = residue_constants.restype_name_to_atom14_names[ + resname].index(source_atom_swap) + target_index = residue_constants.restype_name_to_atom14_names[ + resname].index(target_atom_swap) + correspondences[source_index] = target_index + correspondences[target_index] = source_index + renaming_matrix = np.zeros((14, 14), dtype=np.float32) + for index, correspondence in enumerate(correspondences): + renaming_matrix[index, correspondence] = 1. + all_matrices[resname] = renaming_matrix.astype(np.float32) + renaming_matrices = np.stack([all_matrices[restype] for restype in restype_3]) + + # Pick the transformation matrices for the given residue sequence + # shape (num_res, 14, 14). + renaming_transform = renaming_matrices[batch["aatype"]] + + # Apply it to the ground truth positions. shape (num_res, 14, 3). + alternative_gt_positions = np.einsum("rac,rab->rbc", + residx_atom14_gt_positions, + renaming_transform) + prot["atom14_alt_gt_positions"] = alternative_gt_positions + + # Create the mask for the alternative ground truth (differs from the + # ground truth mask, if only one of the atoms in an ambiguous pair has a + # ground truth position). + alternative_gt_mask = np.einsum("ra,rab->rb", + residx_atom14_gt_mask, + renaming_transform) + + prot["atom14_alt_gt_exists"] = alternative_gt_mask + + # Create an ambiguous atoms mask. shape: (21, 14). + restype_atom14_is_ambiguous = np.zeros((21, 14), dtype=np.float32) + for resname, swap in residue_constants.residue_atom_renaming_swaps.items(): + for atom_name1, atom_name2 in swap.items(): + restype = residue_constants.restype_order[ + residue_constants.restype_3to1[resname]] + atom_idx1 = residue_constants.restype_name_to_atom14_names[resname].index( + atom_name1) + atom_idx2 = residue_constants.restype_name_to_atom14_names[resname].index( + atom_name2) + restype_atom14_is_ambiguous[restype, atom_idx1] = 1 + restype_atom14_is_ambiguous[restype, atom_idx2] = 1 + + # From this create an ambiguous_mask for the given sequence. 
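+  # (Editor's illustrative note, not in the upstream file: a residue such as
+  #  aspartate has chemically equivalent OD1/OD2 oxygens, so both atom14 slots
+  #  are flagged ambiguous and the alternative ground truth built above carries
+  #  the swapped naming.)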
+ prot["atom14_atom_is_ambiguous"] = (restype_atom14_is_ambiguous[batch["aatype"]]) + return prot \ No newline at end of file diff --git a/build/lib/colabdesign/af/alphafold/data/tools/__init__.py b/build/lib/colabdesign/af/alphafold/data/tools/__init__.py new file mode 100644 index 00000000..903d0979 --- /dev/null +++ b/build/lib/colabdesign/af/alphafold/data/tools/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Python wrappers for third party tools.""" diff --git a/build/lib/colabdesign/af/alphafold/data/tools/utils.py b/build/lib/colabdesign/af/alphafold/data/tools/utils.py new file mode 100644 index 00000000..e65b8824 --- /dev/null +++ b/build/lib/colabdesign/af/alphafold/data/tools/utils.py @@ -0,0 +1,40 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Common utilities for data pipeline tools.""" +import contextlib +import shutil +import tempfile +import time +from typing import Optional + +from absl import logging + + +@contextlib.contextmanager +def tmpdir_manager(base_dir: Optional[str] = None): + """Context manager that deletes a temporary directory on exit.""" + tmpdir = tempfile.mkdtemp(dir=base_dir) + try: + yield tmpdir + finally: + shutil.rmtree(tmpdir, ignore_errors=True) + + +@contextlib.contextmanager +def timing(msg: str): + logging.info('Started %s', msg) + tic = time.time() + yield + toc = time.time() + logging.info('Finished %s in %.3f seconds', msg, toc - tic) diff --git a/build/lib/colabdesign/af/alphafold/model/__init__.py b/build/lib/colabdesign/af/alphafold/model/__init__.py new file mode 100644 index 00000000..fc2efc8d --- /dev/null +++ b/build/lib/colabdesign/af/alphafold/model/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Alphafold model.""" diff --git a/build/lib/colabdesign/af/alphafold/model/all_atom.py b/build/lib/colabdesign/af/alphafold/model/all_atom.py new file mode 100644 index 00000000..43331586 --- /dev/null +++ b/build/lib/colabdesign/af/alphafold/model/all_atom.py @@ -0,0 +1,1131 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Ops for all atom representations. + +Generally we employ two different representations for all atom coordinates, +one is atom37 where each heavy atom corresponds to a given position in a 37 +dimensional array, This mapping is non amino acid specific, but each slot +corresponds to an atom of a given name, for example slot 12 always corresponds +to 'C delta 1', positions that are not present for a given amino acid are +zeroed out and denoted by a mask. +The other representation we employ is called atom14, this is a more dense way +of representing atoms with 14 slots. Here a given slot will correspond to a +different kind of atom depending on amino acid type, for example slot 5 +corresponds to 'N delta 2' for Aspargine, but to 'C delta 1' for Isoleucine. +14 is chosen because it is the maximum number of heavy atoms for any standard +amino acid. +The order of slots can be found in 'residue_constants.residue_atoms'. +Internally the model uses the atom14 representation because it is +computationally more efficient. +The internal atom14 representation is turned into the atom37 at the output of +the network to facilitate easier conversion to existing protein datastructures. +""" + +from typing import Dict, Optional +from colabdesign.af.alphafold.common import residue_constants + +from colabdesign.af.alphafold.model import r3 +from colabdesign.af.alphafold.model import utils +import jax +import jax.numpy as jnp +import numpy as np + + +def squared_difference(x, y): + return jnp.square(x - y) + + +def get_chi_atom_indices(): + """Returns atom indices needed to compute chi angles for all residue types. + + Returns: + A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are + in the order specified in residue_constants.restypes + unknown residue type + at the end. For chi angles which are not defined on the residue, the + positions indices are by default set to 0. + """ + chi_atom_indices = [] + for residue_name in residue_constants.restypes: + residue_name = residue_constants.restype_1to3[residue_name] + residue_chi_angles = residue_constants.chi_angles_atoms[residue_name] + atom_indices = [] + for chi_angle in residue_chi_angles: + atom_indices.append( + [residue_constants.atom_order[atom] for atom in chi_angle]) + for _ in range(4 - len(atom_indices)): + atom_indices.append([0, 0, 0, 0]) # For chi angles not defined on the AA. + chi_atom_indices.append(atom_indices) + + chi_atom_indices.append([[0, 0, 0, 0]] * 4) # For UNKNOWN residue. + + return jnp.asarray(chi_atom_indices) + + +def atom14_to_atom37(atom14_data: jnp.ndarray, # (N, 14, ...) 
+ batch: Dict[str, jnp.ndarray] + ) -> jnp.ndarray: # (N, 37, ...) + """Convert atom14 to atom37 representation.""" + assert len(atom14_data.shape) in [2, 3] + assert 'residx_atom37_to_atom14' in batch + assert 'atom37_atom_exists' in batch + + atom37_data = utils.batched_gather(atom14_data, batch['residx_atom37_to_atom14'], batch_dims=1) + + if len(atom14_data.shape) == 2: + atom37_data *= batch['atom37_atom_exists'] + elif len(atom14_data.shape) == 3: + atom37_data *= batch['atom37_atom_exists'][:, :, None].astype(atom37_data.dtype) + return atom37_data + +def atom37_to_atom14( + atom37_data: jnp.ndarray, # (N, 37, ...) + batch: Dict[str, jnp.ndarray]) -> jnp.ndarray: # (N, 14, ...) + """Convert atom14 to atom37 representation.""" + assert len(atom37_data.shape) in [2, 3] + assert 'residx_atom14_to_atom37' in batch + assert 'atom14_atom_exists' in batch + + atom14_data = utils.batched_gather(atom37_data, batch['residx_atom14_to_atom37'], batch_dims=1) + + if len(atom37_data.shape) == 2: + atom14_data *= batch['atom14_atom_exists'].astype(atom14_data.dtype) + elif len(atom37_data.shape) == 3: + atom14_data *= batch['atom14_atom_exists'][:, :, None].astype(atom14_data.dtype) + return atom14_data + + +def atom37_to_frames( + aatype: jnp.ndarray, # (...) + all_atom_positions: jnp.ndarray, # (..., 37, 3) + all_atom_mask: jnp.ndarray, # (..., 37) +) -> Dict[str, jnp.ndarray]: + """Computes the frames for the up to 8 rigid groups for each residue. + + The rigid groups are defined by the possible torsions in a given amino acid. + We group the atoms according to their dependence on the torsion angles into + "rigid groups". E.g., the position of atoms in the chi2-group depend on + chi1 and chi2, but do not depend on chi3 or chi4. + Jumper et al. (2021) Suppl. Table 2 and corresponding text. + + Args: + aatype: Amino acid type, given as array with integers. + all_atom_positions: atom37 representation of all atom coordinates. + all_atom_mask: atom37 representation of mask on all atom coordinates. + Returns: + Dictionary containing: + * 'rigidgroups_gt_frames': 8 Frames corresponding to 'all_atom_positions' + represented as flat 12 dimensional array. + * 'rigidgroups_gt_exists': Mask denoting whether the atom positions for + the given frame are available in the ground truth, e.g. if they were + resolved in the experiment. + * 'rigidgroups_group_exists': Mask denoting whether given group is in + principle present for given amino acid type. + * 'rigidgroups_group_is_ambiguous': Mask denoting whether frame is + affected by naming ambiguity. + * 'rigidgroups_alt_gt_frames': 8 Frames with alternative atom renaming + corresponding to 'all_atom_positions' represented as flat + 12 dimensional array. + """ + # 0: 'backbone group', + # 1: 'pre-omega-group', (empty) + # 2: 'phi-group', (currently empty, because it defines only hydrogens) + # 3: 'psi-group', + # 4,5,6,7: 'chi1,2,3,4-group' + aatype_in_shape = aatype.shape + + # If there is a batch axis, just flatten it away, and reshape everything + # back at the end of the function. + aatype = jnp.reshape(aatype, [-1]) + all_atom_positions = jnp.reshape(all_atom_positions, [-1, 37, 3]) + all_atom_mask = jnp.reshape(all_atom_mask, [-1, 37]) + + # Create an array with the atom names. 
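+  # (Editor's illustrative note, not in the upstream file: for serine only chi1
+  #  is defined, so of the 8 groups only the backbone (0), psi (3) and chi1 (4)
+  #  frames exist for it; groups 1 and 2 stay empty for every residue type.)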
+ # shape (num_restypes, num_rigidgroups, 3_atoms): (21, 8, 3) + restype_rigidgroup_base_atom_names = np.full([21, 8, 3], '', dtype=object) + + # 0: backbone frame + restype_rigidgroup_base_atom_names[:, 0, :] = ['C', 'CA', 'N'] + + # 3: 'psi-group' + restype_rigidgroup_base_atom_names[:, 3, :] = ['CA', 'C', 'O'] + + # 4,5,6,7: 'chi1,2,3,4-group' + for restype, restype_letter in enumerate(residue_constants.restypes): + resname = residue_constants.restype_1to3[restype_letter] + for chi_idx in range(4): + if residue_constants.chi_angles_mask[restype][chi_idx]: + atom_names = residue_constants.chi_angles_atoms[resname][chi_idx] + restype_rigidgroup_base_atom_names[ + restype, chi_idx + 4, :] = atom_names[1:] + + # Create mask for existing rigid groups. + restype_rigidgroup_mask = np.zeros([21, 8], dtype=np.float32) + restype_rigidgroup_mask[:, 0] = 1 + restype_rigidgroup_mask[:, 3] = 1 + restype_rigidgroup_mask[:20, 4:] = residue_constants.chi_angles_mask + + # Translate atom names into atom37 indices. + lookuptable = residue_constants.atom_order.copy() + lookuptable[''] = 0 + restype_rigidgroup_base_atom37_idx = np.vectorize(lambda x: lookuptable[x])( + restype_rigidgroup_base_atom_names) + + # Compute the gather indices for all residues in the chain. + # shape (N, 8, 3) + residx_rigidgroup_base_atom37_idx = utils.batched_gather( + restype_rigidgroup_base_atom37_idx, aatype) + + # Gather the base atom positions for each rigid group. + base_atom_pos = utils.batched_gather( + all_atom_positions, + residx_rigidgroup_base_atom37_idx, + batch_dims=1) + + # Compute the Rigids. + gt_frames = r3.rigids_from_3_points( + point_on_neg_x_axis=r3.vecs_from_tensor(base_atom_pos[:, :, 0, :]), + origin=r3.vecs_from_tensor(base_atom_pos[:, :, 1, :]), + point_on_xy_plane=r3.vecs_from_tensor(base_atom_pos[:, :, 2, :]) + ) + + # Compute a mask whether the group exists. + # (N, 8) + group_exists = utils.batched_gather(restype_rigidgroup_mask, aatype) + + # Compute a mask whether ground truth exists for the group + gt_atoms_exist = utils.batched_gather( # shape (N, 8, 3) + all_atom_mask.astype(jnp.float32), + residx_rigidgroup_base_atom37_idx, + batch_dims=1) + gt_exists = jnp.min(gt_atoms_exist, axis=-1) * group_exists # (N, 8) + + # Adapt backbone frame to old convention (mirror x-axis and z-axis). + rots = np.tile(np.eye(3, dtype=np.float32), [8, 1, 1]) + rots[0, 0, 0] = -1 + rots[0, 2, 2] = -1 + gt_frames = r3.rigids_mul_rots(gt_frames, r3.rots_from_tensor3x3(rots)) + + # The frames for ambiguous rigid groups are just rotated by 180 degree around + # the x-axis. The ambiguous group is always the last chi-group. + restype_rigidgroup_is_ambiguous = np.zeros([21, 8], dtype=np.float32) + restype_rigidgroup_rots = np.tile(np.eye(3, dtype=np.float32), [21, 8, 1, 1]) + + for resname, _ in residue_constants.residue_atom_renaming_swaps.items(): + restype = residue_constants.restype_order[ + residue_constants.restype_3to1[resname]] + chi_idx = int(sum(residue_constants.chi_angles_mask[restype]) - 1) + restype_rigidgroup_is_ambiguous[restype, chi_idx + 4] = 1 + restype_rigidgroup_rots[restype, chi_idx + 4, 1, 1] = -1 + restype_rigidgroup_rots[restype, chi_idx + 4, 2, 2] = -1 + + # Gather the ambiguity information for each residue. + residx_rigidgroup_is_ambiguous = utils.batched_gather( + restype_rigidgroup_is_ambiguous, aatype) + residx_rigidgroup_ambiguity_rot = utils.batched_gather( + restype_rigidgroup_rots, aatype) + + # Create the alternative ground truth frames. 
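+  # (Editor's illustrative note, not in the upstream file: for tyrosine the
+  #  ring-flipping chi2 group is the ambiguous one, so its alternative frame is
+  #  simply the original frame rotated 180 degrees about its own x-axis.)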
+ alt_gt_frames = r3.rigids_mul_rots( + gt_frames, r3.rots_from_tensor3x3(residx_rigidgroup_ambiguity_rot)) + + gt_frames_flat12 = r3.rigids_to_tensor_flat12(gt_frames) + alt_gt_frames_flat12 = r3.rigids_to_tensor_flat12(alt_gt_frames) + + # reshape back to original residue layout + gt_frames_flat12 = jnp.reshape(gt_frames_flat12, aatype_in_shape + (8, 12)) + gt_exists = jnp.reshape(gt_exists, aatype_in_shape + (8,)) + group_exists = jnp.reshape(group_exists, aatype_in_shape + (8,)) + gt_frames_flat12 = jnp.reshape(gt_frames_flat12, aatype_in_shape + (8, 12)) + residx_rigidgroup_is_ambiguous = jnp.reshape(residx_rigidgroup_is_ambiguous, + aatype_in_shape + (8,)) + alt_gt_frames_flat12 = jnp.reshape(alt_gt_frames_flat12, + aatype_in_shape + (8, 12,)) + + return { + 'rigidgroups_gt_frames': gt_frames_flat12, # (..., 8, 12) + 'rigidgroups_gt_exists': gt_exists, # (..., 8) + 'rigidgroups_group_exists': group_exists, # (..., 8) + 'rigidgroups_group_is_ambiguous': + residx_rigidgroup_is_ambiguous, # (..., 8) + 'rigidgroups_alt_gt_frames': alt_gt_frames_flat12, # (..., 8, 12) + } + + +def atom37_to_torsion_angles( + aatype: jnp.ndarray, # (B, N) + all_atom_pos: jnp.ndarray, # (B, N, 37, 3) + all_atom_mask: jnp.ndarray, # (B, N, 37) + placeholder_for_undefined=False, +) -> Dict[str, jnp.ndarray]: + """Computes the 7 torsion angles (in sin, cos encoding) for each residue. + + The 7 torsion angles are in the order + '[pre_omega, phi, psi, chi_1, chi_2, chi_3, chi_4]', + here pre_omega denotes the omega torsion angle between the given amino acid + and the previous amino acid. + + Args: + aatype: Amino acid type, given as array with integers. + all_atom_pos: atom37 representation of all atom coordinates. + all_atom_mask: atom37 representation of mask on all atom coordinates. + placeholder_for_undefined: flag denoting whether to set masked torsion + angles to zero. + Returns: + Dict containing: + * 'torsion_angles_sin_cos': Array with shape (B, N, 7, 2) where the final + 2 dimensions denote sin and cos respectively + * 'alt_torsion_angles_sin_cos': same as 'torsion_angles_sin_cos', but + with the angle shifted by pi for all chi angles affected by the naming + ambiguities. + * 'torsion_angles_mask': Mask for which chi angles are present. + """ + + # Map aatype > 20 to 'Unknown' (20). + aatype = jnp.minimum(aatype, 20) + + # Compute the backbone angles. + num_batch, num_res = aatype.shape + + pad = jnp.zeros([num_batch, 1, 37, 3], jnp.float32) + prev_all_atom_pos = jnp.concatenate([pad, all_atom_pos[:, :-1, :, :]], axis=1) + + pad = jnp.zeros([num_batch, 1, 37], jnp.float32) + prev_all_atom_mask = jnp.concatenate([pad, all_atom_mask[:, :-1, :]], axis=1) + + # For each torsion angle collect the 4 atom positions that define this angle. + # shape (B, N, atoms=4, xyz=3) + pre_omega_atom_pos = jnp.concatenate( + [prev_all_atom_pos[:, :, 1:3, :], # prev CA, C + all_atom_pos[:, :, 0:2, :] # this N, CA + ], axis=-2) + phi_atom_pos = jnp.concatenate( + [prev_all_atom_pos[:, :, 2:3, :], # prev C + all_atom_pos[:, :, 0:3, :] # this N, CA, C + ], axis=-2) + psi_atom_pos = jnp.concatenate( + [all_atom_pos[:, :, 0:3, :], # this N, CA, C + all_atom_pos[:, :, 4:5, :] # this O + ], axis=-2) + + # Collect the masks from these atoms. 
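+  # (Editor's note, not in the upstream file: each torsion mask below is 1 only
+  #  when all four atoms defining that torsion were resolved, e.g. phi requires
+  #  the previous residue's C together with this residue's N, CA and C.)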
+ # Shape [batch, num_res] + pre_omega_mask = ( + jnp.prod(prev_all_atom_mask[:, :, 1:3], axis=-1) # prev CA, C + * jnp.prod(all_atom_mask[:, :, 0:2], axis=-1)) # this N, CA + phi_mask = ( + prev_all_atom_mask[:, :, 2] # prev C + * jnp.prod(all_atom_mask[:, :, 0:3], axis=-1)) # this N, CA, C + psi_mask = ( + jnp.prod(all_atom_mask[:, :, 0:3], axis=-1) * # this N, CA, C + all_atom_mask[:, :, 4]) # this O + + # Collect the atoms for the chi-angles. + # Compute the table of chi angle indices. Shape: [restypes, chis=4, atoms=4]. + chi_atom_indices = get_chi_atom_indices() + # Select atoms to compute chis. Shape: [batch, num_res, chis=4, atoms=4]. + atom_indices = utils.batched_gather( + params=chi_atom_indices, indices=aatype, axis=0, batch_dims=0) + # Gather atom positions. Shape: [batch, num_res, chis=4, atoms=4, xyz=3]. + chis_atom_pos = utils.batched_gather( + params=all_atom_pos, indices=atom_indices, axis=-2, + batch_dims=2) + + # Copy the chi angle mask, add the UNKNOWN residue. Shape: [restypes, 4]. + chi_angles_mask = list(residue_constants.chi_angles_mask) + chi_angles_mask.append([0.0, 0.0, 0.0, 0.0]) + chi_angles_mask = jnp.asarray(chi_angles_mask) + + # Compute the chi angle mask. I.e. which chis angles exist according to the + # aatype. Shape [batch, num_res, chis=4]. + chis_mask = utils.batched_gather(params=chi_angles_mask, indices=aatype, + axis=0, batch_dims=0) + + # Constrain the chis_mask to those chis, where the ground truth coordinates of + # all defining four atoms are available. + # Gather the chi angle atoms mask. Shape: [batch, num_res, chis=4, atoms=4]. + chi_angle_atoms_mask = utils.batched_gather( + params=all_atom_mask, indices=atom_indices, axis=-1, + batch_dims=2) + # Check if all 4 chi angle atoms were set. Shape: [batch, num_res, chis=4]. + chi_angle_atoms_mask = jnp.prod(chi_angle_atoms_mask, axis=[-1]) + chis_mask = chis_mask * (chi_angle_atoms_mask).astype(jnp.float32) + + # Stack all torsion angle atom positions. + # Shape (B, N, torsions=7, atoms=4, xyz=3) + torsions_atom_pos = jnp.concatenate( + [pre_omega_atom_pos[:, :, None, :, :], + phi_atom_pos[:, :, None, :, :], + psi_atom_pos[:, :, None, :, :], + chis_atom_pos + ], axis=2) + + # Stack up masks for all torsion angles. + # shape (B, N, torsions=7) + torsion_angles_mask = jnp.concatenate( + [pre_omega_mask[:, :, None], + phi_mask[:, :, None], + psi_mask[:, :, None], + chis_mask + ], axis=2) + + # Create a frame from the first three atoms: + # First atom: point on x-y-plane + # Second atom: point on negative x-axis + # Third atom: origin + # r3.Rigids (B, N, torsions=7) + torsion_frames = r3.rigids_from_3_points( + point_on_neg_x_axis=r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 1, :]), + origin=r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 2, :]), + point_on_xy_plane=r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 0, :])) + + # Compute the position of the forth atom in this frame (y and z coordinate + # define the chi angle) + # r3.Vecs (B, N, torsions=7) + forth_atom_rel_pos = r3.rigids_mul_vecs( + r3.invert_rigids(torsion_frames), + r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 3, :])) + + # Normalize to have the sin and cos of the torsion angle. + # jnp.ndarray (B, N, torsions=7, sincos=2) + torsion_angles_sin_cos = jnp.stack( + [forth_atom_rel_pos.z, forth_atom_rel_pos.y], axis=-1) + torsion_angles_sin_cos /= jnp.sqrt( + jnp.sum(jnp.square(torsion_angles_sin_cos), axis=-1, keepdims=True) + + 1e-8) + + # Mirror psi, because we computed it from the Oxygen-atom. 
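+  # (Editor's note, not in the upstream file: multiplying both sin and cos of
+  #  psi by -1 shifts the angle by pi, converting the O-based dihedral into the
+  #  conventional psi defined through the next residue's backbone nitrogen.)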
+ torsion_angles_sin_cos *= jnp.asarray( + [1., 1., -1., 1., 1., 1., 1.])[None, None, :, None] + + # Create alternative angles for ambiguous atom names. + chi_is_ambiguous = utils.batched_gather( + jnp.asarray(residue_constants.chi_pi_periodic), aatype) + mirror_torsion_angles = jnp.concatenate( + [jnp.ones([num_batch, num_res, 3]), + 1.0 - 2.0 * chi_is_ambiguous], axis=-1) + alt_torsion_angles_sin_cos = ( + torsion_angles_sin_cos * mirror_torsion_angles[:, :, :, None]) + + if placeholder_for_undefined: + # Add placeholder torsions in place of undefined torsion angles + # (e.g. N-terminus pre-omega) + placeholder_torsions = jnp.stack([ + jnp.ones(torsion_angles_sin_cos.shape[:-1]), + jnp.zeros(torsion_angles_sin_cos.shape[:-1]) + ], axis=-1) + torsion_angles_sin_cos = torsion_angles_sin_cos * torsion_angles_mask[ + ..., None] + placeholder_torsions * (1 - torsion_angles_mask[..., None]) + alt_torsion_angles_sin_cos = alt_torsion_angles_sin_cos * torsion_angles_mask[ + ..., None] + placeholder_torsions * (1 - torsion_angles_mask[..., None]) + + return { + 'torsion_angles_sin_cos': torsion_angles_sin_cos, # (B, N, 7, 2) + 'alt_torsion_angles_sin_cos': alt_torsion_angles_sin_cos, # (B, N, 7, 2) + 'torsion_angles_mask': torsion_angles_mask # (B, N, 7) + } + + +def torsion_angles_to_frames( + aatype: jnp.ndarray, # (N) + backb_to_global: r3.Rigids, # (N) + torsion_angles_sin_cos: jnp.ndarray # (N, 7, 2) +) -> r3.Rigids: # (N, 8) + """Compute rigid group frames from torsion angles. + + Jumper et al. (2021) Suppl. Alg. 24 "computeAllAtomCoordinates" lines 2-10 + Jumper et al. (2021) Suppl. Alg. 25 "makeRotX" + + Args: + aatype: aatype for each residue + backb_to_global: Rigid transformations describing transformation from + backbone frame to global frame. + torsion_angles_sin_cos: sin and cosine of the 7 torsion angles + Returns: + Frames corresponding to all the Sidechain Rigid Transforms + """ + assert len(aatype.shape) == 1 + assert len(backb_to_global.rot.xx.shape) == 1 + assert len(torsion_angles_sin_cos.shape) == 3 + assert torsion_angles_sin_cos.shape[1] == 7 + assert torsion_angles_sin_cos.shape[2] == 2 + + # Gather the default frames for all rigid groups. + # r3.Rigids with shape (N, 8) + + m = utils.batched_gather(residue_constants.restype_rigid_group_default_frame, aatype) + + default_frames = r3.rigids_from_tensor4x4(m) + + # Create the rotation matrices according to the given angles (each frame is + # defined such that its rotation is around the x-axis). + sin_angles = torsion_angles_sin_cos[..., 0] + cos_angles = torsion_angles_sin_cos[..., 1] + + # insert zero rotation for backbone group. + num_residues, = aatype.shape + sin_angles = jnp.concatenate([jnp.zeros([num_residues, 1]), sin_angles],axis=-1) + cos_angles = jnp.concatenate([jnp.ones([num_residues, 1]), cos_angles],axis=-1) + zeros = jnp.zeros_like(sin_angles) + ones = jnp.ones_like(sin_angles) + + # all_rots are r3.Rots with shape (N, 8) + all_rots = r3.Rots(ones, zeros, zeros, + zeros, cos_angles, -sin_angles, + zeros, sin_angles, cos_angles) + + # Apply rotations to the frames. + all_frames = r3.rigids_mul_rots(default_frames, all_rots) + + # chi2, chi3, and chi4 frames do not transform to the backbone frame but to + # the previous frame. So chain them up accordingly. 
+ chi2_frame_to_frame = jax.tree_map(lambda x: x[:, 5], all_frames) + chi3_frame_to_frame = jax.tree_map(lambda x: x[:, 6], all_frames) + chi4_frame_to_frame = jax.tree_map(lambda x: x[:, 7], all_frames) + + chi1_frame_to_backb = jax.tree_map(lambda x: x[:, 4], all_frames) + chi2_frame_to_backb = r3.rigids_mul_rigids(chi1_frame_to_backb, + chi2_frame_to_frame) + chi3_frame_to_backb = r3.rigids_mul_rigids(chi2_frame_to_backb, + chi3_frame_to_frame) + chi4_frame_to_backb = r3.rigids_mul_rigids(chi3_frame_to_backb, + chi4_frame_to_frame) + + # Recombine them to a r3.Rigids with shape (N, 8). + def _concat_frames(xall, x5, x6, x7): + return jnp.concatenate( + [xall[:, 0:5], x5[:, None], x6[:, None], x7[:, None]], axis=-1) + + all_frames_to_backb = jax.tree_map( + _concat_frames, + all_frames, + chi2_frame_to_backb, + chi3_frame_to_backb, + chi4_frame_to_backb) + + # Create the global frames. + # shape (N, 8) + all_frames_to_global = r3.rigids_mul_rigids( + jax.tree_map(lambda x: x[:, None], backb_to_global), + all_frames_to_backb) + + return all_frames_to_global + + +def frames_and_literature_positions_to_atom14_pos( + aatype: jnp.ndarray, # (N) + all_frames_to_global: r3.Rigids # (N, 8) +) -> r3.Vecs: # (N, 14) + """Put atom literature positions (atom14 encoding) in each rigid group. + + Jumper et al. (2021) Suppl. Alg. 24 "computeAllAtomCoordinates" line 11 + + Args: + aatype: aatype for each residue. + all_frames_to_global: All per residue coordinate frames. + Returns: + Positions of all atom coordinates in global frame. + """ + + # Pick the appropriate transform for every atom. + residx_to_group_idx = utils.batched_gather(residue_constants.restype_atom14_to_rigid_group, aatype) + group_mask = jax.nn.one_hot(residx_to_group_idx, num_classes=8) # shape (N, 14, 8) + + # r3.Rigids with shape (N, 14) + map_atoms_to_global = jax.tree_map( + lambda x: jnp.sum(x[:, None, :] * group_mask, axis=-1), + all_frames_to_global) + + # Gather the literature atom positions for each residue. + # r3.Vecs with shape (N, 14) + group_pos = utils.batched_gather(residue_constants.restype_atom14_rigid_group_positions, aatype) + lit_positions = r3.vecs_from_tensor(group_pos) + + # Transform each atom from its local frame to the global frame. + # r3.Vecs with shape (N, 14) + pred_positions = r3.rigids_mul_vecs(map_atoms_to_global, lit_positions) + + # Mask out non-existing atoms. + mask = utils.batched_gather(residue_constants.restype_atom14_mask, aatype) + pred_positions = jax.tree_map(lambda x: x * mask, pred_positions) + return pred_positions + + +def extreme_ca_ca_distance_violations( + pred_atom_positions: jnp.ndarray, # (N, 37(14), 3) + pred_atom_mask: jnp.ndarray, # (N, 37(14)) + residue_index: jnp.ndarray, # (N) + max_angstrom_tolerance=1.5 + ) -> jnp.ndarray: + """Counts residues whose Ca is a large distance from its neighbour. + + Measures the fraction of CA-CA pairs between consecutive amino acids that are + more than 'max_angstrom_tolerance' apart. + + Args: + pred_atom_positions: Atom positions in atom37/14 representation + pred_atom_mask: Atom mask in atom37/14 representation + residue_index: Residue index for given amino acid, this is assumed to be + monotonically increasing. + max_angstrom_tolerance: Maximum distance allowed to not count as violation. + Returns: + Fraction of consecutive CA-CA pairs with violation. 
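The frames_and_literature_positions_to_atom14_pos routine above selects one of the 8 group frames per atom by multiplying with a one-hot group mask and summing. Stripped of the Rigids bookkeeping, the contraction looks like this (a toy NumPy sketch with random stand-in values for a single residue):

```python
import numpy as np

num_atoms, num_groups = 14, 8
rng = np.random.default_rng(1)

# One scalar frame parameter per rigid group for a single residue
# (a stand-in for one field of the r3.Rigids tuple, e.g. rot.xx).
frame_field = rng.normal(size=(num_groups,))
group_idx = rng.integers(0, num_groups, size=num_atoms)   # which group each atom belongs to
group_mask = np.eye(num_groups)[group_idx]                 # (14, 8), like jax.nn.one_hot

# Same contraction as map_atoms_to_global: each atom picks out its group's value.
per_atom = np.sum(frame_field[None, :] * group_mask, axis=-1)
assert np.allclose(per_atom, frame_field[group_idx])
```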
+ """ + this_ca_pos = pred_atom_positions[:-1, 1, :] # (N - 1, 3) + this_ca_mask = pred_atom_mask[:-1, 1] # (N - 1) + next_ca_pos = pred_atom_positions[1:, 1, :] # (N - 1, 3) + next_ca_mask = pred_atom_mask[1:, 1] # (N - 1) + has_no_gap_mask = ((residue_index[1:] - residue_index[:-1]) == 1.0).astype( + jnp.float32) + ca_ca_distance = jnp.sqrt( + 1e-6 + jnp.sum(squared_difference(this_ca_pos, next_ca_pos), axis=-1)) + violations = (ca_ca_distance - + residue_constants.ca_ca) > max_angstrom_tolerance + mask = this_ca_mask * next_ca_mask * has_no_gap_mask + return utils.mask_mean(mask=mask, value=violations) + + +def between_residue_bond_loss( + pred_atom_positions: jnp.ndarray, # (N, 37(14), 3) + pred_atom_mask: jnp.ndarray, # (N, 37(14)) + residue_index: jnp.ndarray, # (N) + aatype: jnp.ndarray, # (N) + tolerance_factor_soft=12.0, + tolerance_factor_hard=12.0 +) -> Dict[str, jnp.ndarray]: + """Flat-bottom loss to penalize structural violations between residues. + + This is a loss penalizing any violation of the geometry around the peptide + bond between consecutive amino acids. This loss corresponds to + Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 44, 45. + + Args: + pred_atom_positions: Atom positions in atom37/14 representation + pred_atom_mask: Atom mask in atom37/14 representation + residue_index: Residue index for given amino acid, this is assumed to be + monotonically increasing. + aatype: Amino acid type of given residue + tolerance_factor_soft: soft tolerance factor measured in standard deviations + of pdb distributions + tolerance_factor_hard: hard tolerance factor measured in standard deviations + of pdb distributions + + Returns: + Dict containing: + * 'c_n_loss_mean': Loss for peptide bond length violations + * 'ca_c_n_loss_mean': Loss for violations of bond angle around C spanned + by CA, C, N + * 'c_n_ca_loss_mean': Loss for violations of bond angle around N spanned + by C, N, CA + * 'per_residue_loss_sum': sum of all losses for each residue + * 'per_residue_violation_mask': mask denoting all residues with violation + present. + """ + assert len(pred_atom_positions.shape) == 3 + assert len(pred_atom_mask.shape) == 2 + assert len(residue_index.shape) == 1 + assert len(aatype.shape) == 1 + + # Get the positions of the relevant backbone atoms. + this_ca_pos = pred_atom_positions[:-1, 1, :] # (N - 1, 3) + this_ca_mask = pred_atom_mask[:-1, 1] # (N - 1) + this_c_pos = pred_atom_positions[:-1, 2, :] # (N - 1, 3) + this_c_mask = pred_atom_mask[:-1, 2] # (N - 1) + next_n_pos = pred_atom_positions[1:, 0, :] # (N - 1, 3) + next_n_mask = pred_atom_mask[1:, 0] # (N - 1) + next_ca_pos = pred_atom_positions[1:, 1, :] # (N - 1, 3) + next_ca_mask = pred_atom_mask[1:, 1] # (N - 1) + has_no_gap_mask = ((residue_index[1:] - residue_index[:-1]) == 1.0).astype( + jnp.float32) + + # Compute loss for the C--N bond. + c_n_bond_length = jnp.sqrt( + 1e-6 + jnp.sum(squared_difference(this_c_pos, next_n_pos), axis=-1)) + + # The C-N bond to proline has slightly different length because of the ring. + next_is_proline = ( + aatype[1:] == residue_constants.resname_to_idx['PRO']).astype(jnp.float32) + gt_length = ( + (1. - next_is_proline) * residue_constants.between_res_bond_length_c_n[0] + + next_is_proline * residue_constants.between_res_bond_length_c_n[1]) + gt_stddev = ( + (1. 
- next_is_proline) * + residue_constants.between_res_bond_length_stddev_c_n[0] + + next_is_proline * residue_constants.between_res_bond_length_stddev_c_n[1]) + c_n_bond_length_error = jnp.sqrt(1e-6 + + jnp.square(c_n_bond_length - gt_length)) + c_n_loss_per_residue = jax.nn.relu( + c_n_bond_length_error - tolerance_factor_soft * gt_stddev) + mask = this_c_mask * next_n_mask * has_no_gap_mask + c_n_loss = jnp.sum(mask * c_n_loss_per_residue) / (jnp.sum(mask) + 1e-6) + c_n_violation_mask = mask * ( + c_n_bond_length_error > (tolerance_factor_hard * gt_stddev)) + + # Compute loss for the angles. + ca_c_bond_length = jnp.sqrt(1e-6 + jnp.sum( + squared_difference(this_ca_pos, this_c_pos), axis=-1)) + n_ca_bond_length = jnp.sqrt(1e-6 + jnp.sum( + squared_difference(next_n_pos, next_ca_pos), axis=-1)) + + c_ca_unit_vec = (this_ca_pos - this_c_pos) / ca_c_bond_length[:, None] + c_n_unit_vec = (next_n_pos - this_c_pos) / c_n_bond_length[:, None] + n_ca_unit_vec = (next_ca_pos - next_n_pos) / n_ca_bond_length[:, None] + + ca_c_n_cos_angle = jnp.sum(c_ca_unit_vec * c_n_unit_vec, axis=-1) + gt_angle = residue_constants.between_res_cos_angles_ca_c_n[0] + gt_stddev = residue_constants.between_res_bond_length_stddev_c_n[0] + ca_c_n_cos_angle_error = jnp.sqrt( + 1e-6 + jnp.square(ca_c_n_cos_angle - gt_angle)) + ca_c_n_loss_per_residue = jax.nn.relu( + ca_c_n_cos_angle_error - tolerance_factor_soft * gt_stddev) + mask = this_ca_mask * this_c_mask * next_n_mask * has_no_gap_mask + ca_c_n_loss = jnp.sum(mask * ca_c_n_loss_per_residue) / (jnp.sum(mask) + 1e-6) + ca_c_n_violation_mask = mask * (ca_c_n_cos_angle_error > + (tolerance_factor_hard * gt_stddev)) + + c_n_ca_cos_angle = jnp.sum((-c_n_unit_vec) * n_ca_unit_vec, axis=-1) + gt_angle = residue_constants.between_res_cos_angles_c_n_ca[0] + gt_stddev = residue_constants.between_res_cos_angles_c_n_ca[1] + c_n_ca_cos_angle_error = jnp.sqrt( + 1e-6 + jnp.square(c_n_ca_cos_angle - gt_angle)) + c_n_ca_loss_per_residue = jax.nn.relu( + c_n_ca_cos_angle_error - tolerance_factor_soft * gt_stddev) + mask = this_c_mask * next_n_mask * next_ca_mask * has_no_gap_mask + c_n_ca_loss = jnp.sum(mask * c_n_ca_loss_per_residue) / (jnp.sum(mask) + 1e-6) + c_n_ca_violation_mask = mask * ( + c_n_ca_cos_angle_error > (tolerance_factor_hard * gt_stddev)) + + # Compute a per residue loss (equally distribute the loss to both + # neighbouring residues). + per_residue_loss_sum = (c_n_loss_per_residue + + ca_c_n_loss_per_residue + + c_n_ca_loss_per_residue) + per_residue_loss_sum = 0.5 * (jnp.pad(per_residue_loss_sum, [[0, 1]]) + + jnp.pad(per_residue_loss_sum, [[1, 0]])) + + # Compute hard violations. + violation_mask = jnp.max( + jnp.stack([c_n_violation_mask, + ca_c_n_violation_mask, + c_n_ca_violation_mask]), axis=0) + violation_mask = jnp.maximum( + jnp.pad(violation_mask, [[0, 1]]), + jnp.pad(violation_mask, [[1, 0]])) + + return {'c_n_loss_mean': c_n_loss, # shape () + 'ca_c_n_loss_mean': ca_c_n_loss, # shape () + 'c_n_ca_loss_mean': c_n_ca_loss, # shape () + 'per_residue_loss_sum': per_residue_loss_sum, # shape (N) + 'per_residue_violation_mask': violation_mask # shape (N) + } + + +def between_residue_clash_loss( + atom14_pred_positions: jnp.ndarray, # (N, 14, 3) + atom14_atom_exists: jnp.ndarray, # (N, 14) + atom14_atom_radius: jnp.ndarray, # (N, 14) + residue_index: jnp.ndarray, # (N) + overlap_tolerance_soft=1.5, + overlap_tolerance_hard=1.5 +) -> Dict[str, jnp.ndarray]: + """Loss to penalize steric clashes between residues. 
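All three terms of the bond loss above share the same "flat-bottom" shape: deviations within tolerance_factor standard deviations of the reference value cost nothing, and the penalty grows linearly beyond that. A small NumPy illustration with made-up numbers standing in for the literature mean and standard deviation:

```python
import numpy as np

def flat_bottom_penalty(value, gt_value, gt_stddev, tolerance_factor=12.0):
    """Zero within tolerance_factor * gt_stddev of gt_value, linear outside."""
    error = np.sqrt(1e-6 + np.square(value - gt_value))
    return np.maximum(error - tolerance_factor * gt_stddev, 0.0)

# Invented bond lengths (Angstrom) against an invented reference of 1.33 +/- 0.01.
bond_lengths = np.array([1.33, 1.35, 1.60])
print(flat_bottom_penalty(bond_lengths, gt_value=1.33, gt_stddev=0.01))
# -> [0., 0., ~0.15]; only the clearly stretched bond is penalized
```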
+ + This is a loss penalizing any steric clashes due to non bonded atoms in + different peptides coming too close. This loss corresponds to the part with + different residues of + Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 46. + + Args: + atom14_pred_positions: Predicted positions of atoms in + global prediction frame + atom14_atom_exists: Mask denoting whether atom at positions exists for given + amino acid type + atom14_atom_radius: Van der Waals radius for each atom. + residue_index: Residue index for given amino acid. + overlap_tolerance_soft: Soft tolerance factor. + overlap_tolerance_hard: Hard tolerance factor. + + Returns: + Dict containing: + * 'mean_loss': average clash loss + * 'per_atom_loss_sum': sum of all clash losses per atom, shape (N, 14) + * 'per_atom_clash_mask': mask whether atom clashes with any other atom + shape (N, 14) + """ + assert len(atom14_pred_positions.shape) == 3 + assert len(atom14_atom_exists.shape) == 2 + assert len(atom14_atom_radius.shape) == 2 + assert len(residue_index.shape) == 1 + + # Create the distance matrix. + # (N, N, 14, 14) + dists = jnp.sqrt(1e-10 + jnp.sum( + squared_difference( + atom14_pred_positions[:, None, :, None, :], + atom14_pred_positions[None, :, None, :, :]), + axis=-1)) + + # Create the mask for valid distances. + # shape (N, N, 14, 14) + dists_mask = (atom14_atom_exists[:, None, :, None] * + atom14_atom_exists[None, :, None, :]) + + # Mask out all the duplicate entries in the lower triangular matrix. + # Also mask out the diagonal (atom-pairs from the same residue) -- these atoms + # are handled separately. + dists_mask *= ( + residue_index[:, None, None, None] < residue_index[None, :, None, None]) + + # Backbone C--N bond between subsequent residues is no clash. + c_one_hot = jax.nn.one_hot(2, num_classes=14) + n_one_hot = jax.nn.one_hot(0, num_classes=14) + neighbour_mask = ((residue_index[:, None, None, None] + + 1) == residue_index[None, :, None, None]) + c_n_bonds = neighbour_mask * c_one_hot[None, None, :, + None] * n_one_hot[None, None, None, :] + dists_mask *= (1. - c_n_bonds) + + # Disulfide bridge between two cysteines is no clash. + cys_sg_idx = residue_constants.restype_name_to_atom14_names['CYS'].index('SG') + cys_sg_one_hot = jax.nn.one_hot(cys_sg_idx, num_classes=14) + disulfide_bonds = (cys_sg_one_hot[None, None, :, None] * + cys_sg_one_hot[None, None, None, :]) + dists_mask *= (1. - disulfide_bonds) + + # Compute the lower bound for the allowed distances. + # shape (N, N, 14, 14) + dists_lower_bound = dists_mask * (atom14_atom_radius[:, None, :, None] + + atom14_atom_radius[None, :, None, :]) + + # Compute the error. + # shape (N, N, 14, 14) + dists_to_low_error = dists_mask * jax.nn.relu( + dists_lower_bound - overlap_tolerance_soft - dists) + + # Compute the mean loss. + # shape () + mean_loss = (jnp.sum(dists_to_low_error) + / (1e-6 + jnp.sum(dists_mask))) + + # Compute the per atom loss sum. + # shape (N, 14) + per_atom_loss_sum = (jnp.sum(dists_to_low_error, axis=[0, 2]) + + jnp.sum(dists_to_low_error, axis=[1, 3])) + + # Compute the hard clash mask. + # shape (N, N, 14, 14) + clash_mask = dists_mask * ( + dists < (dists_lower_bound - overlap_tolerance_hard)) + + # Compute the per atom clash. 
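The clash term boils down to a broadcasted pairwise-distance matrix compared against the sum of van der Waals radii. Here is the same core computation for a toy set of four atoms, without the per-residue (N, N, 14, 14) bookkeeping (all coordinates, radii, and the tolerance are invented):

```python
import numpy as np

# Toy system: four atoms, two of them overlapping.
pos = np.array([[0.0, 0.0, 0.0],
                [1.0, 0.0, 0.0],   # clashes with atom 0
                [5.0, 0.0, 0.0],
                [9.0, 0.0, 0.0]])
radius = np.array([1.5, 1.5, 1.5, 1.5])

# Pairwise distances via broadcasting, as in the (N, N, 14, 14) version above.
dists = np.sqrt(1e-10 + np.sum((pos[:, None, :] - pos[None, :, :]) ** 2, axis=-1))
lower_bound = radius[:, None] + radius[None, :]

# Count each pair once and never an atom against itself.
pair_mask = np.triu(np.ones_like(dists), k=1)

overlap_tolerance = 1.5
clash_error = pair_mask * np.maximum(lower_bound - overlap_tolerance - dists, 0.0)
print(clash_error.sum())   # only the 0-1 pair contributes: 3.0 - 1.5 - 1.0 = 0.5
```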
+ # shape (N, 14) + per_atom_clash_mask = jnp.maximum( + jnp.max(clash_mask, axis=[0, 2]), + jnp.max(clash_mask, axis=[1, 3])) + + return {'mean_loss': mean_loss, # shape () + 'per_atom_loss_sum': per_atom_loss_sum, # shape (N, 14) + 'per_atom_clash_mask': per_atom_clash_mask # shape (N, 14) + } + + +def within_residue_violations( + atom14_pred_positions: jnp.ndarray, # (N, 14, 3) + atom14_atom_exists: jnp.ndarray, # (N, 14) + atom14_dists_lower_bound: jnp.ndarray, # (N, 14, 14) + atom14_dists_upper_bound: jnp.ndarray, # (N, 14, 14) + tighten_bounds_for_loss=0.0, +) -> Dict[str, jnp.ndarray]: + """Loss to penalize steric clashes within residues. + + This is a loss penalizing any steric violations or clashes of non-bonded atoms + in a given peptide. This loss corresponds to the part with + the same residues of + Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 46. + + Args: + atom14_pred_positions: Predicted positions of atoms in + global prediction frame + atom14_atom_exists: Mask denoting whether atom at positions exists for given + amino acid type + atom14_dists_lower_bound: Lower bound on allowed distances. + atom14_dists_upper_bound: Upper bound on allowed distances + tighten_bounds_for_loss: Extra factor to tighten loss + + Returns: + Dict containing: + * 'per_atom_loss_sum': sum of all clash losses per atom, shape (N, 14) + * 'per_atom_clash_mask': mask whether atom clashes with any other atom + shape (N, 14) + """ + assert len(atom14_pred_positions.shape) == 3 + assert len(atom14_atom_exists.shape) == 2 + assert len(atom14_dists_lower_bound.shape) == 3 + assert len(atom14_dists_upper_bound.shape) == 3 + + # Compute the mask for each residue. + # shape (N, 14, 14) + dists_masks = (1. - jnp.eye(14, 14)[None]) + dists_masks *= (atom14_atom_exists[:, :, None] * + atom14_atom_exists[:, None, :]) + + # Distance matrix + # shape (N, 14, 14) + dists = jnp.sqrt(1e-10 + jnp.sum( + squared_difference( + atom14_pred_positions[:, :, None, :], + atom14_pred_positions[:, None, :, :]), + axis=-1)) + + # Compute the loss. + # shape (N, 14, 14) + dists_to_low_error = jax.nn.relu( + atom14_dists_lower_bound + tighten_bounds_for_loss - dists) + dists_to_high_error = jax.nn.relu( + dists - (atom14_dists_upper_bound - tighten_bounds_for_loss)) + loss = dists_masks * (dists_to_low_error + dists_to_high_error) + + # Compute the per atom loss sum. + # shape (N, 14) + per_atom_loss_sum = (jnp.sum(loss, axis=1) + + jnp.sum(loss, axis=2)) + + # Compute the violations mask. + # shape (N, 14, 14) + violations = dists_masks * ((dists < atom14_dists_lower_bound) | + (dists > atom14_dists_upper_bound)) + + # Compute the per atom violations. + # shape (N, 14) + per_atom_violations = jnp.maximum( + jnp.max(violations, axis=1), jnp.max(violations, axis=2)) + + return {'per_atom_loss_sum': per_atom_loss_sum, # shape (N, 14) + 'per_atom_violations': per_atom_violations # shape (N, 14) + } + + +def find_optimal_renaming( + atom14_gt_positions: jnp.ndarray, # (N, 14, 3) + atom14_alt_gt_positions: jnp.ndarray, # (N, 14, 3) + atom14_atom_is_ambiguous: jnp.ndarray, # (N, 14) + atom14_gt_exists: jnp.ndarray, # (N, 14) + atom14_pred_positions: jnp.ndarray, # (N, 14, 3) + atom14_atom_exists: jnp.ndarray, # (N, 14) +) -> jnp.ndarray: # (N): + """Find optimal renaming for ground truth that maximizes LDDT. + + Jumper et al. (2021) Suppl. Alg. 26 + "renameSymmetricGroundTruthAtoms" lines 1-5 + + Args: + atom14_gt_positions: Ground truth positions in global frame of ground truth. 
+ atom14_alt_gt_positions: Alternate ground truth positions in global frame of + ground truth with coordinates of ambiguous atoms swapped relative to + 'atom14_gt_positions'. + atom14_atom_is_ambiguous: Mask denoting whether atom is among ambiguous + atoms, see Jumper et al. (2021) Suppl. Table 3 + atom14_gt_exists: Mask denoting whether atom at positions exists in ground + truth. + atom14_pred_positions: Predicted positions of atoms in + global prediction frame + atom14_atom_exists: Mask denoting whether atom at positions exists for given + amino acid type + + Returns: + Float array of shape [N] with 1. where atom14_alt_gt_positions is closer to + prediction and 0. otherwise + """ + assert len(atom14_gt_positions.shape) == 3 + assert len(atom14_alt_gt_positions.shape) == 3 + assert len(atom14_atom_is_ambiguous.shape) == 2 + assert len(atom14_gt_exists.shape) == 2 + assert len(atom14_pred_positions.shape) == 3 + assert len(atom14_atom_exists.shape) == 2 + + # Create the pred distance matrix. + # shape (N, N, 14, 14) + pred_dists = jnp.sqrt(1e-10 + jnp.sum( + squared_difference( + atom14_pred_positions[:, None, :, None, :], + atom14_pred_positions[None, :, None, :, :]), + axis=-1)) + + # Compute distances for ground truth with original and alternative names. + # shape (N, N, 14, 14) + gt_dists = jnp.sqrt(1e-10 + jnp.sum( + squared_difference( + atom14_gt_positions[:, None, :, None, :], + atom14_gt_positions[None, :, None, :, :]), + axis=-1)) + alt_gt_dists = jnp.sqrt(1e-10 + jnp.sum( + squared_difference( + atom14_alt_gt_positions[:, None, :, None, :], + atom14_alt_gt_positions[None, :, None, :, :]), + axis=-1)) + + # Compute LDDT's. + # shape (N, N, 14, 14) + lddt = jnp.sqrt(1e-10 + squared_difference(pred_dists, gt_dists)) + alt_lddt = jnp.sqrt(1e-10 + squared_difference(pred_dists, alt_gt_dists)) + + # Create a mask for ambiguous atoms in rows vs. non-ambiguous atoms + # in cols. + # shape (N ,N, 14, 14) + mask = (atom14_gt_exists[:, None, :, None] * # rows + atom14_atom_is_ambiguous[:, None, :, None] * # rows + atom14_gt_exists[None, :, None, :] * # cols + (1. - atom14_atom_is_ambiguous[None, :, None, :])) # cols + + # Aggregate distances for each residue to the non-amibuguous atoms. + # shape (N) + per_res_lddt = jnp.sum(mask * lddt, axis=[1, 2, 3]) + alt_per_res_lddt = jnp.sum(mask * alt_lddt, axis=[1, 2, 3]) + + # Decide for each residue, whether alternative naming is better. + # shape (N) + alt_naming_is_better = (alt_per_res_lddt < per_res_lddt).astype(jnp.float32) + + return alt_naming_is_better # shape (N) + + +def frame_aligned_point_error( + pred_frames: r3.Rigids, # shape (num_frames) + target_frames: r3.Rigids, # shape (num_frames) + frames_mask: jnp.ndarray, # shape (num_frames) + pred_positions: r3.Vecs, # shape (num_positions) + target_positions: r3.Vecs, # shape (num_positions) + positions_mask: jnp.ndarray, # shape (num_positions) + length_scale: float, + l1_clamp_distance: Optional[float] = None, + epsilon=1e-4) -> jnp.ndarray: # shape () + """Measure point error under different alignments. + + Jumper et al. (2021) Suppl. Alg. 28 "computeFAPE" + + Computes error between two structures with B points under A alignments derived + from the given pairs of frames. + Args: + pred_frames: num_frames reference frames for 'pred_positions'. + target_frames: num_frames reference frames for 'target_positions'. + frames_mask: Mask for frame pairs to use. + pred_positions: num_positions predicted positions of the structure. 
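frame_aligned_point_error, whose definition begins here, expresses every point in every frame for both structures, compares the two local point clouds, clamps the error, and takes a masked average. A toy NumPy rendering of that computation, with plain rotation matrices and translations standing in for the library's r3.Rigids / r3.Vecs containers (the function name, shapes, and test values are invented, and masking is omitted):

```python
import numpy as np

def toy_fape(R_pred, t_pred, x_pred, R_true, t_true, x_true,
             clamp=10.0, length_scale=10.0):
    """Toy FAPE: express every point in every frame, compare, clamp, average.

    R_*: (F, 3, 3) rotations, t_*: (F, 3) translations, x_*: (P, 3) points.
    """
    # Local coordinates: for frame i and point j, R_i^T (x_j - t_i); shape (F, P, 3).
    local_pred = np.einsum('fab,fpa->fpb', R_pred, x_pred[None, :, :] - t_pred[:, None, :])
    local_true = np.einsum('fab,fpa->fpb', R_true, x_true[None, :, :] - t_true[:, None, :])
    error = np.sqrt(np.sum((local_pred - local_true) ** 2, axis=-1) + 1e-4)
    return np.mean(np.clip(error, 0, clamp)) / length_scale

# One identity frame, three points; the prediction is displaced by 1 Angstrom in x.
R = np.eye(3)[None]
t = np.zeros((1, 3))
x_true = np.array([[0., 0., 0.], [1., 0., 0.], [0., 2., 0.]])
x_pred = x_true + np.array([1., 0., 0.])
print(toy_fape(R, t, x_pred, R, t, x_true))   # ~0.1
```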
+ target_positions: num_positions target positions of the structure. + positions_mask: Mask on which positions to score. + length_scale: length scale to divide loss by. + l1_clamp_distance: Distance cutoff on error beyond which gradients will + be zero. + epsilon: small value used to regularize denominator for masked average. + Returns: + Masked Frame Aligned Point Error. + """ + assert pred_frames.rot.xx.ndim == 1 + assert target_frames.rot.xx.ndim == 1 + assert frames_mask.ndim == 1, frames_mask.ndim + assert pred_positions.x.ndim == 1 + assert target_positions.x.ndim == 1 + assert positions_mask.ndim == 1 + + # Compute array of predicted positions in the predicted frames. + # r3.Vecs (num_frames, num_positions) + local_pred_pos = r3.rigids_mul_vecs( + jax.tree_map(lambda r: r[:, None], r3.invert_rigids(pred_frames)), + jax.tree_map(lambda x: x[None, :], pred_positions)) + + # Compute array of target positions in the target frames. + # r3.Vecs (num_frames, num_positions) + local_target_pos = r3.rigids_mul_vecs( + jax.tree_map(lambda r: r[:, None], r3.invert_rigids(target_frames)), + jax.tree_map(lambda x: x[None, :], target_positions)) + + # Compute errors between the structures. + # jnp.ndarray (num_frames, num_positions) + error_dist = jnp.sqrt( + r3.vecs_squared_distance(local_pred_pos, local_target_pos) + + epsilon) + + if l1_clamp_distance: + error_dist = jnp.clip(error_dist, 0, l1_clamp_distance) + + normed_error = error_dist / length_scale + normed_error *= jnp.expand_dims(frames_mask, axis=-1) + normed_error *= jnp.expand_dims(positions_mask, axis=-2) + + mask = (jnp.expand_dims(frames_mask, axis=-1) * + jnp.expand_dims(positions_mask, axis=-2)) + normalization_factor = jnp.sum(mask, axis=(-1, -2)) + return (jnp.sum(normed_error, axis=(-2, -1)) / + (epsilon + normalization_factor)) + + +def _make_renaming_matrices(): + """Matrices to map atoms to symmetry partners in ambiguous case.""" + # As the atom naming is ambiguous for 7 of the 20 amino acids, provide + # alternative groundtruth coordinates where the naming is swapped + restype_3 = [ + residue_constants.restype_1to3[res] for res in residue_constants.restypes + ] + restype_3 += ['UNK'] + # Matrices for renaming ambiguous atoms. + all_matrices = {res: np.eye(14, dtype=np.float32) for res in restype_3} + for resname, swap in residue_constants.residue_atom_renaming_swaps.items(): + correspondences = np.arange(14) + for source_atom_swap, target_atom_swap in swap.items(): + source_index = residue_constants.restype_name_to_atom14_names[ + resname].index(source_atom_swap) + target_index = residue_constants.restype_name_to_atom14_names[ + resname].index(target_atom_swap) + correspondences[source_index] = target_index + correspondences[target_index] = source_index + renaming_matrix = np.zeros((14, 14), dtype=np.float32) + for index, correspondence in enumerate(correspondences): + renaming_matrix[index, correspondence] = 1. + all_matrices[resname] = renaming_matrix.astype(np.float32) + renaming_matrices = np.stack([all_matrices[restype] for restype in restype_3]) + return renaming_matrices + + +RENAMING_MATRICES = _make_renaming_matrices() + + +def get_alt_atom14(aatype, positions, mask): + """Get alternative atom14 positions. + + Constructs renamed atom positions for ambiguous residues. + + Jumper et al. (2021) Suppl. 
Table 3 "Ambiguous atom names due to 180 degree- + rotation-symmetry" + + Args: + aatype: Amino acid at given position + positions: Atom positions as r3.Vecs in atom14 representation, (N, 14) + mask: Atom masks in atom14 representation, (N, 14) + Returns: + renamed atom positions, renamed atom mask + """ + # pick the transformation matrices for the given residue sequence + # shape (num_res, 14, 14) + renaming_transform = utils.batched_gather( + jnp.asarray(RENAMING_MATRICES), aatype) + + positions = jax.tree_map(lambda x: x[:, :, None], positions) + alternative_positions = jax.tree_map( + lambda x: jnp.sum(x, axis=1), positions * renaming_transform) + + # Create the mask for the alternative ground truth (differs from the + # ground truth mask, if only one of the atoms in an ambiguous pair has a + # ground truth position) + alternative_mask = jnp.sum(mask[..., None] * renaming_transform, axis=1) + + return alternative_positions, alternative_mask diff --git a/build/lib/colabdesign/af/alphafold/model/all_atom_multimer.py b/build/lib/colabdesign/af/alphafold/model/all_atom_multimer.py new file mode 100644 index 00000000..fc45c30a --- /dev/null +++ b/build/lib/colabdesign/af/alphafold/model/all_atom_multimer.py @@ -0,0 +1,966 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Ops for all atom representations.""" + +from typing import Dict, Text + +from colabdesign.af.alphafold.common import residue_constants +from colabdesign.af.alphafold.model import geometry +from colabdesign.af.alphafold.model import utils + +import jax +import jax.numpy as jnp +import numpy as np + +def squared_difference(x, y): + return jnp.square(x - y) + +def _make_chi_atom_indices(): + """Returns atom indices needed to compute chi angles for all residue types. + + Returns: + A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are + in the order specified in residue_constants.restypes + unknown residue type + at the end. For chi angles which are not defined on the residue, the + positions indices are by default set to 0. + """ + chi_atom_indices = [] + for residue_name in residue_constants.restypes: + residue_name = residue_constants.restype_1to3[residue_name] + residue_chi_angles = residue_constants.chi_angles_atoms[residue_name] + atom_indices = [] + for chi_angle in residue_chi_angles: + atom_indices.append( + [residue_constants.atom_order[atom] for atom in chi_angle]) + for _ in range(4 - len(atom_indices)): + atom_indices.append([0, 0, 0, 0]) # For chi angles not defined on the AA. + chi_atom_indices.append(atom_indices) + + chi_atom_indices.append([[0, 0, 0, 0]] * 4) # For UNKNOWN residue. 
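The renaming machinery used by get_alt_atom14 above reduces to per-residue 14x14 permutation matrices gathered from RENAMING_MATRICES. A small NumPy sketch of the same contraction for a single residue, with an invented pair of ambiguous atom slots standing in for the entries of residue_constants.residue_atom_renaming_swaps:

```python
import numpy as np

# Identity mapping for 14 atom slots, then swap two hypothetical ambiguous slots.
swap_a, swap_b = 6, 7      # e.g. two symmetry-equivalent oxygens (indices invented)
renaming = np.eye(14, dtype=np.float32)
renaming[[swap_a, swap_b], :] = 0.0
renaming[swap_a, swap_b] = 1.0
renaming[swap_b, swap_a] = 1.0

positions = np.arange(14 * 3, dtype=np.float32).reshape(14, 3)  # fake (14, 3) coordinates
# Same contraction as in get_alt_atom14: sum over the source-atom axis.
alt_positions = np.sum(positions[:, None, :] * renaming[:, :, None], axis=0)
assert np.allclose(alt_positions[swap_a], positions[swap_b])
assert np.allclose(alt_positions[swap_b], positions[swap_a])
```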
+ + return np.array(chi_atom_indices) + + +def _make_renaming_matrices(): + """Matrices to map atoms to symmetry partners in ambiguous case.""" + # As the atom naming is ambiguous for 7 of the 20 amino acids, provide + # alternative groundtruth coordinates where the naming is swapped + restype_3 = [ + residue_constants.restype_1to3[res] for res in residue_constants.restypes + ] + restype_3 += ['UNK'] + # Matrices for renaming ambiguous atoms. + all_matrices = {res: np.eye(14, dtype=np.float32) for res in restype_3} + for resname, swap in residue_constants.residue_atom_renaming_swaps.items(): + correspondences = np.arange(14) + for source_atom_swap, target_atom_swap in swap.items(): + source_index = residue_constants.restype_name_to_atom14_names[ + resname].index(source_atom_swap) + target_index = residue_constants.restype_name_to_atom14_names[ + resname].index(target_atom_swap) + correspondences[source_index] = target_index + correspondences[target_index] = source_index + renaming_matrix = np.zeros((14, 14), dtype=np.float32) + for index, correspondence in enumerate(correspondences): + renaming_matrix[index, correspondence] = 1. + all_matrices[resname] = renaming_matrix.astype(np.float32) + renaming_matrices = np.stack([all_matrices[restype] for restype in restype_3]) + return renaming_matrices + + +def _make_restype_atom37_mask(): + """Mask of which atoms are present for which residue type in atom37.""" + # create the corresponding mask + restype_atom37_mask = np.zeros([21, 37], dtype=np.float32) + for restype, restype_letter in enumerate(residue_constants.restypes): + restype_name = residue_constants.restype_1to3[restype_letter] + atom_names = residue_constants.residue_atoms[restype_name] + for atom_name in atom_names: + atom_type = residue_constants.atom_order[atom_name] + restype_atom37_mask[restype, atom_type] = 1 + return restype_atom37_mask + + +def _make_restype_atom14_mask(): + """Mask of which atoms are present for which residue type in atom14.""" + restype_atom14_mask = [] + + for rt in residue_constants.restypes: + atom_names = residue_constants.restype_name_to_atom14_names[ + residue_constants.restype_1to3[rt]] + restype_atom14_mask.append([(1. if name else 0.) for name in atom_names]) + + restype_atom14_mask.append([0.] 
* 14) + restype_atom14_mask = np.array(restype_atom14_mask, dtype=np.float32) + return restype_atom14_mask + + +def _make_restype_atom37_to_atom14(): + """Map from atom37 to atom14 per residue type.""" + restype_atom37_to_atom14 = [] # mapping (restype, atom37) --> atom14 + for rt in residue_constants.restypes: + atom_names = residue_constants.restype_name_to_atom14_names[ + residue_constants.restype_1to3[rt]] + atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)} + restype_atom37_to_atom14.append([ + (atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0) + for name in residue_constants.atom_types + ]) + + restype_atom37_to_atom14.append([0] * 37) + restype_atom37_to_atom14 = np.array(restype_atom37_to_atom14, dtype=np.int32) + return restype_atom37_to_atom14 + + +def _make_restype_atom14_to_atom37(): + """Map from atom14 to atom37 per residue type.""" + restype_atom14_to_atom37 = [] # mapping (restype, atom14) --> atom37 + for rt in residue_constants.restypes: + atom_names = residue_constants.restype_name_to_atom14_names[ + residue_constants.restype_1to3[rt]] + restype_atom14_to_atom37.append([ + (residue_constants.atom_order[name] if name else 0) + for name in atom_names + ]) + # Add dummy mapping for restype 'UNK' + restype_atom14_to_atom37.append([0] * 14) + restype_atom14_to_atom37 = np.array(restype_atom14_to_atom37, dtype=np.int32) + return restype_atom14_to_atom37 + + +def _make_restype_atom14_is_ambiguous(): + """Mask which atoms are ambiguous in atom14.""" + # create an ambiguous atoms mask. shape: (21, 14) + restype_atom14_is_ambiguous = np.zeros((21, 14), dtype=np.float32) + for resname, swap in residue_constants.residue_atom_renaming_swaps.items(): + for atom_name1, atom_name2 in swap.items(): + restype = residue_constants.restype_order[ + residue_constants.restype_3to1[resname]] + atom_idx1 = residue_constants.restype_name_to_atom14_names[resname].index( + atom_name1) + atom_idx2 = residue_constants.restype_name_to_atom14_names[resname].index( + atom_name2) + restype_atom14_is_ambiguous[restype, atom_idx1] = 1 + restype_atom14_is_ambiguous[restype, atom_idx2] = 1 + + return restype_atom14_is_ambiguous + + +def _make_restype_rigidgroup_base_atom37_idx(): + """Create Map from rigidgroups to atom37 indices.""" + # Create an array with the atom names. + # shape (num_restypes, num_rigidgroups, 3_atoms): (21, 8, 3) + base_atom_names = np.full([21, 8, 3], '', dtype=object) + + # 0: backbone frame + base_atom_names[:, 0, :] = ['C', 'CA', 'N'] + + # 3: 'psi-group' + base_atom_names[:, 3, :] = ['CA', 'C', 'O'] + + # 4,5,6,7: 'chi1,2,3,4-group' + for restype, restype_letter in enumerate(residue_constants.restypes): + resname = residue_constants.restype_1to3[restype_letter] + for chi_idx in range(4): + if residue_constants.chi_angles_mask[restype][chi_idx]: + atom_names = residue_constants.chi_angles_atoms[resname][chi_idx] + base_atom_names[restype, chi_idx + 4, :] = atom_names[1:] + + # Translate atom names into atom37 indices. 
+ lookuptable = residue_constants.atom_order.copy() + lookuptable[''] = 0 + restype_rigidgroup_base_atom37_idx = np.vectorize(lambda x: lookuptable[x])( + base_atom_names) + return restype_rigidgroup_base_atom37_idx + + +CHI_ATOM_INDICES = _make_chi_atom_indices() +RENAMING_MATRICES = _make_renaming_matrices() +RESTYPE_ATOM14_TO_ATOM37 = _make_restype_atom14_to_atom37() +RESTYPE_ATOM37_TO_ATOM14 = _make_restype_atom37_to_atom14() +RESTYPE_ATOM37_MASK = _make_restype_atom37_mask() +RESTYPE_ATOM14_MASK = _make_restype_atom14_mask() +RESTYPE_ATOM14_IS_AMBIGUOUS = _make_restype_atom14_is_ambiguous() +RESTYPE_RIGIDGROUP_BASE_ATOM37_IDX = _make_restype_rigidgroup_base_atom37_idx() + +# Create mask for existing rigid groups. +RESTYPE_RIGIDGROUP_MASK = np.zeros([21, 8], dtype=np.float32) +RESTYPE_RIGIDGROUP_MASK[:, 0] = 1 +RESTYPE_RIGIDGROUP_MASK[:, 3] = 1 +RESTYPE_RIGIDGROUP_MASK[:20, 4:] = residue_constants.chi_angles_mask + + +def get_atom37_mask(aatype): + return utils.batched_gather(jnp.asarray(RESTYPE_ATOM37_MASK), aatype) + +def get_atom14_mask(aatype): + return utils.batched_gather(jnp.asarray(RESTYPE_ATOM14_MASK), aatype) + +def get_atom14_is_ambiguous(aatype): + return utils.batched_gather(jnp.asarray(RESTYPE_ATOM14_IS_AMBIGUOUS), aatype) + +def get_atom14_to_atom37_map(aatype): + return utils.batched_gather(jnp.asarray(RESTYPE_ATOM14_TO_ATOM37), aatype) + +def get_atom37_to_atom14_map(aatype): + return utils.batched_gather(jnp.asarray(RESTYPE_ATOM37_TO_ATOM14), aatype) + +def atom14_to_atom37(atom14_data: jnp.ndarray, # (N, 14, ...) + aatype: jnp.ndarray + ) -> jnp.ndarray: # (N, 37, ...) + """Convert atom14 to atom37 representation.""" + + assert len(atom14_data.shape) in [2, 3] + idx_atom37_to_atom14 = get_atom37_to_atom14_map(aatype) + atom37_data = utils.batched_gather( + atom14_data, idx_atom37_to_atom14, batch_dims=1) + atom37_mask = get_atom37_mask(aatype) + if len(atom14_data.shape) == 2: + atom37_data *= atom37_mask + elif len(atom14_data.shape) == 3: + atom37_data *= atom37_mask[:, :, None].astype(atom37_data.dtype) + return atom37_data + + +def atom37_to_atom14(aatype, all_atom_pos, all_atom_mask): + """Convert Atom37 positions to Atom14 positions.""" + residx_atom14_to_atom37 = utils.batched_gather( + jnp.asarray(RESTYPE_ATOM14_TO_ATOM37), aatype) + atom14_mask = utils.batched_gather( + all_atom_mask, residx_atom14_to_atom37, batch_dims=1).astype(jnp.float32) + # create a mask for known groundtruth positions + atom14_mask *= utils.batched_gather(jnp.asarray(RESTYPE_ATOM14_MASK), aatype) + # gather the groundtruth positions + atom14_positions = jax.tree_map( + lambda x: utils.batched_gather(x, residx_atom14_to_atom37, batch_dims=1), + all_atom_pos) + atom14_positions = atom14_mask * atom14_positions + return atom14_positions, atom14_mask + + +def get_alt_atom14(aatype, positions: geometry.Vec3Array, mask): + """Get alternative atom14 positions.""" + # pick the transformation matrices for the given residue sequence + # shape (num_res, 14, 14) + renaming_transform = utils.batched_gather( + jnp.asarray(RENAMING_MATRICES), aatype) + + alternative_positions = jax.tree_map( + lambda x: jnp.sum(x, axis=1), positions[:, :, None] * renaming_transform) + + # Create the mask for the alternative ground truth (differs from the + # ground truth mask, if only one of the atoms in an ambiguous pair has a + # ground truth position) + alternative_mask = jnp.sum(mask[..., None] * renaming_transform, axis=1) + + return alternative_positions, alternative_mask + + +def atom37_to_frames( + aatype: 
jnp.ndarray, # (...) + all_atom_positions: geometry.Vec3Array, # (..., 37) + all_atom_mask: jnp.ndarray, # (..., 37) +) -> Dict[Text, jnp.ndarray]: + + """Computes the frames for the up to 8 rigid groups for each residue.""" + # 0: 'backbone group', + # 1: 'pre-omega-group', (empty) + # 2: 'phi-group', (currently empty, because it defines only hydrogens) + # 3: 'psi-group', + # 4,5,6,7: 'chi1,2,3,4-group' + aatype_in_shape = aatype.shape + + # If there is a batch axis, just flatten it away, and reshape everything + # back at the end of the function. + aatype = jnp.reshape(aatype, [-1]) + all_atom_positions = jax.tree_map(lambda x: jnp.reshape(x, [-1, 37]), + all_atom_positions) + all_atom_mask = jnp.reshape(all_atom_mask, [-1, 37]) + + # Compute the gather indices for all residues in the chain. + # shape (N, 8, 3) + residx_rigidgroup_base_atom37_idx = utils.batched_gather( + RESTYPE_RIGIDGROUP_BASE_ATOM37_IDX, aatype) + + # Gather the base atom positions for each rigid group. + base_atom_pos = jax.tree_map( + lambda x: utils.batched_gather( # pylint: disable=g-long-lambda + x, residx_rigidgroup_base_atom37_idx, batch_dims=1), + all_atom_positions) + + # Compute the Rigids. + point_on_neg_x_axis = base_atom_pos[:, :, 0] + origin = base_atom_pos[:, :, 1] + point_on_xy_plane = base_atom_pos[:, :, 2] + gt_rotation = geometry.Rot3Array.from_two_vectors( + origin - point_on_neg_x_axis, point_on_xy_plane - origin) + + gt_frames = geometry.Rigid3Array(gt_rotation, origin) + + # Compute a mask whether the group exists. + # (N, 8) + group_exists = utils.batched_gather(RESTYPE_RIGIDGROUP_MASK, aatype) + + # Compute a mask whether ground truth exists for the group + gt_atoms_exist = utils.batched_gather( # shape (N, 8, 3) + all_atom_mask.astype(jnp.float32), + residx_rigidgroup_base_atom37_idx, + batch_dims=1) + gt_exists = jnp.min(gt_atoms_exist, axis=-1) * group_exists # (N, 8) + + # Adapt backbone frame to old convention (mirror x-axis and z-axis). + rots = np.tile(np.eye(3, dtype=np.float32), [8, 1, 1]) + rots[0, 0, 0] = -1 + rots[0, 2, 2] = -1 + gt_frames = gt_frames.compose_rotation( + geometry.Rot3Array.from_array(rots)) + + # The frames for ambiguous rigid groups are just rotated by 180 degree around + # the x-axis. The ambiguous group is always the last chi-group. + restype_rigidgroup_is_ambiguous = np.zeros([21, 8], dtype=np.float32) + restype_rigidgroup_rots = np.tile(np.eye(3, dtype=np.float32), [21, 8, 1, 1]) + + for resname, _ in residue_constants.residue_atom_renaming_swaps.items(): + restype = residue_constants.restype_order[ + residue_constants.restype_3to1[resname]] + chi_idx = int(sum(residue_constants.chi_angles_mask[restype]) - 1) + restype_rigidgroup_is_ambiguous[restype, chi_idx + 4] = 1 + restype_rigidgroup_rots[restype, chi_idx + 4, 1, 1] = -1 + restype_rigidgroup_rots[restype, chi_idx + 4, 2, 2] = -1 + + # Gather the ambiguity information for each residue. + residx_rigidgroup_is_ambiguous = utils.batched_gather( + restype_rigidgroup_is_ambiguous, aatype) + ambiguity_rot = utils.batched_gather(restype_rigidgroup_rots, aatype) + ambiguity_rot = geometry.Rot3Array.from_array(ambiguity_rot) + + # Create the alternative ground truth frames. 
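The rotation built a few lines above from the three base atoms (via Rot3Array.from_two_vectors) is a Gram-Schmidt style construction: the x-axis points from the first atom to the origin atom, and the third atom pins down the xy-plane. A standalone NumPy sketch of that idea for a single backbone-like frame (coordinates are invented):

```python
import numpy as np

def frame_from_three_points(point_on_neg_x_axis, origin, point_on_xy_plane):
    """Orthonormal rotation whose x-axis points from the first argument towards the origin."""
    e0 = origin - point_on_neg_x_axis
    e0 = e0 / np.linalg.norm(e0)
    v1 = point_on_xy_plane - origin
    e1 = v1 - np.dot(v1, e0) * e0       # remove the component along e0 (Gram-Schmidt)
    e1 = e1 / np.linalg.norm(e1)
    e2 = np.cross(e0, e1)
    return np.stack([e0, e1, e2], axis=-1)   # columns are the new basis vectors

# Toy C, CA, N coordinates (the backbone group uses this atom order above).
c, ca, n = np.array([1.5, 0.2, 0.0]), np.array([0.0, 0.0, 0.0]), np.array([-0.5, 1.4, 0.1])
R = frame_from_three_points(c, ca, n)
assert np.allclose(R @ R.T, np.eye(3), atol=1e-6)   # proper orthonormal frame
assert np.isclose(np.linalg.det(R), 1.0)             # right-handed
```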
+ alt_gt_frames = gt_frames.compose_rotation(ambiguity_rot) + + fix_shape = lambda x: jnp.reshape(x, aatype_in_shape + (8,)) + + # reshape back to original residue layout + gt_frames = jax.tree_map(fix_shape, gt_frames) + gt_exists = fix_shape(gt_exists) + group_exists = fix_shape(group_exists) + residx_rigidgroup_is_ambiguous = fix_shape(residx_rigidgroup_is_ambiguous) + alt_gt_frames = jax.tree_map(fix_shape, alt_gt_frames) + + return { + 'rigidgroups_gt_frames': gt_frames, # Rigid (..., 8) + 'rigidgroups_gt_exists': gt_exists, # (..., 8) + 'rigidgroups_group_exists': group_exists, # (..., 8) + 'rigidgroups_group_is_ambiguous': + residx_rigidgroup_is_ambiguous, # (..., 8) + 'rigidgroups_alt_gt_frames': alt_gt_frames, # Rigid (..., 8) + } + + +def torsion_angles_to_frames( + aatype: jnp.ndarray, # (N) + backb_to_global: geometry.Rigid3Array, # (N) + torsion_angles_sin_cos: jnp.ndarray # (N, 7, 2) +) -> geometry.Rigid3Array: # (N, 8) + """Compute rigid group frames from torsion angles.""" + + assert len(aatype.shape) == 1, ( + f'Expected array of rank 1, got array with shape: {aatype.shape}.') + assert len(backb_to_global.rotation.shape) == 1, ( + f'Expected array of rank 1, got array with shape: ' + f'{backb_to_global.rotation.shape}') + assert len(torsion_angles_sin_cos.shape) == 3, ( + f'Expected array of rank 3, got array with shape: ' + f'{torsion_angles_sin_cos.shape}') + assert torsion_angles_sin_cos.shape[1] == 7, ( + f'wrong shape {torsion_angles_sin_cos.shape}') + assert torsion_angles_sin_cos.shape[2] == 2, ( + f'wrong shape {torsion_angles_sin_cos.shape}') + + # Gather the default frames for all rigid groups. + # geometry.Rigid3Array with shape (N, 8) + m = utils.batched_gather(residue_constants.restype_rigid_group_default_frame, + aatype) + default_frames = geometry.Rigid3Array.from_array4x4(m) + + # Create the rotation matrices according to the given angles (each frame is + # defined such that its rotation is around the x-axis). + sin_angles = torsion_angles_sin_cos[..., 0] + cos_angles = torsion_angles_sin_cos[..., 1] + + # insert zero rotation for backbone group. + num_residues, = aatype.shape + sin_angles = jnp.concatenate([jnp.zeros([num_residues, 1]), sin_angles], + axis=-1) + cos_angles = jnp.concatenate([jnp.ones([num_residues, 1]), cos_angles], + axis=-1) + zeros = jnp.zeros_like(sin_angles) + ones = jnp.ones_like(sin_angles) + + # all_rots are geometry.Rot3Array with shape (N, 8) + all_rots = geometry.Rot3Array(ones, zeros, zeros, + zeros, cos_angles, -sin_angles, + zeros, sin_angles, cos_angles) + + # Apply rotations to the frames. + all_frames = default_frames.compose_rotation(all_rots) + + # chi2, chi3, and chi4 frames do not transform to the backbone frame but to + # the previous frame. So chain them up accordingly. + + chi1_frame_to_backb = all_frames[:, 4] + chi2_frame_to_backb = chi1_frame_to_backb @ all_frames[:, 5] + chi3_frame_to_backb = chi2_frame_to_backb @ all_frames[:, 6] + chi4_frame_to_backb = chi3_frame_to_backb @ all_frames[:, 7] + + all_frames_to_backb = jax.tree_map( + lambda *x: jnp.concatenate(x, axis=-1), all_frames[:, 0:5], + chi2_frame_to_backb[:, None], chi3_frame_to_backb[:, None], + chi4_frame_to_backb[:, None]) + + # Create the global frames. 
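Each torsion contributes a rotation about the local x-axis built directly from its (sin, cos) pair, and successive chi frames are chained by composing transforms, as set up just above. A self-contained NumPy check of both facts (the angles are arbitrary):

```python
import numpy as np

def rot_x(sin_a, cos_a):
    """Rotation about the x-axis from a (sin, cos) pair, matching the layout used above."""
    return np.array([[1.0, 0.0, 0.0],
                     [0.0, cos_a, -sin_a],
                     [0.0, sin_a, cos_a]])

a, b = 0.7, -1.2   # arbitrary chi-like angles in radians
Ra = rot_x(np.sin(a), np.cos(a))
Rb = rot_x(np.sin(b), np.cos(b))

# Chaining two torsion frames about the same axis is just adding the angles.
assert np.allclose(Ra @ Rb, rot_x(np.sin(a + b), np.cos(a + b)), atol=1e-12)
assert np.allclose(Ra @ Ra.T, np.eye(3), atol=1e-12)
```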
+ # shape (N, 8) + all_frames_to_global = backb_to_global[:, None] @ all_frames_to_backb + + return all_frames_to_global + + +def frames_and_literature_positions_to_atom14_pos( + aatype: jnp.ndarray, # (N) + all_frames_to_global: geometry.Rigid3Array # (N, 8) +) -> geometry.Vec3Array: # (N, 14) + """Put atom literature positions (atom14 encoding) in each rigid group.""" + + # Pick the appropriate transform for every atom. + residx_to_group_idx = utils.batched_gather( + residue_constants.restype_atom14_to_rigid_group, aatype) + group_mask = jax.nn.one_hot( + residx_to_group_idx, num_classes=8) # shape (N, 14, 8) + + # geometry.Rigid3Array with shape (N, 14) + map_atoms_to_global = jax.tree_map( + lambda x: jnp.sum(x[:, None, :] * group_mask, axis=-1), + all_frames_to_global) + + # Gather the literature atom positions for each residue. + # geometry.Vec3Array with shape (N, 14) + lit_positions = geometry.Vec3Array.from_array( + utils.batched_gather( + residue_constants.restype_atom14_rigid_group_positions, aatype)) + + # Transform each atom from its local frame to the global frame. + # geometry.Vec3Array with shape (N, 14) + pred_positions = map_atoms_to_global.apply_to_point(lit_positions) + + # Mask out non-existing atoms. + mask = utils.batched_gather(residue_constants.restype_atom14_mask, aatype) + pred_positions = pred_positions * mask + + return pred_positions + + +def extreme_ca_ca_distance_violations( + positions: geometry.Vec3Array, # (N, 37(14)) + mask: jnp.ndarray, # (N, 37(14)) + residue_index: jnp.ndarray, # (N) + max_angstrom_tolerance=1.5 + ) -> jnp.ndarray: + """Counts residues whose Ca is a large distance from its neighbor.""" + this_ca_pos = positions[:-1, 1] # (N - 1,) + this_ca_mask = mask[:-1, 1] # (N - 1) + next_ca_pos = positions[1:, 1] # (N - 1,) + next_ca_mask = mask[1:, 1] # (N - 1) + has_no_gap_mask = ((residue_index[1:] - residue_index[:-1]) == 1.0).astype( + jnp.float32) + ca_ca_distance = geometry.euclidean_distance(this_ca_pos, next_ca_pos, 1e-6) + violations = (ca_ca_distance - + residue_constants.ca_ca) > max_angstrom_tolerance + mask = this_ca_mask * next_ca_mask * has_no_gap_mask + return utils.mask_mean(mask=mask, value=violations) + + +def between_residue_bond_loss( + pred_atom_positions: geometry.Vec3Array, # (N, 37(14)) + pred_atom_mask: jnp.ndarray, # (N, 37(14)) + residue_index: jnp.ndarray, # (N) + aatype: jnp.ndarray, # (N) + tolerance_factor_soft=12.0, + tolerance_factor_hard=12.0) -> Dict[Text, jnp.ndarray]: + """Flat-bottom loss to penalize structural violations between residues.""" + + assert len(pred_atom_positions.shape) == 2 + assert len(pred_atom_mask.shape) == 2 + assert len(residue_index.shape) == 1 + assert len(aatype.shape) == 1 + + # Get the positions of the relevant backbone atoms. + this_ca_pos = pred_atom_positions[:-1, 1] # (N - 1) + this_ca_mask = pred_atom_mask[:-1, 1] # (N - 1) + this_c_pos = pred_atom_positions[:-1, 2] # (N - 1) + this_c_mask = pred_atom_mask[:-1, 2] # (N - 1) + next_n_pos = pred_atom_positions[1:, 0] # (N - 1) + next_n_mask = pred_atom_mask[1:, 0] # (N - 1) + next_ca_pos = pred_atom_positions[1:, 1] # (N - 1) + next_ca_mask = pred_atom_mask[1:, 1] # (N - 1) + has_no_gap_mask = ((residue_index[1:] - residue_index[:-1]) == 1.0).astype( + jnp.float32) + + # Compute loss for the C--N bond. + c_n_bond_length = geometry.euclidean_distance(this_c_pos, next_n_pos, 1e-6) + + # The C-N bond to proline has slightly different length because of the ring. 
+ next_is_proline = ( + aatype[1:] == residue_constants.restype_order['P']).astype(jnp.float32) + gt_length = ( + (1. - next_is_proline) * residue_constants.between_res_bond_length_c_n[0] + + next_is_proline * residue_constants.between_res_bond_length_c_n[1]) + gt_stddev = ( + (1. - next_is_proline) * + residue_constants.between_res_bond_length_stddev_c_n[0] + + next_is_proline * residue_constants.between_res_bond_length_stddev_c_n[1]) + c_n_bond_length_error = jnp.sqrt(1e-6 + + jnp.square(c_n_bond_length - gt_length)) + c_n_loss_per_residue = jax.nn.relu( + c_n_bond_length_error - tolerance_factor_soft * gt_stddev) + mask = this_c_mask * next_n_mask * has_no_gap_mask + c_n_loss = jnp.sum(mask * c_n_loss_per_residue) / (jnp.sum(mask) + 1e-6) + c_n_violation_mask = mask * ( + c_n_bond_length_error > (tolerance_factor_hard * gt_stddev)) + + # Compute loss for the angles. + c_ca_unit_vec = (this_ca_pos - this_c_pos).normalized(1e-6) + c_n_unit_vec = (next_n_pos - this_c_pos) / c_n_bond_length + n_ca_unit_vec = (next_ca_pos - next_n_pos).normalized(1e-6) + + ca_c_n_cos_angle = c_ca_unit_vec.dot(c_n_unit_vec) + gt_angle = residue_constants.between_res_cos_angles_ca_c_n[0] + gt_stddev = residue_constants.between_res_bond_length_stddev_c_n[0] + ca_c_n_cos_angle_error = jnp.sqrt( + 1e-6 + jnp.square(ca_c_n_cos_angle - gt_angle)) + ca_c_n_loss_per_residue = jax.nn.relu( + ca_c_n_cos_angle_error - tolerance_factor_soft * gt_stddev) + mask = this_ca_mask * this_c_mask * next_n_mask * has_no_gap_mask + ca_c_n_loss = jnp.sum(mask * ca_c_n_loss_per_residue) / (jnp.sum(mask) + 1e-6) + ca_c_n_violation_mask = mask * (ca_c_n_cos_angle_error > + (tolerance_factor_hard * gt_stddev)) + + c_n_ca_cos_angle = (-c_n_unit_vec).dot(n_ca_unit_vec) + gt_angle = residue_constants.between_res_cos_angles_c_n_ca[0] + gt_stddev = residue_constants.between_res_cos_angles_c_n_ca[1] + c_n_ca_cos_angle_error = jnp.sqrt( + 1e-6 + jnp.square(c_n_ca_cos_angle - gt_angle)) + c_n_ca_loss_per_residue = jax.nn.relu( + c_n_ca_cos_angle_error - tolerance_factor_soft * gt_stddev) + mask = this_c_mask * next_n_mask * next_ca_mask * has_no_gap_mask + c_n_ca_loss = jnp.sum(mask * c_n_ca_loss_per_residue) / (jnp.sum(mask) + 1e-6) + c_n_ca_violation_mask = mask * ( + c_n_ca_cos_angle_error > (tolerance_factor_hard * gt_stddev)) + + # Compute a per residue loss (equally distribute the loss to both + # neighbouring residues). + per_residue_loss_sum = (c_n_loss_per_residue + + ca_c_n_loss_per_residue + + c_n_ca_loss_per_residue) + per_residue_loss_sum = 0.5 * (jnp.pad(per_residue_loss_sum, [[0, 1]]) + + jnp.pad(per_residue_loss_sum, [[1, 0]])) + + # Compute hard violations. 
+ violation_mask = jnp.max( + jnp.stack([c_n_violation_mask, + ca_c_n_violation_mask, + c_n_ca_violation_mask]), axis=0) + violation_mask = jnp.maximum( + jnp.pad(violation_mask, [[0, 1]]), + jnp.pad(violation_mask, [[1, 0]])) + + return {'c_n_loss_mean': c_n_loss, # shape () + 'ca_c_n_loss_mean': ca_c_n_loss, # shape () + 'c_n_ca_loss_mean': c_n_ca_loss, # shape () + 'per_residue_loss_sum': per_residue_loss_sum, # shape (N) + 'per_residue_violation_mask': violation_mask # shape (N) + } + + +def between_residue_clash_loss( + pred_positions: geometry.Vec3Array, # (N, 14) + atom_exists: jnp.ndarray, # (N, 14) + atom_radius: jnp.ndarray, # (N, 14) + residue_index: jnp.ndarray, # (N) + asym_id: jnp.ndarray, # (N) + overlap_tolerance_soft=1.5, + overlap_tolerance_hard=1.5) -> Dict[Text, jnp.ndarray]: + """Loss to penalize steric clashes between residues.""" + assert len(pred_positions.shape) == 2 + assert len(atom_exists.shape) == 2 + assert len(atom_radius.shape) == 2 + assert len(residue_index.shape) == 1 + + # Create the distance matrix. + # (N, N, 14, 14) + dists = geometry.euclidean_distance(pred_positions[:, None, :, None], + pred_positions[None, :, None, :], 1e-10) + + # Create the mask for valid distances. + # shape (N, N, 14, 14) + dists_mask = (atom_exists[:, None, :, None] * atom_exists[None, :, None, :]) + + # Mask out all the duplicate entries in the lower triangular matrix. + # Also mask out the diagonal (atom-pairs from the same residue) -- these atoms + # are handled separately. + dists_mask *= ( + residue_index[:, None, None, None] < residue_index[None, :, None, None]) + + # Backbone C--N bond between subsequent residues is no clash. + c_one_hot = jax.nn.one_hot(2, num_classes=14) + n_one_hot = jax.nn.one_hot(0, num_classes=14) + neighbour_mask = ((residue_index[:, None] + 1) == residue_index[None, :]) + neighbour_mask &= (asym_id[:, None] == asym_id[None, :]) + neighbour_mask = neighbour_mask[..., None, None] + c_n_bonds = neighbour_mask * c_one_hot[None, None, :, + None] * n_one_hot[None, None, None, :] + dists_mask *= (1. - c_n_bonds) + + # Disulfide bridge between two cysteines is no clash. + cys_sg_idx = residue_constants.restype_name_to_atom14_names['CYS'].index('SG') + cys_sg_one_hot = jax.nn.one_hot(cys_sg_idx, num_classes=14) + disulfide_bonds = (cys_sg_one_hot[None, None, :, None] * + cys_sg_one_hot[None, None, None, :]) + dists_mask *= (1. - disulfide_bonds) + + # Compute the lower bound for the allowed distances. + # shape (N, N, 14, 14) + dists_lower_bound = dists_mask * ( + atom_radius[:, None, :, None] + atom_radius[None, :, None, :]) + + # Compute the error. + # shape (N, N, 14, 14) + dists_to_low_error = dists_mask * jax.nn.relu( + dists_lower_bound - overlap_tolerance_soft - dists) + + # Compute the mean loss. + # shape () + mean_loss = (jnp.sum(dists_to_low_error) + / (1e-6 + jnp.sum(dists_mask))) + + # Compute the per atom loss sum. + # shape (N, 14) + per_atom_loss_sum = (jnp.sum(dists_to_low_error, axis=[0, 2]) + + jnp.sum(dists_to_low_error, axis=[1, 3])) + + # Compute the hard clash mask. + # shape (N, N, 14, 14) + clash_mask = dists_mask * ( + dists < (dists_lower_bound - overlap_tolerance_hard)) + + # Compute the per atom clash. 
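The exemption for the peptide C--N bond above is an outer product of one-hot vectors: within each neighbouring pair of residues, only the (C of residue i, N of residue i+1) cell of the 14x14 block is switched off. A tiny NumPy illustration of that pattern, using the atom14 slot convention the code itself relies on (C in slot 2, N in slot 0); the residue indices are invented:

```python
import numpy as np

num_atoms = 14
c_one_hot = np.eye(num_atoms)[2]   # backbone C sits in slot 2 of atom14
n_one_hot = np.eye(num_atoms)[0]   # backbone N sits in slot 0 of atom14

residue_index = np.array([0, 1, 2, 5])   # note the chain break between 2 and 5
neighbour = (residue_index[:, None] + 1) == residue_index[None, :]   # (R, R)

# (R, R, 14, 14) mask that is 1 exactly at the bonded C(i) -> N(i+1) cells.
c_n_bonds = (neighbour[:, :, None, None]
             * c_one_hot[None, None, :, None]
             * n_one_hot[None, None, None, :])
print(int(c_n_bonds.sum()))   # 2 bonded pairs: (0,1) and (1,2); 2 -> 5 is a gap
```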
+ # shape (N, 14) + per_atom_clash_mask = jnp.maximum( + jnp.max(clash_mask, axis=[0, 2]), + jnp.max(clash_mask, axis=[1, 3])) + + return {'mean_loss': mean_loss, # shape () + 'per_atom_loss_sum': per_atom_loss_sum, # shape (N, 14) + 'per_atom_clash_mask': per_atom_clash_mask # shape (N, 14) + } + + +def within_residue_violations( + pred_positions: geometry.Vec3Array, # (N, 14) + atom_exists: jnp.ndarray, # (N, 14) + dists_lower_bound: jnp.ndarray, # (N, 14, 14) + dists_upper_bound: jnp.ndarray, # (N, 14, 14) + tighten_bounds_for_loss=0.0, +) -> Dict[Text, jnp.ndarray]: + """Find within-residue violations.""" + assert len(pred_positions.shape) == 2 + assert len(atom_exists.shape) == 2 + assert len(dists_lower_bound.shape) == 3 + assert len(dists_upper_bound.shape) == 3 + + # Compute the mask for each residue. + # shape (N, 14, 14) + dists_masks = (1. - jnp.eye(14, 14)[None]) + dists_masks *= (atom_exists[:, :, None] * atom_exists[:, None, :]) + + # Distance matrix + # shape (N, 14, 14) + dists = geometry.euclidean_distance(pred_positions[:, :, None], + pred_positions[:, None, :], 1e-10) + + # Compute the loss. + # shape (N, 14, 14) + dists_to_low_error = jax.nn.relu( + dists_lower_bound + tighten_bounds_for_loss - dists) + dists_to_high_error = jax.nn.relu( + dists + tighten_bounds_for_loss - dists_upper_bound) + loss = dists_masks * (dists_to_low_error + dists_to_high_error) + + # Compute the per atom loss sum. + # shape (N, 14) + per_atom_loss_sum = (jnp.sum(loss, axis=1) + + jnp.sum(loss, axis=2)) + + # Compute the violations mask. + # shape (N, 14, 14) + violations = dists_masks * ((dists < dists_lower_bound) | + (dists > dists_upper_bound)) + + # Compute the per atom violations. + # shape (N, 14) + per_atom_violations = jnp.maximum( + jnp.max(violations, axis=1), jnp.max(violations, axis=2)) + + return {'per_atom_loss_sum': per_atom_loss_sum, # shape (N, 14) + 'per_atom_violations': per_atom_violations # shape (N, 14) + } + + +def find_optimal_renaming( + gt_positions: geometry.Vec3Array, # (N, 14) + alt_gt_positions: geometry.Vec3Array, # (N, 14) + atom_is_ambiguous: jnp.ndarray, # (N, 14) + gt_exists: jnp.ndarray, # (N, 14) + pred_positions: geometry.Vec3Array, # (N, 14) +) -> jnp.ndarray: # (N): + """Find optimal renaming for ground truth that maximizes LDDT.""" + assert len(gt_positions.shape) == 2 + assert len(alt_gt_positions.shape) == 2 + assert len(atom_is_ambiguous.shape) == 2 + assert len(gt_exists.shape) == 2 + assert len(pred_positions.shape) == 2 + + # Create the pred distance matrix. + # shape (N, N, 14, 14) + pred_dists = geometry.euclidean_distance(pred_positions[:, None, :, None], + pred_positions[None, :, None, :], + 1e-10) + + # Compute distances for ground truth with original and alternative names. + # shape (N, N, 14, 14) + gt_dists = geometry.euclidean_distance(gt_positions[:, None, :, None], + gt_positions[None, :, None, :], 1e-10) + + alt_gt_dists = geometry.euclidean_distance(alt_gt_positions[:, None, :, None], + alt_gt_positions[None, :, None, :], + 1e-10) + + # Compute LDDT's. + # shape (N, N, 14, 14) + lddt = jnp.sqrt(1e-10 + squared_difference(pred_dists, gt_dists)) + alt_lddt = jnp.sqrt(1e-10 + squared_difference(pred_dists, alt_gt_dists)) + + # Create a mask for ambiguous atoms in rows vs. non-ambiguous atoms + # in cols. + # shape (N ,N, 14, 14) + mask = ( + gt_exists[:, None, :, None] * # rows + atom_is_ambiguous[:, None, :, None] * # rows + gt_exists[None, :, None, :] * # cols + (1. 
- atom_is_ambiguous[None, :, None, :])) # cols + + # Aggregate distances for each residue to the non-amibuguous atoms. + # shape (N) + per_res_lddt = jnp.sum(mask * lddt, axis=[1, 2, 3]) + alt_per_res_lddt = jnp.sum(mask * alt_lddt, axis=[1, 2, 3]) + + # Decide for each residue, whether alternative naming is better. + # shape (N) + alt_naming_is_better = (alt_per_res_lddt < per_res_lddt).astype(jnp.float32) + + return alt_naming_is_better # shape (N) + + +def frame_aligned_point_error( + pred_frames: geometry.Rigid3Array, # shape (num_frames) + target_frames: geometry.Rigid3Array, # shape (num_frames) + frames_mask: jnp.ndarray, # shape (num_frames) + pred_positions: geometry.Vec3Array, # shape (num_positions) + target_positions: geometry.Vec3Array, # shape (num_positions) + positions_mask: jnp.ndarray, # shape (num_positions) + pair_mask: jnp.ndarray, # shape (num_frames, num_posiitons) + l1_clamp_distance: float, + length_scale=20., + epsilon=1e-4) -> jnp.ndarray: # shape () + """Measure point error under different alignements. + + Computes error between two structures with B points + under A alignments derived form the given pairs of frames. + Args: + pred_frames: num_frames reference frames for 'pred_positions'. + target_frames: num_frames reference frames for 'target_positions'. + frames_mask: Mask for frame pairs to use. + pred_positions: num_positions predicted positions of the structure. + target_positions: num_positions target positions of the structure. + positions_mask: Mask on which positions to score. + pair_mask: A (num_frames, num_positions) mask to use in the loss, useful + for separating intra from inter chain losses. + l1_clamp_distance: Distance cutoff on error beyond which gradients will + be zero. + length_scale: length scale to divide loss by. + epsilon: small value used to regularize denominator for masked average. + Returns: + Masked Frame aligned point error. + """ + # For now we do not allow any batch dimensions. + assert len(pred_frames.rotation.shape) == 1 + assert len(target_frames.rotation.shape) == 1 + assert frames_mask.ndim == 1 + assert pred_positions.x.ndim == 1 + assert target_positions.x.ndim == 1 + assert positions_mask.ndim == 1 + + # Compute array of predicted positions in the predicted frames. + # geometry.Vec3Array (num_frames, num_positions) + local_pred_pos = pred_frames[:, None].inverse().apply_to_point( + pred_positions[None, :]) + + # Compute array of target positions in the target frames. + # geometry.Vec3Array (num_frames, num_positions) + local_target_pos = target_frames[:, None].inverse().apply_to_point( + target_positions[None, :]) + + # Compute errors between the structures. + # jnp.ndarray (num_frames, num_positions) + error_dist = geometry.euclidean_distance(local_pred_pos, local_target_pos, + epsilon) + + clipped_error_dist = jnp.clip(error_dist, 0, l1_clamp_distance) + + normed_error = clipped_error_dist / length_scale + normed_error *= jnp.expand_dims(frames_mask, axis=-1) + normed_error *= jnp.expand_dims(positions_mask, axis=-2) + if pair_mask is not None: + normed_error *= pair_mask + + mask = (jnp.expand_dims(frames_mask, axis=-1) * + jnp.expand_dims(positions_mask, axis=-2)) + if pair_mask is not None: + mask *= pair_mask + normalization_factor = jnp.sum(mask, axis=(-1, -2)) + return (jnp.sum(normed_error, axis=(-2, -1)) / + (epsilon + normalization_factor)) + + +def get_chi_atom_indices(): + """Returns atom indices needed to compute chi angles for all residue types. 
+ + Returns: + A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are + in the order specified in residue_constants.restypes + unknown residue type + at the end. For chi angles which are not defined on the residue, the + positions indices are by default set to 0. + """ + chi_atom_indices = [] + for residue_name in residue_constants.restypes: + residue_name = residue_constants.restype_1to3[residue_name] + residue_chi_angles = residue_constants.chi_angles_atoms[residue_name] + atom_indices = [] + for chi_angle in residue_chi_angles: + atom_indices.append( + [residue_constants.atom_order[atom] for atom in chi_angle]) + for _ in range(4 - len(atom_indices)): + atom_indices.append([0, 0, 0, 0]) # For chi angles not defined on the AA. + chi_atom_indices.append(atom_indices) + + chi_atom_indices.append([[0, 0, 0, 0]] * 4) # For UNKNOWN residue. + + return jnp.asarray(chi_atom_indices) + + +def compute_chi_angles(positions: geometry.Vec3Array, + mask: geometry.Vec3Array, + aatype: geometry.Vec3Array): + """Computes the chi angles given all atom positions and the amino acid type. + + Args: + positions: A Vec3Array of shape + [num_res, residue_constants.atom_type_num], with positions of + atoms needed to calculate chi angles. Supports up to 1 batch dimension. + mask: An optional tensor of shape + [num_res, residue_constants.atom_type_num] that masks which atom + positions are set for each residue. If given, then the chi mask will be + set to 1 for a chi angle only if the amino acid has that chi angle and all + the chi atoms needed to calculate that chi angle are set. If not given + (set to None), the chi mask will be set to 1 for a chi angle if the amino + acid has that chi angle and whether the actual atoms needed to calculate + it were set will be ignored. + aatype: A tensor of shape [num_res] with amino acid type integer + code (0 to 21). Supports up to 1 batch dimension. + + Returns: + A tuple of tensors (chi_angles, mask), where both have shape + [num_res, 4]. The mask masks out unused chi angles for amino acid + types that have less than 4 chi angles. If atom_positions_mask is set, the + chi mask will also mask out uncomputable chi angles. + """ + + # Don't assert on the num_res and batch dimensions as they might be unknown. + + assert positions.shape[-1] == residue_constants.atom_type_num + assert mask.shape[-1] == residue_constants.atom_type_num + + # Compute the table of chi angle indices. Shape: [restypes, chis=4, atoms=4]. + chi_atom_indices = get_chi_atom_indices() + # Select atoms to compute chis. Shape: [num_res, chis=4, atoms=4]. + atom_indices = utils.batched_gather( + params=chi_atom_indices, indices=aatype, axis=0) + # Gather atom positions. Shape: [num_res, chis=4, atoms=4, xyz=3]. + chi_angle_atoms = jax.tree_map( + lambda x: utils.batched_gather( # pylint: disable=g-long-lambda + params=x, indices=atom_indices, axis=-1, batch_dims=1), positions) + a, b, c, d = [chi_angle_atoms[..., i] for i in range(4)] + + chi_angles = geometry.dihedral_angle(a, b, c, d) + + # Copy the chi angle mask, add the UNKNOWN residue. Shape: [restypes, 4]. + chi_angles_mask = list(residue_constants.chi_angles_mask) + chi_angles_mask.append([0.0, 0.0, 0.0, 0.0]) + chi_angles_mask = jnp.asarray(chi_angles_mask) + # Compute the chi angle mask. Shape [num_res, chis=4]. + chi_mask = utils.batched_gather(params=chi_angles_mask, indices=aatype, + axis=0) + + # The chi_mask is set to 1 only when all necessary chi angle atoms were set. + # Gather the chi angle atoms mask. 
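# --- Illustrative standalone sketch (not part of the patch) ---------------
# Each chi angle above is the dihedral defined by four gathered atom
# positions; a minimal numpy dihedral is shown with made-up coordinates.
# The sign convention may differ from the library's geometry.dihedral_angle.
import numpy as np

def dihedral(p0, p1, p2, p3):
  """Signed dihedral angle (radians) about the p1-p2 axis."""
  b0, b1, b2 = p1 - p0, p2 - p1, p3 - p2
  n0, n1 = np.cross(b0, b1), np.cross(b1, b2)
  m = np.cross(n0, b1 / np.linalg.norm(b1))
  return np.arctan2(np.dot(m, n1), np.dot(n0, n1))

p0, p1, p2, p3 = (np.array(v) for v in
                  ([0., 1., 0.], [0., 0., 0.], [1., 0., 0.], [1., 0., 1.]))
print(dihedral(p0, p1, p2, p3))  # magnitude pi/2 for this made-up geometry
# --------------------------------------------------------------------------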
Shape: [num_res, chis=4, atoms=4]. + chi_angle_atoms_mask = utils.batched_gather( + params=mask, indices=atom_indices, axis=-1, batch_dims=1) + # Check if all 4 chi angle atoms were set. Shape: [num_res, chis=4]. + chi_angle_atoms_mask = jnp.prod(chi_angle_atoms_mask, axis=[-1]) + chi_mask = chi_mask * chi_angle_atoms_mask.astype(jnp.float32) + + return chi_angles, chi_mask + +def make_transform_from_reference( + a_xyz: geometry.Vec3Array, + b_xyz: geometry.Vec3Array, + c_xyz: geometry.Vec3Array) -> geometry.Rigid3Array: + """Returns rotation and translation matrices to convert from reference. + + Note that this method does not take care of symmetries. If you provide the + coordinates in the non-standard way, the A atom will end up in the negative + y-axis rather than in the positive y-axis. You need to take care of such + cases in your code. + + Args: + a_xyz: A Vec3Array. + b_xyz: A Vec3Array. + c_xyz: A Vec3Array. + + Returns: + A Rigid3Array which, when applied to coordinates in a canonicalized + reference frame, will give coordinates approximately equal + the original coordinates (in the global frame). + """ + rotation = geometry.Rot3Array.from_two_vectors(c_xyz - b_xyz, + a_xyz - b_xyz) + return geometry.Rigid3Array(rotation, b_xyz) diff --git a/build/lib/colabdesign/af/alphafold/model/common_modules.py b/build/lib/colabdesign/af/alphafold/model/common_modules.py new file mode 100644 index 00000000..dcc66ab1 --- /dev/null +++ b/build/lib/colabdesign/af/alphafold/model/common_modules.py @@ -0,0 +1,184 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A collection of common Haiku modules for use in protein folding.""" +import numbers +from typing import Union, Sequence + +import haiku as hk +import jax.numpy as jnp +import numpy as np + + +# Constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.) +TRUNCATED_NORMAL_STDDEV_FACTOR = np.asarray(.87962566103423978, + dtype=np.float32) + + +def get_initializer_scale(initializer_name, input_shape): + """Get Initializer for weights and scale to multiply activations by.""" + + if initializer_name == 'zeros': + w_init = hk.initializers.Constant(0.0) + else: + # fan-in scaling + scale = 1. + for channel_dim in input_shape: + scale /= channel_dim + if initializer_name == 'relu': + scale *= 2 + + noise_scale = scale + + stddev = np.sqrt(noise_scale) + # Adjust stddev for truncation. + stddev = stddev / TRUNCATED_NORMAL_STDDEV_FACTOR + w_init = hk.initializers.TruncatedNormal(mean=0.0, stddev=stddev) + + return w_init + + +class Linear(hk.Module): + """Protein folding specific Linear module. 
+ This differs from the standard Haiku Linear in a few ways: + * It supports inputs and outputs of arbitrary rank + * Initializers are specified by strings + """ + + def __init__(self, + num_output: Union[int, Sequence[int]], + initializer: str = 'linear', + num_input_dims: int = 1, + use_bias: bool = True, + bias_init: float = 0., + precision = None, + name: str = 'linear'): + """Constructs Linear Module. + Args: + num_output: Number of output channels. Can be tuple when outputting + multiple dimensions. + initializer: What initializer to use, should be one of {'linear', 'relu', + 'zeros'} + num_input_dims: Number of dimensions from the end to project. + use_bias: Whether to include trainable bias + bias_init: Value used to initialize bias. + precision: What precision to use for matrix multiplication, defaults + to None. + name: Name of module, used for name scopes. + """ + super().__init__(name=name) + if isinstance(num_output, numbers.Integral): + self.output_shape = (num_output,) + else: + self.output_shape = tuple(num_output) + self.initializer = initializer + self.use_bias = use_bias + self.bias_init = bias_init + self.num_input_dims = num_input_dims + self.num_output_dims = len(self.output_shape) + self.precision = precision + + def __call__(self, inputs): + """Connects Module. + Args: + inputs: Tensor with at least num_input_dims dimensions. + Returns: + output of shape [...] + num_output. + """ + + num_input_dims = self.num_input_dims + + if self.num_input_dims > 0: + in_shape = inputs.shape[-self.num_input_dims:] + else: + in_shape = () + + weight_init = get_initializer_scale(self.initializer, in_shape) + + in_letters = 'abcde'[:self.num_input_dims] + out_letters = 'hijkl'[:self.num_output_dims] + + weight_shape = in_shape + self.output_shape + weights = hk.get_parameter('weights', weight_shape, inputs.dtype, + weight_init) + + equation = f'...{in_letters}, {in_letters}{out_letters}->...{out_letters}' + + output = jnp.einsum(equation, inputs, weights, precision=self.precision) + + if self.use_bias: + bias = hk.get_parameter('bias', self.output_shape, inputs.dtype, + hk.initializers.Constant(self.bias_init)) + output += bias + + return output + +class LayerNorm(hk.LayerNorm): + """LayerNorm module. + Equivalent to hk.LayerNorm but with different parameter shapes: they are + always vectors rather than possibly higher-rank tensors. This makes it easier + to change the layout whilst keep the model weight-compatible. 
+ """ + + def __init__(self, + axis, + create_scale: bool, + create_offset: bool, + eps: float = 1e-5, + scale_init=None, + offset_init=None, + use_fast_variance: bool = False, + name=None, + param_axis=None): + super().__init__( + axis=axis, + create_scale=False, + create_offset=False, + eps=eps, + scale_init=None, + offset_init=None, + use_fast_variance=use_fast_variance, + name=name, + param_axis=param_axis) + self._temp_create_scale = create_scale + self._temp_create_offset = create_offset + + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + is_bf16 = (x.dtype == jnp.bfloat16) + if is_bf16: + x = x.astype(jnp.float32) + + param_axis = self.param_axis[0] if self.param_axis else -1 + param_shape = (x.shape[param_axis],) + + param_broadcast_shape = [1] * x.ndim + param_broadcast_shape[param_axis] = x.shape[param_axis] + scale = None + offset = None + if self._temp_create_scale: + scale = hk.get_parameter( + 'scale', param_shape, x.dtype, init=self.scale_init) + scale = scale.reshape(param_broadcast_shape) + + if self._temp_create_offset: + offset = hk.get_parameter( + 'offset', param_shape, x.dtype, init=self.offset_init) + offset = offset.reshape(param_broadcast_shape) + + out = super().__call__(x, scale=scale, offset=offset) + + if is_bf16: + out = out.astype(jnp.bfloat16) + + return out \ No newline at end of file diff --git a/build/lib/colabdesign/af/alphafold/model/config.py b/build/lib/colabdesign/af/alphafold/model/config.py new file mode 100644 index 00000000..15eb8f17 --- /dev/null +++ b/build/lib/colabdesign/af/alphafold/model/config.py @@ -0,0 +1,611 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Model config.""" + +import copy +from colabdesign.af.alphafold.model.tf import shape_placeholders +import ml_collections + + +NUM_RES = shape_placeholders.NUM_RES +NUM_MSA_SEQ = shape_placeholders.NUM_MSA_SEQ +NUM_EXTRA_SEQ = shape_placeholders.NUM_EXTRA_SEQ +NUM_TEMPLATES = shape_placeholders.NUM_TEMPLATES + +def model_config(name: str) -> ml_collections.ConfigDict: + """Get the ConfigDict of a CASP14 model.""" + + if 'multimer' in name: + return CONFIG_MULTIMER + + if name not in CONFIG_DIFFS: + raise ValueError(f'Invalid model name {name}.') + cfg = copy.deepcopy(CONFIG) + cfg.update_from_flattened_dict(CONFIG_DIFFS[name]) + return cfg + +CONFIG_DIFFS = { + 'model_1': { + # Jumper et al. (2021) Suppl. Table 5, Model 1.1.1 + 'model.embeddings_and_evoformer.template.embed_torsion_angles': True, + 'model.embeddings_and_evoformer.template.enabled': True + }, + 'model_2': { + # Jumper et al. (2021) Suppl. Table 5, Model 1.1.2 + 'model.embeddings_and_evoformer.template.embed_torsion_angles': True, + 'model.embeddings_and_evoformer.template.enabled': True + }, + 'model_3': { + # Jumper et al. (2021) Suppl. Table 5, Model 1.2.1 + }, + 'model_4': { + # Jumper et al. (2021) Suppl. Table 5, Model 1.2.2 + }, + 'model_5': { + # Jumper et al. (2021) Suppl. 
Table 5, Model 1.2.3 + }, + + # The following models are fine-tuned from the corresponding models above + # with an additional predicted_aligned_error head that can produce + # predicted TM-score (pTM) and predicted aligned errors. + 'model_1_ptm': { + 'model.embeddings_and_evoformer.template.embed_torsion_angles': True, + 'model.embeddings_and_evoformer.template.enabled': True, + 'model.heads.predicted_aligned_error.weight': 0.1 + }, + 'model_2_ptm': { + 'model.embeddings_and_evoformer.template.embed_torsion_angles': True, + 'model.embeddings_and_evoformer.template.enabled': True, + 'model.heads.predicted_aligned_error.weight': 0.1 + }, + 'model_3_ptm': { + 'model.heads.predicted_aligned_error.weight': 0.1 + }, + 'model_4_ptm': { + 'model.heads.predicted_aligned_error.weight': 0.1 + }, + 'model_5_ptm': { + 'model.heads.predicted_aligned_error.weight': 0.1 + } +} + +CONFIG = ml_collections.ConfigDict({ + 'data': { + 'eval': { + 'feat': { + 'aatype': [NUM_RES], + 'all_atom_mask': [NUM_RES, None], + 'all_atom_positions': [NUM_RES, None, None], + 'alt_chi_angles': [NUM_RES, None], + 'atom14_alt_gt_exists': [NUM_RES, None], + 'atom14_alt_gt_positions': [NUM_RES, None, None], + 'atom14_atom_exists': [NUM_RES, None], + 'atom14_atom_is_ambiguous': [NUM_RES, None], + 'atom14_gt_exists': [NUM_RES, None], + 'atom14_gt_positions': [NUM_RES, None, None], + 'atom37_atom_exists': [NUM_RES, None], + 'backbone_affine_mask': [NUM_RES], + 'backbone_affine_tensor': [NUM_RES, None], + 'bert_mask': [NUM_MSA_SEQ, NUM_RES], + 'chi_angles': [NUM_RES, None], + 'chi_mask': [NUM_RES, None], + 'extra_deletion_value': [NUM_EXTRA_SEQ, NUM_RES], + 'extra_has_deletion': [NUM_EXTRA_SEQ, NUM_RES], + 'extra_msa': [NUM_EXTRA_SEQ, NUM_RES], + 'extra_msa_mask': [NUM_EXTRA_SEQ, NUM_RES], + 'extra_msa_row_mask': [NUM_EXTRA_SEQ], + 'is_distillation': [], + 'msa_feat': [NUM_MSA_SEQ, NUM_RES, None], + 'msa_mask': [NUM_MSA_SEQ, NUM_RES], + 'msa_row_mask': [NUM_MSA_SEQ], + 'pseudo_beta': [NUM_RES, None], + 'pseudo_beta_mask': [NUM_RES], + 'random_crop_to_size_seed': [None], + 'residue_index': [NUM_RES], + 'residx_atom14_to_atom37': [NUM_RES, None], + 'residx_atom37_to_atom14': [NUM_RES, None], + 'resolution': [], + 'rigidgroups_alt_gt_frames': [NUM_RES, None, None], + 'rigidgroups_group_exists': [NUM_RES, None], + 'rigidgroups_group_is_ambiguous': [NUM_RES, None], + 'rigidgroups_gt_exists': [NUM_RES, None], + 'rigidgroups_gt_frames': [NUM_RES, None, None], + 'seq_length': [], + 'seq_mask': [NUM_RES], + 'target_feat': [NUM_RES, None], + 'template_aatype': [NUM_TEMPLATES, NUM_RES], + 'template_all_atom_mask': [NUM_TEMPLATES, NUM_RES, None], + 'template_all_atom_positions': [ + NUM_TEMPLATES, NUM_RES, None, None], + 'template_backbone_affine_mask': [NUM_TEMPLATES, NUM_RES], + 'template_backbone_affine_tensor': [ + NUM_TEMPLATES, NUM_RES, None], + 'template_mask': [NUM_TEMPLATES], + 'template_pseudo_beta': [NUM_TEMPLATES, NUM_RES, None], + 'template_pseudo_beta_mask': [NUM_TEMPLATES, NUM_RES], + 'template_sum_probs': [NUM_TEMPLATES, None], + 'true_msa': [NUM_MSA_SEQ, NUM_RES], + 'asym_id': [NUM_RES], + 'sym_id': [NUM_RES], + 'entity_id': [NUM_RES], + + # extras + 'prev_pos': [NUM_RES, None, None], + 'prev_pair': [NUM_RES, NUM_RES, None], + 'prev_msa_first_row': [NUM_RES, None], + 'rm_template': [NUM_RES], + 'rm_template_seq': [NUM_RES], + 'rm_template_sc': [NUM_RES] + }, + }, + }, + 'model': { + 'embeddings_and_evoformer': { + 'evoformer_num_block': 48, + 'evoformer': { + 'msa_row_attention_with_pair_bias': { + 'dropout_rate': 0.15, + 
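# --- Illustrative standalone sketch (not part of the patch) ---------------
# model_config() deep-copies the base CONFIG and applies the per-model
# overrides from CONFIG_DIFFS as flattened dotted keys via ml_collections.
# The toy config below just demonstrates that mechanism.
import copy
import ml_collections

base = ml_collections.ConfigDict(
    {'model': {'heads': {'predicted_aligned_error': {'weight': 0.0}}}})
cfg = copy.deepcopy(base)
cfg.update_from_flattened_dict(
    {'model.heads.predicted_aligned_error.weight': 0.1})
print(cfg.model.heads.predicted_aligned_error.weight)  # 0.1, as in the *_ptm models
# --------------------------------------------------------------------------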
'gating': True, + 'num_head': 8, + 'orientation': 'per_row', + 'shared_dropout': True + }, + 'msa_column_attention': { + 'dropout_rate': 0.0, + 'gating': True, + 'num_head': 8, + 'orientation': 'per_column', + 'shared_dropout': True + }, + 'msa_transition': { + 'dropout_rate': 0.0, + 'num_intermediate_factor': 4, + 'orientation': 'per_row', + 'shared_dropout': True + }, + 'outer_product_mean': { + 'first': False, + 'chunk_size': 128, + 'dropout_rate': 0.0, + 'num_outer_channel': 32, + 'orientation': 'per_row', + 'shared_dropout': True + }, + 'triangle_attention_starting_node': { + 'dropout_rate': 0.25, + 'gating': True, + 'num_head': 4, + 'orientation': 'per_row', + 'shared_dropout': True + }, + 'triangle_attention_ending_node': { + 'dropout_rate': 0.25, + 'gating': True, + 'num_head': 4, + 'orientation': 'per_column', + 'shared_dropout': True + }, + 'triangle_multiplication_outgoing': { + 'dropout_rate': 0.25, + 'equation': 'ikc,jkc->ijc', + 'num_intermediate_channel': 128, + 'orientation': 'per_row', + 'shared_dropout': True, + 'fuse_projection_weights': True, + }, + 'triangle_multiplication_incoming': { + 'dropout_rate': 0.25, + 'equation': 'kjc,kic->ijc', + 'num_intermediate_channel': 128, + 'orientation': 'per_row', + 'shared_dropout': True, + 'fuse_projection_weights': True, + }, + 'pair_transition': { + 'dropout_rate': 0.0, + 'num_intermediate_factor': 4, + 'orientation': 'per_row', + 'shared_dropout': True + } + }, + 'use_msa': True, + 'use_extra_msa': True, + 'extra_msa_channel': 64, + 'extra_msa_stack_num_block': 4, + 'max_relative_feature': 32, + 'custom_relative_features': False, + 'msa_channel': 256, + 'pair_channel': 128, + 'prev_pos': { + 'min_bin': 3.25, + 'max_bin': 20.75, + 'num_bins': 15 + }, + 'recycle_features': True, + 'recycle_pos': True, + 'recycle_dgram': False, + 'backprop_dgram': False, + 'backprop_dgram_temp': 1.0, + 'seq_channel': 384, + 'template': { + 'attention': { + 'gating': False, + 'key_dim': 64, + 'num_head': 4, + 'value_dim': 64 + }, + 'dgram_features': { + 'min_bin': 3.25, + 'max_bin': 50.75, + 'num_bins': 39 + }, + 'backprop_dgram': False, + 'backprop_dgram_temp': 1.0, + 'embed_torsion_angles': False, + 'enabled': False, + 'template_pair_stack': { + 'num_block': 2, + 'triangle_attention_starting_node': { + 'dropout_rate': 0.25, + 'gating': True, + 'key_dim': 64, + 'num_head': 4, + 'orientation': 'per_row', + 'shared_dropout': True, + 'value_dim': 64 + }, + 'triangle_attention_ending_node': { + 'dropout_rate': 0.25, + 'gating': True, + 'key_dim': 64, + 'num_head': 4, + 'orientation': 'per_column', + 'shared_dropout': True, + 'value_dim': 64 + }, + 'triangle_multiplication_outgoing': { + 'dropout_rate': 0.25, + 'equation': 'ikc,jkc->ijc', + 'num_intermediate_channel': 64, + 'orientation': 'per_row', + 'shared_dropout': True, + 'fuse_projection_weights': True, + }, + 'triangle_multiplication_incoming': { + 'dropout_rate': 0.25, + 'equation': 'kjc,kic->ijc', + 'num_intermediate_channel': 64, + 'orientation': 'per_row', + 'shared_dropout': True, + 'fuse_projection_weights': True, + }, + 'pair_transition': { + 'dropout_rate': 0.0, + 'num_intermediate_factor': 2, + 'orientation': 'per_row', + 'shared_dropout': True + } + }, + 'subbatch_size': 128, + 'use_template_unit_vector': False, + } + }, + 'global_config': { + 'bfloat16': True, + 'bfloat16_output': False, + 'multimer_mode': False, + 'subbatch_size': 4, + 'use_remat': False, + 'zero_init': True + }, + 'heads': { + 'distogram': { + 'first_break': 2.3125, + 'last_break': 21.6875, + 'num_bins': 64, + 
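# --- Illustrative standalone sketch (not part of the patch) ---------------
# The distogram head settings (first_break, last_break, num_bins) describe a
# discretization of pairwise residue distances; assuming the usual layout of
# num_bins - 1 evenly spaced edges, a distance maps to a bin like this.
import numpy as np
first_break, last_break, num_bins = 2.3125, 21.6875, 64
edges = np.linspace(first_break, last_break, num_bins - 1)  # 63 interior edges
print(np.searchsorted(edges, 8.0))  # bin index for an 8 angstrom distance
# --------------------------------------------------------------------------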
'weight': 0.3 + }, + 'predicted_aligned_error': { + # `num_bins - 1` bins uniformly space the + # [0, max_error_bin A] range. + # The final bin covers [max_error_bin A, +infty] + # 31A gives bins with 0.5A width. + 'max_error_bin': 31., + 'num_bins': 64, + 'num_channels': 128, + 'filter_by_resolution': True, + 'min_resolution': 0.1, + 'max_resolution': 3.0, + 'weight': 0.0, + }, + 'experimentally_resolved': { + 'filter_by_resolution': True, + 'max_resolution': 3.0, + 'min_resolution': 0.1, + 'weight': 0.01 + }, + 'structure_module': { + 'num_layer': 8, + 'fape': { + 'clamp_distance': 10.0, + 'clamp_type': 'relu', + 'loss_unit_distance': 10.0 + }, + 'angle_norm_weight': 0.01, + 'chi_weight': 0.5, + 'clash_overlap_tolerance': 1.5, + 'compute_in_graph_metrics': True, + 'dropout': 0.1, + 'num_channel': 384, + 'num_head': 12, + 'num_layer_in_transition': 3, + 'num_point_qk': 4, + 'num_point_v': 8, + 'num_scalar_qk': 16, + 'num_scalar_v': 16, + 'position_scale': 10.0, + 'sidechain': { + 'atom_clamp_distance': 10.0, + 'num_channel': 128, + 'num_residual_block': 2, + 'weight_frac': 0.5, + 'length_scale': 10., + }, + 'structural_violation_loss_weight': 1.0, + 'violation_tolerance_factor': 12.0, + 'weight': 1.0 + }, + 'predicted_lddt': { + 'filter_by_resolution': True, + 'max_resolution': 3.0, + 'min_resolution': 0.1, + 'num_bins': 50, + 'num_channels': 128, + 'weight': 0.01 + }, + 'masked_msa': { + 'num_output': 23, + 'weight': 2.0 + }, + }, + 'num_recycle': 3 + }, +}) + +CONFIG_MULTIMER = ml_collections.ConfigDict({ + 'model': { + 'embeddings_and_evoformer': { + 'evoformer_num_block': 48, + 'evoformer': { + 'msa_column_attention': { + 'dropout_rate': 0.0, + 'gating': True, + 'num_head': 8, + 'orientation': 'per_column', + 'shared_dropout': True + }, + 'msa_row_attention_with_pair_bias': { + 'dropout_rate': 0.15, + 'gating': True, + 'num_head': 8, + 'orientation': 'per_row', + 'shared_dropout': True + }, + 'msa_transition': { + 'dropout_rate': 0.0, + 'num_intermediate_factor': 4, + 'orientation': 'per_row', + 'shared_dropout': True + }, + 'outer_product_mean': { + 'chunk_size': 128, + 'dropout_rate': 0.0, + 'first': True, + 'num_outer_channel': 32, + 'orientation': 'per_row', + 'shared_dropout': True + }, + 'pair_transition': { + 'dropout_rate': 0.0, + 'num_intermediate_factor': 4, + 'orientation': 'per_row', + 'shared_dropout': True + }, + 'triangle_attention_ending_node': { + 'dropout_rate': 0.25, + 'gating': True, + 'num_head': 4, + 'orientation': 'per_column', + 'shared_dropout': True + }, + 'triangle_attention_starting_node': { + 'dropout_rate': 0.25, + 'gating': True, + 'num_head': 4, + 'orientation': 'per_row', + 'shared_dropout': True + }, + 'triangle_multiplication_incoming': { + 'dropout_rate': 0.25, + 'equation': 'kjc,kic->ijc', + 'num_intermediate_channel': 128, + 'orientation': 'per_row', + 'shared_dropout': True, + 'fuse_projection_weights': True + }, + 'triangle_multiplication_outgoing': { + 'dropout_rate': 0.25, + 'equation': 'ikc,jkc->ijc', + 'num_intermediate_channel': 128, + 'orientation': 'per_row', + 'shared_dropout': True, + 'fuse_projection_weights': True + } + }, + 'extra_msa_channel': 64, + 'extra_msa_stack_num_block': 4, + 'num_extra_msa': 1, + 'masked_msa': { + 'profile_prob': 0.1, + 'replace_fraction': 0.15, + 'same_prob': 0.1, + 'uniform_prob': 0.1 + }, + 'use_chain_relative': True, + 'max_relative_chain': 2, + 'max_relative_idx': 32, + 'seq_channel': 384, + 'msa_channel': 256, + 'pair_channel': 128, + 'prev_pos': { + 'max_bin': 20.75, + 'min_bin': 3.25, + 'num_bins': 
15 + }, + 'recycle_features': True, + 'recycle_pos': True, + 'template': { + 'attention': { + 'gating': False, + 'num_head': 4 + }, + 'dgram_features': { + 'max_bin': 50.75, + 'min_bin': 3.25, + 'num_bins': 39 + }, + 'enabled': True, + 'num_channels': 64, + 'subbatch_size': 128, + 'template_pair_stack': { + 'num_block': 2, + 'pair_transition': { + 'dropout_rate': 0.0, + 'num_intermediate_factor': 2, + 'orientation': 'per_row', + 'shared_dropout': True + }, + 'triangle_attention_ending_node': { + 'dropout_rate': 0.25, + 'gating': True, + 'num_head': 4, + 'orientation': 'per_column', + 'shared_dropout': True + }, + 'triangle_attention_starting_node': { + 'dropout_rate': 0.25, + 'gating': True, + 'num_head': 4, + 'orientation': 'per_row', + 'shared_dropout': True + }, + 'triangle_multiplication_incoming': { + 'dropout_rate': 0.25, + 'equation': 'kjc,kic->ijc', + 'num_intermediate_channel': 64, + 'orientation': 'per_row', + 'shared_dropout': True, + 'fuse_projection_weights': True, + }, + 'triangle_multiplication_outgoing': { + 'dropout_rate': 0.25, + 'equation': 'ikc,jkc->ijc', + 'num_intermediate_channel': 64, + 'orientation': 'per_row', + 'shared_dropout': True, + 'fuse_projection_weights': True, + } + } + }, + }, + 'global_config': { + 'bfloat16': True, + 'bfloat16_output': False, + 'multimer_mode': True, + 'subbatch_size': 4, + 'use_remat': False, + 'zero_init': True, + 'use_dgram': False + }, + 'heads': { + 'distogram': { + 'first_break': 2.3125, + 'last_break': 21.6875, + 'num_bins': 64, + 'weight': 0.3 + }, + 'experimentally_resolved': { + 'filter_by_resolution': True, + 'max_resolution': 3.0, + 'min_resolution': 0.1, + 'weight': 0.01 + }, + 'masked_msa': { + 'weight': 2.0 + }, + 'predicted_aligned_error': { + 'filter_by_resolution': True, + 'max_error_bin': 31.0, + 'max_resolution': 3.0, + 'min_resolution': 0.1, + 'num_bins': 64, + 'num_channels': 128, + 'weight': 0.1 + }, + 'predicted_lddt': { + 'filter_by_resolution': True, + 'max_resolution': 3.0, + 'min_resolution': 0.1, + 'num_bins': 50, + 'num_channels': 128, + 'weight': 0.01 + }, + 'structure_module': { + 'angle_norm_weight': 0.01, + 'chi_weight': 0.5, + 'clash_overlap_tolerance': 1.5, + 'dropout': 0.1, + 'interface_fape': { + 'atom_clamp_distance': 1000.0, + 'loss_unit_distance': 20.0 + }, + 'intra_chain_fape': { + 'atom_clamp_distance': 10.0, + 'loss_unit_distance': 10.0 + }, + 'num_channel': 384, + 'num_head': 12, + 'num_layer': 8, + 'num_layer_in_transition': 3, + 'num_point_qk': 4, + 'num_point_v': 8, + 'num_scalar_qk': 16, + 'num_scalar_v': 16, + 'position_scale': 20.0, + 'sidechain': { + 'atom_clamp_distance': 10.0, + 'loss_unit_distance': 10.0, + 'num_channel': 128, + 'num_residual_block': 2, + 'weight_frac': 0.5 + }, + 'structural_violation_loss_weight': 1.0, + 'violation_tolerance_factor': 12.0, + 'weight': 1.0 + } + }, + 'num_recycle': 3, + 'use_struct': True, + } +}) \ No newline at end of file diff --git a/build/lib/colabdesign/af/alphafold/model/data.py b/build/lib/colabdesign/af/alphafold/model/data.py new file mode 100644 index 00000000..d34918a5 --- /dev/null +++ b/build/lib/colabdesign/af/alphafold/model/data.py @@ -0,0 +1,41 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Convenience functions for reading data.""" + +import io +import os +from typing import List +from colabdesign.af.alphafold.model import utils +import haiku as hk +import numpy as np +# Internal import (7716). + + +def casp_model_names(data_dir: str) -> List[str]: + params = os.listdir(os.path.join(data_dir, 'params')) + return [os.path.splitext(filename)[0] for filename in params] + + +def get_model_haiku_params(model_name: str, data_dir: str, fuse: bool = None) -> hk.Params: + """Get the Haiku parameters from a model name.""" + + path = os.path.join(data_dir, 'params', f'params_{model_name}.npz') + if not os.path.isfile(path): path = os.path.join(data_dir, f'params_{model_name}.npz') + if not os.path.isfile(path): path = os.path.join(data_dir, 'params', f'{model_name}.npz') + if not os.path.isfile(path): path = os.path.join(data_dir, f'{model_name}.npz') + if os.path.isfile(path): + with open(path, 'rb') as f: + params = np.load(io.BytesIO(f.read()), allow_pickle=False) + return utils.flat_params_to_haiku(params, fuse=fuse) \ No newline at end of file diff --git a/build/lib/colabdesign/af/alphafold/model/folding.py b/build/lib/colabdesign/af/alphafold/model/folding.py new file mode 100644 index 00000000..ac125859 --- /dev/null +++ b/build/lib/colabdesign/af/alphafold/model/folding.py @@ -0,0 +1,982 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Modules and utilities for the structure module.""" + +import functools +from typing import Dict +from colabdesign.af.alphafold.common import residue_constants +from colabdesign.af.alphafold.model import all_atom +from colabdesign.af.alphafold.model import common_modules +from colabdesign.af.alphafold.model import prng +from colabdesign.af.alphafold.model import quat_affine +from colabdesign.af.alphafold.model import r3 +from colabdesign.af.alphafold.model import utils +import haiku as hk +import jax +import jax.numpy as jnp +import ml_collections +import numpy as np + + +def squared_difference(x, y): + return jnp.square(x - y) + + +class InvariantPointAttention(hk.Module): + """Invariant Point attention module. + + The high-level idea is that this attention module works over a set of points + and associated orientations in 3D space (e.g. protein residues). + + Each residue outputs a set of queries and keys as points in their local + reference frame. The attention is then defined as the euclidean distance + between the queries and keys in the global frame. + + Jumper et al. (2021) Suppl. Alg. 
22 "InvariantPointAttention" + """ + + def __init__(self, + config, + global_config, + dist_epsilon=1e-8, + name='invariant_point_attention'): + """Initialize. + + Args: + config: Structure Module Config + global_config: Global Config of Model. + dist_epsilon: Small value to avoid NaN in distance calculation. + name: Haiku Module name. + """ + super().__init__(name=name) + + self._dist_epsilon = dist_epsilon + self._zero_initialize_last = global_config.zero_init + + self.config = config + + self.global_config = global_config + + def __call__(self, inputs_1d, inputs_2d, mask, affine): + """Compute geometry-aware attention. + + Given a set of query residues (defined by affines and associated scalar + features), this function computes geometry-aware attention between the + query residues and target residues. + + The residues produce points in their local reference frame, which + are converted into the global frame in order to compute attention via + euclidean distance. + + Equivalently, the target residues produce points in their local frame to be + used as attention values, which are converted into the query residues' + local frames. + + Args: + inputs_1d: (N, C) 1D input embedding that is the basis for the + scalar queries. + inputs_2d: (N, M, C') 2D input embedding, used for biases and values. + mask: (N, 1) mask to indicate which elements of inputs_1d participate + in the attention. + affine: QuatAffine object describing the position and orientation of + every element in inputs_1d. + + Returns: + Transformation of the input embedding. + """ + num_residues, _ = inputs_1d.shape + + # Improve readability by removing a large number of 'self's. + num_head = self.config.num_head + num_scalar_qk = self.config.num_scalar_qk + num_point_qk = self.config.num_point_qk + num_scalar_v = self.config.num_scalar_v + num_point_v = self.config.num_point_v + num_output = self.config.num_channel + + assert num_scalar_qk > 0 + assert num_point_qk > 0 + assert num_point_v > 0 + + # Construct scalar queries of shape: + # [num_query_residues, num_head, num_points] + q_scalar = common_modules.Linear( + num_head * num_scalar_qk, name='q_scalar')( + inputs_1d) + q_scalar = jnp.reshape( + q_scalar, [num_residues, num_head, num_scalar_qk]) + + # Construct scalar keys/values of shape: + # [num_target_residues, num_head, num_points] + kv_scalar = common_modules.Linear( + num_head * (num_scalar_v + num_scalar_qk), name='kv_scalar')( + inputs_1d) + kv_scalar = jnp.reshape(kv_scalar, + [num_residues, num_head, + num_scalar_v + num_scalar_qk]) + k_scalar, v_scalar = jnp.split(kv_scalar, [num_scalar_qk], axis=-1) + + # Construct query points of shape: + # [num_residues, num_head, num_point_qk] + + # First construct query points in local frame. + q_point_local = common_modules.Linear( + num_head * 3 * num_point_qk, name='q_point_local')( + inputs_1d) + q_point_local = jnp.split(q_point_local, 3, axis=-1) + # Project query points into global frame. + q_point_global = affine.apply_to_point(q_point_local, extra_dims=1) + # Reshape query point for later use. + q_point = [ + jnp.reshape(x, [num_residues, num_head, num_point_qk]) + for x in q_point_global] + + # Construct key and value points. + # Key points have shape [num_residues, num_head, num_point_qk] + # Value points have shape [num_residues, num_head, num_point_v] + + # Construct key and value points in local frame. 
+ kv_point_local = common_modules.Linear( + num_head * 3 * (num_point_qk + num_point_v), name='kv_point_local')( + inputs_1d) + kv_point_local = jnp.split(kv_point_local, 3, axis=-1) + # Project key and value points into global frame. + kv_point_global = affine.apply_to_point(kv_point_local, extra_dims=1) + kv_point_global = [ + jnp.reshape(x, [num_residues, + num_head, (num_point_qk + num_point_v)]) + for x in kv_point_global] + # Split key and value points. + k_point, v_point = list( + zip(*[ + jnp.split(x, [num_point_qk,], axis=-1) + for x in kv_point_global + ])) + + # We assume that all queries and keys come iid from N(0, 1) distribution + # and compute the variances of the attention logits. + # Each scalar pair (q, k) contributes Var q*k = 1 + scalar_variance = max(num_scalar_qk, 1) * 1. + # Each point pair (q, k) contributes Var [0.5 ||q||^2 - ] = 9 / 2 + point_variance = max(num_point_qk, 1) * 9. / 2 + + # Allocate equal variance to scalar, point and attention 2d parts so that + # the sum is 1. + + num_logit_terms = 3 + + scalar_weights = np.sqrt(1.0 / (num_logit_terms * scalar_variance)) + point_weights = np.sqrt(1.0 / (num_logit_terms * point_variance)) + attention_2d_weights = np.sqrt(1.0 / (num_logit_terms)) + + # Trainable per-head weights for points. + trainable_point_weights = jax.nn.softplus(hk.get_parameter( + 'trainable_point_weights', shape=[num_head], + # softplus^{-1} (1) + init=hk.initializers.Constant(np.log(np.exp(1.) - 1.)))) + point_weights *= jnp.expand_dims(trainable_point_weights, axis=1) + + v_point = [jnp.swapaxes(x, -2, -3) for x in v_point] + + q_point = [jnp.swapaxes(x, -2, -3) for x in q_point] + k_point = [jnp.swapaxes(x, -2, -3) for x in k_point] + dist2 = [ + squared_difference(qx[:, :, None, :], kx[:, None, :, :]) + for qx, kx in zip(q_point, k_point) + ] + dist2 = sum(dist2) + attn_qk_point = -0.5 * jnp.sum( + point_weights[:, None, None, :] * dist2, axis=-1) + + v = jnp.swapaxes(v_scalar, -2, -3) + q = jnp.swapaxes(scalar_weights * q_scalar, -2, -3) + k = jnp.swapaxes(k_scalar, -2, -3) + attn_qk_scalar = jnp.matmul(q, jnp.swapaxes(k, -2, -1)) + attn_logits = attn_qk_scalar + attn_qk_point + + attention_2d = common_modules.Linear( + num_head, name='attention_2d')( + inputs_2d) + + attention_2d = jnp.transpose(attention_2d, [2, 0, 1]) + attention_2d = attention_2d_weights * attention_2d + attn_logits += attention_2d + + mask_2d = mask * jnp.swapaxes(mask, -1, -2) + attn_logits -= 1e5 * (1. - mask_2d) + + # [num_head, num_query_residues, num_target_residues] + attn = jax.nn.softmax(attn_logits) + + # [num_head, num_query_residues, num_head * num_scalar_v] + result_scalar = jnp.matmul(attn, v) + + # For point result, implement matmul manually so that it will be a float32 + # on TPU. This is equivalent to + # result_point_global = [jnp.einsum('bhqk,bhkc->bhqc', attn, vx) + # for vx in v_point] + # but on the TPU, doing the multiply and reduce_sum ensures the + # computation happens in float32 instead of bfloat16. + result_point_global = [jnp.sum( + attn[:, :, :, None] * vx[:, None, :, :], + axis=-2) for vx in v_point] + + # [num_query_residues, num_head, num_head * num_(scalar|point)_v] + result_scalar = jnp.swapaxes(result_scalar, -2, -3) + result_point_global = [ + jnp.swapaxes(x, -2, -3) + for x in result_point_global] + + # Features used in the linear output projection. Should have the size + # [num_query_residues, ?] 
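# --- Illustrative standalone sketch (not part of the patch) ---------------
# The logit-variance bookkeeping above relies on the fact that a per-point
# term of the form 0.5*||q||^2 - <q, k> has variance 9/2 when q and k are
# iid standard-normal 3-vectors; a quick Monte Carlo check of that constant.
import numpy as np
rng = np.random.default_rng(0)
q = rng.standard_normal((200_000, 3))
k = rng.standard_normal((200_000, 3))
term = 0.5 * (q ** 2).sum(-1) - (q * k).sum(-1)
print(term.var())  # close to 4.5 == 9/2
# --------------------------------------------------------------------------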
+ output_features = [] + + result_scalar = jnp.reshape( + result_scalar, [num_residues, num_head * num_scalar_v]) + output_features.append(result_scalar) + + result_point_global = [ + jnp.reshape(r, [num_residues, num_head * num_point_v]) + for r in result_point_global] + result_point_local = affine.invert_point(result_point_global, extra_dims=1) + output_features.extend(result_point_local) + + output_features.append(jnp.sqrt(self._dist_epsilon + + jnp.square(result_point_local[0]) + + jnp.square(result_point_local[1]) + + jnp.square(result_point_local[2]))) + + # Dimensions: h = heads, i and j = residues, + # c = inputs_2d channels + # Contraction happens over the second residue dimension, similarly to how + # the usual attention is performed. + result_attention_over_2d = jnp.einsum('hij, ijc->ihc', attn, inputs_2d) + num_out = num_head * result_attention_over_2d.shape[-1] + output_features.append( + jnp.reshape(result_attention_over_2d, + [num_residues, num_out])) + + final_init = 'zeros' if self._zero_initialize_last else 'linear' + + final_act = jnp.concatenate(output_features, axis=-1) + + return common_modules.Linear( + num_output, + initializer=final_init, + name='output_projection')(final_act) + + +class FoldIteration(hk.Module): + """A single iteration of the main structure module loop. + + Jumper et al. (2021) Suppl. Alg. 20 "StructureModule" lines 6-21 + + First, each residue attends to all residues using InvariantPointAttention. + Then, we apply transition layers to update the hidden representations. + Finally, we use the hidden representations to produce an update to the + affine of each residue. + """ + + def __init__(self, config, global_config, + name='fold_iteration'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, + activations, + sequence_mask, + update_affine, + initial_act, + use_dropout, + safe_key=None, + static_feat_2d=None, + aatype=None): + c = self.config + + if safe_key is None: + safe_key = prng.SafeKey(hk.next_rng_key()) + + def safe_dropout_fn(tensor, safe_key): + return prng.safe_dropout( + tensor=tensor, + safe_key=safe_key, + rate=jnp.where(use_dropout, c.dropout, 0)) + + affine = quat_affine.QuatAffine.from_tensor(activations['affine']) + + act = activations['act'] + attention_module = InvariantPointAttention(self.config, self.global_config) + # Attention + attn = attention_module( + inputs_1d=act, + inputs_2d=static_feat_2d, + mask=sequence_mask, + affine=affine) + act += attn + safe_key, *sub_keys = safe_key.split(3) + sub_keys = iter(sub_keys) + act = safe_dropout_fn(act, next(sub_keys)) + act = common_modules.LayerNorm( + axis=[-1], + create_scale=True, + create_offset=True, + name='attention_layer_norm')( + act) + + final_init = 'zeros' if self.global_config.zero_init else 'linear' + + # Transition + input_act = act + for i in range(c.num_layer_in_transition): + init = 'relu' if i < c.num_layer_in_transition - 1 else final_init + act = common_modules.Linear( + c.num_channel, + initializer=init, + name='transition')( + act) + if i < c.num_layer_in_transition - 1: + act = jax.nn.relu(act) + act += input_act + act = safe_dropout_fn(act, next(sub_keys)) + act = common_modules.LayerNorm( + axis=[-1], + create_scale=True, + create_offset=True, + name='transition_layer_norm')(act) + + if update_affine: + # This block corresponds to + # Jumper et al. (2021) Alg. 
23 "Backbone update" + affine_update_size = 6 + + # Affine update + affine_update = common_modules.Linear( + affine_update_size, + initializer=final_init, + name='affine_update')( + act) + + affine = affine.pre_compose(affine_update) + + sc = MultiRigidSidechain(c.sidechain, self.global_config)( + affine.scale_translation(c.position_scale), [act, initial_act], aatype) + + outputs = {'affine': affine.to_tensor(), 'sc': sc} + + # affine = affine.apply_rotation_tensor_fn(jax.lax.stop_gradient) + + new_activations = { + 'act': act, + 'affine': affine.to_tensor() + } + return new_activations, outputs + + +def generate_affines(representations, batch, config, global_config, safe_key): + """Generate predicted affines for a single chain. + + Jumper et al. (2021) Suppl. Alg. 20 "StructureModule" + + This is the main part of the structure module - it iteratively applies + folding to produce a set of predicted residue positions. + + Args: + representations: Representations dictionary. + batch: Batch dictionary. + config: Config for the structure module. + global_config: Global config. + safe_key: A prng.SafeKey object that wraps a PRNG key. + + Returns: + A dictionary containing residue affines and sidechain positions. + """ + c = config + sequence_mask = batch['seq_mask'][:, None] + + act = common_modules.LayerNorm( + axis=[-1], + create_scale=True, + create_offset=True, + name='single_layer_norm')( + representations['single']) + + initial_act = act + act = common_modules.Linear( + c.num_channel, name='initial_projection')( + act) + + if "initial_atom_pos" in batch: + atom = residue_constants.atom_order + atom_pos = batch["initial_atom_pos"] + if global_config.bfloat16: + atom_pos = atom_pos.astype(jnp.float32) + rot, trans = quat_affine.make_transform_from_reference( + n_xyz=atom_pos[:, atom["N"]], + ca_xyz=atom_pos[:, atom["CA"]], + c_xyz=atom_pos[:, atom["C"]]) + + affine = quat_affine.QuatAffine( + quaternion=quat_affine.rot_to_quat(rot, unstack_inputs=True), + translation=trans, + rotation=rot, + unstack_inputs=True).scale_translation(1/c.position_scale) + else: + affine = generate_new_affine(sequence_mask) + + fold_iteration = FoldIteration( + c, global_config, name='fold_iteration') + + assert len(batch['seq_mask'].shape) == 1 + + activations = {'act': act, + 'affine': affine.to_tensor(), + } + + act_2d = common_modules.LayerNorm( + axis=[-1], + create_scale=True, + create_offset=True, + name='pair_layer_norm')( + representations['pair']) + + def fold_iter(act, key): + act, out = fold_iteration( + act, + initial_act=initial_act, + static_feat_2d=act_2d, + safe_key=prng.SafeKey(key), + sequence_mask=sequence_mask, + update_affine=True, + aatype=batch['aatype'], + use_dropout=batch["use_dropout"]) + return act, out + keys = jax.random.split(safe_key.get(), c.num_layer) + activations, output = hk.scan(fold_iter, activations, keys) + + # Include the activations in the output dict for use by the LDDT-Head. + output['act'] = activations['act'] + + return output + + +class dummy(hk.Module): + def __init__(self, config, global_config): + super().__init__(name="dummy") + def __call__(self, representations, batch, safe_key=None): + if safe_key is None: + safe_key = prng.SafeKey(hk.next_rng_key()) + return {} + +class StructureModule(hk.Module): + """StructureModule as a network head. + + Jumper et al. (2021) Suppl. Alg. 
20 "StructureModule" + """ + + def __init__(self, config, global_config, + name='structure_module'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, representations, batch, safe_key=None): + c = self.config + ret = {} + + if safe_key is None: + safe_key = prng.SafeKey(hk.next_rng_key()) + + output = generate_affines( + representations=representations, + batch=batch, + config=self.config, + global_config=self.global_config, + safe_key=safe_key) + + ret['representations'] = {'structure_module': output['act']} + + ret['traj'] = output['affine'] * jnp.array([1.] * 4 + [c.position_scale] * 3) + ret['sidechains'] = output['sc'] + atom14_pred_positions = r3.vecs_to_tensor(output['sc']['atom_pos'])[-1] + ret['final_atom14_positions'] = atom14_pred_positions # (N, 14, 3) + ret['final_atom14_mask'] = batch['atom14_atom_exists'] # (N, 14) + + atom37_pred_positions = all_atom.atom14_to_atom37(atom14_pred_positions, batch) + atom37_pred_positions *= batch['atom37_atom_exists'][:, :, None] + ret['final_atom_positions'] = atom37_pred_positions # (N, 37, 3) + ret['final_atom_mask'] = batch['atom37_atom_exists'] # (N, 37) + ret['final_affines'] = ret['traj'][-1] + + return ret + + +def compute_renamed_ground_truth( + batch: Dict[str, jnp.ndarray], + atom14_pred_positions: jnp.ndarray, + ) -> Dict[str, jnp.ndarray]: + """Find optimal renaming of ground truth based on the predicted positions. + + Jumper et al. (2021) Suppl. Alg. 26 "renameSymmetricGroundTruthAtoms" + + This renamed ground truth is then used for all losses, + such that each loss moves the atoms in the same direction. + Shape (N). + + Args: + batch: Dictionary containing: + * atom14_gt_positions: Ground truth positions. + * atom14_alt_gt_positions: Ground truth positions with renaming swaps. + * atom14_atom_is_ambiguous: 1.0 for atoms that are affected by + renaming swaps. + * atom14_gt_exists: Mask for which atoms exist in ground truth. + * atom14_alt_gt_exists: Mask for which atoms exist in ground truth + after renaming. + * atom14_atom_exists: Mask for whether each atom is part of the given + amino acid type. + atom14_pred_positions: Array of atom positions in global frame with shape + (N, 14, 3). + Returns: + Dictionary containing: + alt_naming_is_better: Array with 1.0 where alternative swap is better. + renamed_atom14_gt_positions: Array of optimal ground truth positions + after renaming swaps are performed. + renamed_atom14_gt_exists: Mask after renaming swap is performed. + """ + alt_naming_is_better = all_atom.find_optimal_renaming( + atom14_gt_positions=batch['atom14_gt_positions'], + atom14_alt_gt_positions=batch['atom14_alt_gt_positions'], + atom14_atom_is_ambiguous=batch['atom14_atom_is_ambiguous'], + atom14_gt_exists=batch['atom14_gt_exists'], + atom14_pred_positions=atom14_pred_positions, + atom14_atom_exists=batch['atom14_atom_exists']) + + renamed_atom14_gt_positions = ( + (1. - alt_naming_is_better[:, None, None]) + * batch['atom14_gt_positions'] + + alt_naming_is_better[:, None, None] + * batch['atom14_alt_gt_positions']) + + renamed_atom14_gt_mask = ( + (1. - alt_naming_is_better[:, None]) * batch['atom14_gt_exists'] + + alt_naming_is_better[:, None] * batch['atom14_alt_gt_exists']) + + return { + 'alt_naming_is_better': alt_naming_is_better, # (N) + 'renamed_atom14_gt_positions': renamed_atom14_gt_positions, # (N, 14, 3) + 'renamed_atom14_gt_exists': renamed_atom14_gt_mask, # (N, 14) + } + + +def backbone_loss(batch, value, config): + """Backbone FAPE Loss. 
+ + Jumper et al. (2021) Suppl. Alg. 20 "StructureModule" line 17 + + Args: + ret: Dictionary to write outputs into, needs to contain 'loss'. + batch: Batch, needs to contain 'backbone_affine_tensor', + 'backbone_affine_mask'. + value: Dictionary containing structure module output, needs to contain + 'traj', a trajectory of rigids. + config: Configuration of loss, should contain 'fape.clamp_distance' and + 'fape.loss_unit_distance'. + """ + affine_trajectory = quat_affine.QuatAffine.from_tensor(value['traj']) + rigid_trajectory = r3.rigids_from_quataffine(affine_trajectory) + + if 'backbone_affine_tensor' in batch: + gt_affine = quat_affine.QuatAffine.from_tensor(batch['backbone_affine_tensor']) + backbone_mask = batch['backbone_affine_mask'] + else: + n_xyz = batch['all_atom_positions'][...,0,:] + ca_xyz = batch['all_atom_positions'][...,1,:] + c_xyz = batch['all_atom_positions'][...,2,:] + rot, trans = quat_affine.make_transform_from_reference(n_xyz, ca_xyz, c_xyz) + gt_affine = quat_affine.QuatAffine(quaternion=None, + translation=trans, + rotation=rot, + unstack_inputs=True) + backbone_mask = batch['all_atom_mask'][...,0] + + gt_rigid = r3.rigids_from_quataffine(gt_affine) + + fape_loss_fn = functools.partial( + all_atom.frame_aligned_point_error, + l1_clamp_distance=config.fape.clamp_distance, + length_scale=config.fape.loss_unit_distance) + + fape_loss_fn = jax.vmap(fape_loss_fn, (0, None, None, 0, None, None)) + fape_loss = fape_loss_fn(rigid_trajectory, gt_rigid, backbone_mask, + rigid_trajectory.trans, gt_rigid.trans, + backbone_mask) + + if 'use_clamped_fape' in batch: + # Jumper et al. (2021) Suppl. Sec. 1.11.5 "Loss clamping details" + use_clamped_fape = jnp.asarray(batch['use_clamped_fape'], jnp.float32) + unclamped_fape_loss_fn = functools.partial( + all_atom.frame_aligned_point_error, + l1_clamp_distance=None, + length_scale=config.fape.loss_unit_distance) + unclamped_fape_loss_fn = jax.vmap(unclamped_fape_loss_fn, + (0, None, None, 0, None, None)) + fape_loss_unclamped = unclamped_fape_loss_fn(rigid_trajectory, gt_rigid, + backbone_mask, + rigid_trajectory.trans, + gt_rigid.trans, + backbone_mask) + + fape_loss = (fape_loss * use_clamped_fape + fape_loss_unclamped * (1 - use_clamped_fape)) + + return jnp.mean(fape_loss), fape_loss[-1] + + +def sidechain_loss(batch, value, config): + """All Atom FAPE Loss using renamed rigids.""" + # Rename Frames + # Jumper et al. (2021) Suppl. Alg. 26 "renameSymmetricGroundTruthAtoms" line 7 + alt_naming_is_better = value['alt_naming_is_better'] + renamed_gt_frames = ( + (1. - alt_naming_is_better[:, None, None]) + * batch['rigidgroups_gt_frames'] + + alt_naming_is_better[:, None, None] + * batch['rigidgroups_alt_gt_frames']) + + flat_gt_frames = r3.rigids_from_tensor_flat12(jnp.reshape(renamed_gt_frames, [-1, 12])) + flat_frames_mask = jnp.reshape(batch['rigidgroups_gt_exists'], [-1]) + + flat_gt_positions = r3.vecs_from_tensor(jnp.reshape(value['renamed_atom14_gt_positions'], [-1, 3])) + flat_positions_mask = jnp.reshape(value['renamed_atom14_gt_exists'], [-1]) + + # Compute frame_aligned_point_error score for the final layer. 
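# --- Illustrative standalone sketch (not part of the patch) ---------------
# The core of a FAPE-style loss is a masked mean of per-point errors, clipped
# at a clamp distance and divided by a length scale, so very large errors stop
# contributing gradient. Toy numbers only.
import jax.numpy as jnp
error_dist = jnp.array([1.0, 5.0, 30.0])   # per-point errors in angstroms
mask = jnp.array([1.0, 1.0, 0.0])          # last point is masked out
clamp, length_scale, eps = 10.0, 10.0, 1e-4
normed = jnp.clip(error_dist, 0, clamp) / length_scale
print((normed * mask).sum() / (eps + mask.sum()))  # about 0.3 = (0.1 + 0.5) / 2
# --------------------------------------------------------------------------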
+ pred_frames = value['sidechains']['frames'] + pred_positions = value['sidechains']['atom_pos'] + + def _slice_last_layer_and_flatten(x): + return jnp.reshape(x[-1], [-1]) + + flat_pred_frames = jax.tree_map(_slice_last_layer_and_flatten, pred_frames) + flat_pred_positions = jax.tree_map(_slice_last_layer_and_flatten, pred_positions) + # FAPE Loss on sidechains + fape = all_atom.frame_aligned_point_error( + pred_frames=flat_pred_frames, + target_frames=flat_gt_frames, + frames_mask=flat_frames_mask, + pred_positions=flat_pred_positions, + target_positions=flat_gt_positions, + positions_mask=flat_positions_mask, + l1_clamp_distance=config.sidechain.atom_clamp_distance, + length_scale=config.sidechain.length_scale) + + return { + 'fape': fape, + 'loss': fape} + + +def structural_violation_loss(ret, batch, value, config): + """Computes loss for structural violations.""" + assert config.sidechain.weight_frac + + # Put all violation losses together to one large loss. + violations = value['violations'] + num_atoms = jnp.sum(batch['atom14_atom_exists']).astype(jnp.float32) + ret['loss'] += (config.structural_violation_loss_weight * ( + violations['between_residues']['bonds_c_n_loss_mean'] + + violations['between_residues']['angles_ca_c_n_loss_mean'] + + violations['between_residues']['angles_c_n_ca_loss_mean'] + + jnp.sum( + violations['between_residues']['clashes_per_atom_loss_sum'] + + violations['within_residues']['per_atom_loss_sum']) / + (1e-6 + num_atoms))) + + +def find_structural_violations( + batch: Dict[str, jnp.ndarray], + atom14_pred_positions: jnp.ndarray, # (N, 14, 3) + config: ml_collections.ConfigDict + ): + """Computes several checks for structural violations.""" + + # Compute between residue backbone violations of bonds and angles. + connection_violations = all_atom.between_residue_bond_loss( + pred_atom_positions=atom14_pred_positions, + pred_atom_mask=batch['atom14_atom_exists'].astype(jnp.float32), + residue_index=batch['residue_index'].astype(jnp.float32), + aatype=batch['aatype'], + tolerance_factor_soft=config.violation_tolerance_factor, + tolerance_factor_hard=config.violation_tolerance_factor) + + # Compute the Van der Waals radius for every atom + # (the first letter of the atom name is the element type). + # Shape: (N, 14). + atomtype_radius = [ + residue_constants.van_der_waals_radius[name[0]] + for name in residue_constants.atom_types + ] + atom14_atom_radius = batch['atom14_atom_exists'] * utils.batched_gather( + atomtype_radius, batch['residx_atom14_to_atom37']) + + # Compute the between residue clash loss. + between_residue_clashes = all_atom.between_residue_clash_loss( + atom14_pred_positions=atom14_pred_positions, + atom14_atom_exists=batch['atom14_atom_exists'], + atom14_atom_radius=atom14_atom_radius, + residue_index=batch['residue_index'], + overlap_tolerance_soft=config.clash_overlap_tolerance, + overlap_tolerance_hard=config.clash_overlap_tolerance) + + # Compute all within-residue violations (clashes, + # bond length and angle violations). 
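# --- Illustrative standalone sketch (not part of the patch) ---------------
# Per-atom radii for the clash check are looked up by element, taken from the
# first letter of the atom name, then gathered per residue. A minimal
# dictionary-based version; radii values here are approximate/illustrative.
van_der_waals_radius = {'C': 1.7, 'N': 1.55, 'O': 1.52, 'S': 1.8}  # angstroms
atom_names = ['N', 'CA', 'C', 'O', 'CB', 'SG']                     # toy atom list
print([van_der_waals_radius[name[0]] for name in atom_names])
# --------------------------------------------------------------------------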
+ restype_atom14_bounds = residue_constants.make_atom14_dists_bounds( + overlap_tolerance=config.clash_overlap_tolerance, + bond_length_tolerance_factor=config.violation_tolerance_factor) + atom14_dists_lower_bound = utils.batched_gather( + restype_atom14_bounds['lower_bound'], batch['aatype']) + atom14_dists_upper_bound = utils.batched_gather( + restype_atom14_bounds['upper_bound'], batch['aatype']) + within_residue_violations = all_atom.within_residue_violations( + atom14_pred_positions=atom14_pred_positions, + atom14_atom_exists=batch['atom14_atom_exists'], + atom14_dists_lower_bound=atom14_dists_lower_bound, + atom14_dists_upper_bound=atom14_dists_upper_bound, + tighten_bounds_for_loss=0.0) + + # Combine them to a single per-residue violation mask (used later for LDDT). + per_residue_violations_mask = jnp.max(jnp.stack([ + connection_violations['per_residue_violation_mask'], + jnp.max(between_residue_clashes['per_atom_clash_mask'], axis=-1), + jnp.max(within_residue_violations['per_atom_violations'], + axis=-1)]), axis=0) + + return { + 'between_residues': { + 'bonds_c_n_loss_mean': + connection_violations['c_n_loss_mean'], # () + 'angles_ca_c_n_loss_mean': + connection_violations['ca_c_n_loss_mean'], # () + 'angles_c_n_ca_loss_mean': + connection_violations['c_n_ca_loss_mean'], # () + 'connections_per_residue_loss_sum': + connection_violations['per_residue_loss_sum'], # (N) + 'connections_per_residue_violation_mask': + connection_violations['per_residue_violation_mask'], # (N) + 'clashes_mean_loss': + between_residue_clashes['mean_loss'], # () + 'clashes_per_atom_loss_sum': + between_residue_clashes['per_atom_loss_sum'], # (N, 14) + 'clashes_per_atom_clash_mask': + between_residue_clashes['per_atom_clash_mask'], # (N, 14) + }, + 'within_residues': { + 'per_atom_loss_sum': + within_residue_violations['per_atom_loss_sum'], # (N, 14) + 'per_atom_violations': + within_residue_violations['per_atom_violations'], # (N, 14), + }, + 'total_per_residue_violations_mask': + per_residue_violations_mask, # (N) + } + + +def compute_violation_metrics( + batch: Dict[str, jnp.ndarray], + atom14_pred_positions: jnp.ndarray, # (N, 14, 3) + violations: Dict[str, jnp.ndarray], + ) -> Dict[str, jnp.ndarray]: + """Compute several metrics to assess the structural violations.""" + + ret = {} + extreme_ca_ca_violations = all_atom.extreme_ca_ca_distance_violations( + pred_atom_positions=atom14_pred_positions, + pred_atom_mask=batch['atom14_atom_exists'].astype(jnp.float32), + residue_index=batch['residue_index'].astype(jnp.float32)) + ret['violations_extreme_ca_ca_distance'] = extreme_ca_ca_violations + ret['violations_between_residue_bond'] = utils.mask_mean( + mask=batch['seq_mask'], + value=violations['between_residues'][ + 'connections_per_residue_violation_mask']) + ret['violations_between_residue_clash'] = utils.mask_mean( + mask=batch['seq_mask'], + value=jnp.max( + violations['between_residues']['clashes_per_atom_clash_mask'], + axis=-1)) + ret['violations_within_residue'] = utils.mask_mean( + mask=batch['seq_mask'], + value=jnp.max( + violations['within_residues']['per_atom_violations'], axis=-1)) + ret['violations_per_residue'] = utils.mask_mean( + mask=batch['seq_mask'], + value=violations['total_per_residue_violations_mask']) + return ret + + +def supervised_chi_loss(ret, batch, value, config): + """Computes loss for direct chi angle supervision. + + Jumper et al. (2021) Suppl. Alg. 27 "torsionAngleLoss" + + Args: + ret: Dictionary to write outputs into, needs to contain 'loss'. 
+ batch: Batch, needs to contain 'seq_mask', 'chi_mask', 'chi_angles'. + value: Dictionary containing structure module output, needs to contain + value['sidechains']['angles_sin_cos'] for angles and + value['sidechains']['unnormalized_angles_sin_cos'] for unnormalized + angles. + config: Configuration of loss, should contain 'chi_weight' and + 'angle_norm_weight', 'angle_norm_weight' scales angle norm term, + 'chi_weight' scales torsion term. + """ + eps = 1e-6 + + sequence_mask = batch['seq_mask'] + num_res = sequence_mask.shape[0] + chi_mask = batch['chi_mask'].astype(jnp.float32) + pred_angles = jnp.reshape( + value['sidechains']['angles_sin_cos'], [-1, num_res, 7, 2]) + pred_angles = pred_angles[:, :, 3:] + + residue_type_one_hot = jax.nn.one_hot( + batch['aatype'], residue_constants.restype_num + 1, + dtype=jnp.float32)[None] + chi_pi_periodic = jnp.einsum('ijk, kl->ijl', residue_type_one_hot, + jnp.asarray(residue_constants.chi_pi_periodic)) + + true_chi = batch['chi_angles'][None] + sin_true_chi = jnp.sin(true_chi) + cos_true_chi = jnp.cos(true_chi) + sin_cos_true_chi = jnp.stack([sin_true_chi, cos_true_chi], axis=-1) + + # This is -1 if chi is pi-periodic and +1 if it's 2pi-periodic + shifted_mask = (1 - 2 * chi_pi_periodic)[..., None] + sin_cos_true_chi_shifted = shifted_mask * sin_cos_true_chi + + sq_chi_error = jnp.sum( + squared_difference(sin_cos_true_chi, pred_angles), -1) + sq_chi_error_shifted = jnp.sum( + squared_difference(sin_cos_true_chi_shifted, pred_angles), -1) + sq_chi_error = jnp.minimum(sq_chi_error, sq_chi_error_shifted) + + sq_chi_loss = utils.mask_mean(mask=chi_mask[None], value=sq_chi_error) + ret['chi_loss'] = sq_chi_loss + ret['loss'] += config.chi_weight * sq_chi_loss + unnormed_angles = jnp.reshape( + value['sidechains']['unnormalized_angles_sin_cos'], [-1, num_res, 7, 2]) + angle_norm = jnp.sqrt(jnp.sum(jnp.square(unnormed_angles), axis=-1) + eps) + norm_error = jnp.abs(angle_norm - 1.) + angle_norm_loss = utils.mask_mean(mask=sequence_mask[None, :, None], + value=norm_error) + + ret['angle_norm_loss'] = angle_norm_loss + ret['loss'] += config.angle_norm_weight * angle_norm_loss + + +def generate_new_affine(sequence_mask): + num_residues, _ = sequence_mask.shape + quaternion = jnp.tile( + jnp.reshape(jnp.asarray([1., 0., 0., 0.]), [1, 4]), + [num_residues, 1]) + + translation = jnp.zeros([num_residues, 3]) + return quat_affine.QuatAffine(quaternion, translation, unstack_inputs=True) + + +def l2_normalize(x, axis=-1, epsilon=1e-12): + return x / jnp.sqrt( + jnp.maximum(jnp.sum(x**2, axis=axis, keepdims=True), epsilon)) + + +class MultiRigidSidechain(hk.Module): + """Class to make side chain atoms.""" + + def __init__(self, config, global_config, name='rigid_sidechain'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, affine, representations_list, aatype): + """Predict side chains using multi-rigid representations. + + Args: + affine: The affines for each residue (translations in angstroms). + representations_list: A list of activations to predict side chains from. + aatype: Amino acid types. + + Returns: + Dict containing atom positions and frames (in angstroms). + """ + act = [ + common_modules.Linear( # pylint: disable=g-complex-comprehension + self.config.num_channel, + name='input_projection')(jax.nn.relu(x)) + for x in representations_list + ] + # Sum the activation list (equivalent to concat then Linear). 
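+ # Summing is equivalent because each Linear is affine: sum_i (W_i x_i + b_i)
+ # equals one Linear over the concatenated inputs with block weights [W_1 .. W_k].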
+ act = sum(act) + + final_init = 'zeros' if self.global_config.zero_init else 'linear' + + # Mapping with some residual blocks. + for _ in range(self.config.num_residual_block): + old_act = act + act = common_modules.Linear( + self.config.num_channel, + initializer='relu', + name='resblock1')( + jax.nn.relu(act)) + act = common_modules.Linear( + self.config.num_channel, + initializer=final_init, + name='resblock2')( + jax.nn.relu(act)) + act += old_act + + # Map activations to torsion angles. Shape: (num_res, 14). + num_res = act.shape[0] + unnormalized_angles = common_modules.Linear( + 14, name='unnormalized_angles')( + jax.nn.relu(act)) + unnormalized_angles = jnp.reshape( + unnormalized_angles, [num_res, 7, 2]) + angles = l2_normalize(unnormalized_angles, axis=-1) + + outputs = { + 'angles_sin_cos': angles, # jnp.ndarray (N, 7, 2) + 'unnormalized_angles_sin_cos': + unnormalized_angles, # jnp.ndarray (N, 7, 2) + } + + # Map torsion angles to frames. + backb_to_global = r3.rigids_from_quataffine(affine) + + # Jumper et al. (2021) Suppl. Alg. 24 "computeAllAtomCoordinates" + + # r3.Rigids with shape (N, 8). + all_frames_to_global = all_atom.torsion_angles_to_frames( + aatype, + backb_to_global, + angles) + + # Use frames and literature positions to create the final atom coordinates. + # r3.Vecs with shape (N, 14). + pred_positions = all_atom.frames_and_literature_positions_to_atom14_pos( + aatype, all_frames_to_global) + + outputs.update({ + 'atom_pos': pred_positions, # r3.Vecs (N, 14) + 'frames': all_frames_to_global, # r3.Rigids (N, 8) + 'angles': angles, + 'backb_to_global': backb_to_global + }) + return outputs diff --git a/build/lib/colabdesign/af/alphafold/model/folding_multimer.py b/build/lib/colabdesign/af/alphafold/model/folding_multimer.py new file mode 100644 index 00000000..14db08ae --- /dev/null +++ b/build/lib/colabdesign/af/alphafold/model/folding_multimer.py @@ -0,0 +1,1033 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
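Both the monomer losses above (`backbone_loss`, `sidechain_loss`) and the multimer structure module that follows revolve around the frame-aligned point error (FAPE): predicted and ground-truth atoms are expressed in the local coordinates of every predicted/ground-truth frame, and the clamped distance between the two local point clouds is averaged. The sketch below is a minimal NumPy illustration of that idea on toy data; `toy_fape` and its clamp/length-scale defaults are illustrative assumptions, not values taken from the model config.

```python
import numpy as np

def toy_fape(R_pred, t_pred, x_pred, R_true, t_true, x_true,
             clamp=10.0, length_scale=10.0, eps=1e-4):
    """R_*: (F, 3, 3) frame rotations, t_*: (F, 3) origins, x_*: (A, 3) atoms."""
    def to_local(R, t, x):
        # Express every atom in every frame's local coordinates: R^T (x - t).
        return np.einsum('fji,faj->fai', R, x[None, :, :] - t[:, None, :])
    err = np.sqrt(np.sum((to_local(R_pred, t_pred, x_pred) -
                          to_local(R_true, t_true, x_true)) ** 2, axis=-1) + eps)
    return np.mean(np.minimum(err, clamp)) / length_scale

# Toy check: identity frames, every predicted atom shifted by 1 Angstrom in x,
# so the per-atom error is ~1 Angstrom and the loss is ~1 / length_scale.
F, A = 2, 5
R, t = np.tile(np.eye(3), (F, 1, 1)), np.zeros((F, 3))
x_true = np.random.RandomState(0).randn(A, 3)
x_pred = x_true + np.array([1.0, 0.0, 0.0])
print(toy_fape(R, t, x_pred, R, t, x_true))  # ~0.1
```

Because every error is measured in a frame's local coordinates, the score is invariant to any global rotation or translation applied jointly to the predicted frames and atoms, which is why no alignment step is needed before computing the loss.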
+ +"""Modules and utilities for the structure module in the multimer system.""" + +import functools +import numbers +from typing import Any, Dict, Iterable, Mapping, Optional, Tuple, Union + +from colabdesign.af.alphafold.common import residue_constants +from colabdesign.af.alphafold.model import all_atom_multimer +from colabdesign.af.alphafold.model import common_modules +from colabdesign.af.alphafold.model import geometry +from colabdesign.af.alphafold.model import modules +from colabdesign.af.alphafold.model import prng +from colabdesign.af.alphafold.model import utils +from colabdesign.af.alphafold.model.geometry import utils as geometry_utils +import haiku as hk +import jax +import jax.numpy as jnp +import ml_collections +import numpy as np + + +EPSILON = 1e-8 +Float = Union[float, jnp.ndarray] + + +def squared_difference(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: + """Computes Squared difference between two arrays.""" + return jnp.square(x - y) + + +def make_backbone_affine( + positions: geometry.Vec3Array, + mask: jnp.ndarray, + aatype: jnp.ndarray, + ) -> Tuple[geometry.Rigid3Array, jnp.ndarray]: + """Make backbone Rigid3Array and mask.""" + del aatype + a = residue_constants.atom_order['N'] + b = residue_constants.atom_order['CA'] + c = residue_constants.atom_order['C'] + + rigid_mask = (mask[:, a] * mask[:, b] * mask[:, c]).astype( + jnp.float32) + + rigid = all_atom_multimer.make_transform_from_reference( + a_xyz=positions[:, a], b_xyz=positions[:, b], c_xyz=positions[:, c]) + + return rigid, rigid_mask + + +class QuatRigid(hk.Module): + """Module for projecting Rigids via a quaternion.""" + + def __init__(self, + global_config: ml_collections.ConfigDict, + rigid_shape: Union[int, Iterable[int]] = tuple(), + full_quat: bool = False, + init: str = 'zeros', + name: str = 'quat_rigid'): + """Module projecting a Rigid Object. + + For this Module the Rotation is parametrized as a quaternion, + If 'full_quat' is True a 4 vector is produced for the rotation which is + normalized and treated as a quaternion. + When 'full_quat' is False a 3 vector is produced and the 1st component of + the quaternion is set to 1. + + Args: + global_config: Global Config, used to set certain properties of underlying + Linear module, see common_modules.Linear for details. + rigid_shape: Shape of Rigids relative to shape of activations, e.g. when + activations have shape (n,) and this is (m,) output will be (n, m) + full_quat: Whether to parametrize rotation using full quaternion. + init: initializer to use, see common_modules.Linear for details + name: Name to use for module. + """ + self.init = init + self.global_config = global_config + if isinstance(rigid_shape, int): + self.rigid_shape = (rigid_shape,) + else: + self.rigid_shape = tuple(rigid_shape) + self.full_quat = full_quat + super(QuatRigid, self).__init__(name=name) + + def __call__(self, activations: jnp.ndarray) -> geometry.Rigid3Array: + """Executes Module. + + This returns a set of rigid with the same shape as activations, projecting + the channel dimension, rigid_shape controls the trailing dimensions. + For example when activations is shape (12, 5) and rigid_shape is (3, 2) + then the shape of the output rigids will be (12, 3, 2). + This also supports passing in an empty tuple for rigid shape, in that case + the example would produce a rigid of shape (12,). + + Args: + activations: Activations to use for projection, shape [..., num_channel] + Returns: + Rigid transformations with shape [...] 
+ rigid_shape + """ + if self.full_quat: + rigid_dim = 7 + else: + rigid_dim = 6 + linear_dims = self.rigid_shape + (rigid_dim,) + rigid_flat = common_modules.Linear( + linear_dims, + initializer=self.init, + precision=jax.lax.Precision.HIGHEST, + name='rigid')( + activations) + rigid_flat = geometry_utils.unstack(rigid_flat) + if self.full_quat: + qw, qx, qy, qz = rigid_flat[:4] + translation = rigid_flat[4:] + else: + qx, qy, qz = rigid_flat[:3] + qw = jnp.ones_like(qx) + translation = rigid_flat[3:] + rotation = geometry.Rot3Array.from_quaternion( + qw, qx, qy, qz, normalize=True) + translation = geometry.Vec3Array(*translation) + return geometry.Rigid3Array(rotation, translation) + + +class PointProjection(hk.Module): + """Given input reprensentation and frame produces points in global frame.""" + + def __init__(self, + num_points: Union[Iterable[int], int], + global_config: ml_collections.ConfigDict, + return_local_points: bool = False, + name: str = 'point_projection'): + """Constructs Linear Module. + + Args: + num_points: number of points to project. Can be tuple when outputting + multiple dimensions + global_config: Global Config, passed through to underlying Linear + return_local_points: Whether to return points in local frame as well. + name: name of module, used for name scopes. + """ + if isinstance(num_points, numbers.Integral): + self.num_points = (num_points,) + else: + self.num_points = tuple(num_points) + + self.return_local_points = return_local_points + + self.global_config = global_config + + super().__init__(name=name) + + def __call__( + self, activations: jnp.ndarray, rigids: geometry.Rigid3Array + ) -> Union[geometry.Vec3Array, Tuple[geometry.Vec3Array, geometry.Vec3Array]]: + output_shape = self.num_points + output_shape = output_shape[:-1] + (3 * output_shape[-1],) + points_local = common_modules.Linear( + output_shape, + precision=jax.lax.Precision.HIGHEST, + name='point_projection')( + activations) + points_local = jnp.split(points_local, 3, axis=-1) + points_local = geometry.Vec3Array(*points_local) + rigids = rigids[(...,) + (None,) * len(output_shape)] + points_global = rigids.apply_to_point(points_local) + if self.return_local_points: + return points_global, points_local + else: + return points_global + + +class InvariantPointAttention(hk.Module): + """Invariant point attention module. + + The high-level idea is that this attention module works over a set of points + and associated orientations in 3D space (e.g. protein residues). + + Each residue outputs a set of queries and keys as points in their local + reference frame. The attention is then defined as the euclidean distance + between the queries and keys in the global frame. + """ + + def __init__(self, + config: ml_collections.ConfigDict, + global_config: ml_collections.ConfigDict, + dist_epsilon: float = 1e-8, + name: str = 'invariant_point_attention'): + """Initialize. + + Args: + config: iterative Fold Head Config + global_config: Global Config of Model. + dist_epsilon: Small value to avoid NaN in distance calculation. + name: Sonnet name. + """ + super().__init__(name=name) + + self._dist_epsilon = dist_epsilon + self._zero_initialize_last = global_config.zero_init + + self.config = config + + self.global_config = global_config + + def __call__( + self, + inputs_1d: jnp.ndarray, + inputs_2d: jnp.ndarray, + mask: jnp.ndarray, + rigid: geometry.Rigid3Array, + ) -> jnp.ndarray: + """Compute geometric aware attention. 
+ + Given a set of query residues (defined by affines and associated scalar + features), this function computes geometric aware attention between the + query residues and target residues. + + The residues produce points in their local reference frame, which + are converted into the global frame to get attention via euclidean distance. + + Equivalently the target residues produce points in their local frame to be + used as attention values, which are converted into the query residues local + frames. + + Args: + inputs_1d: (N, C) 1D input embedding that is the basis for the + scalar queries. + inputs_2d: (N, M, C') 2D input embedding, used for biases values in the + attention between query_inputs_1d and target_inputs_1d. + mask: (N, 1) mask to indicate query_inputs_1d that participate in + the attention. + rigid: Rigid object describing the position and orientation of + every element in query_inputs_1d. + + Returns: + Transformation of the input embedding. + """ + + num_head = self.config.num_head + + attn_logits = 0. + + num_point_qk = self.config.num_point_qk + # Each point pair (q, k) contributes Var [0.5 ||q||^2 - ] = 9 / 2 + point_variance = max(num_point_qk, 1) * 9. / 2 + point_weights = np.sqrt(1.0 / point_variance) + + # This is equivalent to jax.nn.softplus, but avoids a bug in the test... + softplus = lambda x: jnp.logaddexp(x, jnp.zeros_like(x)) + raw_point_weights = hk.get_parameter( + 'trainable_point_weights', + shape=[num_head], + # softplus^{-1} (1) + init=hk.initializers.Constant(np.log(np.exp(1.) - 1.))) + + # Trainable per-head weights for points. + trainable_point_weights = softplus(raw_point_weights) + point_weights *= trainable_point_weights + q_point = PointProjection([num_head, num_point_qk], + self.global_config, + name='q_point_projection')(inputs_1d, + rigid) + + k_point = PointProjection([num_head, num_point_qk], + self.global_config, + name='k_point_projection')(inputs_1d, + rigid) + + dist2 = geometry.square_euclidean_distance( + q_point[:, None, :, :], k_point[None, :, :, :], epsilon=0.) + attn_qk_point = -0.5 * jnp.sum(point_weights[:, None] * dist2, axis=-1) + attn_logits += attn_qk_point + + num_scalar_qk = self.config.num_scalar_qk + # We assume that all queries and keys come iid from N(0, 1) distribution + # and compute the variances of the attention logits. + # Each scalar pair (q, k) contributes Var q*k = 1 + scalar_variance = max(num_scalar_qk, 1) * 1. + scalar_weights = np.sqrt(1.0 / scalar_variance) + q_scalar = common_modules.Linear([num_head, num_scalar_qk], + use_bias=False, + name='q_scalar_projection')( + inputs_1d) + + k_scalar = common_modules.Linear([num_head, num_scalar_qk], + use_bias=False, + name='k_scalar_projection')( + inputs_1d) + q_scalar *= scalar_weights + attn_logits += jnp.einsum('qhc,khc->qkh', q_scalar, k_scalar) + + attention_2d = common_modules.Linear( + num_head, name='attention_2d')(inputs_2d) + attn_logits += attention_2d + + mask_2d = mask * jnp.swapaxes(mask, -1, -2) + attn_logits -= 1e5 * (1. - mask_2d[..., None]) + + attn_logits *= np.sqrt(1. 
/ 3) # Normalize by number of logit terms (3) + attn = jax.nn.softmax(attn_logits, axis=-2) + + num_scalar_v = self.config.num_scalar_v + + v_scalar = common_modules.Linear([num_head, num_scalar_v], + use_bias=False, + name='v_scalar_projection')( + inputs_1d) + + # [num_query_residues, num_head, num_scalar_v] + result_scalar = jnp.einsum('qkh, khc->qhc', attn, v_scalar) + + num_point_v = self.config.num_point_v + v_point = PointProjection([num_head, num_point_v], + self.global_config, + name='v_point_projection')(inputs_1d, + rigid) + + result_point_global = jax.tree_map( + lambda x: jnp.sum(attn[..., None] * x, axis=-3), v_point[None]) + + # Features used in the linear output projection. Should have the size + # [num_query_residues, ?] + output_features = [] + num_query_residues, _ = inputs_1d.shape + + flat_shape = [num_query_residues, -1] + + result_scalar = jnp.reshape(result_scalar, flat_shape) + output_features.append(result_scalar) + + result_point_global = jax.tree_map(lambda r: jnp.reshape(r, flat_shape), + result_point_global) + result_point_local = rigid[..., None].apply_inverse_to_point( + result_point_global) + output_features.extend( + [result_point_local.x, result_point_local.y, result_point_local.z]) + + point_norms = result_point_local.norm(self._dist_epsilon) + output_features.append(point_norms) + + # Dimensions: h = heads, i and j = residues, + # c = inputs_2d channels + # Contraction happens over the second residue dimension, similarly to how + # the usual attention is performed. + result_attention_over_2d = jnp.einsum('ijh, ijc->ihc', attn, inputs_2d) + output_features.append(jnp.reshape(result_attention_over_2d, flat_shape)) + + final_init = 'zeros' if self._zero_initialize_last else 'linear' + + final_act = jnp.concatenate(output_features, axis=-1) + + return common_modules.Linear( + self.config.num_channel, + initializer=final_init, + name='output_projection')(final_act) + + +class FoldIteration(hk.Module): + """A single iteration of iterative folding. + + First, each residue attends to all residues using InvariantPointAttention. + Then, we apply transition layers to update the hidden representations. + Finally, we use the hidden representations to produce an update to the + affine of each residue. 
+ """ + + def __init__(self, + config: ml_collections.ConfigDict, + global_config: ml_collections.ConfigDict, + name: str = 'fold_iteration'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__( + self, + activations: Mapping[str, Any], + aatype: jnp.ndarray, + sequence_mask: jnp.ndarray, + update_rigid: bool, + initial_act: jnp.ndarray, + use_dropout: bool, + safe_key: Optional[prng.SafeKey] = None, + static_feat_2d: Optional[jnp.ndarray] = None, + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + + c = self.config + + if safe_key is None: + safe_key = prng.SafeKey(hk.next_rng_key()) + + def safe_dropout_fn(tensor, safe_key): + return modules.apply_dropout( + tensor=tensor, + safe_key=safe_key, + rate=jnp.where(use_dropout, c.dropout, 0)) + + rigid = activations['rigid'] + + act = activations['act'] + attention_module = InvariantPointAttention( + self.config, self.global_config) + # Attention + act += attention_module( + inputs_1d=act, + inputs_2d=static_feat_2d, + mask=sequence_mask, + rigid=rigid) + + safe_key, *sub_keys = safe_key.split(3) + sub_keys = iter(sub_keys) + act = safe_dropout_fn(act, next(sub_keys)) + act = common_modules.LayerNorm( + axis=-1, + create_scale=True, + create_offset=True, + name='attention_layer_norm')( + act) + final_init = 'zeros' if self.global_config.zero_init else 'linear' + + # Transition + input_act = act + for i in range(c.num_layer_in_transition): + init = 'relu' if i < c.num_layer_in_transition - 1 else final_init + act = common_modules.Linear( + c.num_channel, + initializer=init, + name='transition')( + act) + if i < c.num_layer_in_transition - 1: + act = jax.nn.relu(act) + act += input_act + act = safe_dropout_fn(act, next(sub_keys)) + act = common_modules.LayerNorm( + axis=-1, + create_scale=True, + create_offset=True, + name='transition_layer_norm')(act) + if update_rigid: + # Rigid update + rigid_update = QuatRigid( + self.global_config, init=final_init)( + act) + rigid = rigid @ rigid_update + + sc = MultiRigidSidechain(c.sidechain, self.global_config)( + rigid.scale_translation(c.position_scale), [act, initial_act], aatype) + + outputs = {'rigid': rigid, 'sc': sc} + + rotation = rigid.rotation #jax.tree_map(jax.lax.stop_gradient, rigid.rotation) + rigid = geometry.Rigid3Array(rotation, rigid.translation) + + new_activations = { + 'act': act, + 'rigid': rigid + } + return new_activations, outputs + + +def generate_monomer_rigids(representations: Mapping[str, jnp.ndarray], + batch: Mapping[str, jnp.ndarray], + config: ml_collections.ConfigDict, + global_config: ml_collections.ConfigDict, + safe_key: prng.SafeKey + ) -> Dict[str, Any]: + """Generate predicted Rigid's for a single chain. + + This is the main part of the iterative fold head - it iteratively applies + folding to produce a set of predicted residue positions. + + Args: + representations: Embeddings dictionary. + batch: Batch dictionary. + config: config for the iterative fold head. + global_config: global config. + safe_key: A prng.SafeKey object that wraps a PRNG key. + + Returns: + A dictionary containing residue Rigid's and sidechain positions. 
+ """ + c = config + sequence_mask = batch['seq_mask'][:, None] + act = common_modules.LayerNorm( + axis=-1, create_scale=True, create_offset=True, name='single_layer_norm')( + representations['single']) + + initial_act = act + act = common_modules.Linear( + c.num_channel, name='initial_projection')(act) + + if "initial_atom_pos" in batch: + atom = residue_constants.atom_order + atom_pos = batch["initial_atom_pos"] + if global_config.bfloat16: atom_pos = atom_pos.astype(jnp.float32) + atom_pos = geometry.Vec3Array.from_array(atom_pos) + rigid = all_atom_multimer.make_transform_from_reference( + a_xyz=atom_pos[:, atom["N"]], + b_xyz=atom_pos[:, atom["CA"]], + c_xyz=atom_pos[:, atom["C"]]).scale_translation(1/c.position_scale) + + else: + # Sequence Mask has extra 1 at the end. + rigid = geometry.Rigid3Array.identity(sequence_mask.shape[:-1]) + + fold_iteration = FoldIteration(c, global_config, name='fold_iteration') + + assert len(batch['seq_mask'].shape) == 1 + + activations = { + 'act': + act, + 'rigid': + rigid + } + + act_2d = common_modules.LayerNorm( + axis=-1, + create_scale=True, + create_offset=True, + name='pair_layer_norm')( + representations['pair']) + + outputs = [] + def fold_iter(act, key): + act, out = fold_iteration( + act, + initial_act=initial_act, + static_feat_2d=act_2d, + aatype=batch['aatype'], + safe_key=prng.SafeKey(key), + sequence_mask=sequence_mask, + update_rigid=True, + use_dropout=batch["use_dropout"]) + return act, out + + keys = jax.random.split(safe_key.get(), c.num_layer) + activations, output = hk.scan(fold_iter, activations, keys) + output['act'] = activations['act'] + return output + + +class StructureModule(hk.Module): + """StructureModule as a network head. + + Jumper et al. (2021) Suppl. Alg. 20 "StructureModule" + """ + + def __init__(self, + config: ml_collections.ConfigDict, + global_config: ml_collections.ConfigDict, + name: str = 'structure_module'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, + representations: Mapping[str, jnp.ndarray], + batch: Mapping[str, Any], + safe_key: Optional[prng.SafeKey] = None, + ) -> Dict[str, Any]: + c = self.config + ret = {} + + if safe_key is None: + safe_key = prng.SafeKey(hk.next_rng_key()) + + output = generate_monomer_rigids( + representations=representations, + batch=batch, + config=self.config, + global_config=self.global_config, + safe_key=safe_key) + + ret['traj'] = output['rigid'].scale_translation(c.position_scale).to_array() + ret['sidechains'] = output['sc'] + ret['sidechains']['atom_pos'] = ret['sidechains']['atom_pos'].to_array() + ret['sidechains']['frames'] = ret['sidechains']['frames'].to_array() + if 'local_atom_pos' in ret['sidechains']: + ret['sidechains']['local_atom_pos'] = ret['sidechains'][ + 'local_atom_pos'].to_array() + ret['sidechains']['local_frames'] = ret['sidechains'][ + 'local_frames'].to_array() + + aatype = batch['aatype'] + seq_mask = batch['seq_mask'] + + atom14_pred_mask = all_atom_multimer.get_atom14_mask( + aatype) * seq_mask[:, None] + atom14_pred_positions = output['sc']['atom_pos'][-1] + ret['final_atom14_positions'] = atom14_pred_positions # (N, 14, 3) + ret['final_atom14_mask'] = atom14_pred_mask # (N, 14) + + atom37_mask = all_atom_multimer.get_atom37_mask(aatype) * seq_mask[:, None] + atom37_pred_positions = all_atom_multimer.atom14_to_atom37( + atom14_pred_positions, aatype) + atom37_pred_positions *= atom37_mask[:, :, None] + ret['final_atom_positions'] = atom37_pred_positions # (N, 37, 3) + 
ret['final_atom_mask'] = atom37_mask # (N, 37) + ret['final_rigids'] = ret['traj'][-1] + + ret['act'] = output['act'] + + return ret + + +def compute_atom14_gt( + aatype: jnp.ndarray, + all_atom_positions: geometry.Vec3Array, + all_atom_mask: jnp.ndarray, + pred_pos: geometry.Vec3Array +) -> Tuple[geometry.Vec3Array, jnp.ndarray, jnp.ndarray]: + """Find atom14 positions, this includes finding the correct renaming.""" + gt_positions, gt_mask = all_atom_multimer.atom37_to_atom14( + aatype, all_atom_positions, + all_atom_mask) + alt_gt_positions, alt_gt_mask = all_atom_multimer.get_alt_atom14( + aatype, gt_positions, gt_mask) + atom_is_ambiguous = all_atom_multimer.get_atom14_is_ambiguous(aatype) + + alt_naming_is_better = all_atom_multimer.find_optimal_renaming( + gt_positions=gt_positions, + alt_gt_positions=alt_gt_positions, + atom_is_ambiguous=atom_is_ambiguous, + gt_exists=gt_mask, + pred_positions=pred_pos) + + use_alt = alt_naming_is_better[:, None] + + gt_mask = (1. - use_alt) * gt_mask + use_alt * alt_gt_mask + gt_positions = (1. - use_alt) * gt_positions + use_alt * alt_gt_positions + + return gt_positions, alt_gt_mask, alt_naming_is_better + + +def backbone_loss(gt_rigid: geometry.Rigid3Array, + gt_frames_mask: jnp.ndarray, + gt_positions_mask: jnp.ndarray, + target_rigid: geometry.Rigid3Array, + config: ml_collections.ConfigDict, + pair_mask: jnp.ndarray + ) -> Tuple[Float, jnp.ndarray]: + """Backbone FAPE Loss.""" + loss_fn = functools.partial( + all_atom_multimer.frame_aligned_point_error, + l1_clamp_distance=config.atom_clamp_distance, + length_scale=config.loss_unit_distance) + + loss_fn = jax.vmap(loss_fn, (0, None, None, 0, None, None, None)) + fape = loss_fn(target_rigid, gt_rigid, gt_frames_mask, + target_rigid.translation, gt_rigid.translation, + gt_positions_mask, pair_mask) + + return jnp.mean(fape), fape[-1] + + +def compute_frames( + aatype: jnp.ndarray, + all_atom_positions: geometry.Vec3Array, + all_atom_mask: jnp.ndarray, + use_alt: jnp.ndarray + ) -> Tuple[geometry.Rigid3Array, jnp.ndarray]: + """Compute Frames from all atom positions. + + Args: + aatype: array of aatypes, int of [N] + all_atom_positions: Vector of all atom positions, shape [N, 37] + all_atom_mask: mask, shape [N] + use_alt: whether to use alternative orientation for ambiguous aatypes + shape [N] + Returns: + Rigid corresponding to Frames w shape [N, 8], + mask which Rigids are present w shape [N, 8] + """ + frames_batch = all_atom_multimer.atom37_to_frames(aatype, all_atom_positions, + all_atom_mask) + gt_frames = frames_batch['rigidgroups_gt_frames'] + alt_gt_frames = frames_batch['rigidgroups_alt_gt_frames'] + use_alt = use_alt[:, None] + + renamed_gt_frames = jax.tree_map( + lambda x, y: (1. - use_alt) * x + use_alt * y, gt_frames, alt_gt_frames) + + return renamed_gt_frames, frames_batch['rigidgroups_gt_exists'] + + +def sidechain_loss(gt_frames: geometry.Rigid3Array, + gt_frames_mask: jnp.ndarray, + gt_positions: geometry.Vec3Array, + gt_mask: jnp.ndarray, + pred_frames: geometry.Rigid3Array, + pred_positions: geometry.Vec3Array, + config: ml_collections.ConfigDict + ) -> Dict[str, jnp.ndarray]: + """Sidechain Loss using cleaned up rigids.""" + + flat_gt_frames = jax.tree_map(jnp.ravel, gt_frames) + flat_frames_mask = jnp.ravel(gt_frames_mask) + + flat_gt_positions = jax.tree_map(jnp.ravel, gt_positions) + flat_positions_mask = jnp.ravel(gt_mask) + + # Compute frame_aligned_point_error score for the final layer. 
+ def _slice_last_layer_and_flatten(x): + return jnp.ravel(x[-1]) + + flat_pred_frames = jax.tree_map(_slice_last_layer_and_flatten, pred_frames) + flat_pred_positions = jax.tree_map(_slice_last_layer_and_flatten, + pred_positions) + fape = all_atom_multimer.frame_aligned_point_error( + pred_frames=flat_pred_frames, + target_frames=flat_gt_frames, + frames_mask=flat_frames_mask, + pred_positions=flat_pred_positions, + target_positions=flat_gt_positions, + positions_mask=flat_positions_mask, + pair_mask=None, + length_scale=config.sidechain.loss_unit_distance, + l1_clamp_distance=config.sidechain.atom_clamp_distance) + + return { + 'fape': fape, + 'loss': fape} + + +def structural_violation_loss(mask: jnp.ndarray, + violations: Mapping[str, Float], + config: ml_collections.ConfigDict + ) -> Float: + """Computes Loss for structural Violations.""" + # Put all violation losses together to one large loss. + num_atoms = jnp.sum(mask).astype(jnp.float32) + 1e-6 + between_residues = violations['between_residues'] + within_residues = violations['within_residues'] + return (config.structural_violation_loss_weight * + (between_residues['bonds_c_n_loss_mean'] + + between_residues['angles_ca_c_n_loss_mean'] + + between_residues['angles_c_n_ca_loss_mean'] + + jnp.sum(between_residues['clashes_per_atom_loss_sum'] + + within_residues['per_atom_loss_sum']) / num_atoms + )) + + +def find_structural_violations( + aatype: jnp.ndarray, + residue_index: jnp.ndarray, + mask: jnp.ndarray, + pred_positions: geometry.Vec3Array, # (N, 14) + config: ml_collections.ConfigDict, + asym_id: jnp.ndarray, + ) -> Dict[str, Any]: + """Computes several checks for structural Violations.""" + + # Compute between residue backbone violations of bonds and angles. + connection_violations = all_atom_multimer.between_residue_bond_loss( + pred_atom_positions=pred_positions, + pred_atom_mask=mask.astype(jnp.float32), + residue_index=residue_index.astype(jnp.float32), + aatype=aatype, + tolerance_factor_soft=config.violation_tolerance_factor, + tolerance_factor_hard=config.violation_tolerance_factor) + + # Compute the van der Waals radius for every atom + # (the first letter of the atom name is the element type). + # shape (N, 14) + atomtype_radius = jnp.array([ + residue_constants.van_der_waals_radius[name[0]] + for name in residue_constants.atom_types + ]) + residx_atom14_to_atom37 = all_atom_multimer.get_atom14_to_atom37_map(aatype) + atom_radius = mask * utils.batched_gather(atomtype_radius, + residx_atom14_to_atom37) + + # Compute the between residue clash loss. + between_residue_clashes = all_atom_multimer.between_residue_clash_loss( + pred_positions=pred_positions, + atom_exists=mask, + atom_radius=atom_radius, + residue_index=residue_index, + overlap_tolerance_soft=config.clash_overlap_tolerance, + overlap_tolerance_hard=config.clash_overlap_tolerance, + asym_id=asym_id) + + # Compute all within-residue violations (clashes, + # bond length and angle violations). 
+ restype_atom14_bounds = residue_constants.make_atom14_dists_bounds( + overlap_tolerance=config.clash_overlap_tolerance, + bond_length_tolerance_factor=config.violation_tolerance_factor) + dists_lower_bound = utils.batched_gather(restype_atom14_bounds['lower_bound'], + aatype) + dists_upper_bound = utils.batched_gather(restype_atom14_bounds['upper_bound'], + aatype) + within_residue_violations = all_atom_multimer.within_residue_violations( + pred_positions=pred_positions, + atom_exists=mask, + dists_lower_bound=dists_lower_bound, + dists_upper_bound=dists_upper_bound, + tighten_bounds_for_loss=0.0) + + # Combine them to a single per-residue violation mask (used later for LDDT). + per_residue_violations_mask = jnp.max(jnp.stack([ + connection_violations['per_residue_violation_mask'], + jnp.max(between_residue_clashes['per_atom_clash_mask'], axis=-1), + jnp.max(within_residue_violations['per_atom_violations'], + axis=-1)]), axis=0) + + return { + 'between_residues': { + 'bonds_c_n_loss_mean': + connection_violations['c_n_loss_mean'], # () + 'angles_ca_c_n_loss_mean': + connection_violations['ca_c_n_loss_mean'], # () + 'angles_c_n_ca_loss_mean': + connection_violations['c_n_ca_loss_mean'], # () + 'connections_per_residue_loss_sum': + connection_violations['per_residue_loss_sum'], # (N) + 'connections_per_residue_violation_mask': + connection_violations['per_residue_violation_mask'], # (N) + 'clashes_mean_loss': + between_residue_clashes['mean_loss'], # () + 'clashes_per_atom_loss_sum': + between_residue_clashes['per_atom_loss_sum'], # (N, 14) + 'clashes_per_atom_clash_mask': + between_residue_clashes['per_atom_clash_mask'], # (N, 14) + }, + 'within_residues': { + 'per_atom_loss_sum': + within_residue_violations['per_atom_loss_sum'], # (N, 14) + 'per_atom_violations': + within_residue_violations['per_atom_violations'], # (N, 14), + }, + 'total_per_residue_violations_mask': + per_residue_violations_mask, # (N) + } + + +def compute_violation_metrics( + residue_index: jnp.ndarray, + mask: jnp.ndarray, + seq_mask: jnp.ndarray, + pred_positions: geometry.Vec3Array, # (N, 14) + violations: Mapping[str, jnp.ndarray], +) -> Dict[str, jnp.ndarray]: + """Compute several metrics to assess the structural violations.""" + ret = {} + between_residues = violations['between_residues'] + within_residues = violations['within_residues'] + extreme_ca_ca_violations = all_atom_multimer.extreme_ca_ca_distance_violations( + positions=pred_positions, + mask=mask.astype(jnp.float32), + residue_index=residue_index.astype(jnp.float32)) + ret['violations_extreme_ca_ca_distance'] = extreme_ca_ca_violations + ret['violations_between_residue_bond'] = utils.mask_mean( + mask=seq_mask, + value=between_residues['connections_per_residue_violation_mask']) + ret['violations_between_residue_clash'] = utils.mask_mean( + mask=seq_mask, + value=jnp.max(between_residues['clashes_per_atom_clash_mask'], axis=-1)) + ret['violations_within_residue'] = utils.mask_mean( + mask=seq_mask, + value=jnp.max(within_residues['per_atom_violations'], axis=-1)) + ret['violations_per_residue'] = utils.mask_mean( + mask=seq_mask, value=violations['total_per_residue_violations_mask']) + return ret + + +def supervised_chi_loss( + sequence_mask: jnp.ndarray, + target_chi_mask: jnp.ndarray, + aatype: jnp.ndarray, + target_chi_angles: jnp.ndarray, + pred_angles: jnp.ndarray, + unnormed_angles: jnp.ndarray, + config: ml_collections.ConfigDict) -> Tuple[Float, Float, Float]: + """Computes loss for direct chi angle supervision.""" + eps = 1e-6 + chi_mask = 
target_chi_mask.astype(jnp.float32) + + pred_angles = pred_angles[:, :, 3:] + + residue_type_one_hot = jax.nn.one_hot( + aatype, residue_constants.restype_num + 1, dtype=jnp.float32)[None] + chi_pi_periodic = jnp.einsum('ijk, kl->ijl', residue_type_one_hot, + jnp.asarray(residue_constants.chi_pi_periodic)) + + true_chi = target_chi_angles[None] + sin_true_chi = jnp.sin(true_chi) + cos_true_chi = jnp.cos(true_chi) + sin_cos_true_chi = jnp.stack([sin_true_chi, cos_true_chi], axis=-1) + + # This is -1 if chi is pi periodic and +1 if it's 2 pi periodic + shifted_mask = (1 - 2 * chi_pi_periodic)[..., None] + sin_cos_true_chi_shifted = shifted_mask * sin_cos_true_chi + + sq_chi_error = jnp.sum( + squared_difference(sin_cos_true_chi, pred_angles), -1) + sq_chi_error_shifted = jnp.sum( + squared_difference(sin_cos_true_chi_shifted, pred_angles), -1) + sq_chi_error = jnp.minimum(sq_chi_error, sq_chi_error_shifted) + + sq_chi_loss = utils.mask_mean(mask=chi_mask[None], value=sq_chi_error) + angle_norm = jnp.sqrt(jnp.sum(jnp.square(unnormed_angles), axis=-1) + eps) + norm_error = jnp.abs(angle_norm - 1.) + angle_norm_loss = utils.mask_mean(mask=sequence_mask[None, :, None], + value=norm_error) + loss = (config.chi_weight * sq_chi_loss + + config.angle_norm_weight * angle_norm_loss) + return loss, sq_chi_loss, angle_norm_loss + + +def l2_normalize(x: jnp.ndarray, + axis: int = -1, + epsilon: float = 1e-12 + ) -> jnp.ndarray: + return x / jnp.sqrt( + jnp.maximum(jnp.sum(x**2, axis=axis, keepdims=True), epsilon)) + + +def get_renamed_chi_angles(aatype: jnp.ndarray, + chi_angles: jnp.ndarray, + alt_is_better: jnp.ndarray + ) -> jnp.ndarray: + """Return renamed chi angles.""" + chi_angle_is_ambiguous = utils.batched_gather( + jnp.array(residue_constants.chi_pi_periodic, dtype=jnp.float32), aatype) + alt_chi_angles = chi_angles + np.pi * chi_angle_is_ambiguous + # Map back to [-pi, pi]. + alt_chi_angles = alt_chi_angles - 2 * np.pi * (alt_chi_angles > np.pi).astype( + jnp.float32) + alt_is_better = alt_is_better[:, None] + return (1. - alt_is_better) * chi_angles + alt_is_better * alt_chi_angles + + +class MultiRigidSidechain(hk.Module): + """Class to make side chain atoms.""" + + def __init__(self, + config: ml_collections.ConfigDict, + global_config: ml_collections.ConfigDict, + name: str = 'rigid_sidechain'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, + rigid: geometry.Rigid3Array, + representations_list: Iterable[jnp.ndarray], + aatype: jnp.ndarray + ) -> Dict[str, Any]: + """Predict sidechains using multi-rigid representations. + + Args: + rigid: The Rigid's for each residue (translations in angstoms) + representations_list: A list of activations to predict sidechains from. + aatype: amino acid types. + + Returns: + dict containing atom positions and frames (in angstrom) + """ + act = [ + common_modules.Linear( # pylint: disable=g-complex-comprehension + self.config.num_channel, + name='input_projection')(jax.nn.relu(x)) + for x in representations_list] + # Sum the activation list (equivalent to concat then Conv1D) + act = sum(act) + + final_init = 'zeros' if self.global_config.zero_init else 'linear' + + # Mapping with some residual blocks. 
+ for _ in range(self.config.num_residual_block): + old_act = act + act = common_modules.Linear( + self.config.num_channel, + initializer='relu', + name='resblock1')( + jax.nn.relu(act)) + act = common_modules.Linear( + self.config.num_channel, + initializer=final_init, + name='resblock2')( + jax.nn.relu(act)) + act += old_act + + # Map activations to torsion angles. + # [batch_size, num_res, 14] + num_res = act.shape[0] + unnormalized_angles = common_modules.Linear( + 14, name='unnormalized_angles')( + jax.nn.relu(act)) + unnormalized_angles = jnp.reshape( + unnormalized_angles, [num_res, 7, 2]) + angles = l2_normalize(unnormalized_angles, axis=-1) + + outputs = { + 'angles_sin_cos': angles, # jnp.ndarray (N, 7, 2) + 'unnormalized_angles_sin_cos': + unnormalized_angles, # jnp.ndarray (N, 7, 2) + } + + # Map torsion angles to frames. + # geometry.Rigid3Array with shape (N, 8) + all_frames_to_global = all_atom_multimer.torsion_angles_to_frames( + aatype, + rigid, + angles) + + # Use frames and literature positions to create the final atom coordinates. + # geometry.Vec3Array with shape (N, 14) + pred_positions = all_atom_multimer.frames_and_literature_positions_to_atom14_pos( + aatype, all_frames_to_global) + + outputs.update({ + 'atom_pos': pred_positions, # geometry.Vec3Array (N, 14) + 'frames': all_frames_to_global, # geometry.Rigid3Array (N, 8) + }) + return outputs diff --git a/build/lib/colabdesign/af/alphafold/model/geometry/__init__.py b/build/lib/colabdesign/af/alphafold/model/geometry/__init__.py new file mode 100644 index 00000000..761b886e --- /dev/null +++ b/build/lib/colabdesign/af/alphafold/model/geometry/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Geometry Module.""" + +from colabdesign.af.alphafold.model.geometry import rigid_matrix_vector +from colabdesign.af.alphafold.model.geometry import rotation_matrix +from colabdesign.af.alphafold.model.geometry import struct_of_array +from colabdesign.af.alphafold.model.geometry import vector + +Rot3Array = rotation_matrix.Rot3Array +Rigid3Array = rigid_matrix_vector.Rigid3Array + +StructOfArray = struct_of_array.StructOfArray + +Vec3Array = vector.Vec3Array +square_euclidean_distance = vector.square_euclidean_distance +euclidean_distance = vector.euclidean_distance +dihedral_angle = vector.dihedral_angle +dot = vector.dot +cross = vector.cross diff --git a/build/lib/colabdesign/af/alphafold/model/geometry/rigid_matrix_vector.py b/build/lib/colabdesign/af/alphafold/model/geometry/rigid_matrix_vector.py new file mode 100644 index 00000000..4c7bb105 --- /dev/null +++ b/build/lib/colabdesign/af/alphafold/model/geometry/rigid_matrix_vector.py @@ -0,0 +1,106 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Rigid3Array Transformations represented by a Matrix and a Vector.""" + +from __future__ import annotations +from typing import Union + +from colabdesign.af.alphafold.model.geometry import rotation_matrix +from colabdesign.af.alphafold.model.geometry import struct_of_array +from colabdesign.af.alphafold.model.geometry import vector +import jax +import jax.numpy as jnp + +Float = Union[float, jnp.ndarray] + +VERSION = '0.1' + + +@struct_of_array.StructOfArray(same_dtype=True) +class Rigid3Array: + """Rigid Transformation, i.e. element of special euclidean group.""" + + rotation: rotation_matrix.Rot3Array + translation: vector.Vec3Array + + def __matmul__(self, other: Rigid3Array) -> Rigid3Array: + new_rotation = self.rotation @ other.rotation + new_translation = self.apply_to_point(other.translation) + return Rigid3Array(new_rotation, new_translation) + + def inverse(self) -> Rigid3Array: + """Return Rigid3Array corresponding to inverse transform.""" + inv_rotation = self.rotation.inverse() + inv_translation = inv_rotation.apply_to_point(-self.translation) + return Rigid3Array(inv_rotation, inv_translation) + + def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array: + """Apply Rigid3Array transform to point.""" + return self.rotation.apply_to_point(point) + self.translation + + def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array: + """Apply inverse Rigid3Array transform to point.""" + new_point = point - self.translation + return self.rotation.apply_inverse_to_point(new_point) + + def compose_rotation(self, other_rotation): + rot = self.rotation @ other_rotation + trans = jax.tree_map(lambda x: jnp.broadcast_to(x, rot.shape), + self.translation) + return Rigid3Array(rot, trans) + + @classmethod + def identity(cls, shape, dtype=jnp.float32) -> Rigid3Array: + """Return identity Rigid3Array of given shape.""" + return cls( + rotation_matrix.Rot3Array.identity(shape, dtype=dtype), + vector.Vec3Array.zeros(shape, dtype=dtype)) # pytype: disable=wrong-arg-count # trace-all-classes + + def scale_translation(self, factor: Float) -> Rigid3Array: + """Scale translation in Rigid3Array by 'factor'.""" + return Rigid3Array(self.rotation, self.translation * factor) + + def to_array(self): + rot_array = self.rotation.to_array() + vec_array = self.translation.to_array() + return jnp.concatenate([rot_array, vec_array[..., None]], axis=-1) + + @classmethod + def from_array(cls, array): + rot = rotation_matrix.Rot3Array.from_array(array[..., :3]) + vec = vector.Vec3Array.from_array(array[..., -1]) + return cls(rot, vec) # pytype: disable=wrong-arg-count # trace-all-classes + + @classmethod + def from_array4x4(cls, array: jnp.ndarray) -> Rigid3Array: + """Construct Rigid3Array from homogeneous 4x4 array.""" + assert array.shape[-1] == 4 + assert array.shape[-2] == 4 + rotation = rotation_matrix.Rot3Array( + array[..., 0, 0], array[..., 0, 1], array[..., 0, 2], + array[..., 1, 0], array[..., 1, 1], array[..., 1, 2], + array[..., 2, 0], array[..., 2, 1], array[..., 2, 2] + ) + translation = vector.Vec3Array( + array[..., 0, 3], array[..., 1, 3], array[..., 2, 
3]) + return cls(rotation, translation) # pytype: disable=wrong-arg-count # trace-all-classes + + def __getstate__(self): + return (VERSION, (self.rotation, self.translation)) + + def __setstate__(self, state): + version, (rot, trans) = state + del version + object.__setattr__(self, 'rotation', rot) + object.__setattr__(self, 'translation', trans) diff --git a/build/lib/colabdesign/af/alphafold/model/geometry/rotation_matrix.py b/build/lib/colabdesign/af/alphafold/model/geometry/rotation_matrix.py new file mode 100644 index 00000000..846ea5d2 --- /dev/null +++ b/build/lib/colabdesign/af/alphafold/model/geometry/rotation_matrix.py @@ -0,0 +1,157 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Rot3Array Matrix Class.""" + +from __future__ import annotations +import dataclasses + +from colabdesign.af.alphafold.model.geometry import struct_of_array +from colabdesign.af.alphafold.model.geometry import utils +from colabdesign.af.alphafold.model.geometry import vector +import jax +import jax.numpy as jnp +import numpy as np + +COMPONENTS = ['xx', 'xy', 'xz', 'yx', 'yy', 'yz', 'zx', 'zy', 'zz'] + +VERSION = '0.1' + + +@struct_of_array.StructOfArray(same_dtype=True) +class Rot3Array: + """Rot3Array Matrix in 3 dimensional Space implemented as struct of arrays.""" + + xx: jnp.ndarray = dataclasses.field(metadata={'dtype': jnp.float32}) + xy: jnp.ndarray + xz: jnp.ndarray + yx: jnp.ndarray + yy: jnp.ndarray + yz: jnp.ndarray + zx: jnp.ndarray + zy: jnp.ndarray + zz: jnp.ndarray + + __array_ufunc__ = None + + def inverse(self) -> Rot3Array: + """Returns inverse of Rot3Array.""" + return Rot3Array(self.xx, self.yx, self.zx, + self.xy, self.yy, self.zy, + self.xz, self.yz, self.zz) + + def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array: + """Applies Rot3Array to point.""" + return vector.Vec3Array( + self.xx * point.x + self.xy * point.y + self.xz * point.z, + self.yx * point.x + self.yy * point.y + self.yz * point.z, + self.zx * point.x + self.zy * point.y + self.zz * point.z) + + def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array: + """Applies inverse Rot3Array to point.""" + return self.inverse().apply_to_point(point) + + def __matmul__(self, other: Rot3Array) -> Rot3Array: + """Composes two Rot3Arrays.""" + c0 = self.apply_to_point(vector.Vec3Array(other.xx, other.yx, other.zx)) + c1 = self.apply_to_point(vector.Vec3Array(other.xy, other.yy, other.zy)) + c2 = self.apply_to_point(vector.Vec3Array(other.xz, other.yz, other.zz)) + return Rot3Array(c0.x, c1.x, c2.x, c0.y, c1.y, c2.y, c0.z, c1.z, c2.z) + + @classmethod + def identity(cls, shape, dtype=jnp.float32) -> Rot3Array: + """Returns identity of given shape.""" + ones = jnp.ones(shape, dtype=dtype) + zeros = jnp.zeros(shape, dtype=dtype) + return cls(ones, zeros, zeros, zeros, ones, zeros, zeros, zeros, ones) # pytype: disable=wrong-arg-count # trace-all-classes + + @classmethod + def from_two_vectors(cls, e0: vector.Vec3Array, + e1: vector.Vec3Array) 
-> Rot3Array: + """Construct Rot3Array from two Vectors. + + Rot3Array is constructed such that in the corresponding frame 'e0' lies on + the positive x-Axis and 'e1' lies in the xy plane with positive sign of y. + + Args: + e0: Vector + e1: Vector + Returns: + Rot3Array + """ + # Normalize the unit vector for the x-axis, e0. + e0 = e0.normalized() + # make e1 perpendicular to e0. + c = e1.dot(e0) + e1 = (e1 - c * e0).normalized() + # Compute e2 as cross product of e0 and e1. + e2 = e0.cross(e1) + return cls(e0.x, e1.x, e2.x, e0.y, e1.y, e2.y, e0.z, e1.z, e2.z) # pytype: disable=wrong-arg-count # trace-all-classes + + @classmethod + def from_array(cls, array: jnp.ndarray) -> Rot3Array: + """Construct Rot3Array Matrix from array of shape. [..., 3, 3].""" + unstacked = utils.unstack(array, axis=-2) + unstacked = sum([utils.unstack(x, axis=-1) for x in unstacked], []) + return cls(*unstacked) + + def to_array(self) -> jnp.ndarray: + """Convert Rot3Array to array of shape [..., 3, 3].""" + return jnp.stack( + [jnp.stack([self.xx, self.xy, self.xz], axis=-1), + jnp.stack([self.yx, self.yy, self.yz], axis=-1), + jnp.stack([self.zx, self.zy, self.zz], axis=-1)], + axis=-2) + + @classmethod + def from_quaternion(cls, + w: jnp.ndarray, + x: jnp.ndarray, + y: jnp.ndarray, + z: jnp.ndarray, + normalize: bool = True, + epsilon: float = 1e-6) -> Rot3Array: + """Construct Rot3Array from components of quaternion.""" + if normalize: + inv_norm = jax.lax.rsqrt(jnp.maximum(epsilon, w**2 + x**2 + y**2 + z**2)) + w *= inv_norm + x *= inv_norm + y *= inv_norm + z *= inv_norm + xx = 1 - 2 * (jnp.square(y) + jnp.square(z)) + xy = 2 * (x * y - w * z) + xz = 2 * (x * z + w * y) + yx = 2 * (x * y + w * z) + yy = 1 - 2 * (jnp.square(x) + jnp.square(z)) + yz = 2 * (y * z - w * x) + zx = 2 * (x * z - w * y) + zy = 2 * (y * z + w * x) + zz = 1 - 2 * (jnp.square(x) + jnp.square(y)) + return cls(xx, xy, xz, yx, yy, yz, zx, zy, zz) # pytype: disable=wrong-arg-count # trace-all-classes + + @classmethod + def random_uniform(cls, key, shape, dtype=jnp.float32) -> Rot3Array: + """Samples uniform random Rot3Array according to Haar Measure.""" + quat_array = jax.random.normal(key, tuple(shape) + (4,), dtype=dtype) + quats = utils.unstack(quat_array) + return cls.from_quaternion(*quats) + + def __getstate__(self): + return (VERSION, + [np.asarray(getattr(self, field)) for field in COMPONENTS]) + + def __setstate__(self, state): + version, state = state + del version + for i, field in enumerate(COMPONENTS): + object.__setattr__(self, field, state[i]) diff --git a/build/lib/colabdesign/af/alphafold/model/geometry/struct_of_array.py b/build/lib/colabdesign/af/alphafold/model/geometry/struct_of_array.py new file mode 100644 index 00000000..97a89fd4 --- /dev/null +++ b/build/lib/colabdesign/af/alphafold/model/geometry/struct_of_array.py @@ -0,0 +1,220 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
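`Rot3Array.from_two_vectors` above (and `make_transform_from_reference` earlier in this diff) builds residue frames by Gram-Schmidt orthogonalisation: the x-axis is taken along the first vector, the y-axis is the second vector with its component along the first removed, and the z-axis is their cross product. Below is a minimal NumPy sketch on assumed toy backbone coordinates (not taken from any real structure):

```python
import numpy as np

def frame_from_two_vectors(e0, e1):
    e0 = e0 / np.linalg.norm(e0)            # x-axis along e0
    e1 = e1 - np.dot(e1, e0) * e0           # drop the component along e0
    e1 = e1 / np.linalg.norm(e1)            # y-axis in the (e0, e1) plane
    e2 = np.cross(e0, e1)                   # z-axis completes a right-handed frame
    return np.stack([e0, e1, e2], axis=-1)  # columns are the frame axes

# Backbone-style usage: x-axis from CA toward C, y-axis toward N (toy coordinates).
ca = np.array([0.0, 0.0, 0.0])
c  = np.array([1.5, 0.0, 0.0])
n  = np.array([0.5, 1.4, 0.0])
R = frame_from_two_vectors(c - ca, n - ca)
print(np.allclose(R.T @ R, np.eye(3)))  # True: the frame is orthonormal
```

Since the columns of `R` are the frame axes, `R.T @ (x - ca)` expresses a global point `x` in the residue's local frame, which is exactly the operation the FAPE losses above rely on.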
+"""Class decorator to represent (nested) struct of arrays.""" + +import dataclasses + +import jax + + +def get_item(instance, key): + sliced = {} + for field in get_array_fields(instance): + num_trailing_dims = field.metadata.get('num_trailing_dims', 0) + this_key = key + if isinstance(key, tuple) and Ellipsis in this_key: + this_key += (slice(None),) * num_trailing_dims + sliced[field.name] = getattr(instance, field.name)[this_key] + return dataclasses.replace(instance, **sliced) + + +@property +def get_shape(instance): + """Returns Shape for given instance of dataclass.""" + first_field = dataclasses.fields(instance)[0] + num_trailing_dims = first_field.metadata.get('num_trailing_dims', None) + value = getattr(instance, first_field.name) + if num_trailing_dims: + return value.shape[:-num_trailing_dims] + else: + return value.shape + + +def get_len(instance): + """Returns length for given instance of dataclass.""" + shape = instance.shape + if shape: + return shape[0] + else: + raise TypeError('len() of unsized object') # Match jax.numpy behavior. + + +@property +def get_dtype(instance): + """Returns Dtype for given instance of dataclass.""" + fields = dataclasses.fields(instance) + sets_dtype = [ + field.name for field in fields if field.metadata.get('sets_dtype', False) + ] + if sets_dtype: + assert len(sets_dtype) == 1, 'at most field can set dtype' + field_value = getattr(instance, sets_dtype[0]) + elif instance.same_dtype: + field_value = getattr(instance, fields[0].name) + else: + # Should this be Value Error? + raise AttributeError('Trying to access Dtype on Struct of Array without' + 'either "same_dtype" or field setting dtype') + + if hasattr(field_value, 'dtype'): + return field_value.dtype + else: + # Should this be Value Error? + raise AttributeError(f'field_value {field_value} does not have dtype') + + +def replace(instance, **kwargs): + return dataclasses.replace(instance, **kwargs) + + +def post_init(instance): + """Validate instance has same shapes & dtypes.""" + array_fields = get_array_fields(instance) + arrays = list(get_array_fields(instance, return_values=True).values()) + first_field = array_fields[0] + # These slightly weird constructions about checking whether the leaves are + # actual arrays is since e.g. vmap internally relies on being able to + # construct pytree's with object() as leaves, this would break the checking + # as such we are only validating the object when the entries in the dataclass + # Are arrays or other dataclasses of arrays. 
+ try: + dtype = instance.dtype + except AttributeError: + dtype = None + if dtype is not None: + first_shape = instance.shape + for array, field in zip(arrays, array_fields): + field_shape = array.shape + num_trailing_dims = field.metadata.get('num_trailing_dims', None) + if num_trailing_dims: + array_shape = array.shape + field_shape = array_shape[:-num_trailing_dims] + msg = (f'field {field} should have number of trailing dims' + ' {num_trailing_dims}') + assert len(array_shape) == len(first_shape) + num_trailing_dims, msg + else: + field_shape = array.shape + + shape_msg = (f"Stripped Shape {field_shape} of field {field} doesn't " + f"match shape {first_shape} of field {first_field}") + assert field_shape == first_shape, shape_msg + + field_dtype = array.dtype + + allowed_metadata_dtypes = field.metadata.get('allowed_dtypes', []) + if allowed_metadata_dtypes: + msg = f'Dtype is {field_dtype} but must be in {allowed_metadata_dtypes}' + assert field_dtype in allowed_metadata_dtypes, msg + + if 'dtype' in field.metadata: + target_dtype = field.metadata['dtype'] + else: + target_dtype = dtype + + msg = f'Dtype is {field_dtype} but must be {target_dtype}' + assert field_dtype == target_dtype, msg + + +def flatten(instance): + """Flatten Struct of Array instance.""" + array_likes = list(get_array_fields(instance, return_values=True).values()) + flat_array_likes = [] + inner_treedefs = [] + num_arrays = [] + for array_like in array_likes: + flat_array_like, inner_treedef = jax.tree_flatten(array_like) + inner_treedefs.append(inner_treedef) + flat_array_likes += flat_array_like + num_arrays.append(len(flat_array_like)) + metadata = get_metadata_fields(instance, return_values=True) + metadata = type(instance).metadata_cls(**metadata) + return flat_array_likes, (inner_treedefs, metadata, num_arrays) + + +def make_metadata_class(cls): + metadata_fields = get_fields(cls, + lambda x: x.metadata.get('is_metadata', False)) + metadata_cls = dataclasses.make_dataclass( + cls_name='Meta' + cls.__name__, + fields=[(field.name, field.type, field) for field in metadata_fields], + frozen=True, + eq=True) + return metadata_cls + + +def get_fields(cls_or_instance, filterfn, return_values=False): + fields = dataclasses.fields(cls_or_instance) + fields = [field for field in fields if filterfn(field)] + if return_values: + return { + field.name: getattr(cls_or_instance, field.name) for field in fields + } + else: + return fields + + +def get_array_fields(cls, return_values=False): + return get_fields( + cls, + lambda x: not x.metadata.get('is_metadata', False), + return_values=return_values) + + +def get_metadata_fields(cls, return_values=False): + return get_fields( + cls, + lambda x: x.metadata.get('is_metadata', False), + return_values=return_values) + + +class StructOfArray: + """Class Decorator for Struct Of Arrays.""" + + def __init__(self, same_dtype=True): + self.same_dtype = same_dtype + + def __call__(self, cls): + cls.__array_ufunc__ = None + cls.replace = replace + cls.same_dtype = self.same_dtype + cls.dtype = get_dtype + cls.shape = get_shape + cls.__len__ = get_len + cls.__getitem__ = get_item + cls.__post_init__ = post_init + new_cls = dataclasses.dataclass(cls, frozen=True, eq=False) # pytype: disable=wrong-keyword-args + # pytree claims to require metadata to be hashable, not sure why, + # But making derived dataclass that can just hold metadata + new_cls.metadata_cls = make_metadata_class(new_cls) + + def unflatten(aux, data): + inner_treedefs, metadata, num_arrays = aux + array_fields = 
[field.name for field in get_array_fields(new_cls)] + value_dict = {} + array_start = 0 + for num_array, inner_treedef, array_field in zip(num_arrays, + inner_treedefs, + array_fields): + value_dict[array_field] = jax.tree_unflatten( + inner_treedef, data[array_start:array_start + num_array]) + array_start += num_array + metadata_fields = get_metadata_fields(new_cls) + for field in metadata_fields: + value_dict[field.name] = getattr(metadata, field.name) + + return new_cls(**value_dict) + + jax.tree_util.register_pytree_node( + nodetype=new_cls, flatten_func=flatten, unflatten_func=unflatten) + return new_cls diff --git a/build/lib/colabdesign/af/alphafold/model/geometry/test_utils.py b/build/lib/colabdesign/af/alphafold/model/geometry/test_utils.py new file mode 100644 index 00000000..18de0741 --- /dev/null +++ b/build/lib/colabdesign/af/alphafold/model/geometry/test_utils.py @@ -0,0 +1,98 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Shared utils for tests.""" + +import dataclasses + +from colabdesign.af.alphafold.model.geometry import rigid_matrix_vector +from colabdesign.af.alphafold.model.geometry import rotation_matrix +from colabdesign.af.alphafold.model.geometry import vector +import jax.numpy as jnp +import numpy as np + + +def assert_rotation_matrix_equal(matrix1: rotation_matrix.Rot3Array, + matrix2: rotation_matrix.Rot3Array): + for field in dataclasses.fields(rotation_matrix.Rot3Array): + field = field.name + np.testing.assert_array_equal( + getattr(matrix1, field), getattr(matrix2, field)) + + +def assert_rotation_matrix_close(mat1: rotation_matrix.Rot3Array, + mat2: rotation_matrix.Rot3Array): + np.testing.assert_array_almost_equal(mat1.to_array(), mat2.to_array(), 6) + + +def assert_array_equal_to_rotation_matrix(array: jnp.ndarray, + matrix: rotation_matrix.Rot3Array): + """Check that array and Matrix match.""" + np.testing.assert_array_equal(matrix.xx, array[..., 0, 0]) + np.testing.assert_array_equal(matrix.xy, array[..., 0, 1]) + np.testing.assert_array_equal(matrix.xz, array[..., 0, 2]) + np.testing.assert_array_equal(matrix.yx, array[..., 1, 0]) + np.testing.assert_array_equal(matrix.yy, array[..., 1, 1]) + np.testing.assert_array_equal(matrix.yz, array[..., 1, 2]) + np.testing.assert_array_equal(matrix.zx, array[..., 2, 0]) + np.testing.assert_array_equal(matrix.zy, array[..., 2, 1]) + np.testing.assert_array_equal(matrix.zz, array[..., 2, 2]) + + +def assert_array_close_to_rotation_matrix(array: jnp.ndarray, + matrix: rotation_matrix.Rot3Array): + np.testing.assert_array_almost_equal(matrix.to_array(), array, 6) + + +def assert_vectors_equal(vec1: vector.Vec3Array, vec2: vector.Vec3Array): + np.testing.assert_array_equal(vec1.x, vec2.x) + np.testing.assert_array_equal(vec1.y, vec2.y) + np.testing.assert_array_equal(vec1.z, vec2.z) + + +def assert_vectors_close(vec1: vector.Vec3Array, vec2: vector.Vec3Array): + np.testing.assert_allclose(vec1.x, vec2.x, atol=1e-6, rtol=0.) 
+ np.testing.assert_allclose(vec1.y, vec2.y, atol=1e-6, rtol=0.) + np.testing.assert_allclose(vec1.z, vec2.z, atol=1e-6, rtol=0.) + + +def assert_array_close_to_vector(array: jnp.ndarray, vec: vector.Vec3Array): + np.testing.assert_allclose(vec.to_array(), array, atol=1e-6, rtol=0.) + + +def assert_array_equal_to_vector(array: jnp.ndarray, vec: vector.Vec3Array): + np.testing.assert_array_equal(vec.to_array(), array) + + +def assert_rigid_equal_to_rigid(rigid1: rigid_matrix_vector.Rigid3Array, + rigid2: rigid_matrix_vector.Rigid3Array): + assert_rot_trans_equal_to_rigid(rigid1.rotation, rigid1.translation, rigid2) + + +def assert_rigid_close_to_rigid(rigid1: rigid_matrix_vector.Rigid3Array, + rigid2: rigid_matrix_vector.Rigid3Array): + assert_rot_trans_close_to_rigid(rigid1.rotation, rigid1.translation, rigid2) + + +def assert_rot_trans_equal_to_rigid(rot: rotation_matrix.Rot3Array, + trans: vector.Vec3Array, + rigid: rigid_matrix_vector.Rigid3Array): + assert_rotation_matrix_equal(rot, rigid.rotation) + assert_vectors_equal(trans, rigid.translation) + + +def assert_rot_trans_close_to_rigid(rot: rotation_matrix.Rot3Array, + trans: vector.Vec3Array, + rigid: rigid_matrix_vector.Rigid3Array): + assert_rotation_matrix_close(rot, rigid.rotation) + assert_vectors_close(trans, rigid.translation) diff --git a/build/lib/colabdesign/af/alphafold/model/geometry/utils.py b/build/lib/colabdesign/af/alphafold/model/geometry/utils.py new file mode 100644 index 00000000..64c4a649 --- /dev/null +++ b/build/lib/colabdesign/af/alphafold/model/geometry/utils.py @@ -0,0 +1,23 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utils for geometry library.""" + +from typing import List + +import jax.numpy as jnp + + +def unstack(value: jnp.ndarray, axis: int = -1) -> List[jnp.ndarray]: + return [jnp.squeeze(v, axis=axis) + for v in jnp.split(value, value.shape[axis], axis=axis)] diff --git a/build/lib/colabdesign/af/alphafold/model/geometry/vector.py b/build/lib/colabdesign/af/alphafold/model/geometry/vector.py new file mode 100644 index 00000000..8b5e653b --- /dev/null +++ b/build/lib/colabdesign/af/alphafold/model/geometry/vector.py @@ -0,0 +1,217 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
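+# Usage sketch (editorial, illustrative only): Vec3Array keeps x/y/z as separate
+# arrays, so typical use converts an [..., 3] coordinate array and then works
+# componentwise, e.g.
+#   a = vector.Vec3Array.from_array(jnp.array([1., 0., 0.]))
+#   b = vector.Vec3Array.from_array(jnp.array([0., 1., 0.]))
+#   a.dot(b)                # 0.0
+#   a.cross(b).to_array()   # [0., 0., 1.]
+#   (a - b).norm()          # sqrt(2), clipped from below by epsilon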
+"""Vec3Array Class.""" + +from __future__ import annotations +import dataclasses +from typing import Union + +from colabdesign.af.alphafold.model.geometry import struct_of_array +from colabdesign.af.alphafold.model.geometry import utils +import jax +import jax.numpy as jnp +import numpy as np + +Float = Union[float, jnp.ndarray] + +VERSION = '0.1' + + +@struct_of_array.StructOfArray(same_dtype=True) +class Vec3Array: + """Vec3Array in 3 dimensional Space implemented as struct of arrays. + + This is done in order to improve performance and precision. + On TPU small matrix multiplications are very suboptimal and will waste large + compute ressources, furthermore any matrix multiplication on tpu happen in + mixed bfloat16/float32 precision, which is often undesirable when handling + physical coordinates. + In most cases this will also be faster on cpu's/gpu's since it allows for + easier use of vector instructions. + """ + + x: jnp.ndarray = dataclasses.field(metadata={'dtype': jnp.float32}) + y: jnp.ndarray + z: jnp.ndarray + + def __post_init__(self): + if hasattr(self.x, 'dtype'): + assert self.x.dtype == self.y.dtype + assert self.x.dtype == self.z.dtype + assert all([x == y for x, y in zip(self.x.shape, self.y.shape)]) + assert all([x == z for x, z in zip(self.x.shape, self.z.shape)]) + + def __add__(self, other: Vec3Array) -> Vec3Array: + return jax.tree_map(lambda x, y: x + y, self, other) + + def __sub__(self, other: Vec3Array) -> Vec3Array: + return jax.tree_map(lambda x, y: x - y, self, other) + + def __mul__(self, other: Float) -> Vec3Array: + return jax.tree_map(lambda x: x * other, self) + + def __rmul__(self, other: Float) -> Vec3Array: + return self * other + + def __truediv__(self, other: Float) -> Vec3Array: + return jax.tree_map(lambda x: x / other, self) + + def __neg__(self) -> Vec3Array: + return jax.tree_map(lambda x: -x, self) + + def __pos__(self) -> Vec3Array: + return jax.tree_map(lambda x: x, self) + + def cross(self, other: Vec3Array) -> Vec3Array: + """Compute cross product between 'self' and 'other'.""" + new_x = self.y * other.z - self.z * other.y + new_y = self.z * other.x - self.x * other.z + new_z = self.x * other.y - self.y * other.x + return Vec3Array(new_x, new_y, new_z) + + def dot(self, other: Vec3Array) -> Float: + """Compute dot product between 'self' and 'other'.""" + return self.x * other.x + self.y * other.y + self.z * other.z + + def norm(self, epsilon: float = 1e-6) -> Float: + """Compute Norm of Vec3Array, clipped to epsilon.""" + # To avoid NaN on the backward pass, we must use maximum before the sqrt + norm2 = self.dot(self) + if epsilon: + norm2 = jnp.maximum(norm2, epsilon**2) + return jnp.sqrt(norm2) + + def norm2(self): + return self.dot(self) + + def normalized(self, epsilon: float = 1e-6) -> Vec3Array: + """Return unit vector with optional clipping.""" + return self / self.norm(epsilon) + + @classmethod + def zeros(cls, shape, dtype=jnp.float32): + """Return Vec3Array corresponding to zeros of given shape.""" + return cls( + jnp.zeros(shape, dtype), jnp.zeros(shape, dtype), + jnp.zeros(shape, dtype)) # pytype: disable=wrong-arg-count # trace-all-classes + + def to_array(self) -> jnp.ndarray: + return jnp.stack([self.x, self.y, self.z], axis=-1) + + @classmethod + def from_array(cls, array): + return cls(*utils.unstack(array)) + + def __getstate__(self): + return (VERSION, + [np.asarray(self.x), + np.asarray(self.y), + np.asarray(self.z)]) + + def __setstate__(self, state): + version, state = state + del version + for i, letter in 
enumerate('xyz'): + object.__setattr__(self, letter, state[i]) + + +def square_euclidean_distance(vec1: Vec3Array, + vec2: Vec3Array, + epsilon: float = 1e-6) -> Float: + """Computes square of euclidean distance between 'vec1' and 'vec2'. + + Args: + vec1: Vec3Array to compute distance to + vec2: Vec3Array to compute distance from, should be + broadcast compatible with 'vec1' + epsilon: distance is clipped from below to be at least epsilon + + Returns: + Array of square euclidean distances; + shape will be result of broadcasting 'vec1' and 'vec2' + """ + difference = vec1 - vec2 + distance = difference.dot(difference) + if epsilon: + distance = jnp.maximum(distance, epsilon) + return distance + + +def dot(vector1: Vec3Array, vector2: Vec3Array) -> Float: + return vector1.dot(vector2) + + +def cross(vector1: Vec3Array, vector2: Vec3Array) -> Float: + return vector1.cross(vector2) + + +def norm(vector: Vec3Array, epsilon: float = 1e-6) -> Float: + return vector.norm(epsilon) + + +def normalized(vector: Vec3Array, epsilon: float = 1e-6) -> Vec3Array: + return vector.normalized(epsilon) + + +def euclidean_distance(vec1: Vec3Array, + vec2: Vec3Array, + epsilon: float = 1e-6) -> Float: + """Computes euclidean distance between 'vec1' and 'vec2'. + + Args: + vec1: Vec3Array to compute euclidean distance to + vec2: Vec3Array to compute euclidean distance from, should be + broadcast compatible with 'vec1' + epsilon: distance is clipped from below to be at least epsilon + + Returns: + Array of euclidean distances; + shape will be result of broadcasting 'vec1' and 'vec2' + """ + distance_sq = square_euclidean_distance(vec1, vec2, epsilon**2) + distance = jnp.sqrt(distance_sq) + return distance + + +def dihedral_angle(a: Vec3Array, b: Vec3Array, c: Vec3Array, + d: Vec3Array) -> Float: + """Computes torsion angle for a quadruple of points. + + For points (a, b, c, d), this is the angle between the planes defined by + points (a, b, c) and (b, c, d). It is also known as the dihedral angle. + + Arguments: + a: A Vec3Array of coordinates. + b: A Vec3Array of coordinates. + c: A Vec3Array of coordinates. + d: A Vec3Array of coordinates. + + Returns: + A tensor of angles in radians: [-pi, pi]. + """ + v1 = a - b + v2 = b - c + v3 = d - c + + c1 = v1.cross(v2) + c2 = v3.cross(v2) + c3 = c2.cross(c1) + + v2_mag = v2.norm() + return jnp.arctan2(c3.dot(v2), v2_mag * c1.dot(c2)) + + +def random_gaussian_vector(shape, key, dtype=jnp.float32): + vec_array = jax.random.normal(key, shape + (3,), dtype) + return Vec3Array.from_array(vec_array) diff --git a/build/lib/colabdesign/af/alphafold/model/layer_stack.py b/build/lib/colabdesign/af/alphafold/model/layer_stack.py new file mode 100644 index 00000000..cbbb0dcb --- /dev/null +++ b/build/lib/colabdesign/af/alphafold/model/layer_stack.py @@ -0,0 +1,274 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
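+# Usage sketch (editorial, illustrative only): layer_stack(num_layers)(f) iterates a
+# Haiku function with *unshared* parameters; each parameter is created with a leading
+# [num_layers, ...] axis and the slice for the current iteration is selected inside
+# hk.scan. Inside an hk.transform'd function one might write:
+#   def block(x):
+#     return jax.nn.relu(hk.Linear(x.shape[-1])(x))
+#   x = layer_stack(8)(block)(x)   # eight blocks, each with its own weights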
+ +"""Function to stack repeats of a layer function without shared parameters.""" + +import collections +import contextlib +import functools +import inspect +from typing import Any, Callable, Optional, Tuple, Union + +import haiku as hk +import jax +import jax.numpy as jnp + +LayerStackCarry = collections.namedtuple('LayerStackCarry', ['x', 'rng']) +LayerStackScanned = collections.namedtuple('LayerStackScanned', + ['i', 'args_ys']) + +# WrappedFn should take in arbitrarily nested `jnp.ndarray`, and return the +# exact same type. We cannot express this with `typing`. So we just use it +# to inform the user. In reality, the typing below will accept anything. +NestedArray = Any +WrappedFn = Callable[..., Union[NestedArray, Tuple[NestedArray]]] + + +def _check_no_varargs(f): + if list(inspect.signature( + f).parameters.values())[0].kind == inspect.Parameter.VAR_POSITIONAL: + raise ValueError( + 'The function `f` should not have any `varargs` (that is *args) ' + 'argument. Instead, it should only use explicit positional' + 'arguments.') + + +@contextlib.contextmanager +def nullcontext(): + yield + + +def maybe_with_rng(key): + if key is not None: + return hk.with_rng(key) + else: + return nullcontext() + + +def maybe_fold_in(key, data): + if key is not None: + return jax.random.fold_in(key, data) + else: + return None + + +class _LayerStack(hk.Module): + """Module to compose parameterized functions, implemented as a scan.""" + + def __init__(self, + count: int, + unroll: int, + name: Optional[str] = None): + """Iterate a function `f` `count` times, with non-shared parameters.""" + super().__init__(name=name) + self._count = count + self._unroll = unroll + + def __call__(self, x, *args_ys): + count = self._count + if hk.running_init(): + # At initialization time, we run just one layer but add an extra first + # dimension to every initialized tensor, making sure to use different + # random keys for different slices. + def creator(next_creator, shape, dtype, init, context): + del context + + def multi_init(shape, dtype): + assert shape[0] == count + key = hk.maybe_next_rng_key() + + def rng_context_init(slice_idx): + slice_key = maybe_fold_in(key, slice_idx) + with maybe_with_rng(slice_key): + return init(shape[1:], dtype) + + return jax.vmap(rng_context_init)(jnp.arange(count)) + + return next_creator((count,) + tuple(shape), dtype, multi_init) + + def getter(next_getter, value, context): + trailing_dims = len(context.original_shape) + 1 + sliced_value = jax.lax.index_in_dim( + value, index=0, axis=value.ndim - trailing_dims, keepdims=False) + return next_getter(sliced_value) + + with hk.experimental.custom_creator( + creator), hk.experimental.custom_getter(getter): + if len(args_ys) == 1 and args_ys[0] is None: + args0 = (None,) + else: + args0 = [ + jax.lax.dynamic_index_in_dim(ys, 0, keepdims=False) + for ys in args_ys + ] + x, z = self._call_wrapped(x, *args0) + if z is None: + return x, z + + # Broadcast state to hold each layer state. + def broadcast_state(layer_state): + return jnp.broadcast_to( + layer_state, [count,] + list(layer_state.shape)) + zs = jax.tree_util.tree_map(broadcast_state, z) + return x, zs + else: + # Use scan during apply, threading through random seed so that it's + # unique for each layer. + def layer(carry: LayerStackCarry, scanned: LayerStackScanned): + rng = carry.rng + + def getter(next_getter, value, context): + # Getter slices the full param at the current loop index. 
+ trailing_dims = len(context.original_shape) + 1 + assert value.shape[value.ndim - trailing_dims] == count, ( + f'Attempting to use a parameter stack of size ' + f'{value.shape[value.ndim - trailing_dims]} for a LayerStack of ' + f'size {count}.') + + sliced_value = jax.lax.dynamic_index_in_dim( + value, scanned.i, axis=value.ndim - trailing_dims, keepdims=False) + return next_getter(sliced_value) + + with hk.experimental.custom_getter(getter): + if rng is None: + out_x, z = self._call_wrapped(carry.x, *scanned.args_ys) + else: + rng, rng_ = jax.random.split(rng) + with hk.with_rng(rng_): + out_x, z = self._call_wrapped(carry.x, *scanned.args_ys) + return LayerStackCarry(x=out_x, rng=rng), z + + carry = LayerStackCarry(x=x, rng=hk.maybe_next_rng_key()) + scanned = LayerStackScanned(i=jnp.arange(count, dtype=jnp.int32), + args_ys=args_ys) + + carry, zs = hk.scan( + layer, carry, scanned, length=count, unroll=self._unroll) + return carry.x, zs + + def _call_wrapped(self, + x: jnp.ndarray, + *args, + ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: + raise NotImplementedError() + + +class _LayerStackNoState(_LayerStack): + """_LayerStack impl with no per-layer state provided to the function.""" + + def __init__(self, + f: WrappedFn, + count: int, + unroll: int, + name: Optional[str] = None): + super().__init__(count=count, unroll=unroll, name=name) + _check_no_varargs(f) + self._f = f + + @hk.transparent + def _call_wrapped(self, args, y): + del y + ret = self._f(*args) + if len(args) == 1: + # If the function takes a single argument, the wrapped function receives + # a tuple of length 1, and therefore it must return a tuple of length 1. + ret = (ret,) + return ret, None + + +class _LayerStackWithState(_LayerStack): + """_LayerStack impl with per-layer state provided to the function.""" + + def __init__(self, + f: WrappedFn, + count: int, + unroll: int, + name: Optional[str] = None): + super().__init__(count=count, unroll=unroll, name=name) + self._f = f + + @hk.transparent + def _call_wrapped(self, x, *args): + return self._f(x, *args) + + +def layer_stack(num_layers: int, + with_state=False, + unroll: int = 1, + name: Optional[str] = None): + """Utility to wrap a Haiku function and recursively apply it to an input. + + A function is valid if it uses only explicit position parameters, and + its return type matches its input type. The position parameters can be + arbitrarily nested structures with `jnp.ndarray` at the leaf nodes. Note + that kwargs are not supported, neither are functions with variable number + of parameters (specified by `*args`). + + If `with_state=False` then the new, wrapped function can be understood as + performing the following: + ``` + for i in range(num_layers): + x = f(x) + return x + ``` + + And if `with_state=True`, assuming `f` takes two arguments on top of `x`: + ``` + for i in range(num_layers): + x, zs[i] = f(x, ys_0[i], ys_1[i]) + return x, zs + ``` + The code using `layer_stack` for the above function would be: + ``` + def f(x, y_0, y_1): + ... + return new_x, z + x, zs = layer_stack.layer_stack(num_layers, + with_state=True)(f)(x, ys_0, ys_1) + ``` + + Crucially, any parameters created inside `f` will not be shared across + iterations. + + Args: + num_layers: The number of times to iterate the wrapped function. + with_state: Whether or not to pass per-layer state to the wrapped function. + unroll: the unroll used by `scan`. + name: Name of the Haiku context. + + Returns: + Callable that will produce a layer stack when called with a valid function. 
+ """ + def iterate(f): + if with_state: + @functools.wraps(f) + def wrapped(x, *args): + for ys in args: + assert ys.shape[0] == num_layers + return _LayerStackWithState( + f, num_layers, unroll=unroll, name=name)(x, *args) + else: + _check_no_varargs(f) + @functools.wraps(f) + def wrapped(*args): + ret = _LayerStackNoState( + f, num_layers, unroll=unroll, name=name)(args, None)[0] + if len(args) == 1: + # If the function takes a single argument, we must also return a + # single value, and not a tuple of length 1. + ret = ret[0] + return ret + + return wrapped + return iterate diff --git a/build/lib/colabdesign/af/alphafold/model/lddt.py b/build/lib/colabdesign/af/alphafold/model/lddt.py new file mode 100644 index 00000000..6b2a3f9c --- /dev/null +++ b/build/lib/colabdesign/af/alphafold/model/lddt.py @@ -0,0 +1,88 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""lDDT protein distance score.""" +import jax.numpy as jnp + + +def lddt(predicted_points, + true_points, + true_points_mask, + cutoff=15., + per_residue=False): + """Measure (approximate) lDDT for a batch of coordinates. + + lDDT reference: + Mariani, V., Biasini, M., Barbato, A. & Schwede, T. lDDT: A local + superposition-free score for comparing protein structures and models using + distance difference tests. Bioinformatics 29, 2722–2728 (2013). + + lDDT is a measure of the difference between the true distance matrix and the + distance matrix of the predicted points. The difference is computed only on + points closer than cutoff *in the true structure*. + + This function does not compute the exact lDDT value that the original paper + describes because it does not include terms for physical feasibility + (e.g. bond length violations). Therefore this is only an approximate + lDDT score. + + Args: + predicted_points: (batch, length, 3) array of predicted 3D points + true_points: (batch, length, 3) array of true 3D points + true_points_mask: (batch, length, 1) binary-valued float array. This mask + should be 1 for points that exist in the true points. + cutoff: Maximum distance for a pair of points to be included + per_residue: If true, return score for each residue. Note that the overall + lDDT is not exactly the mean of the per_residue lDDT's because some + residues have more contacts than others. + + Returns: + An (approximate, see above) lDDT score in the range 0-1. + """ + + assert len(predicted_points.shape) == 3 + assert predicted_points.shape[-1] == 3 + assert true_points_mask.shape[-1] == 1 + assert len(true_points_mask.shape) == 3 + + # Compute true and predicted distance matrices. + dmat_true = jnp.sqrt(1e-10 + jnp.sum( + (true_points[:, :, None] - true_points[:, None, :])**2, axis=-1)) + + dmat_predicted = jnp.sqrt(1e-10 + jnp.sum( + (predicted_points[:, :, None] - + predicted_points[:, None, :])**2, axis=-1)) + + dists_to_score = ( + (dmat_true < cutoff).astype(jnp.float32) * true_points_mask * + jnp.transpose(true_points_mask, [0, 2, 1]) * + (1. 
- jnp.eye(dmat_true.shape[1])) # Exclude self-interaction. + ) + + # Shift unscored distances to be far away. + dist_l1 = jnp.abs(dmat_true - dmat_predicted) + + # True lDDT uses a number of fixed bins. + # We ignore the physical plausibility correction to lDDT, though. + score = 0.25 * ((dist_l1 < 0.5).astype(jnp.float32) + + (dist_l1 < 1.0).astype(jnp.float32) + + (dist_l1 < 2.0).astype(jnp.float32) + + (dist_l1 < 4.0).astype(jnp.float32)) + + # Normalize over the appropriate axes. + reduce_axes = (-1,) if per_residue else (-2, -1) + norm = 1. / (1e-10 + jnp.sum(dists_to_score, axis=reduce_axes)) + score = norm * (1e-10 + jnp.sum(dists_to_score * score, axis=reduce_axes)) + + return score diff --git a/build/lib/colabdesign/af/alphafold/model/mapping.py b/build/lib/colabdesign/af/alphafold/model/mapping.py new file mode 100644 index 00000000..1371b618 --- /dev/null +++ b/build/lib/colabdesign/af/alphafold/model/mapping.py @@ -0,0 +1,222 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Specialized mapping functions.""" + +import functools +import inspect + +from typing import Any, Callable, Optional, Sequence, Union + +import haiku as hk +import jax +import jax.numpy as jnp + + +PYTREE = Any +PYTREE_JAX_ARRAY = Any + +partial = functools.partial +PROXY = object() + + +def _maybe_slice(array, i, slice_size, axis): + if axis is PROXY: + return array + else: + return jax.lax.dynamic_slice_in_dim( + array, i, slice_size=slice_size, axis=axis) + + +def _maybe_get_size(array, axis): + if axis == PROXY: + return -1 + else: + return array.shape[axis] + + +def _expand_axes(axes, values, name='sharded_apply'): + values_tree_def = jax.tree_flatten(values)[1] + flat_axes = jax.api_util.flatten_axes(name, values_tree_def, axes) + # Replace None's with PROXY + flat_axes = [PROXY if x is None else x for x in flat_axes] + return jax.tree_unflatten(values_tree_def, flat_axes) + + +def sharded_map( + fun: Callable[..., PYTREE_JAX_ARRAY], + shard_size: Union[int, None] = 1, + in_axes: Union[int, PYTREE] = 0, + out_axes: Union[int, PYTREE] = 0) -> Callable[..., PYTREE_JAX_ARRAY]: + """Sharded vmap. + + Maps `fun` over axes, in a way similar to vmap, but does so in shards of + `shard_size`. This allows a smooth trade-off between memory usage + (as in a plain map) vs higher throughput (as in a vmap). + + Args: + fun: Function to apply smap transform to. + shard_size: Integer denoting shard size. + in_axes: Either integer or pytree describing which axis to map over for each + input to `fun`, None denotes broadcasting. + out_axes: integer or pytree denoting to what axis in the output the mapped + over axis maps. + + Returns: + function with smap applied. + """ + if 'split_rng' in inspect.signature(hk.vmap).parameters: + vmapped_fun = hk.vmap(fun, in_axes, out_axes, split_rng=False) + else: + # TODO(tomhennigan): Remove this when older versions of Haiku aren't used. 
+ vmapped_fun = hk.vmap(fun, in_axes, out_axes) + return sharded_apply(vmapped_fun, shard_size, in_axes, out_axes) + +def sharded_apply( + fun: Callable[..., PYTREE_JAX_ARRAY], # pylint: disable=g-bare-generic + shard_size: Union[int, None] = 1, + in_axes: Union[int, PYTREE] = 0, + out_axes: Union[int, PYTREE] = 0, + new_out_axes: bool = False) -> Callable[..., PYTREE_JAX_ARRAY]: + """Sharded apply. + + Applies `fun` over shards to axes, in a way similar to vmap, + but does so in shards of `shard_size`. Shards are stacked after. + This allows a smooth trade-off between + memory usage (as in a plain map) vs higher throughput (as in a vmap). + + Args: + fun: Function to apply smap transform to. + shard_size: Integer denoting shard size. + in_axes: Either integer or pytree describing which axis to map over for each + input to `fun`, None denotes broadcasting. + out_axes: integer or pytree denoting to what axis in the output the mapped + over axis maps. + new_out_axes: whether to stack outputs on new axes. This assumes that the + output sizes for each shard (including the possible remainder shard) are + the same. + + Returns: + function with smap applied. + """ + docstr = ('Mapped version of {fun}. Takes similar arguments to {fun} ' + 'but with additional array axes over which {fun} is mapped.') + if new_out_axes: + raise NotImplementedError('New output axes not yet implemented.') + + # shard size None denotes no sharding + if shard_size is None: + return fun + + @jax.util.wraps(fun, docstr=docstr) + def mapped_fn(*args): + # Expand in axes and Determine Loop range + in_axes_ = _expand_axes(in_axes, args) + + in_sizes = jax.tree_util.tree_map(_maybe_get_size, args, in_axes_) + flat_sizes = jax.tree_flatten(in_sizes)[0] + in_size = max(flat_sizes) + assert all(i in {in_size, -1} for i in flat_sizes) + + num_extra_shards = (in_size - 1) // shard_size + + # Fix Up if necessary + last_shard_size = in_size % shard_size + last_shard_size = shard_size if last_shard_size == 0 else last_shard_size + + def apply_fun_to_slice(slice_start, slice_size): + input_slice = jax.tree_util.tree_map( + lambda array, axis: _maybe_slice(array, slice_start, slice_size, axis + ), args, in_axes_) + return fun(*input_slice) + + remainder_shape_dtype = hk.eval_shape( + partial(apply_fun_to_slice, 0, last_shard_size)) + out_dtypes = jax.tree_map(lambda x: x.dtype, remainder_shape_dtype) + out_shapes = jax.tree_map(lambda x: x.shape, remainder_shape_dtype) + out_axes_ = _expand_axes(out_axes, remainder_shape_dtype) + + if num_extra_shards > 0: + regular_shard_shape_dtype = hk.eval_shape( + partial(apply_fun_to_slice, 0, shard_size)) + shard_shapes = jax.tree_map(lambda x: x.shape, regular_shard_shape_dtype) + + def make_output_shape(axis, shard_shape, remainder_shape): + return shard_shape[:axis] + ( + shard_shape[axis] * num_extra_shards + + remainder_shape[axis],) + shard_shape[axis + 1:] + + out_shapes = jax.tree_util.tree_map(make_output_shape, out_axes_, shard_shapes, + out_shapes) + + # Calls dynamic Update slice with different argument order + # This is here since tree_multimap only works with positional arguments + def dynamic_update_slice_in_dim(full_array, update, axis, i): + return jax.lax.dynamic_update_slice_in_dim(full_array, update, i, axis) + + def compute_shard(outputs, slice_start, slice_size): + slice_out = apply_fun_to_slice(slice_start, slice_size) + update_slice = partial( + dynamic_update_slice_in_dim, i=slice_start) + return jax.tree_util.tree_map(update_slice, outputs, slice_out, out_axes_) + + def 
scan_iteration(outputs, i): + new_outputs = compute_shard(outputs, i, shard_size) + return new_outputs, () + + slice_starts = jnp.arange(0, in_size - shard_size + 1, shard_size) + + def allocate_buffer(dtype, shape): + return jnp.zeros(shape, dtype=dtype) + + outputs = jax.tree_util.tree_map(allocate_buffer, out_dtypes, out_shapes) + + if slice_starts.shape[0] > 0: + outputs, _ = hk.scan(scan_iteration, outputs, slice_starts) + + if last_shard_size != shard_size: + remainder_start = in_size - last_shard_size + outputs = compute_shard(outputs, remainder_start, last_shard_size) + + return outputs + + return mapped_fn + + +def inference_subbatch( + module: Callable[..., PYTREE_JAX_ARRAY], + subbatch_size: int, + batched_args: Sequence[PYTREE_JAX_ARRAY], + nonbatched_args: Sequence[PYTREE_JAX_ARRAY], + low_memory: bool = True, + input_subbatch_dim: int = 0, + output_subbatch_dim: Optional[int] = None) -> PYTREE_JAX_ARRAY: + """Run through subbatches (like batch apply but with split and concat).""" + assert len(batched_args) > 0 # pylint: disable=g-explicit-length-test + + if not low_memory: + args = list(batched_args) + list(nonbatched_args) + return module(*args) + + if output_subbatch_dim is None: + output_subbatch_dim = input_subbatch_dim + + def run_module(*batched_args): + args = list(batched_args) + list(nonbatched_args) + return module(*args) + sharded_module = sharded_apply(run_module, + shard_size=subbatch_size, + in_axes=input_subbatch_dim, + out_axes=output_subbatch_dim) + return sharded_module(*batched_args) diff --git a/build/lib/colabdesign/af/alphafold/model/model.py b/build/lib/colabdesign/af/alphafold/model/model.py new file mode 100644 index 00000000..63d35919 --- /dev/null +++ b/build/lib/colabdesign/af/alphafold/model/model.py @@ -0,0 +1,97 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
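+# Usage sketch (editorial, illustrative only): the inference_subbatch helper defined
+# in mapping.py above bounds peak memory by running a module over slices of its
+# batched arguments and stitching the results, e.g. inside a Haiku module
+# ('my_module' is a hypothetical callable):
+#   out = mapping.inference_subbatch(
+#       my_module,            # called on each slice
+#       subbatch_size=4,      # rows processed per shard
+#       batched_args=[act],   # sliced along axis 0
+#       nonbatched_args=[],   # passed through unchanged
+#       low_memory=True)      # False simply calls my_module once on the full input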
+ +"""Code for constructing the model.""" +from typing import Any, Mapping, Optional, Union + +from absl import logging +from colabdesign.af.alphafold.model import modules +from colabdesign.af.alphafold.model import modules_multimer + +import haiku as hk +import jax +import ml_collections +import numpy as np +import tree + +class RunModel: + """Container for JAX model.""" + + def __init__(self, + config: ml_collections.ConfigDict, + params: Optional[Mapping[str, Mapping[str, np.ndarray]]] = None, + return_representations=True, + recycle_mode=None, + use_multimer=False): + + self.config = config + self.params = params + + self.mode = recycle_mode + if self.mode is None: self.mode = [] + + def _forward_fn(batch): + if use_multimer: + model = modules_multimer.AlphaFold(self.config.model) + else: + model = modules.AlphaFold(self.config.model) + return model( + batch, + return_representations=return_representations) + + self.init = jax.jit(hk.transform(_forward_fn).init) + self.apply_fn = jax.jit(hk.transform(_forward_fn).apply) + + def apply(params, key, feat): + + if "prev" in feat: + prev = feat["prev"] + else: + L = feat['aatype'].shape[0] + prev = {'prev_msa_first_row': np.zeros([L,256]), + 'prev_pair': np.zeros([L,L,128]), + 'prev_pos': np.zeros([L,37,3])} + if self.config.global_config.use_dgram: + prev['prev_dgram'] = np.zeros([L,L,64]) + feat["prev"] = prev + + ################################ + # decide how to run recycles + ################################ + if self.config.model.num_recycle: + # use scan() + def loop(prev, sub_key): + feat["prev"] = prev + results = self.apply_fn(params, sub_key, feat) + prev = results["prev"] + if "backprop" not in self.mode: + prev = jax.lax.stop_gradient(prev) + return prev, results + + keys = jax.random.split(key, self.config.model.num_recycle + 1) + _, o = jax.lax.scan(loop, prev, keys) + results = jax.tree_map(lambda x:x[-1], o) + + if "add_prev" in self.mode: + for k in ["distogram","predicted_lddt","predicted_aligned_error"]: + if k in results: + results[k]["logits"] = o[k]["logits"].mean(0) + + else: + # single pass + results = self.apply_fn(params, key, feat) + + return results + + self.apply = jax.jit(apply) \ No newline at end of file diff --git a/build/lib/colabdesign/af/alphafold/model/modules.py b/build/lib/colabdesign/af/alphafold/model/modules.py new file mode 100644 index 00000000..05a74bf2 --- /dev/null +++ b/build/lib/colabdesign/af/alphafold/model/modules.py @@ -0,0 +1,1780 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Modules and code used in the core part of AlphaFold. + +The structure generation code is in 'folding.py'. 
+""" +import functools +from colabdesign.af.alphafold.common import residue_constants +from colabdesign.af.alphafold.model import all_atom +from colabdesign.af.alphafold.model import common_modules +from colabdesign.af.alphafold.model import folding +from colabdesign.af.alphafold.model import layer_stack +from colabdesign.af.alphafold.model import lddt +from colabdesign.af.alphafold.model import mapping +from colabdesign.af.alphafold.model import prng +from colabdesign.af.alphafold.model import quat_affine +from colabdesign.af.alphafold.model import utils +import haiku as hk +import jax +import jax.numpy as jnp + +from colabdesign.af.alphafold.model.r3 import Rigids, Rots, Vecs + +def apply_dropout(*, tensor, safe_key, rate, broadcast_dim=None): + """Applies dropout to a tensor.""" + shape = list(tensor.shape) + if broadcast_dim is not None: + shape[broadcast_dim] = 1 + keep_rate = 1.0 - rate + keep = jax.random.bernoulli(safe_key.get(), keep_rate, shape=shape) + return keep * tensor / keep_rate + +def dropout_wrapper(module, + input_act, + mask, + safe_key, + global_config, + use_dropout, + output_act=None, + **kwargs): + """Applies module + dropout + residual update.""" + if output_act is None: + output_act = input_act + + gc = global_config + residual = module(input_act, mask, **kwargs) + + if module.config.shared_dropout: + if module.config.orientation == 'per_row': + broadcast_dim = 0 + else: + broadcast_dim = 1 + else: + broadcast_dim = None + + residual = apply_dropout(tensor=residual, + safe_key=safe_key, + rate=jnp.where(use_dropout, module.config.dropout_rate, 0), + broadcast_dim=broadcast_dim) + + new_act = output_act + residual + + return new_act + + +def create_extra_msa_feature(batch): + """Expand extra_msa into 1hot and concat with other extra msa features. + + We do this as late as possible as the one_hot extra msa can be very large. + + Arguments: + batch: a dictionary with the following keys: + * 'extra_msa': [N_extra_seq, N_res] MSA that wasn't selected as a cluster + centre. Note, that this is not one-hot encoded. + * 'extra_has_deletion': [N_extra_seq, N_res] Whether there is a deletion to + the left of each position in the extra MSA. + * 'extra_deletion_value': [N_extra_seq, N_res] The number of deletions to + the left of each position in the extra MSA. + + Returns: + Concatenated tensor of extra MSA features. + """ + # 23 = 20 amino acids + 'X' for unknown + gap + bert mask + msa_1hot = jax.nn.one_hot(batch['extra_msa'], 23) + msa_feat = [msa_1hot, + jnp.expand_dims(batch['extra_has_deletion'], axis=-1), + jnp.expand_dims(batch['extra_deletion_value'], axis=-1)] + return jnp.concatenate(msa_feat, axis=-1) + +class AlphaFoldIteration(hk.Module): + """A single recycling iteration of AlphaFold architecture. + Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 3-22 + """ + def __init__(self, config, global_config, name='alphafold_iteration'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, batch, **kwargs): + + # Compute representations for each batch element and average. 
+ evoformer_module = EmbeddingsAndEvoformer(self.config.embeddings_and_evoformer, self.global_config) + representations = evoformer_module(batch) + + head_factory = { + 'masked_msa': MaskedMsaHead, + 'distogram': DistogramHead, + 'structure_module': folding.StructureModule, + 'predicted_lddt': PredictedLDDTHead, + 'predicted_aligned_error': PredictedAlignedErrorHead, + 'experimentally_resolved': ExperimentallyResolvedHead} + + heads = {} + for name, head_config in sorted(self.config.heads.items()): + if not head_config.weight: continue + heads[name] = head_factory[name](head_config, self.global_config) + + ret = {'representations':representations} + for name, head in heads.items(): + if name in ('predicted_lddt', 'predicted_aligned_error'): + continue + else: + ret[name] = head(representations, batch) + if 'representations' in ret[name]: + representations.update(ret[name].pop('representations')) + + for name in ('predicted_lddt', 'predicted_aligned_error'): + ret[name] = heads[name](representations, batch) + return ret + +class AlphaFold(hk.Module): + """AlphaFold Jumper et al. (2021) Suppl. Alg. 2 "Inference""" + def __init__(self, config, name='alphafold'): + super().__init__(name=name) + self.config = config + self.global_config = config.global_config + + def __call__(self, batch, **kwargs): + """Run the AlphaFold model.""" + impl = AlphaFoldIteration(self.config, self.global_config) + + def get_prev(ret): + new_prev = { + 'prev_msa_first_row': ret['representations']['msa_first_row'], + 'prev_pair': ret['representations']['pair'], + 'prev_pos': ret['structure_module']['final_atom_positions'] + } + if self.global_config.use_dgram: + new_prev['prev_dgram'] = ret["distogram"]["logits"] + return new_prev + + prev = batch.pop("prev") + ret = impl(batch={**batch, **prev}) + ret["prev"] = get_prev(ret) + return ret + +class TemplatePairStack(hk.Module): + """Pair stack for the templates. + + Jumper et al. (2021) Suppl. Alg. 16 "TemplatePairStack" + """ + + def __init__(self, config, global_config, name='template_pair_stack'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, pair_act, pair_mask, use_dropout, safe_key=None): + """Builds TemplatePairStack module. + + Arguments: + pair_act: Pair activations for single template, shape [N_res, N_res, c_t]. + pair_mask: Pair mask, shape [N_res, N_res]. + safe_key: Safe key object encapsulating the random number generation key. + + Returns: + Updated pair_act, shape [N_res, N_res, c_t]. 
+ """ + + if safe_key is None: + safe_key = prng.SafeKey(hk.next_rng_key()) + + gc = self.global_config + c = self.config + + if not c.num_block: + return pair_act + + def block(x): + """One block of the template pair stack.""" + pair_act, safe_key = x + + dropout_wrapper_fn = functools.partial( + dropout_wrapper, global_config=gc, use_dropout=use_dropout) + + safe_key, *sub_keys = safe_key.split(6) + sub_keys = iter(sub_keys) + + pair_act = dropout_wrapper_fn( + TriangleAttention(c.triangle_attention_starting_node, gc, + name='triangle_attention_starting_node'), + pair_act, + pair_mask, + next(sub_keys)) + pair_act = dropout_wrapper_fn( + TriangleAttention(c.triangle_attention_ending_node, gc, + name='triangle_attention_ending_node'), + pair_act, + pair_mask, + next(sub_keys)) + pair_act = dropout_wrapper_fn( + TriangleMultiplication(c.triangle_multiplication_outgoing, gc, + name='triangle_multiplication_outgoing'), + pair_act, + pair_mask, + next(sub_keys)) + pair_act = dropout_wrapper_fn( + TriangleMultiplication(c.triangle_multiplication_incoming, gc, + name='triangle_multiplication_incoming'), + pair_act, + pair_mask, + next(sub_keys)) + pair_act = dropout_wrapper_fn( + Transition(c.pair_transition, gc, name='pair_transition'), + pair_act, + pair_mask, + next(sub_keys)) + + return pair_act, safe_key + + if gc.use_remat: + block = hk.remat(block) + + res_stack = layer_stack.layer_stack(c.num_block)(block) + pair_act, safe_key = res_stack((pair_act, safe_key)) + return pair_act + + +class Transition(hk.Module): + """Transition layer. + + Jumper et al. (2021) Suppl. Alg. 9 "MSATransition" + Jumper et al. (2021) Suppl. Alg. 15 "PairTransition" + """ + + def __init__(self, config, global_config, name='transition_block'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, act, mask): + """Builds Transition module. + + Arguments: + act: A tensor of queries of size [batch_size, N_res, N_channel]. + mask: A tensor denoting the mask of size [batch_size, N_res]. + + Returns: + A float32 tensor of size [batch_size, N_res, N_channel]. + """ + _, _, nc = act.shape + + num_intermediate = int(nc * self.config.num_intermediate_factor) + mask = jnp.expand_dims(mask, axis=-1) + + act = common_modules.LayerNorm( + axis=[-1], + create_scale=True, + create_offset=True, + name='input_layer_norm')( + act) + + transition_module = hk.Sequential([ + common_modules.Linear( + num_intermediate, + initializer='relu', + name='transition1'), jax.nn.relu, + common_modules.Linear( + nc, + initializer=utils.final_init(self.global_config), + name='transition2') + ]) + + act = mapping.inference_subbatch( + transition_module, + self.global_config.subbatch_size, + batched_args=[act], + nonbatched_args=[], + low_memory=self.global_config.subbatch_size is not None) + + return act + + +def glorot_uniform(): + return hk.initializers.VarianceScaling(scale=1.0, + mode='fan_avg', + distribution='uniform') + + +class Attention(hk.Module): + """Multihead attention.""" + + def __init__(self, config, global_config, output_dim, name='attention'): + super().__init__(name=name) + + self.config = config + self.global_config = global_config + self.output_dim = output_dim + + def __call__(self, q_data, m_data, bias, nonbatched_bias=None): + """Builds Attention module. + + Arguments: + q_data: A tensor of queries, shape [batch_size, N_queries, q_channels]. + m_data: A tensor of memories from which the keys and values are + projected, shape [batch_size, N_keys, m_channels]. 
+ bias: A bias for the attention, shape [batch_size, N_queries, N_keys]. + nonbatched_bias: Shared bias, shape [N_queries, N_keys]. + + Returns: + A float32 tensor of shape [batch_size, N_queries, output_dim]. + """ + # Sensible default for when the config keys are missing + key_dim = self.config.get('key_dim', int(q_data.shape[-1])) + value_dim = self.config.get('value_dim', int(m_data.shape[-1])) + num_head = self.config.num_head + assert key_dim % num_head == 0 + assert value_dim % num_head == 0 + key_dim = key_dim // num_head + value_dim = value_dim // num_head + + q_weights = hk.get_parameter( + 'query_w', shape=(q_data.shape[-1], num_head, key_dim), + dtype=q_data.dtype, + init=glorot_uniform()) + k_weights = hk.get_parameter( + 'key_w', shape=(m_data.shape[-1], num_head, key_dim), + dtype=q_data.dtype, + init=glorot_uniform()) + v_weights = hk.get_parameter( + 'value_w', shape=(m_data.shape[-1], num_head, value_dim), + dtype=q_data.dtype, + init=glorot_uniform()) + + q = jnp.einsum('bqa,ahc->bqhc', q_data, q_weights) * key_dim**(-0.5) + k = jnp.einsum('bka,ahc->bkhc', m_data, k_weights) + v = jnp.einsum('bka,ahc->bkhc', m_data, v_weights) + logits = jnp.einsum('bqhc,bkhc->bhqk', q, k) + bias + if nonbatched_bias is not None: + logits += jnp.expand_dims(nonbatched_bias, axis=0) + + # patch for jax > 0.3.25 + logits = jnp.clip(logits,-1e8,1e8) + + weights = jax.nn.softmax(logits) + weighted_avg = jnp.einsum('bhqk,bkhc->bqhc', weights, v) + + if self.global_config.zero_init: + init = hk.initializers.Constant(0.0) + else: + init = glorot_uniform() + + if self.config.gating: + gating_weights = hk.get_parameter( + 'gating_w', + shape=(q_data.shape[-1], num_head, value_dim), + dtype=q_data.dtype, + init=hk.initializers.Constant(0.0)) + gating_bias = hk.get_parameter( + 'gating_b', + shape=(num_head, value_dim), + dtype=q_data.dtype, + init=hk.initializers.Constant(1.0)) + + gate_values = jnp.einsum('bqc, chv->bqhv', q_data, + gating_weights) + gating_bias + + gate_values = jax.nn.sigmoid(gate_values) + + weighted_avg *= gate_values + + o_weights = hk.get_parameter( + 'output_w', shape=(num_head, value_dim, self.output_dim), + dtype=q_data.dtype, + init=init) + o_bias = hk.get_parameter('output_b', shape=(self.output_dim,), + dtype=q_data.dtype, + init=hk.initializers.Constant(0.0)) + + output = jnp.einsum('bqhc,hco->bqo', weighted_avg, o_weights) + o_bias + + return output + + +class GlobalAttention(hk.Module): + """Global attention. + + Jumper et al. (2021) Suppl. Alg. 19 "MSAColumnGlobalAttention" lines 2-7 + """ + + def __init__(self, config, global_config, output_dim, name='attention'): + super().__init__(name=name) + + self.config = config + self.global_config = global_config + self.output_dim = output_dim + + def __call__(self, q_data, m_data, q_mask, bias): + """Builds GlobalAttention module. + + Arguments: + q_data: A tensor of queries with size [batch_size, N_queries, + q_channels] + m_data: A tensor of memories from which the keys and values + projected. Size [batch_size, N_keys, m_channels] + q_mask: A binary mask for q_data with zeros in the padded sequence + elements and ones otherwise. Size [batch_size, N_queries, q_channels] + (or broadcastable to this shape). + bias: A bias for the attention. + + Returns: + A float32 tensor of size [batch_size, N_queries, output_dim]. 
+ """ + # Sensible default for when the config keys are missing + key_dim = self.config.get('key_dim', int(q_data.shape[-1])) + value_dim = self.config.get('value_dim', int(m_data.shape[-1])) + num_head = self.config.num_head + assert key_dim % num_head == 0 + assert value_dim % num_head == 0 + key_dim = key_dim // num_head + value_dim = value_dim // num_head + + q_weights = hk.get_parameter( + 'query_w', shape=(q_data.shape[-1], num_head, key_dim), + dtype=q_data.dtype, + init=glorot_uniform()) + k_weights = hk.get_parameter( + 'key_w', shape=(m_data.shape[-1], key_dim), + dtype=q_data.dtype, + init=glorot_uniform()) + v_weights = hk.get_parameter( + 'value_w', shape=(m_data.shape[-1], value_dim), + dtype=q_data.dtype, + init=glorot_uniform()) + + v = jnp.einsum('bka,ac->bkc', m_data, v_weights) + + q_avg = utils.mask_mean(q_mask, q_data, axis=1) + + q = jnp.einsum('ba,ahc->bhc', q_avg, q_weights) * key_dim**(-0.5) + k = jnp.einsum('bka,ac->bkc', m_data, k_weights) + bias = (1e9 * (q_mask[:, None, :, 0] - 1.)) + logits = jnp.einsum('bhc,bkc->bhk', q, k) + bias + weights = jax.nn.softmax(logits) + weighted_avg = jnp.einsum('bhk,bkc->bhc', weights, v) + + if self.global_config.zero_init: + init = hk.initializers.Constant(0.0) + else: + init = glorot_uniform() + + o_weights = hk.get_parameter( + 'output_w', shape=(num_head, value_dim, self.output_dim), + dtype=q_data.dtype, + init=init) + o_bias = hk.get_parameter('output_b', shape=(self.output_dim,), + dtype=q_data.dtype, + init=hk.initializers.Constant(0.0)) + + if self.config.gating: + gating_weights = hk.get_parameter( + 'gating_w', + shape=(q_data.shape[-1], num_head, value_dim), + dtype=q_data.dtype, + init=hk.initializers.Constant(0.0)) + gating_bias = hk.get_parameter( + 'gating_b', + shape=(num_head, value_dim), + dtype=q_data.dtype, + init=hk.initializers.Constant(1.0)) + + gate_values = jnp.einsum('bqc, chv->bqhv', q_data, gating_weights) + gate_values = jax.nn.sigmoid(gate_values + gating_bias) + weighted_avg = weighted_avg[:, None] * gate_values + output = jnp.einsum('bqhc,hco->bqo', weighted_avg, o_weights) + o_bias + else: + output = jnp.einsum('bhc,hco->bo', weighted_avg, o_weights) + o_bias + output = output[:, None] + return output + + +class MSARowAttentionWithPairBias(hk.Module): + """MSA per-row attention biased by the pair representation. + + Jumper et al. (2021) Suppl. Alg. 7 "MSARowAttentionWithPairBias" + """ + + def __init__(self, config, global_config, + name='msa_row_attention_with_pair_bias'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, + msa_act, + msa_mask, + pair_act): + """Builds MSARowAttentionWithPairBias module. + + Arguments: + msa_act: [N_seq, N_res, c_m] MSA representation. + msa_mask: [N_seq, N_res] mask of non-padded regions. + pair_act: [N_res, N_res, c_z] pair representation. + + Returns: + Update to msa_act, shape [N_seq, N_res, c_m]. + """ + c = self.config + + assert len(msa_act.shape) == 3 + assert len(msa_mask.shape) == 2 + assert c.orientation == 'per_row' + + bias = (1e9 * (msa_mask - 1.))[:, None, None, :] + assert len(bias.shape) == 4 + + msa_act = common_modules.LayerNorm( + axis=[-1], create_scale=True, create_offset=True, name='query_norm')( + msa_act) + + pair_act = common_modules.LayerNorm( + axis=[-1], + create_scale=True, + create_offset=True, + name='feat_2d_norm')( + pair_act) + + init_factor = 1. 
/ jnp.sqrt(int(pair_act.shape[-1])) + weights = hk.get_parameter( + 'feat_2d_weights', + shape=(pair_act.shape[-1], c.num_head), + dtype=msa_act.dtype, + init=hk.initializers.RandomNormal(stddev=init_factor)) + nonbatched_bias = jnp.einsum('qkc,ch->hqk', pair_act, weights) + + attn_mod = Attention( + c, self.global_config, msa_act.shape[-1]) + msa_act = mapping.inference_subbatch( + attn_mod, + self.global_config.subbatch_size, + batched_args=[msa_act, msa_act, bias], + nonbatched_args=[nonbatched_bias], + low_memory=self.global_config.subbatch_size is not None) + + return msa_act + + +class MSAColumnAttention(hk.Module): + """MSA per-column attention. + + Jumper et al. (2021) Suppl. Alg. 8 "MSAColumnAttention" + """ + + def __init__(self, config, global_config, name='msa_column_attention'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, + msa_act, + msa_mask): + """Builds MSAColumnAttention module. + + Arguments: + msa_act: [N_seq, N_res, c_m] MSA representation. + msa_mask: [N_seq, N_res] mask of non-padded regions. + + Returns: + Update to msa_act, shape [N_seq, N_res, c_m] + """ + c = self.config + + assert len(msa_act.shape) == 3 + assert len(msa_mask.shape) == 2 + assert c.orientation == 'per_column' + + msa_act = jnp.swapaxes(msa_act, -2, -3) + msa_mask = jnp.swapaxes(msa_mask, -1, -2) + + bias = (1e9 * (msa_mask - 1.))[:, None, None, :] + assert len(bias.shape) == 4 + + msa_act = common_modules.LayerNorm( + axis=[-1], create_scale=True, create_offset=True, name='query_norm')( + msa_act) + + attn_mod = Attention( + c, self.global_config, msa_act.shape[-1]) + msa_act = mapping.inference_subbatch( + attn_mod, + self.global_config.subbatch_size, + batched_args=[msa_act, msa_act, bias], + nonbatched_args=[], + low_memory=self.global_config.subbatch_size is not None) + + msa_act = jnp.swapaxes(msa_act, -2, -3) + + return msa_act + + +class MSAColumnGlobalAttention(hk.Module): + """MSA per-column global attention. + + Jumper et al. (2021) Suppl. Alg. 19 "MSAColumnGlobalAttention" + """ + + def __init__(self, config, global_config, name='msa_column_global_attention'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, + msa_act, + msa_mask): + """Builds MSAColumnGlobalAttention module. + + Arguments: + msa_act: [N_seq, N_res, c_m] MSA representation. + msa_mask: [N_seq, N_res] mask of non-padded regions. + + Returns: + Update to msa_act, shape [N_seq, N_res, c_m]. + """ + c = self.config + + assert len(msa_act.shape) == 3 + assert len(msa_mask.shape) == 2 + assert c.orientation == 'per_column' + + msa_act = jnp.swapaxes(msa_act, -2, -3) + msa_mask = jnp.swapaxes(msa_mask, -1, -2) + + bias = (1e9 * (msa_mask - 1.))[:, None, None, :] + assert len(bias.shape) == 4 + + msa_act = common_modules.LayerNorm( + axis=[-1], create_scale=True, create_offset=True, name='query_norm')( + msa_act) + + attn_mod = GlobalAttention( + c, self.global_config, msa_act.shape[-1], + name='attention') + # [N_seq, N_res, 1] + msa_mask = jnp.expand_dims(msa_mask, axis=-1) + msa_act = mapping.inference_subbatch( + attn_mod, + self.global_config.subbatch_size, + batched_args=[msa_act, msa_act, msa_mask, bias], + nonbatched_args=[], + low_memory=self.global_config.subbatch_size is not None) + + msa_act = jnp.swapaxes(msa_act, -2, -3) + + return msa_act + + +class TriangleAttention(hk.Module): + """Triangle Attention. + + Jumper et al. (2021) Suppl. Alg. 
13 "TriangleAttentionStartingNode" + Jumper et al. (2021) Suppl. Alg. 14 "TriangleAttentionEndingNode" + """ + + def __init__(self, config, global_config, name='triangle_attention'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, pair_act, pair_mask): + """Builds TriangleAttention module. + + Arguments: + pair_act: [N_res, N_res, c_z] pair activations tensor + pair_mask: [N_res, N_res] mask of non-padded regions in the tensor. + + Returns: + Update to pair_act, shape [N_res, N_res, c_z]. + """ + c = self.config + + assert len(pair_act.shape) == 3 + assert len(pair_mask.shape) == 2 + assert c.orientation in ['per_row', 'per_column'] + + if c.orientation == 'per_column': + pair_act = jnp.swapaxes(pair_act, -2, -3) + pair_mask = jnp.swapaxes(pair_mask, -1, -2) + + bias = (1e9 * (pair_mask - 1.))[:, None, None, :] + assert len(bias.shape) == 4 + + pair_act = common_modules.LayerNorm( + axis=[-1], create_scale=True, create_offset=True, name='query_norm')( + pair_act) + + init_factor = 1. / jnp.sqrt(int(pair_act.shape[-1])) + weights = hk.get_parameter( + 'feat_2d_weights', + shape=(pair_act.shape[-1], c.num_head), + dtype=pair_act.dtype, + init=hk.initializers.RandomNormal(stddev=init_factor)) + nonbatched_bias = jnp.einsum('qkc,ch->hqk', pair_act, weights) + + attn_mod = Attention( + c, self.global_config, pair_act.shape[-1]) + pair_act = mapping.inference_subbatch( + attn_mod, + self.global_config.subbatch_size, + batched_args=[pair_act, pair_act, bias], + nonbatched_args=[nonbatched_bias], + low_memory=self.global_config.subbatch_size is not None) + + if c.orientation == 'per_column': + pair_act = jnp.swapaxes(pair_act, -2, -3) + + return pair_act + + +class MaskedMsaHead(hk.Module): + """Head to predict MSA at the masked locations. + + The MaskedMsaHead employs a BERT-style objective to reconstruct a masked + version of the full MSA, based on a linear projection of + the MSA representation. + Jumper et al. (2021) Suppl. Sec. 1.9.9 "Masked MSA prediction" + """ + + def __init__(self, config, global_config, name='masked_msa_head'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + if global_config.multimer_mode: + self.num_output = len(residue_constants.restypes_with_x_and_gap) + else: + self.num_output = config.num_output + + def __call__(self, representations, batch): + """Builds MaskedMsaHead module. + + Arguments: + representations: Dictionary of representations, must contain: + * 'msa': MSA representation, shape [N_seq, N_res, c_m]. + batch: Batch, unused. + + Returns: + Dictionary containing: + * 'logits': logits of shape [N_seq, N_res, N_aatype] with + (unnormalized) log probabilies of predicted aatype at position. + """ + del batch + logits = common_modules.Linear( + self.num_output, + initializer=utils.final_init(self.global_config), + name='logits')( + representations['msa']) + return dict(logits=logits) + +class PredictedLDDTHead(hk.Module): + """Head to predict the per-residue LDDT to be used as a confidence measure. + + Jumper et al. (2021) Suppl. Sec. 1.9.6 "Model confidence prediction (pLDDT)" + Jumper et al. (2021) Suppl. Alg. 29 "predictPerResidueLDDT_Ca" + """ + + def __init__(self, config, global_config, name='predicted_lddt_head'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, representations, batch): + """Builds ExperimentallyResolvedHead module. 
+ + Arguments: + representations: Dictionary of representations, must contain: + * 'structure_module': Single representation from the structure module, + shape [N_res, c_s]. + batch: Batch, unused. + + Returns: + Dictionary containing : + * 'logits': logits of shape [N_res, N_bins] with + (unnormalized) log probabilies of binned predicted lDDT. + """ + act = representations['structure_module'] + + act = common_modules.LayerNorm( + axis=[-1], + create_scale=True, + create_offset=True, + name='input_layer_norm')( + act) + + act = common_modules.Linear( + self.config.num_channels, + initializer='relu', + name='act_0')( + act) + act = jax.nn.relu(act) + + act = common_modules.Linear( + self.config.num_channels, + initializer='relu', + name='act_1')( + act) + act = jax.nn.relu(act) + + logits = common_modules.Linear( + self.config.num_bins, + initializer=utils.final_init(self.global_config), + name='logits')( + act) + # Shape (batch_size, num_res, num_bins) + return dict(logits=logits) + + +class PredictedAlignedErrorHead(hk.Module): + """Head to predict the distance errors in the backbone alignment frames. + + Can be used to compute predicted TM-Score. + Jumper et al. (2021) Suppl. Sec. 1.9.7 "TM-score prediction" + """ + + def __init__(self, config, global_config, + name='predicted_aligned_error_head'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, representations, batch): + """Builds PredictedAlignedErrorHead module. + + Arguments: + representations: Dictionary of representations, must contain: + * 'pair': pair representation, shape [N_res, N_res, c_z]. + batch: Batch, unused. + + Returns: + Dictionary containing: + * logits: logits for aligned error, shape [N_res, N_res, N_bins]. + * bin_breaks: array containing bin breaks, shape [N_bins - 1]. + """ + + act = representations['pair'] + + # Shape (num_res, num_res, num_bins) + logits = common_modules.Linear( + self.config.num_bins, + initializer=utils.final_init(self.global_config), + name='logits')(act) + # Shape (num_bins,) + breaks = jnp.linspace( + 0., self.config.max_error_bin, self.config.num_bins - 1) + return dict(logits=logits, breaks=breaks) + + +class ExperimentallyResolvedHead(hk.Module): + """Predicts if an atom is experimentally resolved in a high-res structure. + + Only trained on high-resolution X-ray crystals & cryo-EM. + Jumper et al. (2021) Suppl. Sec. 1.9.10 '"Experimentally resolved" prediction' + """ + + def __init__(self, config, global_config, + name='experimentally_resolved_head'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, representations, batch): + """Builds ExperimentallyResolvedHead module. + + Arguments: + representations: Dictionary of representations, must contain: + * 'single': Single representation, shape [N_res, c_s]. + batch: Batch, unused. + + Returns: + Dictionary containing: + * 'logits': logits of shape [N_res, 37], + log probability that an atom is resolved in atom37 representation, + can be converted to probability by applying sigmoid. 
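+
+    Example (illustrative; `head_config`, `global_config`, `representations`
+    and `batch` are assumed to be in scope, as in AlphaFoldIteration):
+      out = ExperimentallyResolvedHead(head_config, global_config)(representations, batch)
+      resolved_prob = jax.nn.sigmoid(out['logits'])  # [N_res, 37] per-atom probability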
+ """ + logits = common_modules.Linear( + 37, # atom_exists.shape[-1] + initializer=utils.final_init(self.global_config), + name='logits')(representations['single']) + return dict(logits=logits) + + +def _layer_norm(axis=-1, name='layer_norm'): + return common_modules.LayerNorm( + axis=axis, + create_scale=True, + create_offset=True, + eps=1e-5, + use_fast_variance=True, + scale_init=hk.initializers.Constant(1.), + offset_init=hk.initializers.Constant(0.), + param_axis=axis, + name=name) + +class TriangleMultiplication(hk.Module): + """Triangle multiplication layer ("outgoing" or "incoming"). + Jumper et al. (2021) Suppl. Alg. 11 "TriangleMultiplicationOutgoing" + Jumper et al. (2021) Suppl. Alg. 12 "TriangleMultiplicationIncoming" + """ + + def __init__(self, config, global_config, name='triangle_multiplication'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, left_act, left_mask): + """Builds TriangleMultiplication module. + Arguments: + left_act: Pair activations, shape [N_res, N_res, c_z] + left_mask: Pair mask, shape [N_res, N_res]. + Returns: + Outputs, same shape/type as left_act. + """ + + if self.config.fuse_projection_weights: + return self._fused_triangle_multiplication(left_act, left_mask) + else: + return self._triangle_multiplication(left_act, left_mask) + + @hk.transparent + def _triangle_multiplication(self, left_act, left_mask): + """Implementation of TriangleMultiplication used in AF2 and AF-M<2.3.""" + c = self.config + gc = self.global_config + + mask = left_mask[..., None] + + act = common_modules.LayerNorm(axis=[-1], create_scale=True, create_offset=True, + name='layer_norm_input')(left_act) + input_act = act + + left_projection = common_modules.Linear( + c.num_intermediate_channel, + name='left_projection') + left_proj_act = mask * left_projection(act) + + right_projection = common_modules.Linear( + c.num_intermediate_channel, + name='right_projection') + right_proj_act = mask * right_projection(act) + + left_gate_values = jax.nn.sigmoid(common_modules.Linear( + c.num_intermediate_channel, + bias_init=1., + initializer=utils.final_init(gc), + name='left_gate')(act)) + + right_gate_values = jax.nn.sigmoid(common_modules.Linear( + c.num_intermediate_channel, + bias_init=1., + initializer=utils.final_init(gc), + name='right_gate')(act)) + + left_proj_act *= left_gate_values + right_proj_act *= right_gate_values + + # "Outgoing" edges equation: 'ikc,jkc->ijc' + # "Incoming" edges equation: 'kjc,kic->ijc' + # Note on the Suppl. Alg. 
11 & 12 notation: + # For the "outgoing" edges, a = left_proj_act and b = right_proj_act + # For the "incoming" edges, it's swapped: + # b = left_proj_act and a = right_proj_act + act = jnp.einsum(c.equation, left_proj_act, right_proj_act) + + act = common_modules.LayerNorm( + axis=[-1], + create_scale=True, + create_offset=True, + name='center_layer_norm')( + act) + + output_channel = int(input_act.shape[-1]) + + act = common_modules.Linear( + output_channel, + initializer=utils.final_init(gc), + name='output_projection')(act) + + gate_values = jax.nn.sigmoid(common_modules.Linear( + output_channel, + bias_init=1., + initializer=utils.final_init(gc), + name='gating_linear')(input_act)) + act *= gate_values + + return act + + @hk.transparent + def _fused_triangle_multiplication(self, left_act, left_mask): + """TriangleMultiplication with fused projection weights.""" + mask = left_mask[..., None] + c = self.config + gc = self.global_config + + left_act = _layer_norm(axis=-1, name='left_norm_input')(left_act) + + # Both left and right projections are fused into projection. + projection = common_modules.Linear( + 2*c.num_intermediate_channel, name='projection') + proj_act = mask * projection(left_act) + + # Both left + right gate are fused into gate_values. + gate_values = common_modules.Linear( + 2 * c.num_intermediate_channel, + name='gate', + bias_init=1., + initializer=utils.final_init(gc))(left_act) + proj_act *= jax.nn.sigmoid(gate_values) + + left_proj_act = proj_act[:, :, :c.num_intermediate_channel] + right_proj_act = proj_act[:, :, c.num_intermediate_channel:] + act = jnp.einsum(c.equation, left_proj_act, right_proj_act) + + act = _layer_norm(axis=-1, name='center_norm')(act) + + output_channel = int(left_act.shape[-1]) + + act = common_modules.Linear( + output_channel, + initializer=utils.final_init(gc), + name='output_projection')(act) + + gate_values = common_modules.Linear( + output_channel, + bias_init=1., + initializer=utils.final_init(gc), + name='gating_linear')(left_act) + act *= jax.nn.sigmoid(gate_values) + + return act + + +class DistogramHead(hk.Module): + """Head to predict a distogram. + + Jumper et al. (2021) Suppl. Sec. 1.9.8 "Distogram prediction" + """ + + def __init__(self, config, global_config, name='distogram_head'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, representations, batch): + """Builds DistogramHead module. + + Arguments: + representations: Dictionary of representations, must contain: + * 'pair': pair representation, shape [N_res, N_res, c_z]. + batch: Batch, unused. + + Returns: + Dictionary containing: + * logits: logits for distogram, shape [N_res, N_res, N_bins]. + * bin_breaks: array containing bin breaks, shape [N_bins - 1,]. + """ + half_logits = common_modules.Linear( + self.config.num_bins, + initializer=utils.final_init(self.global_config), + name='half_logits')( + representations['pair']) + + logits = half_logits + jnp.swapaxes(half_logits, -2, -3) + breaks = jnp.linspace(self.config.first_break, self.config.last_break, + self.config.num_bins - 1) + + return dict(logits=logits, bin_edges=breaks) + + +class OuterProductMean(hk.Module): + """Computes mean outer product. + + Jumper et al. (2021) Suppl. Alg. 
10 "OuterProductMean" + """ + + def __init__(self, + config, + global_config, + num_output_channel, + name='outer_product_mean'): + super().__init__(name=name) + self.global_config = global_config + self.config = config + self.num_output_channel = num_output_channel + + def __call__(self, act, mask): + """Builds OuterProductMean module. + + Arguments: + act: MSA representation, shape [N_seq, N_res, c_m]. + mask: MSA mask, shape [N_seq, N_res]. + + Returns: + Update to pair representation, shape [N_res, N_res, c_z]. + """ + gc = self.global_config + c = self.config + + mask = mask[..., None] + act = common_modules.LayerNorm([-1], True, True, name='layer_norm_input')(act) + + left_act = mask * common_modules.Linear( + c.num_outer_channel, + initializer='linear', + name='left_projection')( + act) + + right_act = mask * common_modules.Linear( + c.num_outer_channel, + initializer='linear', + name='right_projection')( + act) + + if gc.zero_init: + init_w = hk.initializers.Constant(0.0) + else: + init_w = hk.initializers.VarianceScaling(scale=2., mode='fan_in') + + output_w = hk.get_parameter( + 'output_w', + shape=(c.num_outer_channel, c.num_outer_channel, + self.num_output_channel), + dtype=act.dtype, + init=init_w) + output_b = hk.get_parameter( + 'output_b', shape=(self.num_output_channel,), + dtype=act.dtype, + init=hk.initializers.Constant(0.0)) + + def compute_chunk(left_act): + # This is equivalent to + # + # act = jnp.einsum('abc,ade->dceb', left_act, right_act) + # act = jnp.einsum('dceb,cef->bdf', act, output_w) + output_b + # + # but faster. + left_act = jnp.transpose(left_act, [0, 2, 1]) + act = jnp.einsum('acb,ade->dceb', left_act, right_act) + act = jnp.einsum('dceb,cef->dbf', act, output_w) + output_b + return jnp.transpose(act, [1, 0, 2]) + + act = mapping.inference_subbatch( + compute_chunk, + c.chunk_size, + batched_args=[left_act], + nonbatched_args=[], + low_memory=True, + input_subbatch_dim=1, + output_subbatch_dim=0) + + epsilon = 1e-3 + norm = jnp.einsum('abc,adc->bdc', mask, mask) + act /= epsilon + norm + + return act + +def dgram_from_positions(positions, num_bins, min_bin, max_bin): + """Compute distogram from amino acid positions. + Arguments: + positions: [N_res, 3] Position coordinates. + num_bins: The number of bins in the distogram. + min_bin: The left edge of the first bin. + max_bin: The left edge of the final bin. The final bin catches + everything larger than `max_bin`. + Returns: + Distogram with the specified number of bins. 
+ """ + def squared_difference(x, y): + return jnp.square(x - y) + + lower_breaks = jnp.linspace(min_bin, max_bin, num_bins) + lower_breaks = jnp.square(lower_breaks) + upper_breaks = jnp.concatenate([lower_breaks[1:],jnp.array([1e8], dtype=jnp.float32)], axis=-1) + dist2 = jnp.sum( + squared_difference( + jnp.expand_dims(positions, axis=-2), + jnp.expand_dims(positions, axis=-3)), + axis=-1, keepdims=True) + + return ((dist2 > lower_breaks).astype(jnp.float32) * (dist2 < upper_breaks).astype(jnp.float32)) + +def dgram_from_positions_soft(positions, num_bins, min_bin, max_bin, temp=2.0): + '''soft positions to dgram converter''' + lower_breaks = jnp.append(-1e8,jnp.linspace(min_bin, max_bin, num_bins)) + upper_breaks = jnp.append(lower_breaks[1:],1e8) + dist = jnp.sqrt(jnp.square(positions[...,:,None,:] - positions[...,None,:,:]).sum(-1,keepdims=True) + 1e-8) + o = jax.nn.sigmoid((dist - lower_breaks)/temp) * jax.nn.sigmoid((upper_breaks - dist)/temp) + o = o/(o.sum(-1,keepdims=True) + 1e-8) + return o[...,1:] + +def pseudo_beta_fn(aatype, all_atom_positions, all_atom_mask): + """Create pseudo beta features.""" + + ca_idx = residue_constants.atom_order['CA'] + cb_idx = residue_constants.atom_order['CB'] + + is_gly = jnp.equal(aatype, residue_constants.restype_order['G']) + is_gly_tile = jnp.tile(is_gly[..., None], [1] * len(is_gly.shape) + [3]) + pseudo_beta = jnp.where(is_gly_tile, all_atom_positions[..., ca_idx, :], all_atom_positions[..., cb_idx, :]) + + if all_atom_mask is not None: + pseudo_beta_mask = jnp.where(is_gly, all_atom_mask[..., ca_idx], all_atom_mask[..., cb_idx]) + pseudo_beta_mask = pseudo_beta_mask.astype(jnp.float32) + return pseudo_beta, pseudo_beta_mask + else: + return pseudo_beta + +class EvoformerIteration(hk.Module): + """Single iteration (block) of Evoformer stack. + Jumper et al. (2021) Suppl. Alg. 6 "EvoformerStack" lines 2-10 + """ + + def __init__(self, config, global_config, is_extra_msa, + name='evoformer_iteration'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + self.is_extra_msa = is_extra_msa + + def __call__(self, activations, masks, use_dropout, safe_key=None): + """Builds EvoformerIteration module. + + Arguments: + activations: Dictionary containing activations: + * 'msa': MSA activations, shape [N_seq, N_res, c_m]. + * 'pair': pair activations, shape [N_res, N_res, c_z]. + masks: Dictionary of masks: + * 'msa': MSA mask, shape [N_seq, N_res]. + * 'pair': pair mask, shape [N_res, N_res]. + safe_key: prng.SafeKey encapsulating rng key. + + Returns: + Outputs, same shape/type as act. 
+ """ + c = self.config + gc = self.global_config + + msa_act, pair_act = activations['msa'], activations['pair'] + + if safe_key is None: + safe_key = prng.SafeKey(hk.next_rng_key()) + + msa_mask, pair_mask = masks['msa'], masks['pair'] + + dropout_wrapper_fn = functools.partial( + dropout_wrapper, + global_config=gc, + use_dropout=use_dropout) + + safe_key, *sub_keys = safe_key.split(10) + sub_keys = iter(sub_keys) + + outer_module = OuterProductMean( + config=c.outer_product_mean, + global_config=self.global_config, + num_output_channel=int(pair_act.shape[-1]), + name='outer_product_mean') + if c.outer_product_mean.first: + pair_act = dropout_wrapper_fn( + outer_module, + msa_act, + msa_mask, + safe_key=next(sub_keys), + output_act=pair_act) + + msa_act = dropout_wrapper_fn( + MSARowAttentionWithPairBias( + c.msa_row_attention_with_pair_bias, gc, + name='msa_row_attention_with_pair_bias'), + msa_act, + msa_mask, + safe_key=next(sub_keys), + pair_act=pair_act) + + if not self.is_extra_msa: + attn_mod = MSAColumnAttention( + c.msa_column_attention, gc, name='msa_column_attention') + else: + attn_mod = MSAColumnGlobalAttention( + c.msa_column_attention, gc, name='msa_column_global_attention') + msa_act = dropout_wrapper_fn( + attn_mod, + msa_act, + msa_mask, + safe_key=next(sub_keys)) + + msa_act = dropout_wrapper_fn( + Transition(c.msa_transition, gc, name='msa_transition'), + msa_act, + msa_mask, + safe_key=next(sub_keys)) + + if not c.outer_product_mean.first: + pair_act = dropout_wrapper_fn( + outer_module, + msa_act, + msa_mask, + safe_key=next(sub_keys), + output_act=pair_act) + + pair_act = dropout_wrapper_fn( + TriangleMultiplication(c.triangle_multiplication_outgoing, gc, + name='triangle_multiplication_outgoing'), + pair_act, + pair_mask, + safe_key=next(sub_keys)) + pair_act = dropout_wrapper_fn( + TriangleMultiplication(c.triangle_multiplication_incoming, gc, + name='triangle_multiplication_incoming'), + pair_act, + pair_mask, + safe_key=next(sub_keys)) + + pair_act = dropout_wrapper_fn( + TriangleAttention(c.triangle_attention_starting_node, gc, + name='triangle_attention_starting_node'), + pair_act, + pair_mask, + safe_key=next(sub_keys)) + pair_act = dropout_wrapper_fn( + TriangleAttention(c.triangle_attention_ending_node, gc, + name='triangle_attention_ending_node'), + pair_act, + pair_mask, + safe_key=next(sub_keys)) + + pair_act = dropout_wrapper_fn( + Transition(c.pair_transition, gc, name='pair_transition'), + pair_act, + pair_mask, + safe_key=next(sub_keys)) + + return {'msa': msa_act, 'pair': pair_act} + + +class EmbeddingsAndEvoformer(hk.Module): + """Embeds the input data and runs Evoformer. + + Produces the MSA, single and pair representations. + Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 5-18 + """ + + def __init__(self, config, global_config, name='evoformer'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, batch, safe_key=None): + + c = self.config + gc = self.global_config + dtype = jnp.bfloat16 if gc.bfloat16 else jnp.float32 + + if safe_key is None: + safe_key = prng.SafeKey(hk.next_rng_key()) + + with utils.bfloat16_context(): + + # Embed clustered MSA. + # Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 5 + # Jumper et al. (2021) Suppl. Alg. 
3 "InputEmbedder" + + msa_feat = batch['msa_feat'].astype(dtype) + target_feat = jnp.pad(batch["target_feat"].astype(dtype),[[0,0],[1,1]]) + preprocess_1d = common_modules.Linear(c.msa_channel, name='preprocess_1d')(target_feat) + preprocess_msa = common_modules.Linear(c.msa_channel, name='preprocess_msa')(msa_feat) + msa_activations = preprocess_1d[None] + preprocess_msa + + left_single = common_modules.Linear(c.pair_channel, name='left_single')(target_feat) + right_single = common_modules.Linear(c.pair_channel, name='right_single')(target_feat) + pair_activations = left_single[:, None] + right_single[None] + + mask_2d = batch['seq_mask'][:, None] * batch['seq_mask'][None, :] + mask_2d = mask_2d.astype(dtype) + + # Inject previous outputs for recycling. + # Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 6 + # Jumper et al. (2021) Suppl. Alg. 32 "RecyclingEmbedder" + + if gc.use_dgram: + # use predicted distogram input (from Sergey) + dgram = jax.nn.softmax(batch["prev_dgram"]) + dgram_map = jax.nn.one_hot(jnp.repeat(jnp.append(0,jnp.arange(15)),4),15).at[:,0].set(0) + dgram = dgram @ dgram_map + + else: + # use predicted position input + prev_pseudo_beta = pseudo_beta_fn(batch['aatype'], batch['prev_pos'], None) + if c.backprop_dgram: + dgram = dgram_from_positions_soft(prev_pseudo_beta, temp=c.backprop_dgram_temp, **c.prev_pos) + else: + dgram = dgram_from_positions(prev_pseudo_beta, **c.prev_pos) + dgram = dgram.astype(dtype) + pair_activations += common_modules.Linear(c.pair_channel, name='prev_pos_linear')(dgram) + + if c.recycle_features: + if 'prev_msa_first_row' in batch: + prev_msa_first_row = common_modules.LayerNorm( + axis=[-1], create_scale=True, create_offset=True, + name='prev_msa_first_row_norm')(batch['prev_msa_first_row']).astype(dtype) + msa_activations = msa_activations.at[0].add(prev_msa_first_row) + + if 'prev_pair' in batch: + pair_activations += common_modules.LayerNorm( + axis=[-1], create_scale=True, create_offset=True, + name='prev_pair_norm')(batch['prev_pair']).astype(dtype) + + # Relative position encoding. + # Jumper et al. (2021) Suppl. Alg. 4 "relpos" + # Jumper et al. (2021) Suppl. Alg. 5 "one_hot" + if c.max_relative_feature: + # Add one-hot-encoded clipped residue distances to the pair activations. + if "rel_pos" in batch: + rel_pos = batch['rel_pos'].astype(dtype) + else: + if "offset" in batch: + offset = batch['offset'] + else: + pos = batch['residue_index'] + offset = pos[:, None] - pos[None, :] + rel_pos = jax.nn.one_hot( + jnp.clip( + offset + c.max_relative_feature, + a_min=0, + a_max=2 * c.max_relative_feature), + 2 * c.max_relative_feature + 1).astype(dtype) + pair_activations += common_modules.Linear(c.pair_channel, name='pair_activiations')(rel_pos) + + # Embed templates into the pair activations. + # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 9-13 + + if c.template.enabled: + template_batch = {k: batch[k] for k in batch if k.startswith('template_')} + + multichain_mask = batch['asym_id'][:, None] == batch['asym_id'][None, :] + multichain_mask = jnp.where(batch["mask_template_interchain"], multichain_mask, True) + + template_pair_representation = TemplateEmbedding(c.template, gc)( + pair_activations, + template_batch, + mask_2d, + multichain_mask, + use_dropout=batch["use_dropout"]) + + pair_activations += template_pair_representation + + # Embed extra MSA features. + # Jumper et al. (2021) Suppl. Alg. 
2 "Inference" lines 14-16 + if c.use_extra_msa: + extra_msa_feat = create_extra_msa_feature(batch) + extra_msa_activations = common_modules.Linear(c.extra_msa_channel, + name='extra_msa_activations')(extra_msa_feat).astype(dtype) + # Extra MSA Stack. + # Jumper et al. (2021) Suppl. Alg. 18 "ExtraMsaStack" + extra_msa_stack_input = {'msa': extra_msa_activations, + 'pair': pair_activations} + extra_msa_stack_iteration = EvoformerIteration(c.evoformer, gc, + is_extra_msa=True, name='extra_msa_stack') + def extra_msa_stack_fn(x): + act, safe_key = x + safe_key, safe_subkey = safe_key.split() + extra_evoformer_output = extra_msa_stack_iteration( + activations=act, + masks={'msa': batch['extra_msa_mask'].astype(dtype), + 'pair': mask_2d}, + safe_key=safe_subkey, + use_dropout=batch["use_dropout"]) + return (extra_evoformer_output, safe_key) + if gc.use_remat: extra_msa_stack_fn = hk.remat(extra_msa_stack_fn) + extra_msa_stack = layer_stack.layer_stack(c.extra_msa_stack_num_block)(extra_msa_stack_fn) + extra_msa_output, safe_key = extra_msa_stack((extra_msa_stack_input, safe_key)) + pair_activations = extra_msa_output['pair'] + + evoformer_input = {'msa': msa_activations,'pair': pair_activations} + evoformer_masks = {'msa': batch['msa_mask'].astype(dtype), + 'pair': mask_2d} + #################################################################### + + # Append num_templ rows to msa_activations with template embeddings. + # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 7-8 + if c.template.enabled and c.template.embed_torsion_angles: + num_templ, num_res = batch['template_aatype'].shape + # Embed the templates aatypes. + aatype = batch['template_aatype'] + aatype_one_hot = jax.nn.one_hot(batch['template_aatype'], 22, axis=-1) + + # Embed the templates aatype, torsion angles and masks. + # Shape (templates, residues, msa_channels) + ret = all_atom.atom37_to_torsion_angles( + aatype=aatype, + all_atom_pos=batch['template_all_atom_positions'], + all_atom_mask=batch['template_all_atom_mask'], + # Ensure consistent behaviour during testing: + placeholder_for_undefined=not gc.zero_init) + + template_features = jnp.concatenate([ + aatype_one_hot, + jnp.reshape(ret['torsion_angles_sin_cos'], [num_templ, num_res, 14]), + jnp.reshape(ret['alt_torsion_angles_sin_cos'], [num_templ, num_res, 14]), + ret['torsion_angles_mask']], axis=-1).astype(dtype) + + template_activations = common_modules.Linear( + c.msa_channel, + initializer='relu', + name='template_single_embedding')(template_features) + template_activations = jax.nn.relu(template_activations) + template_activations = common_modules.Linear( + c.msa_channel, + initializer='relu', + name='template_projection')(template_activations) + + # Concatenate the templates to the msa. + evoformer_input['msa'] = jnp.concatenate([evoformer_input['msa'], template_activations], axis=0) + + # Concatenate templates masks to the msa masks. + # Use mask from the psi angle, as it only depends on the backbone atoms + # from a single residue. + torsion_angle_mask = ret['torsion_angles_mask'][:, :, 2] + torsion_angle_mask = torsion_angle_mask.astype(evoformer_masks['msa'].dtype) + evoformer_masks['msa'] = jnp.concatenate([evoformer_masks['msa'], torsion_angle_mask], axis=0) + #################################################################### + if c.use_msa: + # Main trunk of the network + # Jumper et al. (2021) Suppl. Alg. 
2 "Inference" lines 17-18 + evoformer_iteration = EvoformerIteration(c.evoformer, gc, is_extra_msa=False, name='evoformer_iteration') + def evoformer_fn(x): + act, safe_key = x + safe_key, safe_subkey = safe_key.split() + evoformer_output = evoformer_iteration( + activations=act, + masks=evoformer_masks, + safe_key=safe_subkey, + use_dropout=batch["use_dropout"]) + return (evoformer_output, safe_key) + if gc.use_remat: evoformer_fn = hk.remat(evoformer_fn) + evoformer_stack = layer_stack.layer_stack(c.evoformer_num_block)(evoformer_fn) + evoformer_output, safe_key = evoformer_stack((evoformer_input, safe_key)) + msa_activations = evoformer_output['msa'] + pair_activations = evoformer_output['pair'] + + single_activations = common_modules.Linear(c.seq_channel, name='single_activations')(msa_activations[0]) + num_sequences = batch['msa_feat'].shape[0] + + output = { + 'single': single_activations, + 'pair': pair_activations, + # Crop away template rows such that they are not used in MaskedMsaHead. + 'msa': msa_activations[:num_sequences, :, :], + 'msa_first_row': msa_activations[0], + } + + # Convert back to float32 if we're not saving memory. + if not gc.bfloat16_output: + for k, v in output.items(): + if v.dtype == jnp.bfloat16: + output[k] = v.astype(jnp.float32) + return output + +#################################################################### +#################################################################### +class SingleTemplateEmbedding(hk.Module): + """Embeds a single template. + Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 9+11 + """ + + def __init__(self, config, global_config, name='single_template_embedding'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, query_embedding, batch, mask_2d, multichain_mask_2d, use_dropout): + """Build the single template embedding. + Arguments: + query_embedding: Query pair representation, shape [N_res, N_res, c_z]. + batch: A batch of template features (note the template dimension has been + stripped out as this module only runs over a single template). + mask_2d: Padding mask (Note: this doesn't care if a template exists, + unlike the template_pseudo_beta_mask). + Returns: + A template embedding [N_res, N_res, c_z]. 
+ """ + assert mask_2d.dtype == query_embedding.dtype + dtype = query_embedding.dtype + num_res = batch['template_aatype'].shape[0] + num_channels = (self.config.template_pair_stack + .triangle_attention_ending_node.value_dim) + + if "template_dgram" in batch: + template_dgram = batch["template_dgram"] + template_mask_2d = batch["template_dgram"].sum(-1) + + else: + template_mask = batch['template_pseudo_beta_mask'] + template_mask_2d = template_mask[:, None] * template_mask[None, :] + if self.config.backprop_dgram: + template_dgram = dgram_from_positions_soft(batch['template_pseudo_beta'], + temp=self.config.backprop_dgram_temp, + **self.config.dgram_features) + else: + template_dgram = dgram_from_positions(batch['template_pseudo_beta'], + **self.config.dgram_features) + + template_mask_2d = template_mask_2d * multichain_mask_2d + template_mask_2d = template_mask_2d.astype(dtype) + + template_dgram *= template_mask_2d[..., None] + template_dgram = template_dgram.astype(dtype) + to_concat = [template_dgram, template_mask_2d[:, :, None]] + + aatype = jax.nn.one_hot(batch['template_aatype'], 22, axis=-1, dtype=dtype) + + to_concat.append(jnp.tile(aatype[None, :, :], [num_res, 1, 1])) + to_concat.append(jnp.tile(aatype[:, None, :], [1, num_res, 1])) + + if "template_dgram" in batch: + unit_vector = [jnp.zeros((num_res,num_res,1))] * 3 + + else: + # Backbone affine mask: whether the residue has C, CA, N + # (the template mask defined above only considers pseudo CB). + n, ca, c = [residue_constants.atom_order[a] for a in ('N', 'CA', 'C')] + template_mask = ( + batch['template_all_atom_mask'][..., n] * + batch['template_all_atom_mask'][..., ca] * + batch['template_all_atom_mask'][..., c]) + template_mask_2d = template_mask[:, None] * template_mask[None, :] + template_mask_2d = template_mask_2d * multichain_mask_2d + + # compute unit_vector (not used by default) + if self.config.use_template_unit_vector: + raw_atom_pos = template_batch["template_all_atom_positions"] + if gc.bfloat16: + raw_atom_pos = raw_atom_pos.astype(jnp.float32) + + rot, trans = quat_affine.make_transform_from_reference( + n_xyz=raw_atom_pos[:, n], + ca_xyz=raw_atom_pos[:, ca], + c_xyz=raw_atom_pos[:, c]) + affines = quat_affine.QuatAffine( + quaternion=quat_affine.rot_to_quat(rot, unstack_inputs=True), + translation=trans, + rotation=rot, + unstack_inputs=True) + points = [jnp.expand_dims(x, axis=-2) for x in affines.translation] + affine_vec = affines.invert_point(points, extra_dims=1) + inv_distance_scalar = jax.lax.rsqrt(1e-6 + sum([jnp.square(x) for x in affine_vec])) + inv_distance_scalar *= template_mask_2d.astype(inv_distance_scalar.dtype) + unit_vector = [(x * inv_distance_scalar)[..., None] for x in affine_vec] + else: + unit_vector = [jnp.zeros((num_res,num_res,1))] * 3 + + unit_vector = [x.astype(dtype) for x in unit_vector] + to_concat.extend(unit_vector) + + template_mask_2d = template_mask_2d.astype(dtype) + to_concat.append(template_mask_2d[..., None]) + + act = jnp.concatenate(to_concat, axis=-1) + + # Mask out non-template regions so we don't get arbitrary values in the + # distogram for these regions. + act *= template_mask_2d[..., None] + + # Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 9 + act = common_modules.Linear( + num_channels, + initializer='relu', + name='embedding2d')(act) + + # Jumper et al. (2021) Suppl. Alg. 
2 "Inference" line 11 + act = TemplatePairStack( + self.config.template_pair_stack, self.global_config)(act, mask_2d, use_dropout=use_dropout) + + act = common_modules.LayerNorm([-1], True, True, name='output_layer_norm')(act) + return act + + +class TemplateEmbedding(hk.Module): + """Embeds a set of templates. + Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 9-12 + Jumper et al. (2021) Suppl. Alg. 17 "TemplatePointwiseAttention" + """ + + def __init__(self, config, global_config, name='template_embedding'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, query_embedding, template_batch, mask_2d, multichain_mask_2d, use_dropout): + """Build TemplateEmbedding module. + Arguments: + query_embedding: Query pair representation, shape [N_res, N_res, c_z]. + template_batch: A batch of template features. + mask_2d: Padding mask (Note: this doesn't care if a template exists, + unlike the template_pseudo_beta_mask). + Returns: + A template embedding [N_res, N_res, c_z]. + """ + + num_templates = template_batch['template_mask'].shape[0] + num_channels = (self.config.template_pair_stack + .triangle_attention_ending_node.value_dim) + num_res = query_embedding.shape[0] + + dtype = query_embedding.dtype + template_mask = template_batch['template_mask'] + template_mask = template_mask.astype(dtype) + + query_num_channels = query_embedding.shape[-1] + + # Make sure the weights are shared across templates by constructing the + # embedder here. + # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 9-12 + template_embedder = SingleTemplateEmbedding(self.config, self.global_config) + + def map_fn(batch): + return template_embedder(query_embedding, batch, mask_2d, multichain_mask_2d, + use_dropout=use_dropout) + + template_pair_representation = mapping.sharded_map(map_fn, in_axes=0)(template_batch) + + # Cross attend from the query to the templates along the residue + # dimension by flattening everything else into the batch dimension. + # Jumper et al. (2021) Suppl. Alg. 17 "TemplatePointwiseAttention" + flat_query = jnp.reshape(query_embedding,[num_res * num_res, 1, query_num_channels]) + + flat_templates = jnp.reshape( + jnp.transpose(template_pair_representation, [1, 2, 0, 3]), + [num_res * num_res, num_templates, num_channels]) + + bias = (1e9 * (template_mask[None, None, None, :] - 1.)) + + template_pointwise_attention_module = Attention( + self.config.attention, self.global_config, query_num_channels) + nonbatched_args = [bias] + batched_args = [flat_query, flat_templates] + + embedding = mapping.inference_subbatch( + template_pointwise_attention_module, + self.config.subbatch_size, + batched_args=batched_args, + nonbatched_args=nonbatched_args, + low_memory=self.config.subbatch_size is not None) + embedding = jnp.reshape(embedding,[num_res, num_res, query_num_channels]) + + # No gradients if no templates. 
+ embedding *= (jnp.sum(template_mask) > 0.).astype(embedding.dtype) + + return embedding +#################################################################### \ No newline at end of file diff --git a/build/lib/colabdesign/af/alphafold/model/modules_multimer.py b/build/lib/colabdesign/af/alphafold/model/modules_multimer.py new file mode 100644 index 00000000..8822c6d8 --- /dev/null +++ b/build/lib/colabdesign/af/alphafold/model/modules_multimer.py @@ -0,0 +1,805 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Core modules, which have been refactored in AlphaFold-Multimer. + +The main difference is that MSA sampling pipeline is moved inside the JAX model +for easier implementation of recycling and ensembling. + +Lower-level modules up to EvoformerIteration are reused from modules.py. +""" + +import functools +from typing import Sequence + +from colabdesign.af.alphafold.common import residue_constants +from colabdesign.af.alphafold.model import all_atom_multimer +from colabdesign.af.alphafold.model import common_modules +from colabdesign.af.alphafold.model import folding_multimer +from colabdesign.af.alphafold.model import geometry +from colabdesign.af.alphafold.model import layer_stack +from colabdesign.af.alphafold.model import modules +from colabdesign.af.alphafold.model import prng +from colabdesign.af.alphafold.model import utils + +import haiku as hk +import jax +import jax.numpy as jnp +import numpy as np + +def create_extra_msa_feature(batch, num_extra_msa): + """Expand extra_msa into 1hot and concat with other extra msa features. + We do this as late as possible as the one_hot extra msa can be very large. + Args: + batch: a dictionary with the following keys: + * 'extra_msa': [num_seq, num_res] MSA that wasn't selected as a cluster + centre. Note - This isn't one-hotted. + * 'extra_deletion_matrix': [num_seq, num_res] Number of deletions at given + position. + num_extra_msa: Number of extra msa to use. + Returns: + Concatenated tensor of extra MSA features. + """ + # 23 = 20 amino acids + 'X' for unknown + gap + bert mask + extra_msa = batch['extra_msa'][:num_extra_msa] + deletion_matrix = batch['extra_deletion_value'][:num_extra_msa] + msa_1hot = jax.nn.one_hot(extra_msa, 23) + has_deletion = jnp.clip(deletion_matrix, 0., 1.)[..., None] + deletion_value = (jnp.arctan(deletion_matrix / 3.) * (2. / jnp.pi))[..., None] + extra_msa_mask = batch['extra_msa_mask'][:num_extra_msa] + return jnp.concatenate([msa_1hot, has_deletion, deletion_value], + axis=-1), extra_msa_mask + +class AlphaFoldIteration(hk.Module): + """A single recycling iteration of AlphaFold architecture. + + Computes ensembled (averaged) representations from the provided features. + These representations are then passed to the various heads + that have been requested by the configuration file. 
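+
+  Example of the returned dictionary (illustrative; which heads appear depends
+  on the head weights set in the config):
+    ret['representations']   # 'single', 'pair', 'msa', 'msa_first_row'
+    ret['distogram']         # binned pairwise-distance logits
+    ret['masked_msa']        # BERT-style MSA reconstruction logits
+    ret['structure_module']  # final atom positions and related outputs
+    ret['predicted_lddt'], ret['predicted_aligned_error'], ret['experimentally_resolved']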
+ """ + + def __init__(self, config, global_config, name='alphafold_iteration'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, + batch, + return_representations=False, + safe_key=None): + + + # Compute representations for each MSA sample and average. + embedding_module = EmbeddingsAndEvoformer( + self.config.embeddings_and_evoformer, self.global_config) + + safe_key, safe_subkey = safe_key.split() + representations = embedding_module(batch, safe_key=safe_subkey) + + self.representations = representations + self.batch = batch + self.heads = {} + for head_name, head_config in sorted(self.config.heads.items()): + if not head_config.weight: + continue # Do not instantiate zero-weight heads. + + head_factory = { + 'masked_msa': + modules.MaskedMsaHead, + 'distogram': + modules.DistogramHead, + 'structure_module': + folding_multimer.StructureModule, + 'predicted_aligned_error': + modules.PredictedAlignedErrorHead, + 'predicted_lddt': + modules.PredictedLDDTHead, + 'experimentally_resolved': + modules.ExperimentallyResolvedHead, + }[head_name] + self.heads[head_name] = (head_config, + head_factory(head_config, self.global_config)) + + structure_module_output = None + if 'entity_id' in batch and 'all_atom_positions' in batch: + _, fold_module = self.heads['structure_module'] + structure_module_output = fold_module(representations, batch) + + + ret = {} + ret['representations'] = representations + + for name, (head_config, module) in self.heads.items(): + if name == 'structure_module' and structure_module_output is not None: + ret[name] = structure_module_output + representations['structure_module'] = structure_module_output.pop('act') + # Skip confidence heads until StructureModule is executed. + elif name in {'predicted_lddt', 'predicted_aligned_error', + 'experimentally_resolved'}: + continue + else: + ret[name] = module(representations, batch) + + + # Add confidence heads after StructureModule is executed. + if self.config.heads.get('predicted_lddt.weight', 0.0): + name = 'predicted_lddt' + head_config, module = self.heads[name] + ret[name] = module(representations, batch) + + if self.config.heads.experimentally_resolved.weight: + name = 'experimentally_resolved' + head_config, module = self.heads[name] + ret[name] = module(representations, batch) + + if self.config.heads.get('predicted_aligned_error.weight', 0.0): + name = 'predicted_aligned_error' + head_config, module = self.heads[name] + ret[name] = module(representations, batch) + # Will be used for ipTM computation. + ret[name]['asym_id'] = batch['asym_id'] + + return ret + +class AlphaFold(hk.Module): + """AlphaFold-Multimer model with recycling. 
+ """ + + def __init__(self, config, name='alphafold'): + super().__init__(name=name) + self.config = config + self.global_config = config.global_config + + def __call__( + self, + batch, + return_representations=False, + safe_key=None): + + c = self.config + impl = AlphaFoldIteration(c, self.global_config) + + if safe_key is None: + safe_key = prng.SafeKey(hk.next_rng_key()) + elif isinstance(safe_key, jnp.ndarray): + safe_key = prng.SafeKey(safe_key) + + assert isinstance(batch, dict) + num_res = batch['aatype'].shape[0] + + def get_prev(ret): + new_prev = { + 'prev_pos': ret['structure_module']['final_atom_positions'], + 'prev_msa_first_row': ret['representations']['msa_first_row'], + 'prev_pair': ret['representations']['pair'], + } + return new_prev + + def apply_network(prev, safe_key): + recycled_batch = {**batch, **prev} + return impl( + batch=recycled_batch, + safe_key=safe_key) + + ret = apply_network(prev=batch.pop("prev"), safe_key=safe_key) + ret["prev"] = get_prev(ret) + + if not return_representations: + del ret['representations'] + return ret + +class EmbeddingsAndEvoformer(hk.Module): + """Embeds the input data and runs Evoformer. + + Produces the MSA, single and pair representations. + """ + + def __init__(self, config, global_config, name='evoformer'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def _relative_encoding(self, batch): + """Add relative position encodings. + + For position (i, j), the value is (i-j) clipped to [-k, k] and one-hotted. + + When not using 'use_chain_relative' the residue indices are used as is, e.g. + for heteromers relative positions will be computed using the positions in + the corresponding chains. + + When using 'use_chain_relative' we add an extra bin that denotes + 'different chain'. Furthermore we also provide the relative chain index + (i.e. sym_id) clipped and one-hotted to the network. And an extra feature + which denotes whether they belong to the same chain type, i.e. it's 0 if + they are in different heteromer chains and 1 otherwise. + + Args: + batch: batch. + Returns: + Feature embedding using the features as described before. 
+ """ + c = self.config + gc = self.global_config + rel_feats = [] + asym_id = batch['asym_id'] + asym_id_same = jnp.equal(asym_id[:, None], asym_id[None, :]) + + if "offset" in batch: + offset = batch['offset'] + else: + pos = batch['residue_index'] + offset = pos[:, None] - pos[None, :] + + clipped_offset = jnp.clip( + offset + c.max_relative_idx, a_min=0, a_max=2 * c.max_relative_idx) + + dtype = jnp.bfloat16 if gc.bfloat16 else jnp.float32 + + if c.use_chain_relative: + + final_offset = jnp.where(asym_id_same, clipped_offset, + (2 * c.max_relative_idx + 1) * + jnp.ones_like(clipped_offset)) + + rel_pos = jax.nn.one_hot(final_offset, 2 * c.max_relative_idx + 2) + + rel_feats.append(rel_pos) + + entity_id = batch['entity_id'] + entity_id_same = jnp.equal(entity_id[:, None], entity_id[None, :]) + rel_feats.append(entity_id_same.astype(rel_pos.dtype)[..., None]) + + sym_id = batch['sym_id'] + rel_sym_id = sym_id[:, None] - sym_id[None, :] + + max_rel_chain = c.max_relative_chain + + clipped_rel_chain = jnp.clip( + rel_sym_id + max_rel_chain, a_min=0, a_max=2 * max_rel_chain) + + final_rel_chain = jnp.where(entity_id_same, clipped_rel_chain, + (2 * max_rel_chain + 1) * + jnp.ones_like(clipped_rel_chain)) + rel_chain = jax.nn.one_hot(final_rel_chain, 2 * c.max_relative_chain + 2) + + rel_feats.append(rel_chain) + + else: + rel_pos = jax.nn.one_hot(clipped_offset, 2 * c.max_relative_idx + 1) + rel_feats.append(rel_pos) + + rel_feat = jnp.concatenate(rel_feats, axis=-1) + + rel_feat = rel_feat.astype(dtype) + return common_modules.Linear( + c.pair_channel, + name='position_activations')(rel_feat) + + def __call__(self, batch, safe_key=None): + + c = self.config + gc = self.global_config + + batch = dict(batch) + dtype = jnp.bfloat16 if gc.bfloat16 else jnp.float32 + + if safe_key is None: + safe_key = prng.SafeKey(hk.next_rng_key()) + + output = {} + with utils.bfloat16_context(): + + msa_feat = batch['msa_feat'].astype(dtype) + target_feat = jnp.pad(batch["target_feat"].astype(dtype),[[0,0],[0,1]]) + preprocess_1d = common_modules.Linear(c.msa_channel, name='preprocess_1d')(target_feat) + preprocess_msa = common_modules.Linear(c.msa_channel, name='preprocess_msa')(msa_feat) + msa_activations = preprocess_1d[None] + preprocess_msa + num_msa_sequences = msa_activations.shape[0] + + left_single = common_modules.Linear(c.pair_channel, name='left_single')(target_feat) + right_single = common_modules.Linear(c.pair_channel, name='right_single')(target_feat) + pair_activations = left_single[:, None] + right_single[None] + mask_2d = batch['seq_mask'][:, None] * batch['seq_mask'][None, :] + mask_2d = mask_2d.astype(dtype) + + if c.recycle_pos: + prev_pseudo_beta = modules.pseudo_beta_fn(batch['aatype'], batch['prev_pos'], None) + dgram = modules.dgram_from_positions(prev_pseudo_beta, **self.config.prev_pos) + dgram = dgram.astype(dtype) + pair_activations += common_modules.Linear(c.pair_channel, name='prev_pos_linear')(dgram) + + if c.recycle_features: + prev_msa_first_row = common_modules.LayerNorm( + axis=[-1], + create_scale=True, + create_offset=True, + name='prev_msa_first_row_norm')(batch['prev_msa_first_row']).astype(dtype) + + msa_activations = msa_activations.at[0].add(prev_msa_first_row) + + pair_activations += common_modules.LayerNorm( + axis=[-1], + create_scale=True, + create_offset=True, + name='prev_pair_norm')(batch['prev_pair']).astype(dtype) + + if c.max_relative_idx: + pair_activations += self._relative_encoding(batch) + + if c.template.enabled: + template_module = 
TemplateEmbedding(c.template, gc) + template_batch = { + 'template_aatype': batch['template_aatype'], + 'template_all_atom_positions': batch['template_all_atom_positions'], + 'template_all_atom_mask': batch['template_all_atom_mask'] + } + if "template_dgram" in batch: + template_batch["template_dgram"] = batch["template_dgram"] + + # Construct a mask such that only intra-chain template features are + # computed, since all templates are for each chain individually. + multichain_mask = batch['asym_id'][:, None] == batch['asym_id'][None, :] + multichain_mask = jnp.where(batch["mask_template_interchain"], multichain_mask, True) + + safe_key, safe_subkey = safe_key.split() + template_act = template_module( + query_embedding=pair_activations, + template_batch=template_batch, + padding_mask_2d=mask_2d, + multichain_mask_2d=multichain_mask, + use_dropout=batch["use_dropout"], + safe_key=safe_subkey) + pair_activations += template_act + + # Extra MSA stack. + (extra_msa_feat, extra_msa_mask) = create_extra_msa_feature(batch, c.num_extra_msa) + extra_msa_activations = common_modules.Linear(c.extra_msa_channel, + name='extra_msa_activations')(extra_msa_feat).astype(dtype) + extra_msa_mask = extra_msa_mask.astype(dtype) + extra_evoformer_input = {'msa': extra_msa_activations, 'pair': pair_activations} + extra_masks = {'msa': extra_msa_mask, 'pair': mask_2d} + extra_evoformer_iteration = modules.EvoformerIteration(c.evoformer, gc, is_extra_msa=True, name='extra_msa_stack') + + def extra_evoformer_fn(x): + act, safe_key = x + safe_key, safe_subkey = safe_key.split() + extra_evoformer_output = extra_evoformer_iteration( + activations=act, + masks=extra_masks, + use_dropout=batch["use_dropout"], + safe_key=safe_subkey) + return (extra_evoformer_output, safe_key) + + if gc.use_remat: + extra_evoformer_fn = hk.remat(extra_evoformer_fn) + + safe_key, safe_subkey = safe_key.split() + extra_evoformer_stack = layer_stack.layer_stack( + c.extra_msa_stack_num_block)( + extra_evoformer_fn) + extra_evoformer_output, safe_key = extra_evoformer_stack( + (extra_evoformer_input, safe_subkey)) + + pair_activations = extra_evoformer_output['pair'] + # Get the size of the MSA before potentially adding templates, so we + # can crop out the templates later. 
+ num_msa_sequences = msa_activations.shape[0] + evoformer_input = { + 'msa': msa_activations, + 'pair': pair_activations, + } + evoformer_masks = {'msa': batch['msa_mask'].astype(dtype), 'pair': mask_2d} + if c.template.enabled: + template_features, template_masks = ( + template_embedding_1d(batch=batch, num_channel=c.msa_channel, global_config=gc)) + + evoformer_input['msa'] = jnp.concatenate([evoformer_input['msa'], template_features], axis=0) + evoformer_masks['msa'] = jnp.concatenate([evoformer_masks['msa'], template_masks], axis=0) + + evoformer_iteration = modules.EvoformerIteration( + c.evoformer, gc, is_extra_msa=False, name='evoformer_iteration') + + def evoformer_fn(x): + act, safe_key = x + safe_key, safe_subkey = safe_key.split() + evoformer_output = evoformer_iteration( + activations=act, + masks=evoformer_masks, + use_dropout=batch["use_dropout"], + safe_key=safe_subkey) + return (evoformer_output, safe_key) + + if gc.use_remat: + evoformer_fn = hk.remat(evoformer_fn) + + safe_key, safe_subkey = safe_key.split() + evoformer_stack = layer_stack.layer_stack(c.evoformer_num_block)( + evoformer_fn) + + def run_evoformer(evoformer_input): + evoformer_output, _ = evoformer_stack((evoformer_input, safe_subkey)) + return evoformer_output + + evoformer_output = run_evoformer(evoformer_input) + + msa_activations = evoformer_output['msa'] + pair_activations = evoformer_output['pair'] + + single_activations = common_modules.Linear( + c.seq_channel, name='single_activations')(msa_activations[0]) + output.update({ + 'single': + single_activations, + 'pair': + pair_activations, + # Crop away template rows such that they are not used in MaskedMsaHead. + 'msa': + msa_activations[:num_msa_sequences, :, :], + 'msa_first_row': + msa_activations[0], + }) + + # Convert back to float32 if we're not saving memory. + if not gc.bfloat16_output: + for k, v in output.items(): + if v.dtype == jnp.bfloat16: + output[k] = v.astype(jnp.float32) + + return output + + +class TemplateEmbedding(hk.Module): + """Embed a set of templates.""" + + def __init__(self, config, global_config, name='template_embedding'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, query_embedding, template_batch, padding_mask_2d, + multichain_mask_2d, use_dropout, safe_key=None): + """Generate an embedding for a set of templates. + + Args: + query_embedding: [num_res, num_res, num_channel] a query tensor that will + be used to attend over the templates to remove the num_templates + dimension. + template_batch: A dictionary containing: + `template_aatype`: [num_templates, num_res] aatype for each template. + `template_all_atom_positions`: [num_templates, num_res, 37, 3] atom + positions for all templates. + `template_all_atom_mask`: [num_templates, num_res, 37] mask for each + template. + padding_mask_2d: [num_res, num_res] Pair mask for attention operations. + multichain_mask_2d: [num_res, num_res] Mask indicating which residue pairs + are intra-chain, used to mask out residue distance based features + between chains. + safe_key: random key generator. + + Returns: + An embedding of size [num_res, num_res, num_channels] + """ + c = self.config + if safe_key is None: + safe_key = prng.SafeKey(hk.next_rng_key()) + + num_templates = template_batch['template_aatype'].shape[0] + num_res, _, query_num_channels = query_embedding.shape + + # Embed each template separately. 
+ template_embedder = SingleTemplateEmbedding(self.config, self.global_config) + def partial_template_embedder(template_batch, unsafe_key): + safe_key = prng.SafeKey(unsafe_key) + return template_embedder(query_embedding, + template_batch, + padding_mask_2d, + multichain_mask_2d, + use_dropout, + safe_key) + + safe_key, unsafe_key = safe_key.split() + unsafe_keys = jax.random.split(unsafe_key._key, num_templates) + + def scan_fn(carry, x): + return carry + partial_template_embedder(*x), None + + scan_init = jnp.zeros((num_res, num_res, c.num_channels), dtype=query_embedding.dtype) + summed_template_embeddings, _ = hk.scan(scan_fn, scan_init, (template_batch, unsafe_keys)) + + embedding = summed_template_embeddings / num_templates + embedding = jax.nn.relu(embedding) + embedding = common_modules.Linear( + query_num_channels, + initializer='relu', + name='output_linear')(embedding) + + return embedding + + +class SingleTemplateEmbedding(hk.Module): + """Embed a single template.""" + + def __init__(self, config, global_config, name='single_template_embedding'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, query_embedding, template_batch, + padding_mask_2d, multichain_mask_2d, use_dropout, safe_key): + """Build the single template embedding graph. + + Args: + query_embedding: (num_res, num_res, num_channels) - embedding of the + query sequence/msa. + template_aatype: [num_res] aatype for each template. + template_all_atom_positions: [num_res, 37, 3] atom positions for all + templates. + template_all_atom_mask: [num_res, 37] mask for each template. + padding_mask_2d: Padding mask (Note: this doesn't care if a template + exists, unlike the template_pseudo_beta_mask). + multichain_mask_2d: A mask indicating intra-chain residue pairs, used + to mask out between chain distances/features when templates are for + single chains. + safe_key: Random key generator. + + Returns: + A template embedding (num_res, num_res, num_channels). + """ + gc = self.global_config + c = self.config + assert padding_mask_2d.dtype == query_embedding.dtype + dtype = query_embedding.dtype + num_channels = self.config.num_channels + + def construct_input(query_embedding, template_batch, multichain_mask_2d): + + if "template_dgram" in template_batch: + template_dgram = template_batch["template_dgram"].astype(dtype) + template_dgram *= multichain_mask_2d[...,None] + pseudo_beta_mask_2d = template_dgram.sum(-1) + + else: + # Compute distogram feature for the template. + template_positions, pseudo_beta_mask = modules.pseudo_beta_fn( + template_batch["template_aatype"], + template_batch["template_all_atom_positions"], + template_batch["template_all_atom_mask"]) + pseudo_beta_mask_2d = (pseudo_beta_mask[:, None] * + pseudo_beta_mask[None, :]) + pseudo_beta_mask_2d *= multichain_mask_2d + template_dgram = modules.dgram_from_positions( + template_positions, **self.config.dgram_features) + template_dgram *= pseudo_beta_mask_2d[..., None] + template_dgram = template_dgram.astype(dtype) + pseudo_beta_mask_2d = pseudo_beta_mask_2d.astype(dtype) + + to_concat = [(template_dgram, 1), (pseudo_beta_mask_2d, 0)] + aatype = jax.nn.one_hot(template_batch["template_aatype"], 22, axis=-1, dtype=dtype) + to_concat.append((aatype[None, :, :], 1)) + to_concat.append((aatype[:, None, :], 1)) + + # Compute a feature representing the normalized vector between each + # backbone affine - i.e. in each residues local frame, what direction are + # each of the other residues. 
+ raw_atom_pos = template_batch["template_all_atom_positions"] + if gc.bfloat16: + raw_atom_pos = raw_atom_pos.astype(jnp.float32) + + atom_pos = geometry.Vec3Array.from_array(raw_atom_pos) + rigid, backbone_mask = folding_multimer.make_backbone_affine( + atom_pos, + template_batch["template_all_atom_mask"], + template_batch["template_aatype"]) + points = rigid.translation + rigid_vec = rigid[:, None].inverse().apply_to_point(points) + unit_vector = rigid_vec.normalized() + unit_vector = [unit_vector.x, unit_vector.y, unit_vector.z] + + if gc.bfloat16: + unit_vector = [x.astype(jnp.bfloat16) for x in unit_vector] + backbone_mask = backbone_mask.astype(jnp.bfloat16) + + backbone_mask_2d = jnp.sqrt(backbone_mask[:,None] * backbone_mask[None,:]) + backbone_mask_2d *= multichain_mask_2d + unit_vector = [x*backbone_mask_2d for x in unit_vector] + + # Note that the backbone_mask takes into account C, CA and N (unlike + # pseudo beta mask which just needs CB) so we add both masks as features. + to_concat.extend([(x, 0) for x in unit_vector]) + to_concat.append((backbone_mask_2d, 0)) + query_embedding = common_modules.LayerNorm( + axis=[-1], + create_scale=True, + create_offset=True, + name='query_embedding_norm')(query_embedding) + # Allow the template embedder to see the query embedding. Note this + # contains the position relative feature, so this is how the network knows + # which residues are next to each other. + to_concat.append((query_embedding, 1)) + + act = 0 + + for i, (x, n_input_dims) in enumerate(to_concat): + act += common_modules.Linear( + num_channels, + num_input_dims=n_input_dims, + initializer='relu', + name=f'template_pair_embedding_{i}')(x) + return act + + act = construct_input(query_embedding, template_batch, multichain_mask_2d) + + template_iteration = TemplateEmbeddingIteration( + c.template_pair_stack, gc, name='template_embedding_iteration') + + def template_iteration_fn(x): + act, safe_key = x + + safe_key, safe_subkey = safe_key.split() + act = template_iteration( + act=act, + pair_mask=padding_mask_2d, + use_dropout=use_dropout, + safe_key=safe_subkey) + return (act, safe_key) + + if gc.use_remat: + template_iteration_fn = hk.remat(template_iteration_fn) + + safe_key, safe_subkey = safe_key.split() + template_stack = layer_stack.layer_stack( + c.template_pair_stack.num_block)( + template_iteration_fn) + act, safe_key = template_stack((act, safe_subkey)) + + act = common_modules.LayerNorm( + axis=[-1], + create_scale=True, + create_offset=True, + name='output_layer_norm')( + act) + return act + +class TemplateEmbeddingIteration(hk.Module): + """Single Iteration of Template Embedding.""" + + def __init__(self, config, global_config, + name='template_embedding_iteration'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, act, pair_mask, use_dropout, safe_key=None): + """Build a single iteration of the template embedder. + + Args: + act: [num_res, num_res, num_channel] Input pairwise activations. + pair_mask: [num_res, num_res] padding mask. + safe_key: Safe pseudo-random generator key. + + Returns: + [num_res, num_res, num_channel] tensor of activations. 
+ """ + c = self.config + gc = self.global_config + + if safe_key is None: + safe_key = prng.SafeKey(hk.next_rng_key()) + + dropout_wrapper_fn = functools.partial( + modules.dropout_wrapper, + use_dropout=use_dropout, + global_config=gc) + + safe_key, *sub_keys = safe_key.split(20) + sub_keys = iter(sub_keys) + + act = dropout_wrapper_fn( + modules.TriangleMultiplication(c.triangle_multiplication_outgoing, gc, + name='triangle_multiplication_outgoing'), + act, + pair_mask, + safe_key=next(sub_keys)) + + act = dropout_wrapper_fn( + modules.TriangleMultiplication(c.triangle_multiplication_incoming, gc, + name='triangle_multiplication_incoming'), + act, + pair_mask, + safe_key=next(sub_keys)) + + act = dropout_wrapper_fn( + modules.TriangleAttention(c.triangle_attention_starting_node, gc, + name='triangle_attention_starting_node'), + act, + pair_mask, + safe_key=next(sub_keys)) + + act = dropout_wrapper_fn( + modules.TriangleAttention(c.triangle_attention_ending_node, gc, + name='triangle_attention_ending_node'), + act, + pair_mask, + safe_key=next(sub_keys)) + + act = dropout_wrapper_fn( + modules.Transition(c.pair_transition, gc, + name='pair_transition'), + act, + pair_mask, + safe_key=next(sub_keys)) + + return act + + +def template_embedding_1d(batch, num_channel, global_config): + """Embed templates into an (num_res, num_templates, num_channels) embedding. + + Args: + batch: A batch containing: + template_aatype, (num_templates, num_res) aatype for the templates. + template_all_atom_positions, (num_templates, num_residues, 37, 3) atom + positions for the templates. + template_all_atom_mask, (num_templates, num_residues, 37) atom mask for + each template. + num_channel: The number of channels in the output. + + Returns: + An embedding of shape (num_templates, num_res, num_channels) and a mask of + shape (num_templates, num_res). + """ + + # Embed the templates aatypes. 
+ aatype_one_hot = jax.nn.one_hot(batch['template_aatype'], 22, axis=-1) + + num_templates = batch['template_aatype'].shape[0] + all_chi_angles = [] + all_chi_masks = [] + for i in range(num_templates): + atom_pos = geometry.Vec3Array.from_array( + batch['template_all_atom_positions'][i, :, :, :]) + template_chi_angles, template_chi_mask = all_atom_multimer.compute_chi_angles( + atom_pos, + batch['template_all_atom_mask'][i, :, :], + batch['template_aatype'][i, :]) + all_chi_angles.append(template_chi_angles) + all_chi_masks.append(template_chi_mask) + chi_angles = jnp.stack(all_chi_angles, axis=0) + chi_mask = jnp.stack(all_chi_masks, axis=0) + + template_features = jnp.concatenate([ + aatype_one_hot, + jnp.sin(chi_angles) * chi_mask, + jnp.cos(chi_angles) * chi_mask, + chi_mask], axis=-1) + + template_mask = chi_mask[:, :, 0] + + if global_config.bfloat16: + template_features = template_features.astype(jnp.bfloat16) + template_mask = template_mask.astype(jnp.bfloat16) + + template_activations = common_modules.Linear( + num_channel, + initializer='relu', + name='template_single_embedding')( + template_features) + template_activations = jax.nn.relu(template_activations) + template_activations = common_modules.Linear( + num_channel, + initializer='relu', + name='template_projection')( + template_activations) + return template_activations, template_mask diff --git a/build/lib/colabdesign/af/alphafold/model/prng.py b/build/lib/colabdesign/af/alphafold/model/prng.py new file mode 100644 index 00000000..64f348c9 --- /dev/null +++ b/build/lib/colabdesign/af/alphafold/model/prng.py @@ -0,0 +1,67 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
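As a reading aid (illustration only, not part of the diffed file): the per-residue feature that template_embedding_1d assembles above is a 22-way aatype one-hot plus sin/cos of the four chi angles and their mask, i.e. 34 channels per residue, before the two Linear/relu projections. A minimal NumPy sketch with made-up shapes:

import numpy as np

# Toy sizes for illustration: 2 templates, 10 residues, 4 chi angles per residue.
num_templates, num_res = 2, 10
aatype_one_hot = np.eye(22)[np.random.randint(0, 21, (num_templates, num_res))]
chi_angles = np.random.uniform(-np.pi, np.pi, (num_templates, num_res, 4))
chi_mask = np.random.randint(0, 2, (num_templates, num_res, 4)).astype(np.float32)

template_features = np.concatenate(
    [aatype_one_hot,
     np.sin(chi_angles) * chi_mask,
     np.cos(chi_angles) * chi_mask,
     chi_mask], axis=-1)

print(template_features.shape)     # (2, 10, 34): 22 + 4 + 4 + 4 channels per residue
template_mask = chi_mask[:, :, 0]  # per-residue template mask, as in the function above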
+ +"""A collection of utilities surrounding PRNG usage in protein folding.""" + +import haiku as hk +import jax + +def safe_dropout(*, tensor, safe_key, rate): + """Applies dropout to a tensor.""" + keep_rate = 1.0 - rate + keep = jax.random.bernoulli(safe_key.get(), keep_rate, shape=tensor.shape) + return keep * tensor / keep_rate + +class SafeKey: + """Safety wrapper for PRNG keys.""" + + def __init__(self, key): + self._key = key + self._used = False + + def _assert_not_used(self): + if self._used: + raise RuntimeError('Random key has been used previously.') + + def get(self): + self._assert_not_used() + self._used = True + return self._key + + def split(self, num_keys=2): + self._assert_not_used() + self._used = True + new_keys = jax.random.split(self._key, num_keys) + return jax.tree_map(SafeKey, tuple(new_keys)) + + def duplicate(self, num_keys=2): + self._assert_not_used() + self._used = True + return tuple(SafeKey(self._key) for _ in range(num_keys)) + + +def _safe_key_flatten(safe_key): + # Flatten transfers "ownership" to the tree + return (safe_key._key,), safe_key._used # pylint: disable=protected-access + + +def _safe_key_unflatten(aux_data, children): + ret = SafeKey(children[0]) + ret._used = aux_data # pylint: disable=protected-access + return ret + + +jax.tree_util.register_pytree_node( + SafeKey, _safe_key_flatten, _safe_key_unflatten) + diff --git a/build/lib/colabdesign/af/alphafold/model/quat_affine.py b/build/lib/colabdesign/af/alphafold/model/quat_affine.py new file mode 100644 index 00000000..9ebcd20f --- /dev/null +++ b/build/lib/colabdesign/af/alphafold/model/quat_affine.py @@ -0,0 +1,459 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Quaternion geometry modules. + +This introduces a representation of coordinate frames that is based around a +‘QuatAffine’ object. This object describes an array of coordinate frames. +It consists of vectors corresponding to the +origin of the frames as well as orientations which are stored in two +ways, as unit quaternions as well as a rotation matrices. +The rotation matrices are derived from the unit quaternions and the two are kept +in sync. +For an explanation of the relation between unit quaternions and rotations see +https://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation + +This representation is used in the model for the backbone frames. + +One important thing to note here, is that while we update both representations +the jit compiler is going to ensure that only the parts that are +actually used are executed. 
+""" + + +import functools +from typing import Tuple + +import jax +import jax.numpy as jnp +import numpy as np + +# pylint: disable=bad-whitespace +QUAT_TO_ROT = np.zeros((4, 4, 3, 3), dtype=np.float32) + +QUAT_TO_ROT[0, 0] = [[ 1, 0, 0], [ 0, 1, 0], [ 0, 0, 1]] # rr +QUAT_TO_ROT[1, 1] = [[ 1, 0, 0], [ 0,-1, 0], [ 0, 0,-1]] # ii +QUAT_TO_ROT[2, 2] = [[-1, 0, 0], [ 0, 1, 0], [ 0, 0,-1]] # jj +QUAT_TO_ROT[3, 3] = [[-1, 0, 0], [ 0,-1, 0], [ 0, 0, 1]] # kk + +QUAT_TO_ROT[1, 2] = [[ 0, 2, 0], [ 2, 0, 0], [ 0, 0, 0]] # ij +QUAT_TO_ROT[1, 3] = [[ 0, 0, 2], [ 0, 0, 0], [ 2, 0, 0]] # ik +QUAT_TO_ROT[2, 3] = [[ 0, 0, 0], [ 0, 0, 2], [ 0, 2, 0]] # jk + +QUAT_TO_ROT[0, 1] = [[ 0, 0, 0], [ 0, 0,-2], [ 0, 2, 0]] # ir +QUAT_TO_ROT[0, 2] = [[ 0, 0, 2], [ 0, 0, 0], [-2, 0, 0]] # jr +QUAT_TO_ROT[0, 3] = [[ 0,-2, 0], [ 2, 0, 0], [ 0, 0, 0]] # kr + +QUAT_MULTIPLY = np.zeros((4, 4, 4), dtype=np.float32) +QUAT_MULTIPLY[:, :, 0] = [[ 1, 0, 0, 0], + [ 0,-1, 0, 0], + [ 0, 0,-1, 0], + [ 0, 0, 0,-1]] + +QUAT_MULTIPLY[:, :, 1] = [[ 0, 1, 0, 0], + [ 1, 0, 0, 0], + [ 0, 0, 0, 1], + [ 0, 0,-1, 0]] + +QUAT_MULTIPLY[:, :, 2] = [[ 0, 0, 1, 0], + [ 0, 0, 0,-1], + [ 1, 0, 0, 0], + [ 0, 1, 0, 0]] + +QUAT_MULTIPLY[:, :, 3] = [[ 0, 0, 0, 1], + [ 0, 0, 1, 0], + [ 0,-1, 0, 0], + [ 1, 0, 0, 0]] + +QUAT_MULTIPLY_BY_VEC = QUAT_MULTIPLY[:, 1:, :] +# pylint: enable=bad-whitespace + + +def rot_to_quat(rot, unstack_inputs=False): + """Convert rotation matrix to quaternion. + + Note that this function calls self_adjoint_eig which is extremely expensive on + the GPU. If at all possible, this function should run on the CPU. + + Args: + rot: rotation matrix (see below for format). + unstack_inputs: If true, rotation matrix should be shape (..., 3, 3) + otherwise the rotation matrix should be a list of lists of tensors. + + Returns: + Quaternion as (..., 4) tensor. + """ + if unstack_inputs: + rot = [jnp.moveaxis(x, -1, 0) for x in jnp.moveaxis(rot, -2, 0)] + + [[xx, xy, xz], [yx, yy, yz], [zx, zy, zz]] = rot + + # pylint: disable=bad-whitespace + k = [[ xx + yy + zz, zy - yz, xz - zx, yx - xy,], + [ zy - yz, xx - yy - zz, xy + yx, xz + zx,], + [ xz - zx, xy + yx, yy - xx - zz, yz + zy,], + [ yx - xy, xz + zx, yz + zy, zz - xx - yy,]] + # pylint: enable=bad-whitespace + + k = (1./3.) * jnp.stack([jnp.stack(x, axis=-1) for x in k], + axis=-2) + + # Get eigenvalues in non-decreasing order and associated. + _, qs = jnp.linalg.eigh(k) + return qs[..., -1] + + +def rot_list_to_tensor(rot_list): + """Convert list of lists to rotation tensor.""" + return jnp.stack( + [jnp.stack(rot_list[0], axis=-1), + jnp.stack(rot_list[1], axis=-1), + jnp.stack(rot_list[2], axis=-1)], + axis=-2) + + +def vec_list_to_tensor(vec_list): + """Convert list to vector tensor.""" + return jnp.stack(vec_list, axis=-1) + + +def quat_to_rot(normalized_quat): + """Convert a normalized quaternion to a rotation matrix.""" + rot_tensor = jnp.sum( + np.reshape(QUAT_TO_ROT, (4, 4, 9)) * + normalized_quat[..., :, None, None] * + normalized_quat[..., None, :, None], + axis=(-3, -2)) + rot = jnp.moveaxis(rot_tensor, -1, 0) # Unstack. 
+ return [[rot[0], rot[1], rot[2]], + [rot[3], rot[4], rot[5]], + [rot[6], rot[7], rot[8]]] + + +def quat_multiply_by_vec(quat, vec): + """Multiply a quaternion by a pure-vector quaternion.""" + return jnp.sum( + QUAT_MULTIPLY_BY_VEC * + quat[..., :, None, None] * + vec[..., None, :, None], + axis=(-3, -2)) + + +def quat_multiply(quat1, quat2): + """Multiply a quaternion by another quaternion.""" + return jnp.sum( + QUAT_MULTIPLY * + quat1[..., :, None, None] * + quat2[..., None, :, None], + axis=(-3, -2)) + + +def apply_rot_to_vec(rot, vec, unstack=False): + """Multiply rotation matrix by a vector.""" + if unstack: + x, y, z = [vec[:, i] for i in range(3)] + else: + x, y, z = vec + return [rot[0][0] * x + rot[0][1] * y + rot[0][2] * z, + rot[1][0] * x + rot[1][1] * y + rot[1][2] * z, + rot[2][0] * x + rot[2][1] * y + rot[2][2] * z] + + +def apply_inverse_rot_to_vec(rot, vec): + """Multiply the inverse of a rotation matrix by a vector.""" + # Inverse rotation is just transpose + return [rot[0][0] * vec[0] + rot[1][0] * vec[1] + rot[2][0] * vec[2], + rot[0][1] * vec[0] + rot[1][1] * vec[1] + rot[2][1] * vec[2], + rot[0][2] * vec[0] + rot[1][2] * vec[1] + rot[2][2] * vec[2]] + + +class QuatAffine(object): + """Affine transformation represented by quaternion and vector.""" + + def __init__(self, quaternion, translation, rotation=None, normalize=True, + unstack_inputs=False): + """Initialize from quaternion and translation. + + Args: + quaternion: Rotation represented by a quaternion, to be applied + before translation. Must be a unit quaternion unless normalize==True. + translation: Translation represented as a vector. + rotation: Same rotation as the quaternion, represented as a (..., 3, 3) + tensor. If None, rotation will be calculated from the quaternion. + normalize: If True, l2 normalize the quaternion on input. + unstack_inputs: If True, translation is a vector with last component 3 + """ + + if quaternion is not None: + assert quaternion.shape[-1] == 4 + + if unstack_inputs: + if rotation is not None: + rotation = [jnp.moveaxis(x, -1, 0) # Unstack. + for x in jnp.moveaxis(rotation, -2, 0)] # Unstack. + translation = jnp.moveaxis(translation, -1, 0) # Unstack. + + if normalize and quaternion is not None: + quaternion = quaternion / jnp.linalg.norm(quaternion, axis=-1, + keepdims=True) + + if rotation is None: + rotation = quat_to_rot(quaternion) + + self.quaternion = quaternion + self.rotation = [list(row) for row in rotation] + self.translation = list(translation) + + assert all(len(row) == 3 for row in self.rotation) + assert len(self.translation) == 3 + + def to_tensor(self): + return jnp.concatenate( + [self.quaternion] + + [jnp.expand_dims(x, axis=-1) for x in self.translation], + axis=-1) + + def apply_tensor_fn(self, tensor_fn): + """Return a new QuatAffine with tensor_fn applied (e.g. 
stop_gradient).""" + return QuatAffine( + tensor_fn(self.quaternion), + [tensor_fn(x) for x in self.translation], + rotation=[[tensor_fn(x) for x in row] for row in self.rotation], + normalize=False) + + def apply_rotation_tensor_fn(self, tensor_fn): + """Return a new QuatAffine with tensor_fn applied to the rotation part.""" + return QuatAffine( + tensor_fn(self.quaternion), + [x for x in self.translation], + rotation=[[tensor_fn(x) for x in row] for row in self.rotation], + normalize=False) + + def scale_translation(self, position_scale): + """Return a new quat affine with a different scale for translation.""" + + return QuatAffine( + self.quaternion, + [x * position_scale for x in self.translation], + rotation=[[x for x in row] for row in self.rotation], + normalize=False) + + @classmethod + def from_tensor(cls, tensor, normalize=False): + quaternion, tx, ty, tz = jnp.split(tensor, [4, 5, 6], axis=-1) + return cls(quaternion, + [tx[..., 0], ty[..., 0], tz[..., 0]], + normalize=normalize) + + def pre_compose(self, update): + """Return a new QuatAffine which applies the transformation update first. + + Args: + update: Length-6 vector. 3-vector of x, y, and z such that the quaternion + update is (1, x, y, z) and zero for the 3-vector is the identity + quaternion. 3-vector for translation concatenated. + + Returns: + New QuatAffine object. + """ + vector_quaternion_update, x, y, z = jnp.split(update, [3, 4, 5], axis=-1) + trans_update = [jnp.squeeze(x, axis=-1), + jnp.squeeze(y, axis=-1), + jnp.squeeze(z, axis=-1)] + + new_quaternion = (self.quaternion + + quat_multiply_by_vec(self.quaternion, + vector_quaternion_update)) + + trans_update = apply_rot_to_vec(self.rotation, trans_update) + new_translation = [ + self.translation[0] + trans_update[0], + self.translation[1] + trans_update[1], + self.translation[2] + trans_update[2]] + + return QuatAffine(new_quaternion, new_translation) + + def apply_to_point(self, point, extra_dims=0): + """Apply affine to a point. + + Args: + point: List of 3 tensors to apply affine. + extra_dims: Number of dimensions at the end of the transformed_point + shape that are not present in the rotation and translation. The most + common use is rotation N points at once with extra_dims=1 for use in a + network. + + Returns: + Transformed point after applying affine. + """ + rotation = self.rotation + translation = self.translation + for _ in range(extra_dims): + expand_fn = functools.partial(jnp.expand_dims, axis=-1) + rotation = jax.tree_map(expand_fn, rotation) + translation = jax.tree_map(expand_fn, translation) + + rot_point = apply_rot_to_vec(rotation, point) + return [ + rot_point[0] + translation[0], + rot_point[1] + translation[1], + rot_point[2] + translation[2]] + + def invert_point(self, transformed_point, extra_dims=0): + """Apply inverse of transformation to a point. + + Args: + transformed_point: List of 3 tensors to apply affine + extra_dims: Number of dimensions at the end of the transformed_point + shape that are not present in the rotation and translation. The most + common use is rotation N points at once with extra_dims=1 for use in a + network. + + Returns: + Transformed point after applying affine. 
+ """ + rotation = self.rotation + translation = self.translation + for _ in range(extra_dims): + expand_fn = functools.partial(jnp.expand_dims, axis=-1) + rotation = jax.tree_map(expand_fn, rotation) + translation = jax.tree_map(expand_fn, translation) + + rot_point = [ + transformed_point[0] - translation[0], + transformed_point[1] - translation[1], + transformed_point[2] - translation[2]] + + return apply_inverse_rot_to_vec(rotation, rot_point) + + def __repr__(self): + return 'QuatAffine(%r, %r)' % (self.quaternion, self.translation) + + +def _multiply(a, b): + return jnp.stack([ + jnp.array([a[0][0]*b[0][0] + a[0][1]*b[1][0] + a[0][2]*b[2][0], + a[0][0]*b[0][1] + a[0][1]*b[1][1] + a[0][2]*b[2][1], + a[0][0]*b[0][2] + a[0][1]*b[1][2] + a[0][2]*b[2][2]]), + + jnp.array([a[1][0]*b[0][0] + a[1][1]*b[1][0] + a[1][2]*b[2][0], + a[1][0]*b[0][1] + a[1][1]*b[1][1] + a[1][2]*b[2][1], + a[1][0]*b[0][2] + a[1][1]*b[1][2] + a[1][2]*b[2][2]]), + + jnp.array([a[2][0]*b[0][0] + a[2][1]*b[1][0] + a[2][2]*b[2][0], + a[2][0]*b[0][1] + a[2][1]*b[1][1] + a[2][2]*b[2][1], + a[2][0]*b[0][2] + a[2][1]*b[1][2] + a[2][2]*b[2][2]])]) + + +def make_canonical_transform( + n_xyz: jnp.ndarray, + ca_xyz: jnp.ndarray, + c_xyz: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Returns translation and rotation matrices to canonicalize residue atoms. + + Note that this method does not take care of symmetries. If you provide the + atom positions in the non-standard way, the N atom will end up not at + [-0.527250, 1.359329, 0.0] but instead at [-0.527250, -1.359329, 0.0]. You + need to take care of such cases in your code. + + Args: + n_xyz: An array of shape [batch, 3] of nitrogen xyz coordinates. + ca_xyz: An array of shape [batch, 3] of carbon alpha xyz coordinates. + c_xyz: An array of shape [batch, 3] of carbon xyz coordinates. + + Returns: + A tuple (translation, rotation) where: + translation is an array of shape [batch, 3] defining the translation. + rotation is an array of shape [batch, 3, 3] defining the rotation. + After applying the translation and rotation to all atoms in a residue: + * All atoms will be shifted so that CA is at the origin, + * All atoms will be rotated so that C is at the x-axis, + * All atoms will be shifted so that N is in the xy plane. + """ + assert len(n_xyz.shape) == 2, n_xyz.shape + assert n_xyz.shape[-1] == 3, n_xyz.shape + assert n_xyz.shape == ca_xyz.shape == c_xyz.shape, ( + n_xyz.shape, ca_xyz.shape, c_xyz.shape) + + # Place CA at the origin. + translation = -ca_xyz + n_xyz = n_xyz + translation + c_xyz = c_xyz + translation + + # Place C on the x-axis. + c_x, c_y, c_z = [c_xyz[:, i] for i in range(3)] + # Rotate by angle c1 in the x-y plane (around the z-axis). + sin_c1 = -c_y / jnp.sqrt(1e-20 + c_x**2 + c_y**2) + cos_c1 = c_x / jnp.sqrt(1e-20 + c_x**2 + c_y**2) + zeros = jnp.zeros_like(sin_c1) + ones = jnp.ones_like(sin_c1) + # pylint: disable=bad-whitespace + c1_rot_matrix = jnp.stack([jnp.array([cos_c1, -sin_c1, zeros]), + jnp.array([sin_c1, cos_c1, zeros]), + jnp.array([zeros, zeros, ones])]) + + # Rotate by angle c2 in the x-z plane (around the y-axis). 
+ sin_c2 = c_z / jnp.sqrt(1e-20 + c_x**2 + c_y**2 + c_z**2) + cos_c2 = jnp.sqrt(c_x**2 + c_y**2) / jnp.sqrt( + 1e-20 + c_x**2 + c_y**2 + c_z**2) + c2_rot_matrix = jnp.stack([jnp.array([cos_c2, zeros, sin_c2]), + jnp.array([zeros, ones, zeros]), + jnp.array([-sin_c2, zeros, cos_c2])]) + + c_rot_matrix = _multiply(c2_rot_matrix, c1_rot_matrix) + n_xyz = jnp.stack(apply_rot_to_vec(c_rot_matrix, n_xyz, unstack=True)).T + + # Place N in the x-y plane. + _, n_y, n_z = [n_xyz[:, i] for i in range(3)] + # Rotate by angle alpha in the y-z plane (around the x-axis). + sin_n = -n_z / jnp.sqrt(1e-20 + n_y**2 + n_z**2) + cos_n = n_y / jnp.sqrt(1e-20 + n_y**2 + n_z**2) + n_rot_matrix = jnp.stack([jnp.array([ones, zeros, zeros]), + jnp.array([zeros, cos_n, -sin_n]), + jnp.array([zeros, sin_n, cos_n])]) + # pylint: enable=bad-whitespace + + return (translation, + jnp.transpose(_multiply(n_rot_matrix, c_rot_matrix), [2, 0, 1])) + + +def make_transform_from_reference( + n_xyz: jnp.ndarray, + ca_xyz: jnp.ndarray, + c_xyz: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Returns rotation and translation matrices to convert from reference. + + Note that this method does not take care of symmetries. If you provide the + atom positions in the non-standard way, the N atom will end up not at + [-0.527250, 1.359329, 0.0] but instead at [-0.527250, -1.359329, 0.0]. You + need to take care of such cases in your code. + + Args: + n_xyz: An array of shape [batch, 3] of nitrogen xyz coordinates. + ca_xyz: An array of shape [batch, 3] of carbon alpha xyz coordinates. + c_xyz: An array of shape [batch, 3] of carbon xyz coordinates. + + Returns: + A tuple (rotation, translation) where: + rotation is an array of shape [batch, 3, 3] defining the rotation. + translation is an array of shape [batch, 3] defining the translation. + After applying the translation and rotation to the reference backbone, + the coordinates will approximately equal to the input coordinates. + + The order of translation and rotation differs from make_canonical_transform + because the rotation from this function should be applied before the + translation, unlike make_canonical_transform. + """ + translation, rotation = make_canonical_transform(n_xyz, ca_xyz, c_xyz) + return np.transpose(rotation, (0, 2, 1)), -translation diff --git a/build/lib/colabdesign/af/alphafold/model/r3.py b/build/lib/colabdesign/af/alphafold/model/r3.py new file mode 100644 index 00000000..b63054e8 --- /dev/null +++ b/build/lib/colabdesign/af/alphafold/model/r3.py @@ -0,0 +1,320 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Transformations for 3D coordinates. + +This Module contains objects for representing Vectors (Vecs), Rotation Matrices +(Rots) and proper Rigid transformation (Rigids). These are represented as +named tuples with arrays for each entry, for example a set of +[N, M] points would be represented as a Vecs object with arrays of shape [N, M] +for x, y and z. 
+ +This is being done to improve readability by making it very clear what objects +are geometric objects rather than relying on comments and array shapes. +Another reason for this is to avoid using matrix +multiplication primitives like matmul or einsum, on modern accelerator hardware +these can end up on specialized cores such as tensor cores on GPU or the MXU on +cloud TPUs, this often involves lower computational precision which can be +problematic for coordinate geometry. Also these cores are typically optimized +for larger matrices than 3 dimensional, this code is written to avoid any +unintended use of these cores on both GPUs and TPUs. +""" + +import collections +from typing import List +from colabdesign.af.alphafold.model import quat_affine +import jax.numpy as jnp +import tree + +# Array of 3-component vectors, stored as individual array for +# each component. +Vecs = collections.namedtuple('Vecs', ['x', 'y', 'z']) + +# Array of 3x3 rotation matrices, stored as individual array for +# each component. +Rots = collections.namedtuple('Rots', ['xx', 'xy', 'xz', + 'yx', 'yy', 'yz', + 'zx', 'zy', 'zz']) +# Array of rigid 3D transformations, stored as array of rotations and +# array of translations. +Rigids = collections.namedtuple('Rigids', ['rot', 'trans']) + + +def squared_difference(x, y): + return jnp.square(x - y) + + +def invert_rigids(r: Rigids) -> Rigids: + """Computes group inverse of rigid transformations 'r'.""" + inv_rots = invert_rots(r.rot) + t = rots_mul_vecs(inv_rots, r.trans) + inv_trans = Vecs(-t.x, -t.y, -t.z) + return Rigids(inv_rots, inv_trans) + + +def invert_rots(m: Rots) -> Rots: + """Computes inverse of rotations 'm'.""" + return Rots(m.xx, m.yx, m.zx, + m.xy, m.yy, m.zy, + m.xz, m.yz, m.zz) + + +def rigids_from_3_points( + point_on_neg_x_axis: Vecs, # shape (...) + origin: Vecs, # shape (...) + point_on_xy_plane: Vecs, # shape (...) +) -> Rigids: # shape (...) + """Create Rigids from 3 points. + + Jumper et al. (2021) Suppl. Alg. 21 "rigidFrom3Points" + This creates a set of rigid transformations from 3 points by Gram Schmidt + orthogonalization. + + Args: + point_on_neg_x_axis: Vecs corresponding to points on the negative x axis + origin: Origin of resulting rigid transformations + point_on_xy_plane: Vecs corresponding to points in the xy plane + Returns: + Rigid transformations from global frame to local frames derived from + the input points. + """ + m = rots_from_two_vecs( + e0_unnormalized=vecs_sub(origin, point_on_neg_x_axis), + e1_unnormalized=vecs_sub(point_on_xy_plane, origin)) + + return Rigids(rot=m, trans=origin) + + +def rigids_from_list(l: List[jnp.ndarray]) -> Rigids: + """Converts flat list of arrays to rigid transformations.""" + assert len(l) == 12 + return Rigids(Rots(*(l[:9])), Vecs(*(l[9:]))) + + +def rigids_from_quataffine(a: quat_affine.QuatAffine) -> Rigids: + """Converts QuatAffine object to the corresponding Rigids object.""" + return Rigids(Rots(*tree.flatten(a.rotation)), + Vecs(*a.translation)) + + +def rigids_from_tensor4x4( + m: jnp.ndarray # shape (..., 4, 4) +) -> Rigids: # shape (...) + """Construct Rigids object from an 4x4 array. + + Here the 4x4 is representing the transformation in homogeneous coordinates. + + Args: + m: Array representing transformations in homogeneous coordinates. 
+ Returns: + Rigids object corresponding to transformations m + """ + assert m.shape[-1] == 4 + assert m.shape[-2] == 4 + return Rigids( + Rots(m[..., 0, 0], m[..., 0, 1], m[..., 0, 2], + m[..., 1, 0], m[..., 1, 1], m[..., 1, 2], + m[..., 2, 0], m[..., 2, 1], m[..., 2, 2]), + Vecs(m[..., 0, 3], m[..., 1, 3], m[..., 2, 3])) + + +def rigids_from_tensor_flat9( + m: jnp.ndarray # shape (..., 9) +) -> Rigids: # shape (...) + """Flat9 encoding: first two columns of rotation matrix + translation.""" + assert m.shape[-1] == 9 + e0 = Vecs(m[..., 0], m[..., 1], m[..., 2]) + e1 = Vecs(m[..., 3], m[..., 4], m[..., 5]) + trans = Vecs(m[..., 6], m[..., 7], m[..., 8]) + return Rigids(rot=rots_from_two_vecs(e0, e1), + trans=trans) + + +def rigids_from_tensor_flat12( + m: jnp.ndarray # shape (..., 12) +) -> Rigids: # shape (...) + """Flat12 encoding: rotation matrix (9 floats) + translation (3 floats).""" + assert m.shape[-1] == 12 + x = jnp.moveaxis(m, -1, 0) # Unstack + return Rigids(Rots(*x[:9]), Vecs(*x[9:])) + + +def rigids_mul_rigids(a: Rigids, b: Rigids) -> Rigids: + """Group composition of Rigids 'a' and 'b'.""" + return Rigids( + rots_mul_rots(a.rot, b.rot), + vecs_add(a.trans, rots_mul_vecs(a.rot, b.trans))) + + +def rigids_mul_rots(r: Rigids, m: Rots) -> Rigids: + """Compose rigid transformations 'r' with rotations 'm'.""" + return Rigids(rots_mul_rots(r.rot, m), r.trans) + + +def rigids_mul_vecs(r: Rigids, v: Vecs) -> Vecs: + """Apply rigid transforms 'r' to points 'v'.""" + return vecs_add(rots_mul_vecs(r.rot, v), r.trans) + + +def rigids_to_list(r: Rigids) -> List[jnp.ndarray]: + """Turn Rigids into flat list, inverse of 'rigids_from_list'.""" + return list(r.rot) + list(r.trans) + + +def rigids_to_quataffine(r: Rigids) -> quat_affine.QuatAffine: + """Convert Rigids r into QuatAffine, inverse of 'rigids_from_quataffine'.""" + return quat_affine.QuatAffine( + quaternion=None, + rotation=[[r.rot.xx, r.rot.xy, r.rot.xz], + [r.rot.yx, r.rot.yy, r.rot.yz], + [r.rot.zx, r.rot.zy, r.rot.zz]], + translation=[r.trans.x, r.trans.y, r.trans.z]) + + +def rigids_to_tensor_flat9( + r: Rigids # shape (...) +) -> jnp.ndarray: # shape (..., 9) + """Flat9 encoding: first two columns of rotation matrix + translation.""" + return jnp.stack( + [r.rot.xx, r.rot.yx, r.rot.zx, r.rot.xy, r.rot.yy, r.rot.zy] + + list(r.trans), axis=-1) + + +def rigids_to_tensor_flat12( + r: Rigids # shape (...) +) -> jnp.ndarray: # shape (..., 12) + """Flat12 encoding: rotation matrix (9 floats) + translation (3 floats).""" + return jnp.stack(list(r.rot) + list(r.trans), axis=-1) + + +def rots_from_tensor3x3( + m: jnp.ndarray, # shape (..., 3, 3) +) -> Rots: # shape (...) + """Convert rotations represented as (3, 3) array to Rots.""" + assert m.shape[-1] == 3 + assert m.shape[-2] == 3 + return Rots(m[..., 0, 0], m[..., 0, 1], m[..., 0, 2], + m[..., 1, 0], m[..., 1, 1], m[..., 1, 2], + m[..., 2, 0], m[..., 2, 1], m[..., 2, 2]) + + +def rots_from_two_vecs(e0_unnormalized: Vecs, e1_unnormalized: Vecs) -> Rots: + """Create rotation matrices from unnormalized vectors for the x and y-axes. + + This creates a rotation matrix from two vectors using Gram-Schmidt + orthogonalization. + + Args: + e0_unnormalized: vectors lying along x-axis of resulting rotation + e1_unnormalized: vectors lying in xy-plane of resulting rotation + Returns: + Rotations resulting from Gram-Schmidt procedure. + """ + # Normalize the unit vector for the x-axis, e0. + e0 = vecs_robust_normalize(e0_unnormalized) + + # make e1 perpendicular to e0. 
+ c = vecs_dot_vecs(e1_unnormalized, e0) + e1 = Vecs(e1_unnormalized.x - c * e0.x, + e1_unnormalized.y - c * e0.y, + e1_unnormalized.z - c * e0.z) + e1 = vecs_robust_normalize(e1) + + # Compute e2 as cross product of e0 and e1. + e2 = vecs_cross_vecs(e0, e1) + + return Rots(e0.x, e1.x, e2.x, e0.y, e1.y, e2.y, e0.z, e1.z, e2.z) + + +def rots_mul_rots(a: Rots, b: Rots) -> Rots: + """Composition of rotations 'a' and 'b'.""" + c0 = rots_mul_vecs(a, Vecs(b.xx, b.yx, b.zx)) + c1 = rots_mul_vecs(a, Vecs(b.xy, b.yy, b.zy)) + c2 = rots_mul_vecs(a, Vecs(b.xz, b.yz, b.zz)) + return Rots(c0.x, c1.x, c2.x, c0.y, c1.y, c2.y, c0.z, c1.z, c2.z) + + +def rots_mul_vecs(m: Rots, v: Vecs) -> Vecs: + """Apply rotations 'm' to vectors 'v'.""" + return Vecs(m.xx * v.x + m.xy * v.y + m.xz * v.z, + m.yx * v.x + m.yy * v.y + m.yz * v.z, + m.zx * v.x + m.zy * v.y + m.zz * v.z) + + +def vecs_add(v1: Vecs, v2: Vecs) -> Vecs: + """Add two vectors 'v1' and 'v2'.""" + return Vecs(v1.x + v2.x, v1.y + v2.y, v1.z + v2.z) + + +def vecs_dot_vecs(v1: Vecs, v2: Vecs) -> jnp.ndarray: + """Dot product of vectors 'v1' and 'v2'.""" + return v1.x * v2.x + v1.y * v2.y + v1.z * v2.z + + +def vecs_cross_vecs(v1: Vecs, v2: Vecs) -> Vecs: + """Cross product of vectors 'v1' and 'v2'.""" + return Vecs(v1.y * v2.z - v1.z * v2.y, + v1.z * v2.x - v1.x * v2.z, + v1.x * v2.y - v1.y * v2.x) + + +def vecs_from_tensor(x: jnp.ndarray # shape (..., 3) + ) -> Vecs: # shape (...) + """Converts from tensor of shape (3,) to Vecs.""" + num_components = x.shape[-1] + assert num_components == 3 + return Vecs(x[..., 0], x[..., 1], x[..., 2]) + + +def vecs_robust_normalize(v: Vecs, epsilon: float = 1e-8) -> Vecs: + """Normalizes vectors 'v'. + + Args: + v: vectors to be normalized. + epsilon: small regularizer added to squared norm before taking square root. + Returns: + normalized vectors + """ + norms = vecs_robust_norm(v, epsilon) + return Vecs(v.x / norms, v.y / norms, v.z / norms) + + +def vecs_robust_norm(v: Vecs, epsilon: float = 1e-8) -> jnp.ndarray: + """Computes norm of vectors 'v'. + + Args: + v: vectors to be normalized. + epsilon: small regularizer added to squared norm before taking square root. + Returns: + norm of 'v' + """ + return jnp.sqrt(jnp.square(v.x) + jnp.square(v.y) + jnp.square(v.z) + epsilon) + + +def vecs_sub(v1: Vecs, v2: Vecs) -> Vecs: + """Computes v1 - v2.""" + return Vecs(v1.x - v2.x, v1.y - v2.y, v1.z - v2.z) + + +def vecs_squared_distance(v1: Vecs, v2: Vecs) -> jnp.ndarray: + """Computes squared euclidean difference between 'v1' and 'v2'.""" + return (squared_difference(v1.x, v2.x) + + squared_difference(v1.y, v2.y) + + squared_difference(v1.z, v2.z)) + + +def vecs_to_tensor(v: Vecs # shape (...) + ) -> jnp.ndarray: # shape(..., 3) + """Converts 'v' to tensor with shape 3, inverse of 'vecs_from_tensor'.""" + return jnp.stack([v.x, v.y, v.z], axis=-1) diff --git a/build/lib/colabdesign/af/alphafold/model/tf/__init__.py b/build/lib/colabdesign/af/alphafold/model/tf/__init__.py new file mode 100644 index 00000000..6c520687 --- /dev/null +++ b/build/lib/colabdesign/af/alphafold/model/tf/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Alphafold model TensorFlow code.""" diff --git a/build/lib/colabdesign/af/alphafold/model/tf/shape_placeholders.py b/build/lib/colabdesign/af/alphafold/model/tf/shape_placeholders.py new file mode 100644 index 00000000..cffdeb5e --- /dev/null +++ b/build/lib/colabdesign/af/alphafold/model/tf/shape_placeholders.py @@ -0,0 +1,20 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Placeholder values for run-time varying dimension sizes.""" + +NUM_RES = 'num residues placeholder' +NUM_MSA_SEQ = 'msa placeholder' +NUM_EXTRA_SEQ = 'extra msa placeholder' +NUM_TEMPLATES = 'num templates placeholder' diff --git a/build/lib/colabdesign/af/alphafold/model/utils.py b/build/lib/colabdesign/af/alphafold/model/utils.py new file mode 100644 index 00000000..b59123c4 --- /dev/null +++ b/build/lib/colabdesign/af/alphafold/model/utils.py @@ -0,0 +1,125 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
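A short sketch of how these placeholder strings are used (the spec entries below are assumed examples; the real table lives in config.CONFIG.data.eval.feat): each feature's shape is listed per axis as either a fixed size or one of the placeholders, and cropping/padding code scans for NUM_RES to find which axes run over residues. The crop.py callback later in this diff performs exactly this lookup.

from colabdesign.af.alphafold.model.tf.shape_placeholders import NUM_RES, NUM_MSA_SEQ

# Assumed example entries, for illustration only.
feat_spec = {
    'aatype':   [NUM_RES],
    'msa_mask': [NUM_MSA_SEQ, NUM_RES],
}

# For each feature, find the axes that run over residues.
res_axes = {k: [i for i, d in enumerate(dims) if d == NUM_RES]
            for k, dims in feat_spec.items()}
print(res_axes)  # {'aatype': [0], 'msa_mask': [1]}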
+ +"""A collection of JAX utility functions for use in protein folding.""" + +import collections +import contextlib +import functools +import numbers +from typing import Mapping + +import haiku as hk +import jax +import jax.numpy as jnp +import numpy as np +import io + +def bfloat16_creator(next_creator, shape, dtype, init, context): + """Creates float32 variables when bfloat16 is requested.""" + if context.original_dtype == jnp.bfloat16: + dtype = jnp.float32 + return next_creator(shape, dtype, init) + +def bfloat16_getter(next_getter, value, context): + """Casts float32 to bfloat16 when bfloat16 was originally requested.""" + if context.original_dtype == jnp.bfloat16: + assert value.dtype == jnp.float32 + value = value.astype(jnp.bfloat16) + return next_getter(value) + +@contextlib.contextmanager +def bfloat16_context(): + with hk.custom_creator(bfloat16_creator), hk.custom_getter(bfloat16_getter): + yield + +def final_init(config): + if config.zero_init: + return 'zeros' + else: + return 'linear' + +def batched_gather(params, indices, axis=0, batch_dims=0): + """Implements a JAX equivalent of `tf.gather` with `axis` and `batch_dims`.""" + take_fn = lambda p, i: jnp.take(p, i, axis=axis, mode="clip") + for _ in range(batch_dims): + take_fn = jax.vmap(take_fn) + return take_fn(params, indices) + + +def mask_mean(mask, value, axis=None, drop_mask_channel=False, eps=1e-10): + """Masked mean.""" + if drop_mask_channel: + mask = mask[..., 0] + + mask_shape = mask.shape + value_shape = value.shape + + assert len(mask_shape) == len(value_shape) + + if isinstance(axis, numbers.Integral): + axis = [axis] + elif axis is None: + axis = list(range(len(mask_shape))) + + broadcast_factor = 1. + for axis_ in axis: + value_size = value_shape[axis_] + mask_size = mask_shape[axis_] + if mask_size == 1: + broadcast_factor *= value_size + else: + assert mask_size == value_size + + return (jnp.sum(mask * value, axis=axis) / + (jnp.sum(mask, axis=axis) * broadcast_factor + eps)) + +def flat_params_to_haiku(params, fuse=None): + """Convert a dictionary of NumPy arrays to Haiku parameters.""" + P = {} + for path, array in params.items(): + scope, name = path.split('//') + if scope not in P: + P[scope] = {} + P[scope][name] = jnp.array(array) + if fuse is not None: + for a in ["evoformer_iteration", + "extra_msa_stack", + "template_embedding/single_template_embedding/template_embedding_iteration", + "template_embedding/single_template_embedding/template_pair_stack/__layer_stack_no_state"]: + for b in ["triangle_multiplication_incoming","triangle_multiplication_outgoing"]: + k = f"alphafold/alphafold_iteration/evoformer/{a}/{b}" + + if fuse and f"{k}/center_layer_norm" in P: + for c in ["gate","projection"]: + L = P.pop(f"{k}/left_{c}") + R = P.pop(f"{k}/right_{c}") + P[f"{k}/{c}"] = {} + for d in ["bias","weights"]: + P[f"{k}/{c}"][d] = jnp.concatenate([L[d],R[d]],-1) + P[f"{k}/center_norm"] = P.pop(f"{k}/center_layer_norm") + P[f"{k}/left_norm_input"] = P.pop(f"{k}/layer_norm_input") + + if not fuse and f"{k}/center_norm" in P: + for c in ["gate","projection"]: + LR = P.pop(f"{k}/{c}") + P[f"{k}/left_{c}"] = {} + P[f"{k}/right_{c}"] = {} + for d in ["bias","weights"]: + half = LR[d].shape[-1] // 2 + P[f"{k}/left_{c}"][d] = LR[d][...,:half] + P[f"{k}/right_{c}"][d] = LR[d][...,half:] + P[f"{k}/center_layer_norm"] = P.pop(f"{k}/center_norm") + P[f"{k}/layer_norm_input"] = P.pop(f"{k}/left_norm_input") + return P \ No newline at end of file diff --git a/build/lib/colabdesign/af/contrib/__init__.py 
b/build/lib/colabdesign/af/contrib/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/build/lib/colabdesign/af/contrib/crop.py b/build/lib/colabdesign/af/contrib/crop.py new file mode 100644 index 00000000..bba1d113 --- /dev/null +++ b/build/lib/colabdesign/af/contrib/crop.py @@ -0,0 +1,154 @@ +import jax +import jax.numpy as jnp +import numpy as np + +from colabdesign.shared.utils import copy_dict +from colabdesign.af.alphafold.model import config + +def setup(self, crop_len=128, crop_mode="slide", crop_iter=5): + + def set_crop(crop_len=128, crop_mode="slide", crop_iter=5): + assert crop_mode in ["slide","roll","pair"] + assert crop_len < sum(self._lengths) + assert self._args["copies"] == 1 or self._args["repeat"] + assert self.protocol not in ["partial","binder"] + self._args["crop"] = {"len":crop_len, "mode":crop_mode, "iter":crop_iter} + + ################################################################# + # function to apply to inputs + ################################################################# + def pre_callback(inputs, opt): + def crop_feat(feat, pos): + '''crop features to specified [pos]itions''' + if feat is None: return None + def find(x,k): + i = [] + for j,y in enumerate(x): + if y == k: i.append(j) + return i + shapes = config.CONFIG.data.eval.feat + NUM_RES = "num residues placeholder" + idx = {k:find(v,NUM_RES) for k,v in shapes.items()} + new_feat = copy_dict(feat) + for k in new_feat.keys(): + if k in ["batch","prev"]: + new_feat[k] = crop_feat(feat[k], pos) + if k in idx: + for i in idx[k]: new_feat[k] = jnp.take(new_feat[k], pos, i) + return new_feat + + p = opt["crop_pos"] + inputs.update(crop_feat(inputs, p)) + + ################################################################# + # function to apply to outputs + ################################################################# + def post_callback(aux, opt): + length = sum(self._lengths) + def uncrop_feat(x, pos, pair=False): + '''uncrop features''' + if pair: + p1, p2 = pos[:,None], pos[None,:] + return jnp.zeros((length,length)+x.shape[2:]).at[p1,p2].set(x) + else: + return jnp.zeros((length,)+x.shape[1:]).at[pos].set(x) + p,x = opt["crop_pos"], aux["prev"] + full_aux = {"prev":{ + "prev_pos": uncrop_feat(x["prev_pos"], p), + "prev_pair": uncrop_feat(x["prev_pair"], p, pair=True), + "prev_msa_first_row": uncrop_feat(x["prev_msa_first_row"], p) + }, + } + aux.update(full_aux) + + ################################################################# + # function to apply before design step + ################################################################# + def pre_design_callback(self): + def update_pos(): + c = self._args["crop"] + L = sum(self._lengths) + if c["mode"] == "slide": + i = np.random.randint(0,(L-c["len"])+1) + self.opt["crop_pos"] = np.arange(i,i+c["len"]) + if c["mode"] == "roll": + i = np.random.randint(0,L) + self.opt["crop_pos"] = np.sort(np.roll(np.arange(L),L-i)[:c["len"]]) + if c["mode"] == "pair": + # pick random pair of interactig crops + max_L = c["len"] // 2 + # pick first crop + i_range = np.append(np.arange(0,(L-2*max_L)+1),np.arange(max_L,(L-max_L)+1)) + i = np.random.choice(i_range) + # pick second crop + j_range = np.append(np.arange(0,(i-max_L)+1),np.arange(i+max_L,(L-max_L)+1)) + if hasattr(self,"_cmap"): + # if contact map defined, bias to interacting pairs + w = np.array([self._cmap[i:i+max_L,j:j+max_L].sum() for j in j_range]) + 1e-8 + j = np.random.choice(j_range, p=w/w.sum()) + else: + j = np.random.choice(j_range) + self.opt["crop_pos"] = 
np.sort(np.append(np.arange(i,i+max_L),np.arange(j,j+max_L))) + + if "crop" not in self._tmp: + self._tmp["crop"] = {"k":self._k} + update_pos() + if self._k != self._tmp["crop"]["k"]: + if (self._k % self._args["crop"]["iter"]) == 0: + update_pos() + + self._tmp["crop"]["k"] = self._k + if hasattr(self,"aux"): + for k in ["cmap","pae","plddt","atom_positions"]: + self._tmp["crop"][k] = self.aux[k] + + ################################################################# + # function to apply after design step + ################################################################# + def post_design_callback(self): + # uncrop/accumulate features + L = sum(self._lengths) + def uncrop_feat(x, pos, pair=False): + if pair: + y = np.full((L,L)+x.shape[2:],np.nan) + y[pos[:,None],pos[None,:]] = x + else: + y = np.full((L,)+x.shape[1:],np.nan) + y[pos] = x + return y + + # uncrop features + a, p = self.aux, self.opt["crop_pos"] + vs = {k:uncrop_feat(a[k],p,pair=True) for k in ["cmap","pae"]} + vs.update({k:uncrop_feat(a[k],p) for k in ["plddt","atom_positions"]}) + + # accumulate features + for k,v in vs.items(): + w = self._tmp["crop"].get(k,v) + w = np.where(np.isnan(w),v,w) + vs[k] = np.where(np.isnan(v),w,(v+w)/2) + + self.aux.update(vs) + + ################################################################# + # SETUP + ################################################################# + if hasattr(self,"set_crop"): + self.set_crop(crop_len, crop_mode, crop_iter) + + else: + self.set_crop = set_crop + self.set_crop(crop_len, crop_mode, crop_iter) + + # add distances + if self.protocol == "fixbb": + cb_atoms = self._pdb["cb_feat"]["atoms"] + cb_atoms[self._pdb["cb_feat"]["mask"] == 0,:] = np.nan + cb_dist = np.sqrt(np.square(cb_atoms[:,None] - cb_atoms[None,:]).sum(-1)) + self._cmap = cb_dist < self.opt["cmap_cutoff"] + + # populate callbacks + self._callbacks["model"]["pre"].append(pre_callback) + self._callbacks["model"]["post"].append(post_callback) + self._callbacks["design"]["pre"].append(pre_design_callback) + self._callbacks["design"]["post"].append(post_design_callback) \ No newline at end of file diff --git a/build/lib/colabdesign/af/design.py b/build/lib/colabdesign/af/design.py new file mode 100644 index 00000000..732b3d58 --- /dev/null +++ b/build/lib/colabdesign/af/design.py @@ -0,0 +1,561 @@ +import random, os +import jax +import jax.numpy as jnp +import numpy as np +from colabdesign.af.alphafold.common import residue_constants +from colabdesign.shared.utils import copy_dict, update_dict, Key, dict_to_str, to_float, softmax, categorical, to_list, copy_missing + +#################################################### +# AF_DESIGN - design functions +#################################################### +#\ +# \_af_design +# |\ +# | \_restart +# \ +# \_design +# \_step +# \_run +# \_recycle +# \_single +# +#################################################### + +class _af_design: + + def restart(self, seed=None, opt=None, weights=None, + seq=None, mode=None, keep_history=False, reset_opt=True, **kwargs): + ''' + restart the optimization + ------------ + note: model.restart() resets the [opt]ions and weights to their defaults + use model.set_opt(..., set_defaults=True) and model.set_weights(..., set_defaults=True) + or model.restart(reset_opt=False) to avoid this + ------------ + seed=0 - set seed for reproducibility + reset_opt=False - do NOT reset [opt]ions/weights to defaults + keep_history=True - do NOT clear the trajectory/[opt]ions/weights + ''' + # reset [opt]ions + if reset_opt and not 
keep_history: + copy_missing(self.opt, self._opt) + self.opt = copy_dict(self._opt) + if hasattr(self,"aux"): del self.aux + + if not keep_history: + # initialize trajectory + self._tmp = {"traj":{"seq":[],"xyz":[],"plddt":[],"pae":[]}, + "log":[],"best":{}} + + # update options/settings (if defined) + self.set_opt(opt) + self.set_weights(weights) + + # initialize sequence + self.set_seed(seed) + self.set_seq(seq=seq, mode=mode, **kwargs) + + # reset optimizer + self._k = 0 + self.set_optimizer() + + def _get_model_nums(self, num_models=None, sample_models=None, models=None): + '''decide which model params to use''' + if num_models is None: num_models = self.opt["num_models"] + if sample_models is None: sample_models = self.opt["sample_models"] + + ns_name = self._model_names + ns = list(range(len(ns_name))) + if models is not None: + models = models if isinstance(models,list) else [models] + ns = [ns[n if isinstance(n,int) else ns_name.index(n)] for n in models] + + m = min(num_models,len(ns)) + if sample_models and m != len(ns): + model_nums = np.random.choice(ns,(m,),replace=False) + else: + model_nums = ns[:m] + return model_nums + + def run(self, num_recycles=None, num_models=None, sample_models=None, models=None, + backprop=True, callback=None, model_nums=None, return_aux=False): + '''run model to get outputs, losses and gradients''' + + # pre-design callbacks + for fn in self._callbacks["design"]["pre"]: fn(self) + + # decide which model params to use + if model_nums is None: + model_nums = self._get_model_nums(num_models, sample_models, models) + assert len(model_nums) > 0, "ERROR: no model params defined" + + # loop through model params + auxs = [] + for n in model_nums: + p = self._model_params[n] + auxs.append(self._recycle(p, num_recycles=num_recycles, backprop=backprop)) + auxs = jax.tree_map(lambda *x: np.stack(x), *auxs) + + # update aux (average outputs) + def avg_or_first(x): + if np.issubdtype(x.dtype, np.integer): return x[0] + else: return x.mean(0) + + self.aux = jax.tree_map(avg_or_first, auxs) + self.aux["atom_positions"] = auxs["atom_positions"][0] + self.aux["all"] = auxs + + # post-design callbacks + for fn in (self._callbacks["design"]["post"] + to_list(callback)): fn(self) + + # update log + self.aux["log"] = {**self.aux["losses"]} + self.aux["log"]["plddt"] = 1 - self.aux["log"]["plddt"] + for k in ["loss","i_ptm","ptm"]: self.aux["log"][k] = self.aux[k] + for k in ["hard","soft","temp"]: self.aux["log"][k] = self.opt[k] + + # compute sequence recovery + if self.protocol in ["fixbb","partial"] or (self.protocol == "binder" and self._args["redesign"]): + if self.protocol == "partial": + aatype = self.aux["aatype"][...,self.opt["pos"]] + else: + aatype = self.aux["seq"]["pseudo"].argmax(-1) + + mask = self._wt_aatype != -1 + true = self._wt_aatype[mask] + pred = aatype[...,mask] + self.aux["log"]["seqid"] = (true == pred).mean() + + self.aux["log"] = to_float(self.aux["log"]) + self.aux["log"].update({"recycles":int(self.aux["num_recycles"]), + "models":model_nums}) + + if return_aux: return self.aux + + def _single(self, model_params, backprop=True): + '''single pass through the model''' + self._inputs["opt"] = self.opt + flags = [self._params, model_params, self._inputs, self.key()] + if backprop: + (loss, aux), grad = self._model["grad_fn"](*flags) + else: + loss, aux = self._model["fn"](*flags) + grad = jax.tree_map(np.zeros_like, self._params) + aux.update({"loss":loss,"grad":grad}) + return aux + + def _recycle(self, model_params, num_recycles=None, 
backprop=True): + '''multiple passes through the model (aka recycle)''' + a = self._args + mode = a["recycle_mode"] + if num_recycles is None: + num_recycles = self.opt["num_recycles"] + + if mode in ["backprop","add_prev"]: + # recycles compiled into model, only need single-pass + aux = self._single(model_params, backprop) + + else: + L = self._inputs["residue_index"].shape[0] + + # intialize previous + if "prev" not in self._inputs or a["clear_prev"]: + prev = {'prev_msa_first_row': np.zeros([L,256]), + 'prev_pair': np.zeros([L,L,128])} + + if a["use_initial_guess"] and "batch" in self._inputs: + prev["prev_pos"] = self._inputs["batch"]["all_atom_positions"] + else: + prev["prev_pos"] = np.zeros([L,37,3]) + + if a["use_dgram"]: + # TODO: add support for initial_guess + use_dgram + prev["prev_dgram"] = np.zeros([L,L,64]) + + if a["use_initial_atom_pos"]: + if "batch" in self._inputs: + self._inputs["initial_atom_pos"] = self._inputs["batch"]["all_atom_positions"] + else: + self._inputs["initial_atom_pos"] = np.zeros([L,37,3]) + + self._inputs["prev"] = prev + # decide which layers to compute gradients for + cycles = (num_recycles + 1) + mask = [0] * cycles + + if mode == "sample": mask[np.random.randint(0,cycles)] = 1 + if mode == "average": mask = [1/cycles] * cycles + if mode == "last": mask[-1] = 1 + if mode == "first": mask[0] = 1 + + # gather gradients across recycles + grad = [] + for m in mask: + if m == 0: + aux = self._single(model_params, backprop=False) + else: + aux = self._single(model_params, backprop) + grad.append(jax.tree_map(lambda x:x*m, aux["grad"])) + self._inputs["prev"] = aux["prev"] + if a["use_initial_atom_pos"]: + self._inputs["initial_atom_pos"] = aux["prev"]["prev_pos"] + + aux["grad"] = jax.tree_map(lambda *x: np.stack(x).sum(0), *grad) + + aux["num_recycles"] = num_recycles + return aux + + def step(self, lr_scale=1.0, num_recycles=None, + num_models=None, sample_models=None, models=None, backprop=True, + callback=None, save_best=False, verbose=1): + '''do one step of gradient descent''' + + # run + self.run(num_recycles=num_recycles, num_models=num_models, sample_models=sample_models, + models=models, backprop=backprop, callback=callback) + + # modify gradients + if self.opt["norm_seq_grad"]: self._norm_seq_grad() + self._state, self.aux["grad"] = self._optimizer(self._state, self.aux["grad"], self._params) + + # apply gradients + lr = self.opt["learning_rate"] * lr_scale + self._params = jax.tree_map(lambda x,g:x-lr*g, self._params, self.aux["grad"]) + + # save results + self._save_results(save_best=save_best, verbose=verbose) + + # increment + self._k += 1 + + def _print_log(self, print_str=None, aux=None): + if aux is None: aux = self.aux + keys = ["models","recycles","hard","soft","temp","seqid","loss", + "seq_ent","mlm","helix","pae","i_pae","exp_res","con","i_con", + "sc_fape","sc_rmsd","dgram_cce","fape","plddt","ptm"] + + if "i_ptm" in aux["log"]: + if len(self._lengths) > 1: + keys.append("i_ptm") + else: + aux["log"].pop("i_ptm") + + print(dict_to_str(aux["log"], filt=self.opt["weights"], + print_str=print_str, keys=keys+["rmsd"], ok=["plddt","rmsd"])) + + def _save_results(self, aux=None, save_best=False, + best_metric=None, metric_higher_better=False, + verbose=True): + if aux is None: aux = self.aux + self._tmp["log"].append(aux["log"]) + if (self._k % self._args["traj_iter"]) == 0: + # update traj + traj = {"seq": aux["seq"]["pseudo"], + "xyz": aux["atom_positions"][:,1,:], + "plddt": aux["plddt"], + "pae": aux["pae"]} + for k,v in 
traj.items(): + if len(self._tmp["traj"][k]) == self._args["traj_max"]: + self._tmp["traj"][k].pop(0) + self._tmp["traj"][k].append(v) + + # save best + if save_best: + if best_metric is None: + best_metric = self._args["best_metric"] + metric = float(aux["log"][best_metric]) + if self._args["best_metric"] in ["plddt","ptm","i_ptm","seqid","composite"] or metric_higher_better: + metric = -metric + if "metric" not in self._tmp["best"] or metric < self._tmp["best"]["metric"]: + self._tmp["best"]["aux"] = copy_dict(aux) + self._tmp["best"]["metric"] = metric + + if verbose and ((self._k+1) % verbose) == 0: + self._print_log(f"{self._k+1}", aux=aux) + + def predict(self, seq=None, bias=None, + num_models=None, num_recycles=None, models=None, sample_models=False, + dropout=False, hard=True, soft=False, temp=1, + return_aux=False, verbose=True, seed=None, **kwargs): + '''predict structure for input sequence (if provided)''' + + def load_settings(): + if "save" in self._tmp: + [self.opt, self._args, self._params, self._inputs] = self._tmp.pop("save") + + def save_settings(): + load_settings() + self._tmp["save"] = [copy_dict(x) for x in [self.opt, self._args, self._params, self._inputs]] + + save_settings() + + # set seed if defined + if seed is not None: self.set_seed(seed) + + # set [seq]uence/[opt]ions + if seq is not None: self.set_seq(seq=seq, bias=bias) + self.set_opt(hard=hard, soft=soft, temp=temp, dropout=dropout, pssm_hard=True) + self.set_args(shuffle_first=False) + + # run + self.run(num_recycles=num_recycles, num_models=num_models, + sample_models=sample_models, models=models, backprop=False, **kwargs) + if verbose: self._print_log("predict") + + load_settings() + + # return (or save) results + if return_aux: return self.aux + + # --------------------------------------------------------------------------------- + # example design functions + # --------------------------------------------------------------------------------- + def design(self, iters=100, + soft=0.0, e_soft=None, + temp=1.0, e_temp=None, + hard=0.0, e_hard=None, + step=1.0, e_step=None, + dropout=True, opt=None, weights=None, + num_recycles=None, ramp_recycles=False, + num_models=None, sample_models=None, models=None, + backprop=True, callback=None, save_best=False, verbose=1): + + # update options/settings (if defined) + self.set_opt(opt, dropout=dropout) + self.set_weights(weights) + m = {"soft":[soft,e_soft],"temp":[temp,e_temp], + "hard":[hard,e_hard],"step":[step,e_step]} + m = {k:[s,(s if e is None else e)] for k,(s,e) in m.items()} + + if ramp_recycles: + if num_recycles is None: + num_recycles = self.opt["num_recycles"] + m["num_recycles"] = [0,num_recycles] + + for i in range(iters): + for k,(s,e) in m.items(): + if k == "temp": + self.set_opt({k:(e+(s-e)*(1-(i+1)/iters)**2)}) + else: + v = (s+(e-s)*((i+1)/iters)) + if k == "step": step = v + elif k == "num_recycles": num_recycles = round(v) + else: self.set_opt({k:v}) + + # decay learning rate based on temperature + lr_scale = step * ((1 - self.opt["soft"]) + (self.opt["soft"] * self.opt["temp"])) + + self.step(lr_scale=lr_scale, num_recycles=num_recycles, + num_models=num_models, sample_models=sample_models, models=models, + backprop=backprop, callback=callback, save_best=save_best, verbose=verbose) + + def design_logits(self, iters=100, **kwargs): + ''' optimize logits ''' + self.design(iters, **kwargs) + + def design_soft(self, iters=100, temp=1, **kwargs): + ''' optimize softmax(logits/temp)''' + self.design(iters, soft=1, temp=temp, **kwargs) + + def 
design_hard(self, iters=100, **kwargs): + ''' optimize argmax(logits) ''' + self.design(iters, soft=1, hard=1, **kwargs) + + # --------------------------------------------------------------------------------- + # experimental + # --------------------------------------------------------------------------------- + def design_3stage(self, soft_iters=300, temp_iters=100, hard_iters=10, + ramp_recycles=True, **kwargs): + '''three stage design (logits→soft→hard)''' + + verbose = kwargs.get("verbose",1) + + # stage 1: logits -> softmax(logits/1.0) + if soft_iters > 0: + if verbose: print("Stage 1: running (logits → soft)") + self.design_logits(soft_iters, e_soft=1, + ramp_recycles=ramp_recycles, **kwargs) + self._tmp["seq_logits"] = self.aux["seq"]["logits"] + + # stage 2: softmax(logits/1.0) -> softmax(logits/0.01) + if temp_iters > 0: + if verbose: print("Stage 2: running (soft → hard)") + self.design_soft(temp_iters, e_temp=1e-2, **kwargs) + + # stage 3: + if hard_iters > 0: + if verbose: print("Stage 3: running (hard)") + kwargs["dropout"] = False + kwargs["save_best"] = True + kwargs["num_models"] = len(self._model_names) + self.design_hard(hard_iters, temp=1e-2, **kwargs) + + def _mutate(self, seq, plddt=None, logits=None, mutation_rate=1): + '''mutate random position''' + seq = np.array(seq) + N,L = seq.shape + + # fix some positions + i_prob = np.ones(L) if plddt is None else np.maximum(1-plddt,0) + i_prob[np.isnan(i_prob)] = 0 + if "fix_pos" in self.opt: + if "pos" in self.opt: + p = self.opt["pos"][self.opt["fix_pos"]] + seq[...,p] = self._wt_aatype_sub + else: + p = self.opt["fix_pos"] + seq[...,p] = self._wt_aatype[...,p] + i_prob[p] = 0 + + for m in range(mutation_rate): + # sample position + # https://www.biorxiv.org/content/10.1101/2021.08.24.457549v1 + i = np.random.choice(np.arange(L),p=i_prob/i_prob.sum()) + + # sample amino acid + logits = np.array(0 if logits is None else logits) + if logits.ndim == 3: logits = logits[:,i] + elif logits.ndim == 2: logits = logits[i] + a_logits = logits - np.eye(self._args["alphabet_size"])[seq[:,i]] * 1e8 + a = categorical(softmax(a_logits)) + + # return mutant + seq[:,i] = a + + return seq + + def design_semigreedy(self, iters=100, tries=10, dropout=False, + save_best=True, seq_logits=None, e_tries=None, **kwargs): + + '''semigreedy search''' + if e_tries is None: e_tries = tries + + # get starting sequence + if hasattr(self,"aux"): + seq = self.aux["seq"]["logits"].argmax(-1) + else: + seq = (self._params["seq"] + self._inputs["bias"]).argmax(-1) + + # bias sampling towards the defined bias + if seq_logits is None: seq_logits = 0 + + model_flags = {k:kwargs.pop(k,None) for k in ["num_models","sample_models","models"]} + verbose = kwargs.pop("verbose",1) + + # get current plddt + aux = self.predict(seq, return_aux=True, verbose=False, **model_flags, **kwargs) + plddt = self.aux["plddt"] + plddt = plddt[self._target_len:] if self.protocol == "binder" else plddt[:self._len] + + # optimize! 
+    if verbose:
+      print("Running semigreedy optimization...")
+
+    for i in range(iters):
+      buff = []
+      model_nums = self._get_model_nums(**model_flags)
+      num_tries = (tries+(e_tries-tries)*((i+1)/iters))
+      for t in range(int(num_tries)):
+        mut_seq = self._mutate(seq=seq, plddt=plddt,
+                               logits=seq_logits + self._inputs["bias"])
+        aux = self.predict(seq=mut_seq, return_aux=True, model_nums=model_nums, verbose=False, **kwargs)
+        buff.append({"aux":aux, "seq":np.array(mut_seq)})
+
+      # accept best
+      losses = [x["aux"]["loss"] for x in buff]
+      best = buff[np.argmin(losses)]
+      self.aux, seq = best["aux"], jnp.array(best["seq"])
+      self.set_seq(seq=seq, bias=self._inputs["bias"])
+      self._save_results(save_best=save_best, verbose=verbose)
+
+      # update plddt
+      plddt = best["aux"]["plddt"]
+      plddt = plddt[self._target_len:] if self.protocol == "binder" else plddt[:self._len]
+      self._k += 1
+
+  def design_pssm_semigreedy(self, soft_iters=300, hard_iters=32, tries=10, e_tries=None,
+                             ramp_recycles=True, ramp_models=True, **kwargs):
+
+    verbose = kwargs.get("verbose",1)
+
+    # stage 1: logits -> softmax(logits)
+    if soft_iters > 0:
+      self.design_3stage(soft_iters, 0, 0, ramp_recycles=ramp_recycles, **kwargs)
+      self._tmp["seq_logits"] = kwargs["seq_logits"] = self.aux["seq"]["logits"]
+
+    # stage 2: semi_greedy
+    if hard_iters > 0:
+      kwargs["dropout"] = False
+      if ramp_models:
+        num_models = len(kwargs.get("models",self._model_names))
+        iters = hard_iters
+        for m in range(num_models):
+          if verbose and m > 0: print(f'Increasing number of models to {m+1}.')
+
+          kwargs["num_models"] = m + 1
+          kwargs["save_best"] = (m + 1) == num_models
+          self.design_semigreedy(iters, tries=tries, e_tries=e_tries, **kwargs)
+          if m < 2: iters = iters // 2
+      else:
+        self.design_semigreedy(hard_iters, tries=tries, e_tries=e_tries, **kwargs)
+
+  # ---------------------------------------------------------------------------------
+  # experimental optimizers (not extensively evaluated)
+  # ---------------------------------------------------------------------------------
+
+  def _design_mcmc(self, steps=1000, half_life=200, T_init=0.01, mutation_rate=1,
+                   seq_logits=None, save_best=True, **kwargs):
+    '''
+    MCMC with simulated annealing
+    ----------------------------------------
+    steps = number of steps for the MCMC trajectory
+    half_life = half-life for the temperature decay during simulated annealing
+    T_init = starting temperature for simulated annealing. Temperature is decayed exponentially
+    mutation_rate = number of mutations at each MCMC step
+    '''
+
+    # code borrowed from: github.com/bwicky/oligomer_hallucination
+
+    # gather settings
+    verbose = kwargs.pop("verbose",1)
+    model_flags = {k:kwargs.pop(k,None) for k in ["num_models","sample_models","models"]}
+
+    # initialize
+    plddt, best_loss, current_loss = None, np.inf, np.inf
+    current_seq = (self._params["seq"] + self._inputs["bias"]).argmax(-1)
+    if seq_logits is None: seq_logits = 0
+
+    # run!
+ if verbose: print("Running MCMC with simulated annealing...") + for i in range(steps): + + # update temperature + T = T_init * (np.exp(np.log(0.5) / half_life) ** i) + + # mutate sequence + if i == 0: + mut_seq = current_seq + else: + mut_seq = self._mutate(seq=current_seq, plddt=plddt, + logits=seq_logits + self._inputs["bias"], + mutation_rate=mutation_rate) + + # get loss + model_nums = self._get_model_nums(**model_flags) + aux = self.predict(seq=mut_seq, return_aux=True, verbose=False, model_nums=model_nums, **kwargs) + loss = aux["log"]["loss"] + + # decide + delta = loss - current_loss + if i == 0 or delta < 0 or np.random.uniform() < np.exp( -delta / T): + + # accept + (current_seq,current_loss) = (mut_seq,loss) + + plddt = aux["all"]["plddt"].mean(0) + plddt = plddt[self._target_len:] if self.protocol == "binder" else plddt[:self._len] + + if loss < best_loss: + (best_loss, self._k) = (loss, i) + self.set_seq(seq=current_seq, bias=self._inputs["bias"]) + self._save_results(save_best=save_best, verbose=verbose) diff --git a/build/lib/colabdesign/af/inputs.py b/build/lib/colabdesign/af/inputs.py new file mode 100644 index 00000000..fe390a60 --- /dev/null +++ b/build/lib/colabdesign/af/inputs.py @@ -0,0 +1,155 @@ +import jax +import jax.numpy as jnp +import numpy as np + +from colabdesign.shared.utils import copy_dict +from colabdesign.shared.model import soft_seq +from colabdesign.af.alphafold.common import residue_constants +from colabdesign.af.alphafold.model import model, config + +############################################################################ +# AF_INPUTS - functions for modifying inputs before passing to alphafold +############################################################################ +class _af_inputs: + + def _get_seq(self, inputs, aux, key=None): + params, opt = inputs["params"], inputs["opt"] + '''get sequence features''' + seq = soft_seq(params["seq"], inputs["bias"], opt, key, num_seq=self._num, + shuffle_first=self._args["shuffle_first"]) + seq = self._fix_pos(seq) + aux.update({"seq":seq, "seq_pseudo":seq["pseudo"]}) + + # protocol specific modifications to seq features + if self.protocol == "binder": + # concatenate target and binder sequence + seq_target = jax.nn.one_hot(inputs["batch"]["aatype"][:self._target_len],self._args["alphabet_size"]) + seq_target = jnp.broadcast_to(seq_target,(self._num, *seq_target.shape)) + seq = jax.tree_map(lambda x:jnp.concatenate([seq_target,x],1), seq) + + if self.protocol in ["fixbb","hallucination","partial"] and self._args["copies"] > 1: + seq = jax.tree_map(lambda x:expand_copies(x, self._args["copies"], self._args["block_diag"]), seq) + + return seq + + def _fix_pos(self, seq, return_p=False): + if "fix_pos" in self.opt: + if "pos" in self.opt: + seq_ref = jax.nn.one_hot(self._wt_aatype_sub,self._args["alphabet_size"]) + p = self.opt["pos"][self.opt["fix_pos"]] + fix_seq = lambda x: x.at[...,p,:].set(seq_ref) + else: + seq_ref = jax.nn.one_hot(self._wt_aatype,self._args["alphabet_size"]) + p = self.opt["fix_pos"] + fix_seq = lambda x: x.at[...,p,:].set(seq_ref[...,p,:]) + seq = jax.tree_map(fix_seq, seq) + if return_p: return seq, p + return seq + + def _update_template(self, inputs, key): + ''''dynamically update template features''' + if "batch" in inputs: + batch, opt = inputs["batch"], inputs["opt"] + + # enable templates + inputs["template_mask"] = inputs["template_mask"].at[0].set(1) + L = batch["aatype"].shape[0] + + # decide which position to remove sequence and/or sidechains + rm = 
jnp.broadcast_to(inputs.get("rm_template",False),L) + rm_seq = jnp.where(rm,True,jnp.broadcast_to(inputs.get("rm_template_seq",True),L)) + rm_sc = jnp.where(rm_seq,True,jnp.broadcast_to(inputs.get("rm_template_sc",True),L)) + + # define template features + template_feats = {"template_aatype":jnp.where(rm_seq,21,batch["aatype"])} + + if "dgram" in batch: + # use dgram from batch if provided + template_feats.update({"template_dgram":batch["dgram"]}) + nT,nL = inputs["template_aatype"].shape + inputs["template_dgram"] = jnp.zeros((nT,nL,nL,39)) + + if "all_atom_positions" in batch: + # get pseudo-carbon-beta coordinates (carbon-alpha for glycine) + # aatype = is used to define template's CB coordinates (CA in case of glycine) + cb, cb_mask = model.modules.pseudo_beta_fn( + jnp.where(rm_seq,0,batch["aatype"]), + batch["all_atom_positions"], + batch["all_atom_mask"]) + template_feats.update({"template_pseudo_beta": cb, + "template_pseudo_beta_mask": cb_mask, + "template_all_atom_positions": batch["all_atom_positions"], + "template_all_atom_mask": batch["all_atom_mask"]}) + + # inject template features + if self.protocol == "partial": + pos = opt["pos"] + if self._args["repeat"] or self._args["homooligomer"]: + C,L = self._args["copies"], self._len + pos = (jnp.repeat(pos,C).reshape(-1,C) + jnp.arange(C) * L).T.flatten() + + for k,v in template_feats.items(): + if self.protocol == "partial": + if k in ["template_dgram"]: + inputs[k] = inputs[k].at[0,pos[:,None],pos[None,:]].set(v) + else: + inputs[k] = inputs[k].at[0,pos].set(v) + else: + inputs[k] = inputs[k].at[0].set(v) + + # remove sidechains (mask anything beyond CB) + if k in ["template_all_atom_mask"]: + if self.protocol == "partial": + inputs[k] = inputs[k].at[:,pos,5:].set(jnp.where(rm_sc[:,None],0,inputs[k][:,pos,5:])) + inputs[k] = inputs[k].at[:,pos].set(jnp.where(rm[:,None],0,inputs[k][:,pos])) + else: + inputs[k] = inputs[k].at[...,5:].set(jnp.where(rm_sc[:,None],0,inputs[k][...,5:])) + inputs[k] = jnp.where(rm[:,None],0,inputs[k]) + +def update_seq(seq, inputs, seq_1hot=None, seq_pssm=None, mlm=None): + '''update the sequence features''' + + if seq_1hot is None: seq_1hot = seq + if seq_pssm is None: seq_pssm = seq + target_feat = seq_1hot[0,:,:20] + + seq_1hot = jnp.pad(seq_1hot,[[0,0],[0,0],[0,22-seq_1hot.shape[-1]]]) + seq_pssm = jnp.pad(seq_pssm,[[0,0],[0,0],[0,22-seq_pssm.shape[-1]]]) + msa_feat = jnp.zeros_like(inputs["msa_feat"]).at[...,0:22].set(seq_1hot).at[...,25:47].set(seq_pssm) + + # masked language modeling (randomly mask positions) + if mlm is not None: + X = jax.nn.one_hot(22,23) + X = jnp.zeros(msa_feat.shape[-1]).at[...,:23].set(X).at[...,25:48].set(X) + msa_feat = jnp.where(mlm[...,None],X,msa_feat) + + inputs.update({"msa_feat":msa_feat, "target_feat":target_feat}) + +def update_aatype(aatype, inputs): + r = residue_constants + a = {"atom14_atom_exists":r.restype_atom14_mask, + "atom37_atom_exists":r.restype_atom37_mask, + "residx_atom14_to_atom37":r.restype_atom14_to_atom37, + "residx_atom37_to_atom14":r.restype_atom37_to_atom14} + mask = inputs["seq_mask"][:,None] + inputs.update(jax.tree_map(lambda x:jnp.where(mask,jnp.asarray(x)[aatype],0),a)) + inputs["aatype"] = aatype + +def expand_copies(x, copies, block_diag=True): + ''' + given msa (N,L,20) expand to (1+N*copies,L*copies,22) if block_diag else (N,L*copies,22) + ''' + if x.shape[-1] < 22: + x = jnp.pad(x,[[0,0],[0,0],[0,22-x.shape[-1]]]) + x = jnp.tile(x,[1,copies,1]) + if copies > 1 and block_diag: + L = x.shape[1] + sub_L = L // copies + y = 
x.reshape((-1,1,copies,sub_L,22)) + block_diag_mask = jnp.expand_dims(jnp.eye(copies),(0,3,4)) + seq = block_diag_mask * y + gap_seq = (1-block_diag_mask) * jax.nn.one_hot(jnp.repeat(21,sub_L),22) + y = (seq + gap_seq).swapaxes(0,1).reshape(-1,L,22) + return jnp.concatenate([x[:1],y],0) + else: + return x \ No newline at end of file diff --git a/build/lib/colabdesign/af/loss.py b/build/lib/colabdesign/af/loss.py new file mode 100644 index 00000000..d9982151 --- /dev/null +++ b/build/lib/colabdesign/af/loss.py @@ -0,0 +1,548 @@ +import jax +import jax.numpy as jnp +import numpy as np + +from colabdesign.shared.utils import Key, copy_dict +from colabdesign.shared.protein import jnp_rmsd_w, _np_kabsch, _np_rmsd, _np_get_6D_loss +from colabdesign.af.alphafold.model import model, folding, all_atom +from colabdesign.af.alphafold.common import confidence, residue_constants + +#################################################### +# AF_LOSS - setup loss function +#################################################### + +class _af_loss: + # protocol specific loss functions + def _loss_fixbb(self, inputs, outputs, aux): + opt = inputs["opt"] + '''get losses''' + copies = self._args["copies"] if self._args["homooligomer"] else 1 + # rmsd loss + aln = get_rmsd_loss(inputs, outputs, copies=copies) + if self._args["realign"]: + aux["atom_positions"] = aln["align"](aux["atom_positions"]) * aux["atom_mask"][...,None] + + # supervised losses + aux["losses"].update({ + "fape": get_fape_loss(inputs, outputs, copies=copies, clamp=opt["fape_cutoff"]), + "dgram_cce": get_dgram_loss(inputs, outputs, copies=copies, aatype=inputs["aatype"]), + "rmsd": aln["rmsd"], + }) + + # unsupervised losses + self._loss_unsupervised(inputs, outputs, aux) + + def _loss_binder(self, inputs, outputs, aux): + '''get losses''' + opt = inputs["opt"] + mask = inputs["seq_mask"] + zeros = jnp.zeros_like(mask) + tL,bL = self._target_len, self._binder_len + binder_id = zeros.at[-bL:].set(mask[-bL:]) + if "hotspot" in opt: + target_id = zeros.at[opt["hotspot"]].set(mask[opt["hotspot"]]) + i_con_loss = get_con_loss(inputs, outputs, opt["i_con"], mask_1d=target_id, mask_1b=binder_id) + else: + target_id = zeros.at[:tL].set(mask[:tL]) + i_con_loss = get_con_loss(inputs, outputs, opt["i_con"], mask_1d=binder_id, mask_1b=target_id) + + # unsupervised losses + aux["losses"].update({ + "plddt": get_plddt_loss(outputs, mask_1d=binder_id), # plddt over binder + "exp_res": get_exp_res_loss(outputs, mask_1d=binder_id), + "pae": get_pae_loss(outputs, mask_1d=binder_id), # pae over binder + interface + "con": get_con_loss(inputs, outputs, opt["con"], mask_1d=binder_id, mask_1b=binder_id), + # interface + "i_con": i_con_loss, + "i_pae": get_pae_loss(outputs, mask_1d=binder_id, mask_1b=target_id), + }) + + # supervised losses + if self._args["redesign"]: + + aln = get_rmsd_loss(inputs, outputs, L=tL, include_L=False) + align_fn = aln["align"] + + # compute cce of binder + interface + aatype = inputs["aatype"] + cce = get_dgram_loss(inputs, outputs, aatype=aatype, return_mtx=True) + + # compute fape + fape = get_fape_loss(inputs, outputs, clamp=opt["fape_cutoff"], return_mtx=True) + + aux["losses"].update({ + "rmsd": aln["rmsd"], + "dgram_cce": cce[-bL:].sum() / (mask[-bL:].sum() + 1e-8), + "fape": fape[-bL:].sum() / (mask[-bL:].sum() + 1e-8) + }) + + else: + align_fn = get_rmsd_loss(inputs, outputs, L=tL)["align"] + + if self._args["realign"]: + aux["atom_positions"] = align_fn(aux["atom_positions"]) * aux["atom_mask"][...,None] + + def 
_loss_partial(self, inputs, outputs, aux): + '''get losses''' + opt = inputs["opt"] + pos = opt["pos"] + if self._args["repeat"] or self._args["homooligomer"]: + C,L = self._args["copies"], self._len + pos = (jnp.repeat(pos,C).reshape(-1,C) + jnp.arange(C) * L).T.flatten() + + def sub(x, axis=0): + return jax.tree_map(lambda y:jnp.take(y,pos,axis),x) + + copies = self._args["copies"] if self._args["homooligomer"] else 1 + aatype = sub(inputs["aatype"]) + dgram = {"logits":sub(sub(outputs["distogram"]["logits"]),1), + "bin_edges":outputs["distogram"]["bin_edges"]} + atoms = sub(outputs["structure_module"]["final_atom_positions"]) + + I = {"aatype": aatype, "batch": inputs["batch"], "seq_mask":sub(inputs["seq_mask"])} + O = {"distogram": dgram, "structure_module": {"final_atom_positions": atoms}} + aln = get_rmsd_loss(I, O, copies=copies) + + # supervised losses + aux["losses"].update({ + "dgram_cce": get_dgram_loss(I, O, copies=copies, aatype=I["aatype"]), + "fape": get_fape_loss(I, O, copies=copies, clamp=opt["fape_cutoff"]), + "rmsd": aln["rmsd"], + }) + + # unsupervised losses + self._loss_unsupervised(inputs, outputs, aux) + + # sidechain specific losses + if self._args["use_sidechains"] and copies == 1: + + struct = outputs["structure_module"] + pred_pos = sub(struct["final_atom14_positions"]) + true_pos = all_atom.atom37_to_atom14(inputs["batch"]["all_atom_positions"], self._sc["batch"]) + + # sc_rmsd + aln = _get_sc_rmsd_loss(true_pos, pred_pos, self._sc["pos"]) + aux["losses"]["sc_rmsd"] = aln["rmsd"] + + # sc_fape + if not self._args["use_multimer"]: + sc_struct = {**folding.compute_renamed_ground_truth(self._sc["batch"], pred_pos), + "sidechains":{k: sub(struct["sidechains"][k],1) for k in ["frames","atom_pos"]}} + batch = {**inputs["batch"], + **all_atom.atom37_to_frames(**inputs["batch"])} + aux["losses"]["sc_fape"] = folding.sidechain_loss(batch, sc_struct, + self._cfg.model.heads.structure_module)["loss"] + + else: + # TODO + print("ERROR: 'sc_fape' not currently supported for 'multimer' mode") + aux["losses"]["sc_fape"] = 0.0 + + # align final atoms + if self._args["realign"]: + aux["atom_positions"] = aln["align"](aux["atom_positions"]) * aux["atom_mask"][...,None] + + def _loss_hallucination(self, inputs, outputs, aux): + # unsupervised losses + self._loss_unsupervised(inputs, outputs, aux) + + def _loss_unsupervised(self, inputs, outputs, aux): + + # define masks + opt = inputs["opt"] + if "pos" in opt: + C,L = self._args["copies"], self._len + pos = opt["pos"] + if C > 1: pos = (jnp.repeat(pos,C).reshape(-1,C) + jnp.arange(C) * L).T.flatten() + mask_1d = inputs["seq_mask"].at[pos].set(0) + else: + mask_1d = inputs["seq_mask"] + + seq_mask_2d = inputs["seq_mask"][:,None] * inputs["seq_mask"][None,:] + mask_2d = inputs["asym_id"][:,None] == inputs["asym_id"][None,:] + masks = {"mask_1d":mask_1d, + "mask_2d":jnp.where(seq_mask_2d,mask_2d,0)} + + # define losses + losses = { + "exp_res": get_exp_res_loss(outputs, mask_1d=mask_1d), + "plddt": get_plddt_loss(outputs, mask_1d=mask_1d), + "pae": get_pae_loss(outputs, **masks), + "con": get_con_loss(inputs, outputs, opt["con"], **masks), + "helix": get_helix_loss(inputs, outputs) + } + + # define losses at interface + if self._args["copies"] > 1 and not self._args["repeat"]: + masks = {"mask_1d": mask_1d if self._args["homooligomer"] else inputs["seq_mask"], + "mask_2d": jnp.where(seq_mask_2d,mask_2d == False,0)} + losses.update({ + "i_pae": get_pae_loss(outputs, **masks), + "i_con": get_con_loss(inputs, outputs, opt["i_con"], 
**masks), + }) + + aux["losses"].update(losses) + +##################################################################################### + +def get_plddt(outputs): + logits = outputs["predicted_lddt"]["logits"] + num_bins = logits.shape[-1] + bin_width = 1.0 / num_bins + bin_centers = jnp.arange(start=0.5 * bin_width, stop=1.0, step=bin_width) + probs = jax.nn.softmax(logits, axis=-1) + return jnp.sum(probs * bin_centers[None, :], axis=-1) + +def get_pae(outputs): + prob = jax.nn.softmax(outputs["predicted_aligned_error"]["logits"],-1) + breaks = outputs["predicted_aligned_error"]["breaks"] + step = breaks[1]-breaks[0] + bin_centers = breaks + step/2 + bin_centers = jnp.append(bin_centers,bin_centers[-1]+step) + return (prob*bin_centers).sum(-1) + +def get_ptm(inputs, outputs, interface=False): + pae = {"residue_weights":inputs["seq_mask"], + **outputs["predicted_aligned_error"]} + if interface: + if "asym_id" not in pae: + pae["asym_id"] = inputs["asym_id"] + else: + if "asym_id" in pae: + pae.pop("asym_id") + return confidence.predicted_tm_score(**pae, use_jnp=True) + +def get_dgram_bins(outputs): + dgram = outputs["distogram"]["logits"] + if dgram.shape[-1] == 64: + dgram_bins = jnp.append(0,jnp.linspace(2.3125,21.6875,63)) + if dgram.shape[-1] == 39: + dgram_bins = jnp.linspace(3.25,50.75,39) + 1.25 + return dgram_bins + +def get_contact_map(outputs, dist=8.0): + '''get contact map from distogram''' + dist_logits = outputs["distogram"]["logits"] + dist_bins = get_dgram_bins(outputs) + return (jax.nn.softmax(dist_logits) * (dist_bins < dist)).sum(-1) + +#################### +# confidence metrics +#################### +def mask_loss(x, mask=None, mask_grad=False): + if mask is None: + return x.mean() + else: + x_masked = (x * mask).sum() / (1e-8 + mask.sum()) + if mask_grad: + return jax.lax.stop_gradient(x.mean() - x_masked) + x_masked + else: + return x_masked + +def get_exp_res_loss(outputs, mask_1d=None): + p = jax.nn.sigmoid(outputs["experimentally_resolved"]["logits"]) + p = 1 - p[...,residue_constants.atom_order["CA"]] + return mask_loss(p, mask_1d) + +def get_plddt_loss(outputs, mask_1d=None): + p = 1 - get_plddt(outputs) + return mask_loss(p, mask_1d) + +def get_pae_loss(outputs, mask_1d=None, mask_1b=None, mask_2d=None): + p = get_pae(outputs) / 31.0 + p = (p + p.T) / 2 + L = p.shape[0] + if mask_1d is None: mask_1d = jnp.ones(L) + if mask_1b is None: mask_1b = jnp.ones(L) + if mask_2d is None: mask_2d = jnp.ones((L,L)) + mask_2d = mask_2d * mask_1d[:,None] * mask_1b[None,:] + return mask_loss(p, mask_2d) + +def get_con_loss(inputs, outputs, con_opt, + mask_1d=None, mask_1b=None, mask_2d=None): + + # get top k + def min_k(x, k=1, mask=None): + y = jnp.sort(x if mask is None else jnp.where(mask,x,jnp.nan)) + k_mask = jnp.logical_and(jnp.arange(y.shape[-1]) < k, jnp.isnan(y) == False) + return jnp.where(k_mask,y,0).sum(-1) / (k_mask.sum(-1) + 1e-8) + + # decide on what offset to use + if "offset" in inputs: + offset = inputs["offset"] + else: + idx = inputs["residue_index"].flatten() + offset = idx[:,None] - idx[None,:] + + # define distogram + dgram = outputs["distogram"]["logits"] + dgram_bins = get_dgram_bins(outputs) + + p = _get_con_loss(dgram, dgram_bins, cutoff=con_opt["cutoff"], binary=con_opt["binary"]) + if "seqsep" in con_opt: + m = jnp.abs(offset) >= con_opt["seqsep"] + else: + m = jnp.ones_like(offset) + + # mask results + if mask_1d is None: mask_1d = jnp.ones(m.shape[0]) + if mask_1b is None: mask_1b = jnp.ones(m.shape[0]) + + if mask_2d is None: + m = 
jnp.logical_and(m, mask_1b)
+  else:
+    m = jnp.logical_and(m, mask_2d)
+
+  p = min_k(p, con_opt["num"], m)
+  return min_k(p, con_opt["num_pos"], mask_1d)
+
+def _get_con_loss(dgram, dgram_bins, cutoff=None, binary=True):
+  '''dgram to contacts'''
+  if cutoff is None: cutoff = dgram_bins[-1]
+  bins = dgram_bins < cutoff
+  px = jax.nn.softmax(dgram)
+  px_ = jax.nn.softmax(dgram - 1e7 * (1-bins))
+  # binary/categorical cross-entropy
+  con_loss_cat_ent = -(px_ * jax.nn.log_softmax(dgram)).sum(-1)
+  con_loss_bin_ent = -jnp.log((bins * px + 1e-8).sum(-1))
+  return jnp.where(binary, con_loss_bin_ent, con_loss_cat_ent)
+
+def get_helix_loss(inputs, outputs):
+  # decide on what offset to use
+  if "offset" in inputs:
+    offset = inputs["offset"]
+  else:
+    idx = inputs["residue_index"].flatten()
+    offset = idx[:,None] - idx[None,:]
+
+  # define distogram
+  dgram = outputs["distogram"]["logits"]
+  dgram_bins = get_dgram_bins(outputs)
+
+  mask_2d = inputs["seq_mask"][:,None] * inputs["seq_mask"][None,:]
+  return _get_helix_loss(dgram, dgram_bins, offset, mask_2d=mask_2d)
+
+def _get_helix_loss(dgram, dgram_bins, offset=None, mask_2d=None, **kwargs):
+  '''helix bias loss'''
+  x = _get_con_loss(dgram, dgram_bins, cutoff=6.0, binary=True)
+  if offset is None:
+    if mask_2d is None:
+      return jnp.diagonal(x,3).mean()
+    else:
+      return jnp.diagonal(x * mask_2d,3).sum() / (jnp.diagonal(mask_2d,3).sum() + 1e-8)
+  else:
+    mask = offset == 3
+    if mask_2d is not None:
+      mask = jnp.where(mask_2d,mask,0)
+    return jnp.where(mask,x,0.0).sum() / (mask.sum() + 1e-8)
+
+####################
+# loss functions
+####################
+def get_dgram_loss(inputs, outputs, copies=1, aatype=None, return_mtx=False):
+
+  batch = inputs["batch"]
+  # gather features
+  if aatype is None: aatype = batch["aatype"]
+  pred = outputs["distogram"]["logits"]
+
+  # get true features
+  x, weights = model.modules.pseudo_beta_fn(aatype=aatype,
+                                            all_atom_positions=batch["all_atom_positions"],
+                                            all_atom_mask=batch["all_atom_mask"])
+
+  dm = jnp.square(x[:,None]-x[None,:]).sum(-1,keepdims=True)
+  bin_edges = jnp.linspace(2.3125, 21.6875, pred.shape[-1] - 1)
+  true = jax.nn.one_hot((dm > jnp.square(bin_edges)).sum(-1), pred.shape[-1])
+
+  def loss_fn(t,p,m):
+    cce = -(t*jax.nn.log_softmax(p)).sum(-1)
+    return cce, (cce*m).sum((-1,-2))/(m.sum((-1,-2))+1e-8)
+
+  weights = jnp.where(inputs["seq_mask"],weights,0)
+  return _get_pw_loss(true, pred, loss_fn, weights=weights, copies=copies, return_mtx=return_mtx)
+
+def get_fape_loss(inputs, outputs, copies=1, clamp=10.0, return_mtx=False):
+
+  def robust_norm(x, axis=-1, keepdims=False, eps=1e-8):
+    return jnp.sqrt(jnp.square(x).sum(axis=axis, keepdims=keepdims) + eps)
+
+  def get_R(N, CA, C):
+    (v1,v2) = (C-CA, N-CA)
+    e1 = v1 / robust_norm(v1, axis=-1, keepdims=True)
+    c = jnp.einsum('li, li -> l', e1, v2)[:,None]
+    e2 = v2 - c * e1
+    e2 = e2 / robust_norm(e2, axis=-1, keepdims=True)
+    e3 = jnp.cross(e1, e2, axis=-1)
+    return jnp.concatenate([e1[:,:,None], e2[:,:,None], e3[:,:,None]], axis=-1)
+
+  def get_ij(R,T):
+    return jnp.einsum('rji,rsj->rsi',R,T[None,:]-T[:,None])
+
+  def loss_fn(t,p,m):
+    fape = robust_norm(t-p)
+    fape = jnp.clip(fape, 0, clamp) / 10.0
+    return fape, (fape*m).sum((-1,-2))/(m.sum((-1,-2)) + 1e-8)
+
+  true = inputs["batch"]["all_atom_positions"]
+  pred = outputs["structure_module"]["final_atom_positions"]
+
+  N,CA,C = (residue_constants.atom_order[k] for k in ["N","CA","C"])
+
+  true_mask = jnp.where(inputs["seq_mask"][:,None],inputs["batch"]["all_atom_mask"],0)
+  weights =
true_mask[:,N] * true_mask[:,CA] * true_mask[:,C] + + true = get_ij(get_R(true[:,N],true[:,CA],true[:,C]),true[:,CA]) + pred = get_ij(get_R(pred[:,N],pred[:,CA],pred[:,C]),pred[:,CA]) + + return _get_pw_loss(true, pred, loss_fn, weights=weights, copies=copies, return_mtx=return_mtx) + +def _get_pw_loss(true, pred, loss_fn, weights=None, copies=1, return_mtx=False): + length = true.shape[0] + + if weights is None: + weights = jnp.ones(length) + + F = {"t":true, "p":pred, "m":weights[:,None] * weights[None,:]} + + if copies > 1: + (L,C) = (length//copies, copies-1) + + # intra (L,L,F) + intra = jax.tree_map(lambda x:x[:L,:L], F) + mtx, loss = loss_fn(**intra) + + # inter (C*L,L,F) + inter = jax.tree_map(lambda x:x[L:,:L], F) + if C == 0: + i_mtx, i_loss = loss_fn(**inter) + + else: + # (C,L,L,F) + inter = jax.tree_map(lambda x:x.reshape(C,L,L,-1), inter) + inter = {"t":inter["t"][:,None], # (C,1,L,L,F) + "p":inter["p"][None,:], # (1,C,L,L,F) + "m":inter["m"][:,None,:,:,0]} # (C,1,L,L) + + # (C,C,L,L,F) → (C,C,L,L) → (C,C) → (C) → () + i_mtx, i_loss = loss_fn(**inter) + i_loss = sum([i_loss.min(i).sum() for i in [0,1]]) / 2 + + total_loss = (loss + i_loss) / copies + return (mtx, i_mtx) if return_mtx else total_loss + + else: + mtx, loss = loss_fn(**F) + return mtx if return_mtx else loss + +def get_rmsd_loss(inputs, outputs, L=None, include_L=True, copies=1): + batch = inputs["batch"] + true = batch["all_atom_positions"][:,1] + pred = outputs["structure_module"]["final_atom_positions"][:,1] + weights = jnp.where(inputs["seq_mask"],batch["all_atom_mask"][:,1],0) + return _get_rmsd_loss(true, pred, weights=weights, L=L, include_L=include_L, copies=copies) + +def _get_rmsd_loss(true, pred, weights=None, L=None, include_L=True, copies=1): + ''' + get rmsd + alignment function + align based on the first L positions, computed weighted rmsd using all + positions (if include_L=True) or remaining positions (if include_L=False). 
+ ''' + # normalize weights + length = true.shape[-2] + if weights is None: + weights = (jnp.ones(length)/length)[...,None] + else: + weights = (weights/(weights.sum(-1,keepdims=True) + 1e-8))[...,None] + + # determine alignment [L]ength and remaining [l]ength + if copies > 1: + if L is None: + L = iL = length // copies; C = copies-1 + else: + (iL,C) = ((length-L) // copies, copies) + else: + (L,iL,C) = (length,0,0) if L is None else (L,length-L,1) + + # slice inputs + if iL == 0: + (T,P,W) = (true,pred,weights) + else: + (T,P,W) = (x[...,:L,:] for x in (true,pred,weights)) + (iT,iP,iW) = (x[...,L:,:] for x in (true,pred,weights)) + + # get alignment and rmsd functions + (T_mu,P_mu) = ((x*W).sum(-2,keepdims=True)/W.sum((-1,-2)) for x in (T,P)) + aln = _np_kabsch((P-P_mu)*W, T-T_mu) + align_fn = lambda x: (x - P_mu) @ aln + T_mu + msd_fn = lambda t,p,w: (w*jnp.square(align_fn(p)-t)).sum((-1,-2)) + + # compute rmsd + if iL == 0: + msd = msd_fn(true,pred,weights) + elif C > 1: + # all vs all alignment of remaining, get min RMSD + iT = iT.reshape(-1,C,1,iL,3).swapaxes(0,-3) + iP = iP.reshape(-1,1,C,iL,3).swapaxes(0,-3) + imsd = msd_fn(iT, iP, iW.reshape(-1,C,1,iL,1).swapaxes(0,-3)) + imsd = (imsd.min(0).sum(0) + imsd.min(1).sum(0)) / 2 + imsd = imsd.reshape(jnp.broadcast_shapes(true.shape[:-2],pred.shape[:-2])) + msd = (imsd + msd_fn(T,P,W)) if include_L else (imsd/iW.sum((-1,-2))) + else: + msd = msd_fn(true,pred,weights) if include_L else (msd_fn(iT,iP,iW)/iW.sum((-1,-2))) + rmsd = jnp.sqrt(msd + 1e-8) + + return {"rmsd":rmsd, "align":align_fn} + +def _get_sc_rmsd_loss(true, pred, sc): + '''get sidechain rmsd + alignment function''' + + # select atoms + (T, P) = (true.reshape(-1,3), pred.reshape(-1,3)) + (T, T_alt, P) = (T[sc["pos"]], T[sc["pos_alt"]], P[sc["pos"]]) + + # select non-ambigious atoms + (T_na, P_na) = (T[sc["non_amb"]], P[sc["non_amb"]]) + + # get alignment of non-ambigious atoms + if "weight_non_amb" in sc: + T_mu_na = (T_na * sc["weight_non_amb"]).sum(0) + P_mu_na = (P_na * sc["weight_non_amb"]).sum(0) + aln = _np_kabsch((P_na-P_mu_na) * sc["weight_non_amb"], T_na-T_mu_na) + else: + T_mu_na, P_mu_na = T_na.mean(0), P_na.mean(0) + aln = _np_kabsch(P_na-P_mu_na, T_na-T_mu_na) + + # apply alignment to all atoms + align_fn = lambda x: (x - P_mu_na) @ aln + T_mu_na + P = align_fn(P) + + # compute rmsd + sd = jnp.minimum(jnp.square(P-T).sum(-1), jnp.square(P-T_alt).sum(-1)) + if "weight" in sc: + msd = (sd*sc["weight"]).sum() + else: + msd = sd.mean() + rmsd = jnp.sqrt(msd + 1e-8) + return {"rmsd":rmsd, "align":align_fn} + +def get_seq_ent_loss(inputs): + opt = inputs["opt"] + x = inputs["seq"]["logits"] / opt["temp"] + ent = -(jax.nn.softmax(x) * jax.nn.log_softmax(x)).sum(-1) + mask = inputs["seq_mask"][-x.shape[1]:] + if "fix_pos" in opt: + if "pos" in opt: + p = opt["pos"][opt["fix_pos"]] + else: + p = opt["fix_pos"] + mask = mask.at[p].set(0) + ent = (ent * mask).sum() / (mask.sum() + 1e-8) + return {"seq_ent":ent.mean()} + +def get_mlm_loss(outputs, mask, truth=None): + x = outputs["masked_msa"]["logits"][...,:20] + if truth is None: truth = jax.nn.softmax(x) + ent = -(truth[...,:20] * jax.nn.log_softmax(x)).sum(-1) + ent = (ent * mask).sum(-1) / (mask.sum() + 1e-8) + return {"mlm":ent.mean()} \ No newline at end of file diff --git a/build/lib/colabdesign/af/model.py b/build/lib/colabdesign/af/model.py new file mode 100644 index 00000000..37053bda --- /dev/null +++ b/build/lib/colabdesign/af/model.py @@ -0,0 +1,245 @@ +import os +import jax +import jax.numpy as jnp +import 
numpy as np +from inspect import signature + +from colabdesign.af.alphafold.model import data, config, model, all_atom + +from colabdesign.shared.model import design_model +from colabdesign.shared.utils import Key + +from colabdesign.af.prep import _af_prep +from colabdesign.af.loss import _af_loss, get_plddt, get_pae, get_ptm +from colabdesign.af.loss import get_contact_map, get_seq_ent_loss, get_mlm_loss +from colabdesign.af.utils import _af_utils +from colabdesign.af.design import _af_design +from colabdesign.af.inputs import _af_inputs, update_seq, update_aatype + +################################################################ +# MK_DESIGN_MODEL - initialize model, and put it all together +################################################################ + +class mk_af_model(design_model, _af_inputs, _af_loss, _af_prep, _af_design, _af_utils): + def __init__(self, + protocol="fixbb", + use_multimer=False, + use_templates=False, + debug=False, + data_dir=".", + **kwargs): + assert protocol in ["fixbb","hallucination","binder","partial"] + + self.protocol = protocol + self._num = kwargs.pop("num_seq",1) + self._args = {"use_templates":use_templates, "use_multimer":use_multimer, "use_bfloat16":True, + "recycle_mode":"last", "use_mlm": False, "realign": True, + "debug":debug, "repeat":False, "homooligomer":False, "copies":1, + "optimizer":"sgd", "best_metric":"loss", + "traj_iter":1, "traj_max":10000, + "clear_prev": True, "use_dgram":False, + "shuffle_first":True, "use_remat":True, + "alphabet_size":20, + "use_initial_guess":False, "use_initial_atom_pos":False} + + if self.protocol == "binder": self._args["use_templates"] = True + + self.opt = {"dropout":True, "pssm_hard":False, "learning_rate":0.1, "norm_seq_grad":True, + "num_recycles":0, "num_models":1, "sample_models":True, + "temp":1.0, "soft":0.0, "hard":0.0, "alpha":2.0, + "con": {"num":2, "cutoff":14.0, "binary":False, "seqsep":9, "num_pos":float("inf")}, + "i_con": {"num":1, "cutoff":21.6875, "binary":False, "num_pos":float("inf")}, + "template": {"rm_ic":False}, + "weights": {"seq_ent":0.0, "plddt":0.0, "pae":0.0, "exp_res":0.0, "helix":0.0}, + "fape_cutoff":10.0} + + self._params = {} + self._inputs = {} + self._tmp = {"traj":{"seq":[],"xyz":[],"plddt":[],"pae":[]}, + "log":[],"best":{}} + + # set arguments/options + if "initial_guess" in kwargs: kwargs["use_initial_guess"] = kwargs.pop("initial_guess") + model_names = kwargs.pop("model_names",None) + keys = list(kwargs.keys()) + for k in keys: + if k in self._args: self._args[k] = kwargs.pop(k) + if k in self.opt: self.opt[k] = kwargs.pop(k) + + # collect callbacks + self._callbacks = {"model": {"pre": kwargs.pop("pre_callback",None), + "post":kwargs.pop("post_callback",None), + "loss":kwargs.pop("loss_callback",None)}, + "design":{"pre": kwargs.pop("pre_design_callback",None), + "post":kwargs.pop("post_design_callback",None)}} + + for m,n in self._callbacks.items(): + for k,v in n.items(): + if v is None: v = [] + if not isinstance(v,list): v = [v] + self._callbacks[m][k] = v + + if self._args["use_mlm"]: + self.opt["mlm_dropout"] = 0.15 + self.opt["weights"]["mlm"] = 0.1 + + assert len(kwargs) == 0, f"ERROR: the following inputs were not set: {kwargs}" + + ############################# + # configure AlphaFold + ############################# + if self._args["use_multimer"]: + self._cfg = config.model_config("model_1_multimer") + # TODO + self.opt["pssm_hard"] = True + else: + self._cfg = config.model_config("model_1_ptm" if self._args["use_templates"] else "model_3_ptm") + + 
if self._args["recycle_mode"] in ["average","first","last","sample"]: + num_recycles = 0 + else: + num_recycles = self.opt["num_recycles"] + self._cfg.model.num_recycle = num_recycles + self._cfg.model.global_config.use_remat = self._args["use_remat"] + self._cfg.model.global_config.use_dgram = self._args["use_dgram"] + self._cfg.model.global_config.bfloat16 = self._args["use_bfloat16"] + + # load model_params + if model_names is None: + model_names = [] + if self._args["use_multimer"]: + model_names += [f"model_{k}_multimer_v3" for k in [1,2,3,4,5]] + else: + if self._args["use_templates"]: + model_names += [f"model_{k}_ptm" for k in [1,2]] + else: + model_names += [f"model_{k}_ptm" for k in [1,2,3,4,5]] + + self._model_params, self._model_names = [],[] + for model_name in model_names: + params = data.get_model_haiku_params(model_name=model_name, data_dir=data_dir, fuse=True) + if params is not None: + if not self._args["use_multimer"] and not self._args["use_templates"]: + params = {k:v for k,v in params.items() if "template" not in k} + self._model_params.append(params) + self._model_names.append(model_name) + else: + print(f"WARNING: '{model_name}' not found") + + ##################################### + # set protocol specific functions + ##################################### + idx = ["fixbb","hallucination","binder","partial"].index(self.protocol) + self.prep_inputs = [self._prep_fixbb, self._prep_hallucination, self._prep_binder, self._prep_partial][idx] + self._get_loss = [self._loss_fixbb, self._loss_hallucination, self._loss_binder, self._loss_partial][idx] + + def _get_model(self, cfg, callback=None): + + a = self._args + runner = model.RunModel(cfg, + recycle_mode=a["recycle_mode"], + use_multimer=a["use_multimer"]) + + # setup function to get gradients + def _model(params, model_params, inputs, key): + inputs["params"] = params + opt = inputs["opt"] + + aux = {} + key = Key(key=key).get + + ####################################################################### + # INPUTS + ####################################################################### + # get sequence + seq = self._get_seq(inputs, aux, key()) + + # update sequence features + pssm = jnp.where(opt["pssm_hard"], seq["hard"], seq["pseudo"]) + if a["use_mlm"]: + shape = seq["pseudo"].shape[:2] + mlm = jax.random.bernoulli(key(),opt["mlm_dropout"],shape) + update_seq(seq["pseudo"], inputs, seq_pssm=pssm, mlm=mlm) + else: + update_seq(seq["pseudo"], inputs, seq_pssm=pssm) + + # update amino acid sidechain identity + update_aatype(seq["pseudo"][0].argmax(-1), inputs) + + # define masks + inputs["msa_mask"] = jnp.where(inputs["seq_mask"],inputs["msa_mask"],0) + + inputs["seq"] = aux["seq"] + + # update template features + inputs["mask_template_interchain"] = opt["template"]["rm_ic"] + if a["use_templates"]: + self._update_template(inputs, key()) + + # set dropout + inputs["use_dropout"] = opt["dropout"] + + if "batch" not in inputs: + inputs["batch"] = None + + # pre callback + for fn in self._callbacks["model"]["pre"]: + fn_args = {"inputs":inputs, "opt":opt, "aux":aux, + "seq":seq, "key":key(), "params":params} + sub_args = {k:fn_args.get(k,None) for k in signature(fn).parameters} + fn(**sub_args) + + ####################################################################### + # OUTPUTS + ####################################################################### + outputs = runner.apply(model_params, key(), inputs) + + # add aux outputs + aux.update({"atom_positions": outputs["structure_module"]["final_atom_positions"], + 
"atom_mask": outputs["structure_module"]["final_atom_mask"], + "residue_index": inputs["residue_index"], + "aatype": inputs["aatype"], + "plddt": get_plddt(outputs), + "pae": get_pae(outputs), + "ptm": get_ptm(inputs, outputs), + "i_ptm": get_ptm(inputs, outputs, interface=True), + "cmap": get_contact_map(outputs, opt["con"]["cutoff"]), + "i_cmap": get_contact_map(outputs, opt["i_con"]["cutoff"]), + "prev": outputs["prev"]}) + + ####################################################################### + # LOSS + ####################################################################### + aux["losses"] = {} + + # add protocol specific losses + self._get_loss(inputs=inputs, outputs=outputs, aux=aux) + + # sequence entropy loss + aux["losses"].update(get_seq_ent_loss(inputs)) + + # experimental masked-language-modeling + if a["use_mlm"]: + aux["mlm"] = outputs["masked_msa"]["logits"] + mask = jnp.where(inputs["seq_mask"],mlm,0) + aux["losses"].update(get_mlm_loss(outputs, mask=mask, truth=seq["pssm"])) + + # run user defined callbacks + for c in ["loss","post"]: + for fn in self._callbacks["model"][c]: + fn_args = {"inputs":inputs, "outputs":outputs, "opt":opt, + "aux":aux, "seq":seq, "key":key(), "params":params} + sub_args = {k:fn_args.get(k,None) for k in signature(fn).parameters} + if c == "loss": aux["losses"].update(fn(**sub_args)) + if c == "post": fn(**sub_args) + + # save for debugging + if a["debug"]: aux["debug"] = {"inputs":inputs,"outputs":outputs} + + # weighted loss + w = opt["weights"] + loss = sum([v * w[k] if k in w else v for k,v in aux["losses"].items()]) + return loss, aux + + return {"grad_fn":jax.jit(jax.value_and_grad(_model, has_aux=True, argnums=0)), + "fn":jax.jit(_model), "runner":runner} diff --git a/build/lib/colabdesign/af/prep.py b/build/lib/colabdesign/af/prep.py new file mode 100644 index 00000000..b5dc238c --- /dev/null +++ b/build/lib/colabdesign/af/prep.py @@ -0,0 +1,581 @@ +import jax +import jax.numpy as jnp +import numpy as np +import re + +from colabdesign.af.alphafold.data import pipeline, prep_inputs +from colabdesign.af.alphafold.common import protein, residue_constants +from colabdesign.af.alphafold.model.tf import shape_placeholders +from colabdesign.af.alphafold.model import config + + +from colabdesign.shared.protein import _np_get_cb, pdb_to_string +from colabdesign.shared.prep import prep_pos +from colabdesign.shared.utils import copy_dict +from colabdesign.shared.model import order_aa + +resname_to_idx = residue_constants.resname_to_idx +idx_to_resname = dict((v,k) for k,v in resname_to_idx.items()) + +################################################# +# AF_PREP - input prep functions +################################################# +class _af_prep: + + def _prep_model(self, **kwargs): + '''prep model''' + if not hasattr(self,"_model") or self._cfg != self._model["runner"].config: + self._cfg.model.global_config.subbatch_size = None + self._model = self._get_model(self._cfg) + if sum(self._lengths) > 384: + self._cfg.model.global_config.subbatch_size = 4 + self._model["fn"] = self._get_model(self._cfg)["fn"] + + self._opt = copy_dict(self.opt) + self.restart(**kwargs) + + def _prep_features(self, num_res, num_seq=None, num_templates=1): + '''process features''' + if num_seq is None: num_seq = self._num + return prep_input_features(L=num_res, N=num_seq, T=num_templates) + + def _prep_fixbb(self, pdb_filename, chain=None, + copies=1, repeat=False, homooligomer=False, + rm_template=False, + rm_template_seq=True, + rm_template_sc=True, + 
rm_template_ic=False, + fix_pos=None, ignore_missing=True, **kwargs): + ''' + prep inputs for fixed backbone design + --------------------------------------------------- + if copies > 1: + -homooligomer=True - input pdb chains are parsed as homo-oligomeric units + -repeat=True - tie the repeating sequence within single chain + -rm_template_seq - if template is defined, remove information about template sequence + -fix_pos="1,2-10" - specify which positions to keep fixed in the sequence + note: supervised loss is applied to all positions, use "partial" + protocol to apply supervised loss to only subset of positions + -ignore_missing=True - skip positions that have missing density (no CA coordinate) + --------------------------------------------------- + ''' + # prep features + self._pdb = prep_pdb(pdb_filename, chain=chain, ignore_missing=ignore_missing, + offsets=kwargs.pop("pdb_offsets",None), + lengths=kwargs.pop("pdb_lengths",None)) + + self._len = self._pdb["residue_index"].shape[0] + self._lengths = [self._len] + + # feat dims + num_seq = self._num + res_idx = self._pdb["residue_index"] + + # get [pos]itions of interests + if fix_pos is not None and fix_pos != "": + self._pos_info = prep_pos(fix_pos, **self._pdb["idx"]) + self.opt["fix_pos"] = self._pos_info["pos"] + + if homooligomer and chain is not None and copies == 1: + copies = len(chain.split(",")) + + # repeat/homo-oligomeric support + if copies > 1: + + if repeat or homooligomer: + self._len = self._len // copies + if "fix_pos" in self.opt: + self.opt["fix_pos"] = self.opt["fix_pos"][self.opt["fix_pos"] < self._len] + + if repeat: + self._lengths = [self._len * copies] + block_diag = False + + else: + self._lengths = [self._len] * copies + block_diag = not self._args["use_multimer"] + + res_idx = repeat_idx(res_idx[:self._len], copies) + num_seq = (self._num * copies + 1) if block_diag else self._num + self.opt["weights"].update({"i_pae":0.0, "i_con":0.0}) + + self._args.update({"copies":copies, "repeat":repeat, "homooligomer":homooligomer, "block_diag":block_diag}) + homooligomer = not repeat + else: + self._lengths = self._pdb["lengths"] + + # configure input features + self._inputs = self._prep_features(num_res=sum(self._lengths), num_seq=num_seq) + self._inputs["residue_index"] = res_idx + self._inputs["batch"] = make_fixed_size(self._pdb["batch"], num_res=sum(self._lengths)) + self._inputs.update(get_multi_id(self._lengths, homooligomer=homooligomer)) + + # configure options/weights + self.opt["weights"].update({"dgram_cce":1.0, "rmsd":0.0, "fape":0.0, "con":0.0}) + self._wt_aatype = self._inputs["batch"]["aatype"][:self._len] + + # configure template [opt]ions + rm,L = {},sum(self._lengths) + for n,x in {"rm_template": rm_template, + "rm_template_seq":rm_template_seq, + "rm_template_sc": rm_template_sc}.items(): + rm[n] = np.full(L,False) + if isinstance(x,str): + rm[n][prep_pos(x,**self._pdb["idx"])["pos"]] = True + else: + rm[n][:] = x + self.opt["template"]["rm_ic"] = rm_template_ic + self._inputs.update(rm) + + self._prep_model(**kwargs) + + def _prep_hallucination(self, length=100, copies=1, repeat=False, **kwargs): + ''' + prep inputs for hallucination + --------------------------------------------------- + if copies > 1: + -repeat=True - tie the repeating sequence within single chain + --------------------------------------------------- + ''' + + # define num copies (for repeats/ homo-oligomers) + if not repeat and copies > 1 and not self._args["use_multimer"]: + (num_seq, block_diag) = (self._num * copies + 1, 
True)
+    else:
+      (num_seq, block_diag) = (self._num, False)
+
+    self._args.update({"repeat":repeat,"block_diag":block_diag,"copies":copies})
+
+    # prep features
+    self._len = length
+
+    # set weights
+    self.opt["weights"].update({"con":1.0})
+    if copies > 1:
+      if repeat:
+        offset = 1
+        self._lengths = [self._len * copies]
+        self._args["repeat"] = True
+      else:
+        offset = 50
+        self._lengths = [self._len] * copies
+        self.opt["weights"].update({"i_pae":0.0, "i_con":1.0})
+        self._args["homooligomer"] = True
+      res_idx = repeat_idx(np.arange(length), copies, offset=offset)
+    else:
+      self._lengths = [self._len]
+      res_idx = np.arange(length)
+
+    # configure input features
+    self._inputs = self._prep_features(num_res=sum(self._lengths), num_seq=num_seq)
+    self._inputs["residue_index"] = res_idx
+    self._inputs.update(get_multi_id(self._lengths, homooligomer=True))
+
+    self._prep_model(**kwargs)
+
+  def _prep_binder(self, pdb_filename,
+                   target_chain="A", binder_len=50,
+                   rm_target = False,
+                   rm_target_seq = False,
+                   rm_target_sc = False,
+
+                   # if binder_chain is defined
+                   binder_chain=None,
+                   rm_binder=True,
+                   rm_binder_seq=True,
+                   rm_binder_sc=True,
+                   rm_template_ic=False,
+
+                   hotspot=None, ignore_missing=True, **kwargs):
+    '''
+    prep inputs for binder design
+    ---------------------------------------------------
+    -binder_len = length of binder to hallucinate (option ignored if binder_chain is defined)
+    -binder_chain = chain of binder to redesign
+    -use_binder_template = use binder coordinates as template input
+    -rm_template_ic = use target and binder coordinates as separate template inputs
+    -hotspot = define position/hotspots on target
+    -rm_[binder/target]_seq = remove sequence info from template
+    -rm_[binder/target]_sc = remove sidechain info from template
+    -ignore_missing=True - skip positions that have missing density (no CA coordinate)
+    ---------------------------------------------------
+    '''
+    redesign = binder_chain is not None
+    rm_binder = not kwargs.pop("use_binder_template", not rm_binder)
+
+    self._args.update({"redesign":redesign})
+
+    # get pdb info
+    target_chain = kwargs.pop("chain",target_chain) # backward comp
+    chains = f"{target_chain},{binder_chain}" if redesign else target_chain
+    im = [True] * len(target_chain.split(","))
+    if redesign: im += [ignore_missing] * len(binder_chain.split(","))
+
+    self._pdb = prep_pdb(pdb_filename, chain=chains, ignore_missing=im)
+    res_idx = self._pdb["residue_index"]
+
+    if redesign:
+      self._target_len = sum([(self._pdb["idx"]["chain"] == c).sum() for c in target_chain.split(",")])
+      self._binder_len = sum([(self._pdb["idx"]["chain"] == c).sum() for c in binder_chain.split(",")])
+    else:
+      self._target_len = self._pdb["residue_index"].shape[0]
+      self._binder_len = binder_len
+      res_idx = np.append(res_idx, res_idx[-1] + np.arange(binder_len) + 50)
+
+    self._len = self._binder_len
+    self._lengths = [self._target_len, self._binder_len]
+
+    # gather hotspot info
+    if hotspot is not None:
+      self.opt["hotspot"] = prep_pos(hotspot, **self._pdb["idx"])["pos"]
+
+    if redesign:
+      # binder redesign
+      self._wt_aatype = self._pdb["batch"]["aatype"][self._target_len:]
+      self.opt["weights"].update({"dgram_cce":1.0, "rmsd":0.0, "fape":0.0,
+                                  "con":0.0, "i_con":0.0, "i_pae":0.0})
+    else:
+      # binder hallucination
+      self._pdb["batch"] = make_fixed_size(self._pdb["batch"], num_res=sum(self._lengths))
+      self.opt["weights"].update({"plddt":0.1, "con":0.0, "i_con":1.0, "i_pae":0.0})
+
+    # configure input features
+    self._inputs =
self._prep_features(num_res=sum(self._lengths), num_seq=1) + self._inputs["residue_index"] = res_idx + self._inputs["batch"] = self._pdb["batch"] + self._inputs.update(get_multi_id(self._lengths)) + + # configure template rm masks + (T,L,rm) = (self._lengths[0],sum(self._lengths),{}) + rm_opt = { + "rm_template": {"target":rm_target, "binder":rm_binder}, + "rm_template_seq":{"target":rm_target_seq,"binder":rm_binder_seq}, + "rm_template_sc": {"target":rm_target_sc, "binder":rm_binder_sc} + } + for n,x in rm_opt.items(): + rm[n] = np.full(L,False) + for m,y in x.items(): + if isinstance(y,str): + rm[n][prep_pos(y,**self._pdb["idx"])["pos"]] = True + else: + if m == "target": rm[n][:T] = y + if m == "binder": rm[n][T:] = y + + # set template [opt]ions + self.opt["template"]["rm_ic"] = rm_template_ic + self._inputs.update(rm) + + self._prep_model(**kwargs) + + def _prep_partial(self, pdb_filename, chain=None, length=None, + copies=1, repeat=False, homooligomer=False, + pos=None, fix_pos=None, use_sidechains=False, atoms_to_exclude=None, + rm_template=False, + rm_template_seq=False, + rm_template_sc=False, + rm_template_ic=False, + ignore_missing=True, **kwargs): + ''' + prep input for partial hallucination + --------------------------------------------------- + -length=100 - total length of protein (if different from input PDB) + -pos="1,2-10" - specify which positions to apply supervised loss to + -use_sidechains=True - add a sidechain supervised loss to the specified positions + -atoms_to_exclude=["N","C","O"] (for sc_rmsd loss, specify which atoms to exclude) + -rm_template_seq - if template is defined, remove information about template sequence + -ignore_missing=True - skip positions that have missing density (no CA coordinate) + --------------------------------------------------- + ''' + # prep features + self._pdb = prep_pdb(pdb_filename, chain=chain, ignore_missing=ignore_missing, + offsets=kwargs.pop("pdb_offsets",None), + lengths=kwargs.pop("pdb_lengths",None)) + + self._pdb["len"] = sum(self._pdb["lengths"]) + + self._len = self._pdb["len"] if length is None else length + self._lengths = [self._len] + + # feat dims + num_seq = self._num + res_idx = np.arange(self._len) + + # get [pos]itions of interests + if pos is None: + self.opt["pos"] = self._pdb["pos"] = np.arange(self._pdb["len"]) + self._pos_info = {"length":np.array([self._pdb["len"]]), "pos":self._pdb["pos"]} + else: + self._pos_info = prep_pos(pos, **self._pdb["idx"]) + self.opt["pos"] = self._pdb["pos"] = self._pos_info["pos"] + + if homooligomer and chain is not None and copies == 1: + copies = len(chain.split(",")) + + # repeat/homo-oligomeric support + if copies > 1: + + if repeat or homooligomer: + self._len = self._len // copies + self._pdb["len"] = self._pdb["len"] // copies + self.opt["pos"] = self._pdb["pos"][self._pdb["pos"] < self._pdb["len"]] + + # repeat positions across copies + self._pdb["pos"] = repeat_pos(self.opt["pos"], copies, self._pdb["len"]) + + if repeat: + self._lengths = [self._len * copies] + block_diag = False + + else: + self._lengths = [self._len] * copies + block_diag = not self._args["use_multimer"] + + num_seq = (self._num * copies + 1) if block_diag else self._num + res_idx = repeat_idx(np.arange(self._len), copies) + + self.opt["weights"].update({"i_pae":0.0, "i_con":1.0}) + + self._args.update({"copies":copies, "repeat":repeat, "homooligomer":homooligomer, "block_diag":block_diag}) + homooligomer = not repeat + + # configure input features + self._inputs = 
self._prep_features(num_res=sum(self._lengths), num_seq=num_seq) + self._inputs["residue_index"] = res_idx + self._inputs["batch"] = jax.tree_util.tree_map(lambda x:x[self._pdb["pos"]], self._pdb["batch"]) + self._inputs.update(get_multi_id(self._lengths, homooligomer=homooligomer)) + + # configure options/weights + self.opt["weights"].update({"dgram_cce":1.0, "rmsd":0.0, "fape":0.0, "con":1.0}) + self._wt_aatype = self._pdb["batch"]["aatype"][self.opt["pos"]] + + # configure sidechains + self._args["use_sidechains"] = use_sidechains + if use_sidechains: + self._sc = {"batch":prep_inputs.make_atom14_positions(self._inputs["batch"]), + "pos":get_sc_pos(self._wt_aatype, atoms_to_exclude)} + self.opt["weights"].update({"sc_rmsd":0.1, "sc_fape":0.1}) + self.opt["fix_pos"] = np.arange(self.opt["pos"].shape[0]) + self._wt_aatype_sub = self._wt_aatype + + elif fix_pos is not None and fix_pos != "": + sub_fix_pos = [] + sub_i = [] + pos = self.opt["pos"].tolist() + for i in prep_pos(fix_pos, **self._pdb["idx"])["pos"]: + if i in pos: + sub_i.append(i) + sub_fix_pos.append(pos.index(i)) + self.opt["fix_pos"] = np.array(sub_fix_pos) + self._wt_aatype_sub = self._pdb["batch"]["aatype"][sub_i] + + elif kwargs.pop("fix_seq",False): + self.opt["fix_pos"] = np.arange(self.opt["pos"].shape[0]) + self._wt_aatype_sub = self._wt_aatype + + self.opt["template"].update({"rm_ic":rm_template_ic}) + self._inputs.update({"rm_template": rm_template, + "rm_template_seq": rm_template_seq, + "rm_template_sc": rm_template_sc}) + + self._prep_model(**kwargs) + +####################### +# utils +####################### +def repeat_idx(idx, copies=1, offset=50): + idx_offset = np.repeat(np.cumsum([0]+[idx[-1]+offset]*(copies-1)),len(idx)) + return np.tile(idx,copies) + idx_offset + +def repeat_pos(pos, copies, length): + return (np.repeat(pos,copies).reshape(-1,copies) + np.arange(copies) * length).T.flatten() + +def prep_pdb(pdb_filename, chain=None, + offsets=None, lengths=None, + ignore_missing=False): + '''extract features from pdb''' + + def add_cb(batch): + '''add missing CB atoms based on N,CA,C''' + p,m = batch["all_atom_positions"], batch["all_atom_mask"] + atom_idx = residue_constants.atom_order + atoms = {k:p[...,atom_idx[k],:] for k in ["N","CA","C"]} + cb = atom_idx["CB"] + cb_atoms = _np_get_cb(**atoms, use_jax=False) + cb_mask = np.prod([m[...,atom_idx[k]] for k in ["N","CA","C"]],0) + batch["all_atom_positions"][...,cb,:] = np.where(m[:,cb,None], p[:,cb,:], cb_atoms) + batch["all_atom_mask"][...,cb] = (m[:,cb] + cb_mask) > 0 + return {"atoms":batch["all_atom_positions"][:,cb],"mask":cb_mask} + + if isinstance(chain,str) and "," in chain: + chains = chain.split(",") + elif not isinstance(chain,list): + chains = [chain] + + o,last = [],0 + residue_idx, chain_idx = [],[] + full_lengths = [] + + # go through each defined chain + for n,chain in enumerate(chains): + pdb_str = pdb_to_string(pdb_filename, chains=chain, models=[1]) + protein_obj = protein.from_pdb_string(pdb_str, chain_id=chain) + batch = {'aatype': protein_obj.aatype, + 'all_atom_positions': protein_obj.atom_positions, + 'all_atom_mask': protein_obj.atom_mask, + 'residue_index': protein_obj.residue_index} + + cb_feat = add_cb(batch) # add in missing cb (in the case of glycine) + + im = ignore_missing[n] if isinstance(ignore_missing,list) else ignore_missing + if im: + r = batch["all_atom_mask"][:,0] == 1 + batch = jax.tree_util.tree_map(lambda x:x[r], batch) + residue_index = batch["residue_index"] + last + + else: + # pad values + offset = 0 if 
offsets is None else (offsets[n] if isinstance(offsets,list) else offsets) + r = offset + (protein_obj.residue_index - protein_obj.residue_index.min()) + length = (r.max()+1) if lengths is None else (lengths[n] if isinstance(lengths,list) else lengths) + def scatter(x, value=0): + shape = (length,) + x.shape[1:] + y = np.full(shape, value, dtype=x.dtype) + y[r] = x + return y + + batch = {"aatype":scatter(batch["aatype"],-1), + "all_atom_positions":scatter(batch["all_atom_positions"]), + "all_atom_mask":scatter(batch["all_atom_mask"]), + "residue_index":scatter(batch["residue_index"],-1)} + + residue_index = np.arange(length) + last + + last = residue_index[-1] + 50 + o.append({"batch":batch, + "residue_index": residue_index, + "cb_feat":cb_feat}) + + residue_idx.append(batch.pop("residue_index")) + chain_idx.append([chain] * len(residue_idx[-1])) + full_lengths.append(len(residue_index)) + + # concatenate chains + o = jax.tree_util.tree_map(lambda *x:np.concatenate(x,0),*o) + + # save original residue and chain index + o["idx"] = {"residue":np.concatenate(residue_idx), "chain":np.concatenate(chain_idx)} + o["lengths"] = full_lengths + return o + +def make_fixed_size(feat, num_res, num_seq=1, num_templates=1): + '''pad input features''' + shape_schema = {k:v for k,v in config.CONFIG.data.eval.feat.items()} + + pad_size_map = { + shape_placeholders.NUM_RES: num_res, + shape_placeholders.NUM_MSA_SEQ: num_seq, + shape_placeholders.NUM_EXTRA_SEQ: 1, + shape_placeholders.NUM_TEMPLATES: num_templates + } + for k,v in feat.items(): + if k == "batch": + feat[k] = make_fixed_size(v, num_res) + else: + shape = list(v.shape) + schema = shape_schema[k] + assert len(shape) == len(schema), ( + f'Rank mismatch between shape and shape schema for {k}: ' + f'{shape} vs {schema}') + pad_size = [pad_size_map.get(s2, None) or s1 for (s1, s2) in zip(shape, schema)] + padding = [(0, p - v.shape[i]) for i, p in enumerate(pad_size)] + feat[k] = np.pad(v, padding) + return feat + +def get_sc_pos(aa_ident, atoms_to_exclude=None): + '''get sidechain indices/weights for all_atom14_positions''' + + # decide what atoms to exclude for each residue type + a2e = {} + for r in resname_to_idx: + if isinstance(atoms_to_exclude,dict): + a2e[r] = atoms_to_exclude.get(r,atoms_to_exclude.get("ALL",["N","C","O"])) + else: + a2e[r] = ["N","C","O"] if atoms_to_exclude is None else atoms_to_exclude + + # collect atom indices + pos,pos_alt = [],[] + N,N_non_amb = [],[] + for n,a in enumerate(aa_ident): + aa = idx_to_resname[a] + atoms = set(residue_constants.residue_atoms[aa]) + atoms14 = residue_constants.restype_name_to_atom14_names[aa] + swaps = residue_constants.residue_atom_renaming_swaps.get(aa,{}) + swaps.update({v:k for k,v in swaps.items()}) + for atom in atoms.difference(a2e[aa]): + pos.append(n * 14 + atoms14.index(atom)) + if atom in swaps: + pos_alt.append(n * 14 + atoms14.index(swaps[atom])) + else: + pos_alt.append(pos[-1]) + N_non_amb.append(n) + N.append(n) + + pos, pos_alt = np.asarray(pos), np.asarray(pos_alt) + non_amb = pos == pos_alt + N, N_non_amb = np.asarray(N), np.asarray(N_non_amb) + w = np.array([1/(n == N).sum() for n in N]) + w_na = np.array([1/(n == N_non_amb).sum() for n in N_non_amb]) + w, w_na = w/w.sum(), w_na/w_na.sum() + return {"pos":pos, "pos_alt":pos_alt, "non_amb":non_amb, + "weight":w, "weight_non_amb":w_na[:,None]} + +def prep_input_features(L, N=1, T=1, eN=1): + ''' + given [L]ength, [N]umber of sequences and number of [T]emplates + return dictionary of blank features + ''' + inputs = 
{'aatype': np.zeros(L,int), + 'target_feat': np.zeros((L,20)), + 'msa_feat': np.zeros((N,L,49)), + # 23 = one_hot -> (20, UNK, GAP, MASK) + # 1 = has deletion + # 1 = deletion_value + # 23 = profile + # 1 = deletion_mean_value + + 'seq_mask': np.ones(L), + 'msa_mask': np.ones((N,L)), + 'msa_row_mask': np.ones(N), + 'atom14_atom_exists': np.zeros((L,14)), + 'atom37_atom_exists': np.zeros((L,37)), + 'residx_atom14_to_atom37': np.zeros((L,14),int), + 'residx_atom37_to_atom14': np.zeros((L,37),int), + 'residue_index': np.arange(L), + 'extra_deletion_value': np.zeros((eN,L)), + 'extra_has_deletion': np.zeros((eN,L)), + 'extra_msa': np.zeros((eN,L),int), + 'extra_msa_mask': np.zeros((eN,L)), + 'extra_msa_row_mask': np.zeros(eN), + + # for template inputs + 'template_aatype': np.zeros((T,L),int), + 'template_all_atom_mask': np.zeros((T,L,37)), + 'template_all_atom_positions': np.zeros((T,L,37,3)), + 'template_mask': np.zeros(T), + 'template_pseudo_beta': np.zeros((T,L,3)), + 'template_pseudo_beta_mask': np.zeros((T,L)), + + # for alphafold-multimer + 'asym_id': np.zeros(L), + 'sym_id': np.zeros(L), + 'entity_id': np.zeros(L), + 'all_atom_positions': np.zeros((N,37,3))} + return inputs + +def get_multi_id(lengths, homooligomer=False): + '''set info for alphafold-multimer''' + i = np.concatenate([[n]*l for n,l in enumerate(lengths)]) + if homooligomer: + return {"asym_id":i, "sym_id":i, "entity_id":np.zeros_like(i)} + else: + return {"asym_id":i, "sym_id":i, "entity_id":i} \ No newline at end of file diff --git a/build/lib/colabdesign/af/utils.py b/build/lib/colabdesign/af/utils.py new file mode 100644 index 00000000..ca9dd725 --- /dev/null +++ b/build/lib/colabdesign/af/utils.py @@ -0,0 +1,189 @@ +import jax +import jax.numpy as jnp +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.gridspec import GridSpec + +from colabdesign.shared.protein import _np_kabsch +from colabdesign.shared.utils import update_dict, Key +from colabdesign.shared.plot import plot_pseudo_3D, make_animation, show_pdb +from colabdesign.shared.protein import renum_pdb_str +from colabdesign.af.alphafold.common import protein + +#################################################### +# AF_UTILS - various utils (save, plot, etc) +#################################################### +class _af_utils: + + def set_opt(self, *args, **kwargs): + ''' + set [opt]ions + ------------------- + note: model.restart() resets the [opt]ions to their defaults + use model.set_opt(..., set_defaults=True) + or model.restart(..., reset_opt=False) to avoid this + ------------------- + model.set_opt(num_models=1, num_recycles=0) + model.set_opt(con=dict(num=1)) or set_opt({"con":{"num":1}}) or set_opt("con",num=1) + model.set_opt(lr=1, set_defaults=True) + ''' + ks = list(kwargs.keys()) + self.set_args(**{k:kwargs.pop(k) for k in ks if k in self._args}) + + if kwargs.pop("set_defaults", False): + update_dict(self._opt, *args, **kwargs) + + update_dict(self.opt, *args, **kwargs) + + def set_args(self, **kwargs): + ''' + set [arg]uments + ''' + for k in ["best_metric", "traj_iter", "shuffle_first"]: + if k in kwargs: self._args[k] = kwargs.pop(k) + + if "recycle_mode" in kwargs: + ok_recycle_mode_swap = ["average","sample","first","last"] + if kwargs["recycle_mode"] in ok_recycle_mode_swap and self._args["recycle_mode"] in ok_recycle_mode_swap: + self._args["recycle_mode"] = kwargs.pop("recycle_mode") + else: + print(f"ERROR: use {self.__class__.__name__}(recycle_mode=...) 
to set the recycle_mode") + + if "optimizer" in kwargs: + self.set_optimizer(kwargs.pop("optimizer"), + learning_rate=kwargs.pop("learning_rate", None)) + + ks = list(kwargs.keys()) + if len(ks) > 0: + print(f"ERROR: the following args were not set: {ks}") + + def get_loss(self, x="loss"): + '''output the loss (for entire trajectory)''' + return np.array([loss[x] for loss in self._tmp["log"]]) + + def save_pdb(self, filename=None, get_best=True, renum_pdb=True, aux=None): + ''' + save pdb coordinates (if filename provided, otherwise return as string) + - set get_best=False, to get the last sampled sequence + ''' + if aux is None: + aux = self._tmp["best"]["aux"] if (get_best and "aux" in self._tmp["best"]) else self.aux + aux = aux["all"] + + p = {k:aux[k] for k in ["aatype","residue_index","atom_positions","atom_mask"]} + p["b_factors"] = 100 * p["atom_mask"] * aux["plddt"][...,None] + + def to_pdb_str(x, n=None): + p_str = protein.to_pdb(protein.Protein(**x)) + p_str = "\n".join(p_str.splitlines()[1:-2]) + if renum_pdb: p_str = renum_pdb_str(p_str, self._lengths) + if n is not None: + p_str = f"MODEL{n:8}\n{p_str}\nENDMDL\n" + return p_str + + p_str = "" + for n in range(p["atom_positions"].shape[0]): + p_str += to_pdb_str(jax.tree_map(lambda x:x[n],p), n+1) + p_str += "END\n" + + if filename is None: + return p_str + else: + with open(filename, 'w') as f: + f.write(p_str) + + #------------------------------------- + # plotting functions + #------------------------------------- + def animate(self, s=0, e=None, dpi=100, get_best=True, traj=None, aux=None, color_by="plddt"): + ''' + animate the trajectory + - use [s]tart and [e]nd to define range to be animated + - use dpi to specify the resolution of animation + - color_by = ["plddt","chain","rainbow"] + ''' + if aux is None: + aux = self._tmp["best"]["aux"] if (get_best and "aux" in self._tmp["best"]) else self.aux + aux = aux["all"] + if self.protocol in ["fixbb","binder"]: + pos_ref = self._inputs["batch"]["all_atom_positions"][:,1].copy() + pos_ref[(pos_ref == 0).any(-1)] = np.nan + else: + pos_ref = aux["atom_positions"][0,:,1,:] + + if traj is None: traj = self._tmp["traj"] + sub_traj = {k:v[s:e] for k,v in traj.items()} + + align_xyz = self.protocol == "hallucination" + return make_animation(**sub_traj, pos_ref=pos_ref, length=self._lengths, + color_by=color_by, align_xyz=align_xyz, dpi=dpi) + + def plot_pdb(self, show_sidechains=False, show_mainchains=False, + color="pLDDT", color_HP=False, size=(800,480), animate=False, + get_best=True, aux=None, pdb_str=None): + ''' + use py3Dmol to plot pdb coordinates + - color=["pLDDT","chain","rainbow"] + ''' + if pdb_str is None: + pdb_str = self.save_pdb(get_best=get_best, aux=aux) + view = show_pdb(pdb_str, + show_sidechains=show_sidechains, + show_mainchains=show_mainchains, + color=color, + Ls=self._lengths, + color_HP=color_HP, + size=size, + animate=animate) + view.show() + + def plot_traj(self, dpi=100): + fig = plt.figure(figsize=(5,5), dpi=dpi) + gs = GridSpec(4,1, figure=fig) + ax1 = fig.add_subplot(gs[:3,:]) + ax2 = fig.add_subplot(gs[3:,:]) + ax1_ = ax1.twinx() + + if self.protocol in ["fixbb","partial"] or (self.protocol == "binder" and self._args["redesign"]): + if self.protocol == "partial" and self._args["use_sidechains"]: + rmsd = self.get_loss("sc_rmsd") + else: + rmsd = self.get_loss("rmsd") + for k in [0.5,1,2,4,8,16,32]: + ax1.plot([0,len(rmsd)],[k,k],color="lightgrey") + ax1.plot(rmsd,color="black") + seqid = self.get_loss("seqid") + 
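+        # overlay sequence identity (fraction of native residues recovered) on the twin y-axis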
ax1_.plot(seqid,color="green",label="seqid") + # axes labels + ax1.set_yscale("log") + ticks = [0.25,0.5,1,2,4,8,16,32,64] + ax1.set(xticks=[]) + ax1.set_yticks(ticks); ax1.set_yticklabels(ticks) + ax1.set_ylabel("RMSD",color="black");ax1_.set_ylabel("seqid",color="green") + ax1.set_ylim(0.25,64) + ax1_.set_ylim(0,0.8) + # extras + ax2.plot(self.get_loss("soft"),color="yellow",label="soft") + ax2.plot(self.get_loss("temp"),color="orange",label="temp") + ax2.plot(self.get_loss("hard"),color="red",label="hard") + ax2.set_ylim(-0.1,1.1) + ax2.set_xlabel("iterations") + ax2.legend(loc='center left') + else: + print("TODO") + plt.show() + + def clear_best(self): + self._tmp["best"] = {} + + def save_current_pdb(self, filename=None): + '''save pdb coordinates (if filename provided, otherwise return as string)''' + self.save_pdb(filename=filename, get_best=False) + + def plot_current_pdb(self, show_sidechains=False, show_mainchains=False, + color="pLDDT", color_HP=False, size=(800,480), animate=False): + '''use py3Dmol to plot pdb coordinates + - color=["pLDDT","chain","rainbow"] + ''' + self.plot_pdb(show_sidechains=show_sidechains, show_mainchains=show_mainchains, color=color, + color_HP=color_HP, size=size, animate=animate, get_best=False) \ No newline at end of file diff --git a/build/lib/colabdesign/af/weights/__init__.py b/build/lib/colabdesign/af/weights/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/build/lib/colabdesign/af/weights/__init__.py @@ -0,0 +1 @@ + diff --git a/build/lib/colabdesign/af/weights/template_dgram_head.npy b/build/lib/colabdesign/af/weights/template_dgram_head.npy new file mode 100644 index 00000000..3aa36daf Binary files /dev/null and b/build/lib/colabdesign/af/weights/template_dgram_head.npy differ diff --git a/build/lib/colabdesign/esm_msa/__init__.py b/build/lib/colabdesign/esm_msa/__init__.py new file mode 100644 index 00000000..940a2bf0 --- /dev/null +++ b/build/lib/colabdesign/esm_msa/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Levinthal, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .data import Alphabet, BatchConverter, FastaBatchedDataset # noqa +from .model import MSATransformer, RunModel # noqa +from . import pretrained # noqa +from . import config diff --git a/build/lib/colabdesign/esm_msa/axial_attention.py b/build/lib/colabdesign/esm_msa/axial_attention.py new file mode 100644 index 00000000..4f1a8f6b --- /dev/null +++ b/build/lib/colabdesign/esm_msa/axial_attention.py @@ -0,0 +1,200 @@ +# Copyright (c) Levinthal, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
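+
+# Haiku port of the MSA Transformer's axial attention. RowSelfAttention attends
+# over alignment columns with a single per-head attention map summed across rows;
+# ColumnSelfAttention attends over rows at each column. Both fall back to chunked
+# evaluation when num_rows * num_cols exceeds config.max_tokens_per_msa.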
+ +import jax +import math +import jax.numpy as jnp +import haiku as hk +from jax.nn import softmax + +from colabdesign.shared.prng import SafeKey + +class RowSelfAttention(hk.Module): + """Compute self-attention over rows of a 2D input.""" + + def __init__( + self, + config, + ): + super().__init__() + self.head_num = config.RowAtt.head_num + self.embed_dim = config.RowAtt.embed_dim + self.dropout = config.dropout + self.max_tokens_per_msa = config.max_tokens_per_msa + + self.head_dim = self.embed_dim // self.head_num + self.scaling = self.head_dim ** -0.5 + self.attn_shape = "hij" + + self.k_proj = hk.Linear(self.embed_dim, name="k_proj") + self.v_proj = hk.Linear(self.embed_dim, name="v_proj") + self.q_proj = hk.Linear(self.embed_dim, name="q_proj") + self.out_proj = hk.Linear(self.embed_dim, name="out_proj") + self.safe_key = SafeKey(hk.next_rng_key()) + + def align_scaling(self, q): + num_rows = q.shape[0] + return self.scaling / math.sqrt(num_rows) + + def _batched_forward( + self, + x, + self_attn_padding_mask, + ): + num_rows, num_cols, embed_dim = x.shape + max_rows = max(1, self.max_tokens_per_msa // num_cols) + attns = 0 + scaling = self.align_scaling(x) + for start in range(0, num_rows, max_rows): + attn_weights = self.compute_attention_weights( + x[start: start + max_rows], + scaling, + self_attn_padding_mask=self_attn_padding_mask[:, start: start + max_rows] + ) + attns += attn_weights + attn_probs = softmax(attns, -1) + self.safe_key, use_key = self.safe_key.split() + attn_probs = hk.dropout(use_key.get(), self.dropout, attn_probs) + + outputs = [] + for start in range(0, num_rows, max_rows): + output = self.compute_attention_update(x[start: start + max_rows], attn_probs) + outputs.append(output) + + output = jnp.concatenate(outputs, 0) + return output, attn_probs + + def compute_attention_weights( + self, + x, + scaling: float, + self_attn_padding_mask, + ): + num_rows, num_cols, embed_dim = x.shape + q = self.q_proj(x).reshape([num_rows, num_cols, self.head_num, self.head_dim]) + k = self.k_proj(x).reshape([num_rows, num_cols, self.head_num, self.head_dim]) + q *= scaling + # Zero out any padded aligned positions - this is important since + # we take a sum across the alignment axis. 
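+    # self_attn_padding_mask has shape (num_rows, num_cols); expanding it to
+    # (num_rows, num_cols, 1, 1) lets it broadcast against q of shape
+    # (num_rows, num_cols, head_num, head_dim)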
+ q *= 1 - jnp.expand_dims(jnp.expand_dims(self_attn_padding_mask, 2), 3) + + attn_weights = jnp.einsum(f"rihd,rjhd->{self.attn_shape}", q, k) + attn_weights *= 1 - jnp.expand_dims(jnp.expand_dims(self_attn_padding_mask[0], 0), 2) + attn_weights += jnp.expand_dims(jnp.expand_dims(self_attn_padding_mask[0], 0), 2) * -10000 + + return attn_weights + + def compute_attention_update( + self, + x, + attn_probs, + ): + num_rows, num_cols, embed_dim = x.shape + v = self.v_proj(x).reshape([num_rows, num_cols, self.head_num, self.head_dim]) + context = jnp.einsum(f"{self.attn_shape},rjhd->rihd", attn_probs, v) + context = context.reshape([num_rows, num_cols, embed_dim]) + output = self.out_proj(context) + return output + + def __call__(self, x, + self_attn_padding_mask,): + + num_rows, num_cols, embed_dim = x.shape + if num_rows * num_cols > self.max_tokens_per_msa: + return self._batched_forward(x, self_attn_padding_mask) + else: + scaling = self.align_scaling(x) + attn_weights = self.compute_attention_weights( + x, scaling, self_attn_padding_mask + ) + attn_probs = softmax(attn_weights, -1) + self.safe_key, use_key = self.safe_key.split() + attn_probs = hk.dropout(use_key.get(), self.dropout, attn_probs) + output = self.compute_attention_update(x, attn_probs) + return output, attn_probs + + +class ColumnSelfAttention(hk.Module): + """Compute self-attention over columns of a 2D input.""" + + def __init__(self, config): + super().__init__() + self.head_num = config.ColAtt.head_num + self.embed_dim = config.RowAtt.embed_dim + self.dropout = config.dropout + self.max_tokens_per_msa = config.max_tokens_per_msa + + self.head_dim = self.embed_dim // self.head_num + self.scaling = self.head_dim ** -0.5 + self.safe_key = SafeKey(hk.next_rng_key()) + + self.k_proj = hk.Linear(self.embed_dim, name='k_proj') + self.v_proj = hk.Linear(self.embed_dim, name='v_proj') + self.q_proj = hk.Linear(self.embed_dim, name='q_proj') + self.out_proj = hk.Linear(self.embed_dim, name='out_proj') + + def _batched_forward( + self, + x, + self_attn_padding_mask, + ): + num_rows, num_cols, embed_dim = x.shape + max_cols = max(1, self.max_tokens_per_msa // num_rows) + outputs = [] + attns = [] + for start in range(0, num_cols, max_cols): + output, attn = self.compute_attention_update( + x[:, start: start + max_cols], + self_attn_padding_mask=self_attn_padding_mask[:, :, start: start + max_cols] + ) + outputs.append(output) + attns.append(attn) + output = jnp.concatenate(outputs, 1) + attns = jnp.concatenate(attns, 1) + return output, attns + + def compute_attention_update( + self, + x, + self_attn_padding_mask, + ): + num_rows, num_cols, embed_dim = x.shape + if num_rows == 1: + attn_probs = jnp.ones( + [self.head_num, num_cols, num_rows, num_rows], + dtype=x.dtype, + ) + output = self.out_proj(self.v_proj(x)) + return output, attn_probs + else: + q = self.q_proj(x).reshape([num_rows, num_cols, self.head_num, self.head_dim]) + k = self.k_proj(x).reshape([num_rows, num_cols, self.head_num, self.head_dim]) + v = self.v_proj(x).reshape([num_rows, num_cols, self.head_num, self.head_dim]) + q *= self.scaling + + attn_weights = jnp.einsum("ichd,jchd->hcij", q, k) + attn_weights *= 1 - jnp.expand_dims(jnp.expand_dims(self_attn_padding_mask.transpose(), 0), 3) + attn_weights += jnp.expand_dims(jnp.expand_dims(self_attn_padding_mask.transpose(), 0), 3) * -10000 + attn_probs = softmax(attn_weights, -1) + + self.safe_key, use_key = self.safe_key.split() + attn_probs = hk.dropout(use_key.get(), self.dropout, attn_probs) + context = 
jnp.einsum("hcij,jchd->ichd", attn_probs, v) + context = context.reshape([num_rows, num_cols, embed_dim]) + output = self.out_proj(context) + return output, attn_probs + + def __call__( + self, + x, + self_attn_padding_mask, + ): + # if False and num_rows * num_cols > 2 ** 14 and not torch.is_grad_enabled(): + num_rows, num_cols, embed_dim = x.shape + if num_rows * num_cols > self.max_tokens_per_msa: + return self._batched_forward(x, self_attn_padding_mask) + else: + return self.compute_attention_update(x, self_attn_padding_mask) + diff --git a/build/lib/colabdesign/esm_msa/config.py b/build/lib/colabdesign/esm_msa/config.py new file mode 100644 index 00000000..99c19c81 --- /dev/null +++ b/build/lib/colabdesign/esm_msa/config.py @@ -0,0 +1,37 @@ +# Copyright 2021 Levinthal Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Model config.""" + +import ml_collections + +CONFIG = ml_collections.ConfigDict({ + 'RowAtt': { + 'embed_dim': 768, + 'head_num': 12, + + }, + 'ColAtt': { + 'embed_dim': 768, + 'head_num': 12, + + }, + 'Ffn': { + 'embed_dim': 3072, + }, + 'dropout': 0.0, + 'max_tokens_per_msa': 2 ** 16, + 'layer_num': 12, + 'embed_dim': 768, + 'max_position': 1024, +}) diff --git a/build/lib/colabdesign/esm_msa/constants.py b/build/lib/colabdesign/esm_msa/constants.py new file mode 100644 index 00000000..c975e75d --- /dev/null +++ b/build/lib/colabdesign/esm_msa/constants.py @@ -0,0 +1,10 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# fmt: off +proteinseq_toks = { + 'toks': ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C', 'X', 'B', 'U', 'Z', 'O', '.', '-'] +} +# fmt: on diff --git a/build/lib/colabdesign/esm_msa/data.py b/build/lib/colabdesign/esm_msa/data.py new file mode 100644 index 00000000..263913c6 --- /dev/null +++ b/build/lib/colabdesign/esm_msa/data.py @@ -0,0 +1,294 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
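+
+# Tokenization and batching utilities for the MSA Transformer port:
+# FastaBatchedDataset reads sequences from FASTA files, Alphabet defines the
+# token vocabulary, and BatchConverter / MSABatchConverter turn (label, sequence)
+# pairs into padded integer token arrays (NumPy, or jax.numpy when return_j=True).
+#
+# Illustrative usage sketch (not part of the original source):
+#   alphabet = Alphabet.from_architecture("msa_transformer")
+#   converter = alphabet.get_batch_converter()   # -> MSABatchConverter
+#   labels, strs, tokens = converter([("seq1", "MKTAYIAK"), ("seq2", "MKTAYIAR")])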
+ +from typing import Sequence, Tuple, Union +import re +import numpy as np +import jax.numpy as jnp + +from .constants import proteinseq_toks + +RawMSA = Sequence[Tuple[str, str]] + + +class FastaBatchedDataset(object): + def __init__(self, sequence_labels, sequence_strs): + self.sequence_labels = list(sequence_labels) + self.sequence_strs = list(sequence_strs) + + @classmethod + def from_file(cls, fasta_file): + sequence_labels, sequence_strs = [], [] + cur_seq_label = None + buf = [] + + def _flush_current_seq(): + nonlocal cur_seq_label, buf + if cur_seq_label is None: + return + sequence_labels.append(cur_seq_label) + sequence_strs.append("".join(buf)) + cur_seq_label = None + buf = [] + + with open(fasta_file, "r") as infile: + for line_idx, line in enumerate(infile): + if line.startswith(">"): # label line + _flush_current_seq() + line = line[1:].strip() + if len(line) > 0: + cur_seq_label = line + else: + cur_seq_label = f"seqnum{line_idx:09d}" + else: # sequence line + buf.append(line.strip()) + + _flush_current_seq() + + assert len(set(sequence_labels)) == len(sequence_labels), "Found duplicate sequence labels" + + return cls(sequence_labels, sequence_strs) + + def __len__(self): + return len(self.sequence_labels) + + def __getitem__(self, idx): + return self.sequence_labels[idx], self.sequence_strs[idx] + + def get_batch_indices(self, toks_per_batch, extra_toks_per_seq=0): + sizes = [(len(s), i) for i, s in enumerate(self.sequence_strs)] + sizes.sort() + batches = [] + buf = [] + max_len = 0 + + def _flush_current_buf(): + nonlocal max_len, buf + if len(buf) == 0: + return + batches.append(buf) + buf = [] + max_len = 0 + + for sz, i in sizes: + sz += extra_toks_per_seq + if max(sz, max_len) * (len(buf) + 1) > toks_per_batch: + _flush_current_buf() + max_len = max(max_len, sz) + buf.append(i) + + _flush_current_buf() + return batches + + +class Alphabet(object): + def __init__( + self, + standard_toks: Sequence[str], + prepend_toks: Sequence[str] = ("", "", "", ""), + append_toks: Sequence[str] = ("", "", ""), + prepend_bos: bool = True, + append_eos: bool = False, + use_msa: bool = False, + ): + self.standard_toks = list(standard_toks) + self.prepend_toks = list(prepend_toks) + self.append_toks = list(append_toks) + self.prepend_bos = prepend_bos + self.append_eos = append_eos + self.use_msa = use_msa + + self.all_toks = list(self.prepend_toks) + self.all_toks.extend(self.standard_toks) + for i in range((8 - (len(self.all_toks) % 8)) % 8): + self.all_toks.append(f"") + self.all_toks.extend(self.append_toks) + + self.tok_to_idx = {tok: i for i, tok in enumerate(self.all_toks)} + + self.unk_idx = self.tok_to_idx[""] + self.padding_idx = self.get_idx("") + self.cls_idx = self.get_idx("") + self.mask_idx = self.get_idx("") + self.eos_idx = self.get_idx("") + + def __len__(self): + return len(self.all_toks) + + def get_idx(self, tok): + return self.tok_to_idx.get(tok, self.unk_idx) + + def get_tok(self, ind): + return self.all_toks[ind] + + def to_dict(self): + return {"toks": self.toks} + + def get_batch_converter(self): + if self.use_msa: + return MSABatchConverter(self) + else: + return BatchConverter(self) + + @classmethod + def from_dict(cls, d, **kwargs): + return cls(standard_toks=d["toks"], **kwargs) + + @classmethod + def from_architecture(cls, name: str) -> "Alphabet": + if name in ("ESM-1", "protein_bert_base"): + standard_toks = proteinseq_toks["toks"] + prepend_toks: Tuple[str, ...] = ("", "", "", "") + append_toks: Tuple[str, ...] 
= ("", "", "") + prepend_bos = True + append_eos = False + use_msa = False + elif name in ("ESM-1b", "roberta_large"): + standard_toks = proteinseq_toks["toks"] + prepend_toks = ("", "", "", "") + append_toks = ("",) + prepend_bos = True + append_eos = True + use_msa = False + elif name in ("MSA Transformer", "msa_transformer"): + standard_toks = proteinseq_toks["toks"] + prepend_toks = ("", "", "", "") + append_toks = ("",) + prepend_bos = True + append_eos = False + use_msa = True + else: + raise ValueError("Unknown architecture selected") + return cls(standard_toks, prepend_toks, append_toks, prepend_bos, append_eos, use_msa) + + +class BatchConverter(object): + """Callable to convert an unprocessed (labels + strings) batch to a + processed (labels + tensor) batch. + """ + + def __init__(self, alphabet): + self.alphabet = alphabet + + def __call__(self, raw_batch: Sequence[Tuple[str, str]], return_j=True): + # RoBERTa uses an eos token, while ESM-1 does not. + batch_size = len(raw_batch) + max_len = max(len(seq_str) for _, seq_str in raw_batch) + tokens_np = np.ones( + [ + batch_size, + max_len + int(self.alphabet.prepend_bos) + int(self.alphabet.append_eos) + ], + dtype=np.int64 + ) * self.alphabet.padding_idx + + labels = [] + strs = [] + + for i, (label, seq_str) in enumerate(raw_batch): + labels.append(label) + strs.append(seq_str) + if self.alphabet.prepend_bos: + tokens_np[i, 0] = self.alphabet.cls_idx + seq = np.array([self.alphabet.get_idx(s) for s in seq_str], dtype=np.int64) + tokens_np[ + i, + int(self.alphabet.prepend_bos): len(seq_str) + int(self.alphabet.prepend_bos), + ] = seq + if self.alphabet.append_eos: + tokens_np[i, len(seq_str) + int(self.alphabet.prepend_bos)] = self.alphabet.eos_idx + + if return_j: + tokens = jnp.array(tokens_np) + else: + tokens = tokens_np + + return labels, strs, tokens + + +class MSABatchConverter(BatchConverter): + def __call__(self, inputs: Union[Sequence[RawMSA], RawMSA], return_j=True): + if isinstance(inputs[0][0], str): + # Input is a single MSA + raw_batch: Sequence[RawMSA] = [inputs] # type: ignore + else: + raw_batch = inputs # type: ignore + + batch_size = len(raw_batch) + max_alignments = max(len(msa) for msa in raw_batch) + max_seqlen = max(len(msa[0][1]) for msa in raw_batch) + + tokens_np = np.ones( + [ + batch_size, + max_alignments, + max_seqlen + int(self.alphabet.prepend_bos) + int(self.alphabet.append_eos), + ], + dtype=np.int64, + ) * self.alphabet.padding_idx + + labels = [] + strs = [] + + for i, msa in enumerate(raw_batch): + msa_seqlens = set(len(seq) for _, seq in msa) + if not len(msa_seqlens) == 1: + raise RuntimeError( + "Received unaligned sequences for input to MSA, all sequence " + "lengths must be equal." 
+ ) + msa_labels, msa_strs, msa_tokens = super().__call__(msa, return_j=False) + labels.append(msa_labels) + strs.append(msa_strs) + tokens_np[i, :msa_tokens.shape[0], :msa_tokens.shape[1]] = msa_tokens + + if return_j: + tokens = jnp.array(tokens_np) + else: + tokens = tokens_np + + return labels, strs, tokens + + +def read_fasta( + path, + keep_gaps=True, + keep_insertions=True, + to_upper=False, +): + with open(path, "r") as f: + for result in read_alignment_lines( + f, keep_gaps=keep_gaps, keep_insertions=keep_insertions, to_upper=to_upper + ): + yield result + + +def read_alignment_lines( + lines, + keep_gaps=True, + keep_insertions=True, + to_upper=False, +): + seq = desc = None + + def parse(s): + if not keep_gaps: + s = re.sub("-", "", s) + if not keep_insertions: + s = re.sub("[a-z]", "", s) + return s.upper() if to_upper else s + + for line in lines: + # Line may be empty if seq % file_line_width == 0 + if len(line) > 0 and line[0] == ">": + if seq is not None: + yield desc, parse(seq) + desc = line.strip() + seq = "" + else: + assert isinstance(seq, str) + seq += line.strip() + assert isinstance(seq, str) and isinstance(desc, str) + yield desc, parse(seq) diff --git a/build/lib/colabdesign/esm_msa/model.py b/build/lib/colabdesign/esm_msa/model.py new file mode 100644 index 00000000..90eff9d8 --- /dev/null +++ b/build/lib/colabdesign/esm_msa/model.py @@ -0,0 +1,141 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import joblib +import jax.numpy as jnp +import numpy as np +import haiku as hk +import jax + +from .modules import ( + AxialTransformerLayer, + EmbedPosition, + MSAPositionEmbedding, + ContactPredictionHead, + LmHead, +) + +from colabdesign.shared.prng import SafeKey + +class MSATransformer(hk.Module): + def __init__(self, alphabet, config): + super().__init__() + self.alphabet_size = len(alphabet) + self.padding_idx = alphabet.padding_idx + self.mask_idx = alphabet.mask_idx + self.cls_idx = alphabet.cls_idx + self.eos_idx = alphabet.eos_idx + self.prepend_bos = alphabet.prepend_bos + self.append_eos = alphabet.append_eos + self.config = config + self.dropout = config.dropout + + self.embed_tokens = hk.Embed( + vocab_size=self.alphabet_size, + embed_dim=self.config.embed_dim, + ) + + self.msa_position_embedding = MSAPositionEmbedding(self.config.embed_dim) + self.safe_key = SafeKey(hk.next_rng_key()) + + self.layers = [ + AxialTransformerLayer(self.config) + for _ in range(self.config.layer_num) + ] + + self.contact_head = ContactPredictionHead( + self.config.layer_num * self.config.RowAtt.head_num, + self.prepend_bos, + self.append_eos, + eos_idx=self.eos_idx, + ) + self.embed_positions = EmbedPosition( + self.config, + self.padding_idx, + ) + + self.emb_layer_norm_before = hk.LayerNorm(-1, create_scale=True, create_offset=True) + self.emb_layer_norm_after = hk.LayerNorm(-1, create_scale=True, create_offset=True) + + self.lm_head = LmHead( + config=self.config, + output_dim=self.alphabet_size, + weight=self.embed_tokens.embeddings.transpose(), + ) + + def __call__(self, tokens): + num_alignments, seqlen = tokens.shape + padding_mask = jnp.equal(tokens, self.padding_idx) # R, C + x = self.embed_tokens(tokens) + x += self.embed_positions(tokens) + x += self.msa_position_embedding(tokens) + x = self.emb_layer_norm_before(x) + + self.safe_key, use_key = self.safe_key.split() + x = hk.dropout(use_key.get(), self.dropout, x) + x = x * (1 - 
jnp.expand_dims(padding_mask, axis=-1)) + + row_attn_weights = [] + col_attn_weights = [] + + for layer_idx, layer in enumerate(self.layers): + x = layer( + x, + self_attn_padding_mask=padding_mask, + ) + x, col_attn, row_attn = x + col_attn_weights.append(col_attn) + row_attn_weights.append(row_attn) + + x = self.emb_layer_norm_after(x) + x = self.lm_head(x) + + result = {"logits": x} + # col_attentions: L x H x C x R x R + col_attentions = jnp.stack(col_attn_weights, 0) + # row_attentions: L x H x C x C + row_attentions = jnp.stack(row_attn_weights, 0) + result["col_attentions"] = col_attentions + result["row_attentions"] = row_attentions + contacts = self.contact_head(tokens, row_attentions) + result["contacts"] = contacts + + return result + + +class RunModel: + '''container for msa transformer''' + + def __init__(self, alphabet, config): + self.padding_idx = alphabet.padding_idx + + def _forward(tokens): + model = MSATransformer(alphabet, config) + return model(tokens) + + _forward_t = hk.transform(_forward) + self.init = jax.jit(_forward_t.init) + self.apply = jax.jit(_forward_t.apply) + self.key = jax.random.PRNGKey(42) + + def load_params(self, path): + self.params = joblib.load(path) + + def __call__(self, tokens): + assert tokens.ndim == 2 + num_alignments, seqlen = tokens.shape + + if num_alignments > 1024: + raise RuntimeError( + "Using model with MSA position embedding trained on maximum MSA " + f"depth of 1024, but received {num_alignments} alignments." + ) + + self.key, use_key = jax.random.split(self.key) + result = self.apply(self.params, use_key, tokens) + result_new = {} + for ikey in result.keys(): + result_new[ikey] = np.array(result[ikey]) + return result_new diff --git a/build/lib/colabdesign/esm_msa/modules.py b/build/lib/colabdesign/esm_msa/modules.py new file mode 100644 index 00000000..ca6ba98a --- /dev/null +++ b/build/lib/colabdesign/esm_msa/modules.py @@ -0,0 +1,229 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import haiku as hk +import jax +import jax.numpy as jnp + +from .axial_attention import ColumnSelfAttention, RowSelfAttention +from colabdesign.shared.prng import SafeKey + + + +def symmetrize(x): + "Make layer symmetric in final two dimensions, used for contact prediction." + return x + x.transpose([0, 2, 1]) + + +def apc(x): + "Perform average product correct, used for contact prediction." 
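+    # average product correction (APC): x_ij - (sum_j x_ij)(sum_i x_ij) / sum_ij x_ij,
+    # applied over the last two dimensions via broadcasting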
+ a1 = x.sum(-1, keepdims=True) + a2 = x.sum(-2, keepdims=True) + a12 = x.sum((-1, -2), keepdims=True) + + avg = a1 * a2 + avg = avg / a12 + normalized = x - avg + return normalized + + +class AxialTransformerLayer(hk.Module): + """Implements an Axial MSA Transformer block.""" + + def __init__( + self, + config, + ) -> None: + super().__init__() + self.config = config + + row_self_attention = RowSelfAttention(config) + column_self_attention = ColumnSelfAttention(config) + feed_forward_layer = FeedForwardNetwork(config) + + self.row_self_attention = self.build_residual(row_self_attention, name='row_self_attention') + self.column_self_attention = self.build_residual(column_self_attention, name='column_self_attention') + self.feed_forward_layer = self.build_residual(feed_forward_layer, name='feed_forward_layer') + + def build_residual(self, layer: hk.Module, name=None): + return NormalizedResidualBlock( + layer, + self.config, + name=name, + ) + + def __call__( + self, + x, + self_attn_padding_mask, + ): + """ + LayerNorm is applied either before or after the self-attention/ffn + modules similar to the original Transformer implementation. + """ + x, row_attn = self.row_self_attention( + x, + self_attn_padding_mask=self_attn_padding_mask, + ) + x, column_attn = self.column_self_attention( + x, + self_attn_padding_mask=self_attn_padding_mask, + ) + x = self.feed_forward_layer(x) + return x, column_attn, row_attn + + +class LmHead(hk.Module): + def __init__(self, config, output_dim, weight): + super().__init__() + self.layer_norm = hk.LayerNorm(-1, create_scale=True, create_offset=True) + self.dense = hk.Linear(config.embed_dim, name='dense') + self.weight = weight + self.bias = hk.get_parameter(name='bias', shape=[output_dim], init=jnp.zeros) + + def __call__(self, input): + x = self.dense(input) + x = jax.nn.gelu(x) + x = self.layer_norm(x) + x = jnp.dot(x, self.weight) + self.bias + return x + + +class ContactPredictionHead(hk.Module): + """Performs symmetrization, apc, and computes a logistic regression on the output features""" + + def __init__( + self, + in_features: int, + prepend_bos: bool, + append_eos: bool, + bias=True, + eos_idx: Optional[int] = None, + ): + super().__init__() + self.in_features = in_features + self.prepend_bos = prepend_bos + self.append_eos = append_eos + self.eos_idx = eos_idx + self.regression = hk.Linear(1, with_bias=bias) + self.activation = jax.nn.sigmoid + + def __call__(self, tokens, attentions): + # remove eos token attentions + if self.append_eos: + eos_mask = jnp.not_equal(tokens, self.eos_idx) + eos_mask = jnp.expand_dims(eos_mask, axis=0) * jnp.expand_dims(eos_mask, axis=1) + attentions = attentions * eos_mask[None, None, :, :] + attentions = attentions[..., :-1, :-1] + + # remove cls token attentions + if self.prepend_bos: + attentions = attentions[..., 1:, 1:] + + layers, heads, seqlen, _ = attentions.shape + attentions = attentions.reshape([layers * heads, seqlen, seqlen]) + + # features: C x T x T + attentions = apc(symmetrize(attentions)) + attentions = attentions.transpose([1, 2, 0]) + return self.activation(self.regression(attentions).squeeze(2)) + + +class NormalizedResidualBlock(hk.Module): + def __init__( + self, + layer: hk.Module, + config, + name=None, + ): + super().__init__(name=name) + self.embed_dim = config.embed_dim + self.dropout = config.dropout + self.safe_key = SafeKey(hk.next_rng_key()) + + self.layer = layer + self.layer_norm = hk.LayerNorm(-1, create_scale=True, create_offset=True) + + def __call__(self, x, *args, **kwargs): + 
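+        # pre-norm residual block: layer_norm -> wrapped layer -> dropout -> add skip
+        # connection; any extra outputs from the wrapped layer (e.g. attention maps)
+        # are passed through unchanged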
residual = x + x = self.layer_norm(x) + outputs = self.layer(x, *args, **kwargs) + if isinstance(outputs, tuple): + x, *out = outputs + else: + x = outputs + out = None + + self.safe_key, use_key = self.safe_key.split() + x = hk.dropout(use_key.get(), self.dropout, x) + x = residual + x + + if out is not None: + return (x,) + tuple(out) + else: + return x + + +class FeedForwardNetwork(hk.Module): + def __init__( + self, + config, + ): + super().__init__() + self.embed_dim = config.embed_dim + self.ffn_embed_dim = config.Ffn.embed_dim + self.max_tokens_per_msa = config.max_tokens_per_msa + self.dropout = config.dropout + + self.safe_key = SafeKey(hk.next_rng_key()) + self.activation_fn = jax.nn.gelu + + self.fc1 = hk.Linear(self.ffn_embed_dim, name='fc1') + self.fc2 = hk.Linear(self.embed_dim, name='fc2') + + def __call__(self, x): + x = self.activation_fn(self.fc1(x)) + self.safe_key, use_key = self.safe_key.split() + x = hk.dropout(use_key.get(), self.dropout, x) + x = self.fc2(x) + return x + + +class MSAPositionEmbedding(hk.Module): + def __init__(self, embed_dim): + super().__init__() + self.embed_dim = embed_dim + self.weight = hk.get_parameter(name='data', + shape=[1024, 1, embed_dim], + init=jnp.zeros) + + def __call__(self, x): + # num_alignments, seq_len = x.shape + num_rows, num_cols = x.shape + return self.weight[:num_rows] + + +class EmbedPosition(hk.Module): + def __init__(self, config, padding_idx): + super().__init__() + self.max_position = config.max_position + self.embed_dim = config.embed_dim + self.padding_idx = padding_idx + self.max_position_ = self.max_position + self.padding_idx + 1 + self.embed = hk.Embed(vocab_size=self.max_position_, + embed_dim=self.embed_dim) + + def __call__(self, tokens): + mask = jnp.not_equal(tokens, self.padding_idx) + # tokens always begin with , do not consider. is before in alphabet. + positions = jnp.cumsum(mask, axis=-1, dtype='int32') * mask + self.padding_idx + + # position_oh = jax.nn.one_hot(positions, self.max_position_) + # weight = hk.get_parameter('weight', shape=[self.max_position_, self.embed_dim], init=jnp.zeros) + # x = jnp.dot(position_one_hot, weight) + # return x + return self.embed(positions) diff --git a/build/lib/colabdesign/esm_msa/pretrained.py b/build/lib/colabdesign/esm_msa/pretrained.py new file mode 100644 index 00000000..89dd4773 --- /dev/null +++ b/build/lib/colabdesign/esm_msa/pretrained.py @@ -0,0 +1,10 @@ +from colabdesign import esm_msa +import joblib + + +def get_model(): + alphabet = esm_msa.Alphabet.from_architecture('msa_transformer') + config = esm_msa.config.CONFIG + model = esm_msa.RunModel(alphabet, config) + + return model, alphabet diff --git a/build/lib/colabdesign/mpnn/__init__.py b/build/lib/colabdesign/mpnn/__init__.py new file mode 100644 index 00000000..0b3b65a7 --- /dev/null +++ b/build/lib/colabdesign/mpnn/__init__.py @@ -0,0 +1,10 @@ +import os,jax +# disable triton_gemm for jax versions > 0.3 +if int(jax.__version__.split(".")[1]) > 3: + os.environ["XLA_FLAGS"] = "--xla_gpu_enable_triton_gemm=false" + +import warnings +warnings.simplefilter(action='ignore', category=FutureWarning) + +from colabdesign.shared.utils import clear_mem +from .model import mk_mpnn_model \ No newline at end of file diff --git a/build/lib/colabdesign/mpnn/ensemble_model.py b/build/lib/colabdesign/mpnn/ensemble_model.py new file mode 100644 index 00000000..5d3ef258 --- /dev/null +++ b/build/lib/colabdesign/mpnn/ensemble_model.py @@ -0,0 +1,582 @@ +""" +ProteinMPNN adapted to molecular dynamics trajectories. 
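+
+mk_mpnn_ensemble_model loads ProteinMPNN weights ("original" or "soluble"),
+takes its structural inputs from an MDTraj trajectory via prep_from_mdtraj,
+and samples one sequence per conformer by mapping the single-structure sampler
+over the frame axis with a chunked vmap
+(colabdesign.shared.chunked_vmap.vmap_chunked, chunk size set by batch_size).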
+ +""" + +import jax +import jax.numpy as jnp +import numpy as np +import re +import copy +import random +import os +import joblib +from functools import partial + +from .modules import RunModel +from .mdtraj_io import prep_from_mdtraj + +from scipy.special import softmax, log_softmax + +from colabdesign.shared.prep import prep_pos +from colabdesign.shared.utils import Key, copy_dict +from colabdesign.shared.chunked_vmap import vmap_chunked as cvmap + +# borrow some stuff from AfDesign +from colabdesign.af.prep import prep_pdb +from colabdesign.af.alphafold.common import protein, residue_constants +aa_order = residue_constants.restype_order +order_aa = {b:a for a,b in aa_order.items()} + +class mk_mpnn_ensemble_model(): + def __init__(self, model_name="v_48_020", + backbone_noise=0.0, dropout=0.0, + seed=None, verbose=False, weights="original", # weights can be set to either original or soluble + batch_size=500): + + # load model + if weights == "original": + from .weights import __file__ as mpnn_path + elif weights == "soluble": + from .weights_soluble import __file__ as mpnn_path + else: + raise ValueError(f'Invalid value {weights} supplied for weights. Value must be either "original" or "soluble".') + + path = os.path.join(os.path.dirname(mpnn_path), f'{model_name}.pkl') + checkpoint = joblib.load(path) + config = {'num_letters': 21, + 'node_features': 128, + 'edge_features': 128, + 'hidden_dim': 128, + 'num_encoder_layers': 3, + 'num_decoder_layers': 3, + 'augment_eps': backbone_noise, + 'k_neighbors': checkpoint['num_edges'], + 'dropout': dropout} + + self._model = RunModel(config) + self._model.params = jax.tree_util.tree_map(np.array, checkpoint['model_state_dict']) + self.batch_size = batch_size + self.set_seed(seed) + self._num = 1 + self._inputs = {} + self._tied_lengths = False + self._setup() + + def prep_inputs( + self, + traj=None, + chain=None, + homooligomer=False, + fix_pos=None, + inverse=False, + rm_aa=None, + verbose=False, + **kwargs, + ): + """Get inputs from an MDTraj object.""" + if traj is not None: + traj = prep_from_mdtraj(traj, chain=chain,) + else: + raise ValueError( + "One of 'mdtraj_frame', 'pdb_filename', or 'pdb_string' must be provided." 
+ ) + + # atom idx + atom_idx = tuple(residue_constants.atom_order[k] for k in ["N", "CA", "C", "O"]) + chain_idx = np.concatenate([[n] * l for n, l in enumerate(traj["lengths"])]) + self._lengths = traj["lengths"] + L = sum(self._lengths) + + self._inputs = { + "X": traj["batch"]["all_atom_positions"][:, :, atom_idx], # atom_idx_moved_back_one + "mask": traj["batch"]["all_atom_mask"][:, 1], + "S": traj["batch"]["aatype"], + "residue_idx": traj["residue_index"], + "chain_idx": chain_idx, + "lengths": np.array(self._lengths), + "bias": np.zeros((L, 20)), + } + + if rm_aa is not None: + for aa in rm_aa.split(","): + self._inputs["bias"][..., aa_order[aa]] -= 1e6 + + if fix_pos is not None: + p = prep_pos(fix_pos, **traj["idx"])["pos"] + if inverse: + p = np.delete(np.arange(L), p) + self._inputs["fix_pos"] = p + self._inputs["bias"][p] = 1e7 * np.eye(21)[self._inputs["S"]][p, :20] + + if homooligomer: + #raise NotImplementedError("'homooligomer=True' not yet implemented") + assert min(self._lengths) == max(self._lengths) + self._tied_lengths = True + self._len = self._lengths[0] + else: + self._tied_lengths = False + self._len = sum(self._lengths) + + self.traj = traj + + if verbose: + print("lengths", self._lengths) + if "fix_pos" in self._inputs: + print("the following positions will be fixed:") + print(self._inputs["fix_pos"]) + + def get_af_inputs(self, af): + '''get inputs from alphafold model''' + + self._lengths = af._lengths + self._len = af._len + + self._inputs["residue_idx"] = af._inputs["residue_index"] + self._inputs["chain_idx"] = af._inputs["asym_id"] + self._inputs["lengths"] = np.array(self._lengths) + + # set bias + L = sum(self._lengths) + self._inputs["bias"] = np.zeros((L,20)) + self._inputs["bias"][-af._len:] = af._inputs["bias"] + + if "offset" in af._inputs: + self._inputs["offset"] = af._inputs["offset"] + + if "batch" in af._inputs: + atom_idx = tuple(residue_constants.atom_order[k] for k in ["N","CA","C","O"]) + batch = af._inputs["batch"] + self._inputs["X"] = batch["all_atom_positions"][:,atom_idx] + self._inputs["mask"] = batch["all_atom_mask"][:,1] + self._inputs["S"] = batch["aatype"] + + # fix positions + if af.protocol == "binder": + p = np.arange(af._target_len) + else: + p = af.opt.get("fix_pos",None) + + if p is not None: + self._inputs["fix_pos"] = p + self._inputs["bias"][p] = 1e7 * np.eye(21)[self._inputs["S"]][p,:20] + + # tie positions + if af._args["homooligomer"]: + assert min(self._lengths) == max(self._lengths) + self._tied_lengths = True + else: + self._tied_lengths = False + + def sample(self, temperature=0.1, rescore=False, **kwargs): + '''Sample one sequence for each conformer''' + I = copy_dict(self._inputs) + I.update(kwargs) + key = I.pop("key",self.key()) + keys = jax.random.split(key,1) + O = self._sample_conformers(keys, I, temperature, self._tied_lengths) + if rescore: + O = self._rescore_parallel(keys, I, O["S"], O["decoding_order"]) + + # must squeeze here, unlike regular model + O = jax.tree_util.tree_map(lambda x: jnp.squeeze(jnp.array(x)), O) + + # process outputs to human-readable form + O.update(self._get_seq(O)) + O.update(self._get_score(I,O)) + return O + + def sample_minimal(self, temperature=0.1, rescore=False, **kwargs): + '''Sample one sequence for each conformer''' + I = copy_dict(self._inputs) + I.update(kwargs) + key = I.pop("key",self.key()) + keys = jax.random.split(key,1) + O = self._sample_conformers(keys, I, temperature, self._tied_lengths) + + # must squeeze here, unlike regular model + O = 
jax.tree_util.tree_map(lambda x: jnp.squeeze(jnp.array(x)), O) + return O + + def sample_parallel(self, batch=10, temperature=0.1, rescore=False, **kwargs): + '''sample new sequence(s) in parallel + NOT IMPLEMENTED + + ''' + if batch != 1: + raise NotImplementedError("Batched sampling not implemented for conformational ensembles.") + + + def _get_seq(self, O): + """one_hot to amino acid sequence (still returns Python strings)""" + + def split_seq(seq_str, lengths, tied_lengths): # pass lengths and tied_lengths explicitly + if len(lengths) > 1: + # This string manipulation cannot be JITted. + # If this were inside a JITted function, it would be a host callback. + seq_str = "".join(np.insert(list(seq_str), np.cumsum(lengths[:-1]), "/")) + if tied_lengths: + seq_str = seq_str.split("/")[0] + return seq_str + + seqs = [] + # Assuming O["S"] is (batch, L, 21) or (L, 21) + # Convert JAX array to NumPy for iteration and string conversion + S_numpy = np.array(O["S"].argmax(axis=1)).T # modified axis=1 and transposed for ensemble + if S_numpy.ndim == 1: + S_numpy = S_numpy[None, :] # ensure batch dimension + + for s_np in S_numpy: + # This part is Python string manipulation + seq = "".join([order_aa[a_idx] for a_idx in s_np]) + seq = split_seq(seq, self._lengths, self._tied_lengths) # pass necessary attributes + seqs.append(seq) + return {"seq": np.array(seqs)} + + def _get_score(self, I, O): + ''' + logits to score/sequence_recovery + return {"score":score (L, n_frames), "seqid":seqid (L, n_frames)} + ''' + # this is reasonably fast, even without jax + mask = I["mask"].copy() + if "fix_pos" in I: + mask[I["fix_pos"]] = 0 + + mask = np.expand_dims(mask, -1) + + # softmaxes are now mapped over axis=1 + log_q = log_softmax(O["logits"], axis=1)[:,:20,:] + q = softmax(O["logits"][:,:20,:], axis=1) + + # sums are over axis 0 + if "S" in O: + S = O["S"][:,:20,:] + score = -(S * log_q).sum(axis=1) + seqid = S.argmax(axis=1) == np.expand_dims(self._inputs["S"], -1) + else: + score = -(q * log_q).sum(axis=0) + seqid = np.zeros_like(score) + + score = (score * mask).sum(axis=0) / (mask.sum() + 1e-8) + seqid = (seqid * mask).sum(axis=0) / (mask.sum() + 1e-8) + + return {"score":score, "seqid":seqid} + + + + def _get_score_jax(self, inputs_S, logits, mask, fix_pos=None): # Pass necessary inputs directly + ''' logits to score/sequence_recovery - JAX compatible version ''' + + + current_mask = mask + if fix_pos is not None and fix_pos.shape[0] > 0: # Ensure fix_pos is not empty + # Ensure mask is a JAX array for .at.set to work + current_mask = jnp.array(current_mask) # Convert if it's numpy + current_mask = current_mask.at[fix_pos].set(0) + + # Use jax.nn functions + log_q = jax.nn.log_softmax(logits, axis=-1)[..., :20] + q = jax.nn.softmax(logits[..., :20], axis=-1) + + S_scored_one_hot = jax.nn.one_hot(inputs_S, num_classes=21)[...,:20] # Assuming inputs_S is integer encoded for the sequence to score + # This would be I["S"] from the score() method + + score = -(S_scored_one_hot * log_q).sum(-1) + + seqid = (inputs_S == self._inputs["S"]) + + masked_score_sum = (score * current_mask).sum(-1) + masked_seqid_sum = (seqid * current_mask).sum(-1) + mask_sum = current_mask.sum() + 1e-8 + + final_score = masked_score_sum / mask_sum + final_seqid = masked_seqid_sum / mask_sum + + return {"score": final_score, "seqid": final_seqid} + + + def score(self, seq_numeric=None, **kwargs): # seq_numeric is an integer array + '''score sequence - JAX compatible version (mostly)''' + current_inputs = 
jax.tree_util.tree_map(jnp.array, self._inputs) + + if seq_numeric is not None: + # seq_numeric is expected to be an integer array of amino acid indices + p = jnp.arange(current_inputs["S"].shape[0]) + s_shape_0 = current_inputs["S"].shape[0] # Store shape for JAX tracing + + if self._tied_lengths and seq_numeric.shape[0] == self._lengths[0]: + # Assuming self._lengths is available and compatible + # seq_numeric might need tiling if it represents one chain of a homooligomer + num_repeats = len(self._lengths) + seq_numeric = jnp.tile(seq_numeric, num_repeats) + + if "fix_pos" in current_inputs and current_inputs["fix_pos"].shape[0] > 0: + # Ensure shapes are concrete or JAX can trace them + if seq_numeric.shape[0] == (s_shape_0 - current_inputs["fix_pos"].shape[0]): + p = jnp.delete(p, current_inputs["fix_pos"], axis=0) + + # Update S using .at[].set() + # Ensure seq_numeric is correctly broadcasted or indexed if p is tricky + current_inputs["S"] = current_inputs["S"].at[p].set(seq_numeric) + + # Combine kwargs with current_inputs, ensuring JAX types + for k, v in kwargs.items(): + current_inputs[k] = jnp.asarray(v) if not isinstance(v, jax.Array) else v + + + key_to_use = current_inputs.pop("key", self.key()) # self.key() provides a JAX key + + # _score is already JITted and expects JAX-compatible inputs + # The arguments to _score are X, mask, residue_idx, chain_idx, key, S, bias, decoding_order etc. + # Ensure all these are present in current_inputs and are JAX arrays. + + # Prepare arguments for self._score, ensuring they are all JAX arrays + score_fn_args = {k: current_inputs[k] for k in [ + 'X', 'mask', 'residue_idx', 'chain_idx', 'S', 'bias' + ] if k in current_inputs} + + if "decoding_order" in current_inputs: + score_fn_args["decoding_order"] = current_inputs["decoding_order"] + if "fix_pos" in current_inputs: # _score uses fix_pos to adjust decoding_order + score_fn_args["fix_pos"] = current_inputs["fix_pos"] + + + # O will be a dictionary of JAX arrays + O = self._score(**score_fn_args, key=key_to_use) + + # Call the JAX-compatible _get_score + # It needs: current_inputs["S"] (the sequence being scored, possibly modified), + # O["logits"], current_inputs["mask"], and current_inputs.get("fix_pos") + score_info = self._get_score( + inputs_S=current_inputs["S"], # This is the S that was actually scored by _score + logits=O["logits"], + mask=current_inputs["mask"], + fix_pos=current_inputs.get("fix_pos") + ) + O.update(score_info) # O remains a dict of JAX arrays + + # If you need to convert to NumPy arrays for external use, do it here, + # but the function itself now primarily deals with JAX arrays. + # For full JAX compatibility of `score` itself (e.g. to JIT it), + # this conversion should be outside. 
+ # return jax.tree_map(np.array, O) + return O # Returns dict of JAX array + + def get_logits(self, **kwargs): + '''get logits''' + return self.score(**kwargs)["logits"] + + def get_unconditional_logits(self, **kwargs): + L = self._inputs["X"].shape[0] + kwargs["ar_mask"] = np.zeros((L,L)) + return self.score(**kwargs)["logits"] + + def set_seed(self, seed=None): + np.random.seed(seed=seed) + self.key = Key(seed=seed).get + + def _setup(self): + def _score_internal( + X, mask, residue_idx, chain_idx, key, S, bias, **kwargs + ): # Added S and bias + I = { + "X": X, + "mask": mask, + "residue_idx": residue_idx, + "chain_idx": chain_idx, + "S": S, # Pass S + "bias": bias, # Pass bias + } + I.update(kwargs) + + if "decoding_order" not in I: + key, sub_key = jax.random.split(key) + randn = _randomize_sophie(sub_key, X) + randn = jnp.where(I["mask"], randn, randn + 1) + if "fix_pos" in I and I["fix_pos"].shape[0] > 0: # check if fix_pos is not empty + randn = randn.at[I["fix_pos"]].add(-1) + I["decoding_order"] = randn.argsort() + + # _aa_convert is JAX-compatible + for k_item in ["S", "bias"]: # Use k_item to avoid conflict with key + if k_item in I: + I[k_item] = _aa_convert(I[k_item]) + + output_dict = self._model.score(self._model.params, key, I) + output_dict["S"] = _aa_convert(output_dict["S"], rev=True) + output_dict["logits"] = _aa_convert(output_dict["logits"], rev=True) + return output_dict + + self._score = jax.jit(_score_internal) + + def _sample_internal( + X, + mask, + residue_idx, + chain_idx, + key, + temperature=0.1, + tied_lengths=False, + bias=None, + **kwargs, + ): # added bias + # single conformer sampling + I = { + "X": X, + "mask": mask, + "residue_idx": residue_idx, + "chain_idx": chain_idx, + "temperature": temperature, + "bias": bias, # Pass bias + } + I.update(kwargs) + + # define decoding order (as in original _sample) + if "decoding_order" in I: + if I["decoding_order"].ndim == 1: + I["decoding_order"] = I["decoding_order"][:, None] + else: + key, sub_key = jax.random.split(key) + #randn = jax.random.uniform(sub_key, (I["X"].shape[0],)) + randn = _randomize_sophie(sub_key, X) + + + + randn = jnp.where(I["mask"], randn, randn + 1) + if "fix_pos" in I and I["fix_pos"].shape[0] > 0: + randn = randn.at[I["fix_pos"]].add(-1) + if tied_lengths: + copies = I["lengths"].shape[0] + decoding_order_tied = randn.reshape(copies, -1).mean(0).argsort() + I["decoding_order"] = ( + jnp.arange(I["X"].shape[0]).reshape(copies, -1).T[decoding_order_tied] + ) + else: + I["decoding_order"] = randn.argsort()[:, None] + + # S is not an input to _model.sample, but bias is + if "S" in I: + I["S"] = _aa_convert( + I["S"] + ) # If S is somehow passed (e.g. 
for conditioning, though MPNN typically doesn't) + if "bias" in I: + I["bias"] = _aa_convert(I["bias"]) + + O_dict = self._model.sample(self._model.params, key, I) + O_dict["S"] = _aa_convert(O_dict["S"], rev=True) # This is the sampled S + O_dict["logits"] = _aa_convert(O_dict["logits"], rev=True) + return O_dict + + self._sample = jax.jit(_sample_internal, static_argnames=["tied_lengths"]) + + # + def _vmap_sample_seqs_from_conformers(key, inputs, temperature, tied_lengths): + inputs_copy = dict(inputs) # Shallow copy for modification + inputs_copy.pop("temperature", None) + inputs_copy.pop("key", None) + # Ensure 'bias' is correctly handled if it's part of 'inputs' + f_of_X = jax.jit( + partial(self._sample, key=key, + **{k : v for k,v in inputs_copy.items() if k not in ("X",)}, + temperature=temperature, tied_lengths=tied_lengths), static_argnames=["tied_lengths"] + ) + # vmap over positions + return cvmap(f_of_X, chunk_size=min(self.batch_size, inputs_copy["X"].shape[0]))(inputs_copy["X"]) + + # this is vmap over keys, but there's only one. + # this is just the easiest way to square with earliest code + # but might be good to refactor + fn_vmap_sample_conformers = jax.vmap(_vmap_sample_seqs_from_conformers, in_axes=[0, None, None, None]) + # difference, no jit for now + self._sample_conformers = fn_vmap_sample_conformers + + @jax.jit + def _vmap_rescore_parallel(key, inputs, S_rescore, decoding_order_rescore): + inputs_copy = dict(inputs) # Shallow copy + inputs_copy.pop("S", None) + inputs_copy.pop("decoding_order", None) + inputs_copy.pop("key", None) + # Ensure 'bias' from original inputs is used, and S_rescore is the new S + return self._score( + **inputs_copy, key=key, S=S_rescore, decoding_order=decoding_order_rescore + ) # Pass S and decoding_order + fn_vmap_rescore = jax.vmap(_vmap_rescore_parallel, in_axes=[0, None, 0, 0]) + self._rescore_parallel = fn_vmap_rescore + +####################################################################################### + +def _aa_convert(x, rev=False): + mpnn_alphabet = 'ACDEFGHIKLMNPQRSTVWYX' + af_alphabet = 'ARNDCQEGHILKMFPSTWYVX' + if x is None: + return x + else: + if rev: + return x[...,tuple(mpnn_alphabet.index(k) for k in af_alphabet)] + else: + x = jax.nn.one_hot(x,21) if jnp.issubdtype(x.dtype, jnp.integer) else x + if x.shape[-1] == 20: + x = jnp.pad(x,[[0,0],[0,1]]) + return x[...,tuple(af_alphabet.index(k) for k in mpnn_alphabet)] + +unknown_aa_index = aa_order.get('X', 20) # Default index for unknown AAs + +def convert_sequence_to_numeric(sequence_str: str, + aa_map: dict = aa_order, + all_chain_lengths: list = None, + is_homooligomer_tied: bool = False) -> jnp.array: + """ + Converts a protein sequence string into a JAX integer array. + + Args: + sequence_str: The amino acid sequence string. + - For monomers: "ACEG..." + - For heteromers (chains separated by '/'): "ACEG.../FGHI..." + - For homooligomers where is_homooligomer_tied is True and + only one chain's sequence is provided: "ACEG..." (will be tiled). + aa_map: Dictionary mapping amino acid characters to integers (e.g., aa_order). + all_chain_lengths: List of lengths of all chains in the complex. + Example: [100, 100] for a dimer of length 100 each. + Used for homooligomer tiling. + is_homooligomer_tied: Boolean. If True and sequence_str is for a single + chain of a homooligomer, the sequence will be tiled. + + Returns: + jnp.array: A JAX array of integers representing the full sequence. 
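+
+    Example (illustrative, not from the original source):
+        convert_sequence_to_numeric("ACDG/FGHI")                # heterodimer -> 8 ints
+        convert_sequence_to_numeric("ACDG",
+                                    all_chain_lengths=[4, 4],
+                                    is_homooligomer_tied=True)  # tiled -> 8 ints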
+ """ + numeric_sequence_list = [] + + # Handle homooligomer case where a single chain sequence is provided to be tiled + if is_homooligomer_tied and \ + all_chain_lengths and \ + len(all_chain_lengths) > 0 and \ + "/" not in sequence_str: + # Check if the provided sequence string matches the length of one chain + if len(sequence_str) == all_chain_lengths[0]: # Assuming all chains have the same length + num_chains = len(all_chain_lengths) + # Tile the string sequence before converting to numeric + sequence_str = "/".join([sequence_str] * num_chains) + # TODO: add a warning or error if the lengths don't match + + # Process chain by chain if '/' is present, otherwise process the whole string + chains = sequence_str.split('/') + + for chain_seq_str in chains: + for aa_char in chain_seq_str: + # Use .get(key, default_value) to handle unexpected characters + numeric_sequence_list.append(aa_map.get(aa_char, unknown_aa_index)) + + return jnp.array(numeric_sequence_list, dtype=jnp.int32) + +### >:D >:D >:D >:D >:D >:D >:D >:D # >:D >:D >:D >:D >:D >:D >:D >:D # >:D >:D >:D >:D +def _randomize_sophie(key, X_conformer, max_freq=1e9, min_freq=1e3): + """ + WARNING: EXPERIMENTAL + + Use X_conformer as a natural entropy source to randomize decoding order + by transforming spatial coordinates with a random-frequency sine wave. + """ + randfreq = (max_freq - min_freq)*jax.random.uniform(key) + min_freq + randn = 0.5*(1 + jnp.sin(X_conformer * randfreq).sum(axis=(1,2))) + return randn \ No newline at end of file diff --git a/build/lib/colabdesign/mpnn/mdtraj_io.py b/build/lib/colabdesign/mpnn/mdtraj_io.py new file mode 100644 index 00000000..da4f03e5 --- /dev/null +++ b/build/lib/colabdesign/mpnn/mdtraj_io.py @@ -0,0 +1,106 @@ +import jax +import numpy as np + +from colabdesign.af.alphafold.common import residue_constants +from colabdesign.shared.protein import _np_get_cb + +order_aa = {b: a for a, b in residue_constants.restype_order.items()} +aa_order = residue_constants.restype_order + + +def prep_from_mdtraj(traj, chain=None, **kwargs): + """ + Extracts features directly from an mdtraj.Trajectory object. 
+ """ + + chains_to_process = [] + if chain is None: + chains_to_process = list(traj.topology.chains) + else: + requested_chain_ids = list(chain) + for c in traj.topology.chains: + if c.chain_id in requested_chain_ids: + chains_to_process.append(c) + + all_chains_data = [] + last_res_idx = 0 + full_lengths = [] + + for chain_obj in chains_to_process: + chain_id = chain_obj.chain_id + atom_indices = [a.index for a in chain_obj.atoms] + + chain_top = traj.topology.subset(atom_indices) + chain_xyz = traj.xyz[:, atom_indices, :] * 10 # Convert nm to Angstroms + n_res = chain_top.n_residues + + all_atom_positions = np.zeros((traj.n_frames, n_res, residue_constants.atom_type_num, 3)) + all_atom_mask = np.zeros((n_res, residue_constants.atom_type_num)) + aatype = np.zeros(n_res, dtype=int) + residue_index = np.zeros(n_res, dtype=int) + + for res_idx, residue in enumerate(chain_top.residues): + res_name = residue.name + aatype[res_idx] = residue_constants.resname_to_idx.get( + res_name, residue_constants.resname_to_idx["UNK"] + ) + residue_index[res_idx] = residue.resSeq + + for atom in residue.atoms: + if atom.name in residue_constants.atom_order: + atom_type_idx = residue_constants.atom_order[atom.name] + chain_atom_index = next( + a.index for a in chain_top.atoms if a.serial == atom.serial + ) + all_atom_positions[:,res_idx, atom_type_idx] = chain_xyz[:,chain_atom_index] + all_atom_mask[res_idx, atom_type_idx] = 1 + + batch = { + "aatype": aatype, + "all_atom_positions": all_atom_positions, + "all_atom_mask": all_atom_mask, + } + + p, m = batch["all_atom_positions"], batch["all_atom_mask"] + atom_idx = residue_constants.atom_order + atoms = {k: p[..., atom_idx[k], :] for k in ["N", "CA", "C",]} + + cb_atoms = _np_get_cb(**atoms, use_jax=False) + cb_mask = np.prod([m[..., atom_idx[k]] for k in ["N", "CA", "C"]], 0) + cb_idx = atom_idx["CB"] + batch["all_atom_positions"][..., cb_idx, :] = np.where( + m[:, cb_idx, None], p[..., cb_idx, :], cb_atoms + ) + batch["all_atom_mask"][..., cb_idx] = (m[:, cb_idx] + cb_mask) > 0 + #batch["all_atom_positions"] = np.moveaxis(batch["all_atom_positions"], 0, -1) + + chain_data = { + "batch": batch, + "residue_index": residue_index + last_res_idx, + "chain_id": [chain_id] * n_res, + "res_indices_original": residue_index, + } + all_chains_data.append(chain_data) + + last_res_idx += n_res + 50 + full_lengths.append(n_res) + + if not all_chains_data: + raise ValueError("No valid chains found or processed from the mdtraj frame.") + + # the one with 4 dimensions is mapped of axis=1, this will be the coordinates + final_batch = jax.tree_util.tree_map( + lambda *x: np.concatenate(x, int(len(x[0].shape)==4)), *[d.pop("batch") for d in all_chains_data] + ) + final_residue_index = np.concatenate([d.pop("residue_index") for d in all_chains_data]) + final_idx = { + "residue": np.concatenate([d.pop("res_indices_original") for d in all_chains_data]), + "chain": np.concatenate([d.pop("chain_id") for d in all_chains_data]), + } + + return { + "batch": final_batch, + "residue_index": final_residue_index, + "idx": final_idx, + "lengths": full_lengths, + } diff --git a/build/lib/colabdesign/mpnn/model.py b/build/lib/colabdesign/mpnn/model.py new file mode 100644 index 00000000..077ed2fd --- /dev/null +++ b/build/lib/colabdesign/mpnn/model.py @@ -0,0 +1,460 @@ +import jax +import jax.numpy as jnp +import numpy as np +import re +import copy +import random +import os +import joblib + +from .modules import RunModel + +from colabdesign.shared.prep import prep_pos +from 
colabdesign.shared.utils import Key, copy_dict + +# borrow some stuff from AfDesign +from colabdesign.af.prep import prep_pdb +from colabdesign.af.alphafold.common import protein, residue_constants +aa_order = residue_constants.restype_order +order_aa = {b:a for a,b in aa_order.items()} + +class mk_mpnn_model(): + def __init__(self, model_name="v_48_020", + backbone_noise=0.0, dropout=0.0, + seed=None, verbose=False, weights="original", + batch_size=1000): # weights can be set to either original or soluble + # load model + if weights == "original": + from .weights import __file__ as mpnn_path + elif weights == "soluble": + from .weights_soluble import __file__ as mpnn_path + else: + raise ValueError(f'Invalid value {weights} supplied for weights. Value must be either "original" or "soluble".') + + path = os.path.join(os.path.dirname(mpnn_path), f'{model_name}.pkl') + checkpoint = joblib.load(path) + config = {'num_letters': 21, + 'node_features': 128, + 'edge_features': 128, + 'hidden_dim': 128, + 'num_encoder_layers': 3, + 'num_decoder_layers': 3, + 'augment_eps': backbone_noise, + 'k_neighbors': checkpoint['num_edges'], + 'dropout': dropout} + + self._model = RunModel(config) + self._model.params = jax.tree_util.tree_map(np.array, checkpoint['model_state_dict']) + self._setup() + self.set_seed(seed) + + self._num = 1 + self._inputs = {} + self._tied_lengths = False + self.batch_size + + def prep_inputs(self, pdb_filename=None, chain=None, homooligomer=False, + ignore_missing=True, fix_pos=None, inverse=False, + rm_aa=None, verbose=False, **kwargs): + + '''get inputs from input pdb''' + pdb = prep_pdb(pdb_filename, chain, ignore_missing=ignore_missing) + atom_idx = tuple(residue_constants.atom_order[k] for k in ["N","CA","C","O"]) + chain_idx = np.concatenate([[n]*l for n,l in enumerate(pdb["lengths"])]) + self._lengths = pdb["lengths"] + L = sum(self._lengths) + + self._inputs = {"X": pdb["batch"]["all_atom_positions"][:,atom_idx], + "mask": pdb["batch"]["all_atom_mask"][:,1], + "S": pdb["batch"]["aatype"], + "residue_idx": pdb["residue_index"], + "chain_idx": chain_idx, + "lengths": np.array(self._lengths), + "bias": np.zeros((L,20))} + + + if rm_aa is not None: + for aa in rm_aa.split(","): + self._inputs["bias"][...,aa_order[aa]] -= 1e6 + + if fix_pos is not None: + p = prep_pos(fix_pos, **pdb["idx"])["pos"] + if inverse: + p = np.delete(np.arange(L),p) + self._inputs["fix_pos"] = p + self._inputs["bias"][p] = 1e7 * np.eye(21)[self._inputs["S"]][p,:20] + + if homooligomer: + assert min(self._lengths) == max(self._lengths) + self._tied_lengths = True + self._len = self._lengths[0] + else: + self._tied_lengths = False + self._len = sum(self._lengths) + + self.pdb = pdb + + if verbose: + print("lengths", self._lengths) + if "fix_pos" in self._inputs: + print("the following positions will be fixed:") + print(self._inputs["fix_pos"]) + + def get_af_inputs(self, af): + '''get inputs from alphafold model''' + + self._lengths = af._lengths + self._len = af._len + + self._inputs["residue_idx"] = af._inputs["residue_index"] + self._inputs["chain_idx"] = af._inputs["asym_id"] + self._inputs["lengths"] = np.array(self._lengths) + + # set bias + L = sum(self._lengths) + self._inputs["bias"] = np.zeros((L,20)) + self._inputs["bias"][-af._len:] = af._inputs["bias"] + + if "offset" in af._inputs: + self._inputs["offset"] = af._inputs["offset"] + + if "batch" in af._inputs: + atom_idx = tuple(residue_constants.atom_order[k] for k in ["N","CA","C","O"]) + batch = af._inputs["batch"] + self._inputs["X"] 
= batch["all_atom_positions"][:,atom_idx] + self._inputs["mask"] = batch["all_atom_mask"][:,1] + self._inputs["S"] = batch["aatype"] + + # fix positions + if af.protocol == "binder": + p = np.arange(af._target_len) + else: + p = af.opt.get("fix_pos",None) + + if p is not None: + self._inputs["fix_pos"] = p + self._inputs["bias"][p] = 1e7 * np.eye(21)[self._inputs["S"]][p,:20] + + # tie positions + if af._args["homooligomer"]: + assert min(self._lengths) == max(self._lengths) + self._tied_lengths = True + else: + self._tied_lengths = False + + def sample(self, num=1, batch=1, temperature=0.1, rescore=False, **kwargs): + '''sample sequence''' + O = [] + for _ in range(num): + O.append(self.sample_parallel(batch, temperature, rescore, **kwargs)) + return jax.tree_util.tree_map(lambda *x:np.concatenate(x,0),*O) + + def sample_parallel(self, batch=10, temperature=0.1, rescore=False, **kwargs): + '''sample new sequence(s) in parallel''' + I = copy_dict(self._inputs) + I.update(kwargs) + key = I.pop("key",self.key()) + keys = jax.random.split(key,batch) + O = self._sample_parallel(keys, I, temperature, self._tied_lengths) + if rescore: + O = self._rescore_parallel(keys, I, O["S"], O["decoding_order"]) + O = jax.tree_util.tree_map(np.array, O) + + # process outputs to human-readable form + O.update(self._get_seq(O)) + O.update(self._get_score(I,O)) + return O + + def _get_seq(self, O): + ''' one_hot to amino acid sequence (still returns Python strings) ''' + # This function inherently produces Python strings, so it remains a host-side utility. + # If you needed JAX-manipulable sequences from here, the output format would need to change. + def split_seq(seq_str, lengths, tied_lengths): # pass lengths and tied_lengths explicitly + if len(lengths) > 1: + # This string manipulation cannot be JITted. + # If this were inside a JITted function, it would be a host callback. 
+ seq_str = "".join(np.insert(list(seq_str), np.cumsum(lengths[:-1]), "/")) + if tied_lengths: + seq_str = seq_str.split("/")[0] + return seq_str + + seqs = [] + # Assuming O["S"] is (batch, L, 21) or (L, 21) + # Convert JAX array to NumPy for iteration and string conversion + S_numpy = np.array(O["S"].argmax(axis=-1)) + if S_numpy.ndim == 1: S_numpy = S_numpy[None,:] # ensure batch dimension + + for s_np in S_numpy: + # This part is Python string manipulation + seq = "".join([order_aa[a_idx] for a_idx in s_np]) + seq = split_seq(seq, self._lengths, self._tied_lengths) # pass necessary attributes + seqs.append(seq) + return {"seq": np.array(seqs)} + + + def _get_score(self, inputs_S, logits, mask, fix_pos=None): # Pass necessary inputs directly + ''' logits to score/sequence_recovery - JAX compatible version ''' + + + current_mask = mask + if fix_pos is not None and fix_pos.shape[0] > 0: # Ensure fix_pos is not empty + # Ensure mask is a JAX array for .at.set to work + current_mask = jnp.array(current_mask) # Convert if it's numpy + current_mask = current_mask.at[fix_pos].set(0) + + # Use jax.nn functions + log_q = jax.nn.log_softmax(logits, axis=-1)[..., :20] + q = jax.nn.softmax(logits[..., :20], axis=-1) + + S_scored_one_hot = jax.nn.one_hot(inputs_S, num_classes=21)[...,:20] # Assuming inputs_S is integer encoded for the sequence to score + # This would be I["S"] from the score() method + + score = -(S_scored_one_hot * log_q).sum(-1) + + seqid = (inputs_S == self._inputs["S"]) + + masked_score_sum = (score * current_mask).sum(-1) + masked_seqid_sum = (seqid * current_mask).sum(-1) + mask_sum = current_mask.sum() + 1e-8 + + final_score = masked_score_sum / mask_sum + final_seqid = masked_seqid_sum / mask_sum + + return {"score": final_score, "seqid": final_seqid} + + + def score(self, seq_numeric=None, **kwargs): # seq_numeric is an integer array + '''score sequence - JAX compatible version (mostly)''' + current_inputs = jax.tree_util.tree_map(jnp.array, self._inputs) + + if seq_numeric is not None: + # seq_numeric is expected to be an integer array of amino acid indices + p = jnp.arange(current_inputs["S"].shape[0]) + s_shape_0 = current_inputs["S"].shape[0] # Store shape for JAX tracing + + if self._tied_lengths and seq_numeric.shape[0] == self._lengths[0]: + # Assuming self._lengths is available and compatible + # seq_numeric might need tiling if it represents one chain of a homooligomer + num_repeats = len(self._lengths) + seq_numeric = jnp.tile(seq_numeric, num_repeats) + + if "fix_pos" in current_inputs and current_inputs["fix_pos"].shape[0] > 0: + # Ensure shapes are concrete or JAX can trace them + if seq_numeric.shape[0] == (s_shape_0 - current_inputs["fix_pos"].shape[0]): + p = jnp.delete(p, current_inputs["fix_pos"], axis=0) + + # Update S using .at[].set() + # Ensure seq_numeric is correctly broadcasted or indexed if p is tricky + current_inputs["S"] = current_inputs["S"].at[p].set(seq_numeric) + + # Combine kwargs with current_inputs, ensuring JAX types + for k, v in kwargs.items(): + current_inputs[k] = jnp.asarray(v) if not isinstance(v, jax.Array) else v + + + key_to_use = current_inputs.pop("key", self.key()) # self.key() provides a JAX key + + # _score is already JITted and expects JAX-compatible inputs + # The arguments to _score are X, mask, residue_idx, chain_idx, key, S, bias, decoding_order etc. + # Ensure all these are present in current_inputs and are JAX arrays. 
+ + # Prepare arguments for self._score, ensuring they are all JAX arrays + score_fn_args = {k: current_inputs[k] for k in [ + 'X', 'mask', 'residue_idx', 'chain_idx', 'S', 'bias' + ] if k in current_inputs} + + if "decoding_order" in current_inputs: + score_fn_args["decoding_order"] = current_inputs["decoding_order"] + if "fix_pos" in current_inputs: # _score uses fix_pos to adjust decoding_order + score_fn_args["fix_pos"] = current_inputs["fix_pos"] + + + # O will be a dictionary of JAX arrays + O = self._score(**score_fn_args, key=key_to_use) + + # Call the JAX-compatible _get_score + # It needs: current_inputs["S"] (the sequence being scored, possibly modified), + # O["logits"], current_inputs["mask"], and current_inputs.get("fix_pos") + score_info = self._get_score( + inputs_S=current_inputs["S"], # This is the S that was actually scored by _score + logits=O["logits"], + mask=current_inputs["mask"], + fix_pos=current_inputs.get("fix_pos") + ) + O.update(score_info) # O remains a dict of JAX arrays + + # If you need to convert to NumPy arrays for external use, do it here, + # but the function itself now primarily deals with JAX arrays. + # For full JAX compatibility of `score` itself (e.g. to JIT it), + # this conversion should be outside. + # return jax.tree_map(np.array, O) + return O # Returns dict of JAX array + + def get_logits(self, **kwargs): + '''get logits''' + return self.score(**kwargs)["logits"] + + def get_unconditional_logits(self, **kwargs): + L = self._inputs["X"].shape[0] + kwargs["ar_mask"] = np.zeros((L,L)) + return self.score(**kwargs)["logits"] + + def set_seed(self, seed=None): + np.random.seed(seed=seed) + self.key = Key(seed=seed).get + + def _setup(self): + def _score_internal(X, mask, residue_idx, chain_idx, key, S, bias, **kwargs): # Added S and bias + I = {'X': X, + 'mask': mask, + 'residue_idx': residue_idx, + 'chain_idx': chain_idx, + 'S': S, # Pass S + 'bias': bias # Pass bias + } + I.update(kwargs) + + if "decoding_order" not in I: + key, sub_key = jax.random.split(key) + randn = jax.random.uniform(sub_key, (I["X"].shape[0],)) + randn = jnp.where(I["mask"], randn, randn+1) + if "fix_pos" in I and I["fix_pos"].shape[0] > 0: # check if fix_pos is not empty + randn = randn.at[I["fix_pos"]].add(-1) + I["decoding_order"] = randn.argsort() + + # _aa_convert is JAX-compatible + for k_item in ["S", "bias"]: # Use k_item to avoid conflict with key + if k_item in I: I[k_item] = _aa_convert(I[k_item]) + + output_dict = self._model.score(self._model.params, key, I) + output_dict["S"] = _aa_convert(output_dict["S"], rev=True) + output_dict["logits"] = _aa_convert(output_dict["logits"], rev=True) + return output_dict + + self._score = jax.jit(_score_internal) + + def _sample_internal(X, mask, residue_idx, chain_idx, key, + temperature=0.1, tied_lengths=False, bias=None, **kwargs): # added bias + I = {'X': X, + 'mask': mask, + 'residue_idx': residue_idx, + 'chain_idx': chain_idx, + 'temperature': temperature, + 'bias': bias # Pass bias + } + I.update(kwargs) + + # define decoding order (as in original _sample) + if "decoding_order" in I: + if I["decoding_order"].ndim == 1: + I["decoding_order"] = I["decoding_order"][:,None] + else: + key, sub_key = jax.random.split(key) + randn = jax.random.uniform(sub_key, (I["X"].shape[0],)) + randn = jnp.where(I["mask"], randn, randn+1) + if "fix_pos" in I and I["fix_pos"].shape[0] > 0: randn = randn.at[I["fix_pos"]].add(-1) + if tied_lengths: + copies = I["lengths"].shape[0] + decoding_order_tied = 
randn.reshape(copies,-1).mean(0).argsort() + I["decoding_order"] = jnp.arange(I["X"].shape[0]).reshape(copies,-1).T[decoding_order_tied] + else: + I["decoding_order"] = randn.argsort()[:,None] + + # S is not an input to _model.sample, but bias is + if "S" in I: I["S"] = _aa_convert(I["S"]) # If S is somehow passed (e.g. for conditioning, though MPNN typically doesn't) + if "bias" in I : I["bias"] = _aa_convert(I["bias"]) + + + O_dict = self._model.sample(self._model.params, key, I) + O_dict["S"] = _aa_convert(O_dict["S"], rev=True) # This is the sampled S + O_dict["logits"] = _aa_convert(O_dict["logits"], rev=True) + return O_dict + + self._sample = jax.jit(_sample_internal, static_argnames=["tied_lengths"]) + + # Adjust vmapped functions accordingly + def _vmap_sample_parallel(key, inputs, temperature, tied_lengths): + inputs_copy = dict(inputs) # Shallow copy for modification + inputs_copy.pop("temperature",None) + inputs_copy.pop("key",None) + # Ensure 'bias' is correctly handled if it's part of 'inputs' + return self._sample(**inputs_copy, key=key, temperature=temperature, tied_lengths=tied_lengths, batch_size=self.batch_size) + + fn_vmap_sample = jax.lax.map(_vmap_sample_parallel, in_axes=[0,None,None,None]) + self._sample_parallel = jax.jit(fn_vmap_sample, static_argnames=["tied_lengths"]) + + def _vmap_rescore_parallel(key, inputs, S_rescore, decoding_order_rescore): + inputs_copy = dict(inputs) # Shallow copy + inputs_copy.pop("S",None) + inputs_copy.pop("decoding_order",None) + inputs_copy.pop("key",None) + # Ensure 'bias' from original inputs is used, and S_rescore is the new S + return self._score(**inputs_copy, key=key, S=S_rescore, decoding_order=decoding_order_rescore) # Pass S and decoding_order + + fn_vmap_rescore = jax.lax.map(_vmap_rescore_parallel, in_axes=[0,None,0,0]) + self._rescore_parallel = jax.jit(fn_vmap_rescore) + +####################################################################################### + +def _aa_convert(x, rev=False): + mpnn_alphabet = 'ACDEFGHIKLMNPQRSTVWYX' + af_alphabet = 'ARNDCQEGHILKMFPSTWYVX' + if x is None: + return x + else: + if rev: + return x[...,tuple(mpnn_alphabet.index(k) for k in af_alphabet)] + else: + x = jax.nn.one_hot(x,21) if jnp.issubdtype(x.dtype, jnp.integer) else x + if x.shape[-1] == 20: + x = jnp.pad(x,[[0,0],[0,1]]) + return x[...,tuple(af_alphabet.index(k) for k in mpnn_alphabet)] + +unknown_aa_index = aa_order.get('X', 20) # Default index for unknown AAs + +def convert_sequence_to_numeric(sequence_str: str, + aa_map: dict = aa_order, + all_chain_lengths: list = None, + is_homooligomer_tied: bool = False) -> jnp.array: + """ + Converts a protein sequence string into a JAX integer array. + + Args: + sequence_str: The amino acid sequence string. + - For monomers: "ACEG..." + - For heteromers (chains separated by '/'): "ACEG.../FGHI..." + - For homooligomers where is_homooligomer_tied is True and + only one chain's sequence is provided: "ACEG..." (will be tiled). + aa_map: Dictionary mapping amino acid characters to integers (e.g., aa_order). + all_chain_lengths: List of lengths of all chains in the complex. + Example: [100, 100] for a dimer of length 100 each. + Used for homooligomer tiling. + is_homooligomer_tied: Boolean. If True and sequence_str is for a single + chain of a homooligomer, the sequence will be tiled. + + Returns: + jnp.array: A JAX array of integers representing the full sequence. 
+ """ + numeric_sequence_list = [] + + # Handle homooligomer case where a single chain sequence is provided to be tiled + if is_homooligomer_tied and \ + all_chain_lengths and \ + len(all_chain_lengths) > 0 and \ + "/" not in sequence_str: + # Check if the provided sequence string matches the length of one chain + if len(sequence_str) == all_chain_lengths[0]: # Assuming all chains have the same length + num_chains = len(all_chain_lengths) + # Tile the string sequence before converting to numeric + sequence_str = "/".join([sequence_str] * num_chains) + # TODO: add a warning or error if the lengths don't match + + # Process chain by chain if '/' is present, otherwise process the whole string + chains = sequence_str.split('/') + + for chain_seq_str in chains: + for aa_char in chain_seq_str: + # Use .get(key, default_value) to handle unexpected characters + numeric_sequence_list.append(aa_map.get(aa_char, unknown_aa_index)) + + return jnp.array(numeric_sequence_list, dtype=jnp.int32) diff --git a/build/lib/colabdesign/mpnn/modules.py b/build/lib/colabdesign/mpnn/modules.py new file mode 100644 index 00000000..462246c5 --- /dev/null +++ b/build/lib/colabdesign/mpnn/modules.py @@ -0,0 +1,332 @@ +import functools +import haiku as hk +import jax +import jax.numpy as jnp +import numpy as np +import joblib + +from colabdesign.shared.prng import SafeKey +from .utils import cat_neighbors_nodes, get_ar_mask +from .sample import mpnn_sample +from .score import mpnn_score + +Gelu = functools.partial(jax.nn.gelu, approximate=False) + +class dropout_cust(hk.Module): + def __init__(self, rate) -> None: + super().__init__() + self.rate = rate + self.safe_key = SafeKey(hk.next_rng_key()) + + def __call__(self, x): + self.safe_key, use_key = self.safe_key.split() + return hk.dropout(use_key.get(), self.rate, x) + + +class EncLayer(hk.Module): + def __init__(self, num_hidden, + num_in, dropout=0.1, + num_heads=None, scale=30, + name=None): + super(EncLayer, self).__init__() + self.num_hidden = num_hidden + self.num_in = num_in + self.scale = scale + + self.safe_key = SafeKey(hk.next_rng_key()) + + self.dropout1 = dropout_cust(dropout) + self.dropout2 = dropout_cust(dropout) + self.dropout3 = dropout_cust(dropout) + self.norm1 = hk.LayerNorm(-1, create_scale=True, create_offset=True, + name=name + '_norm1') + self.norm2 = hk.LayerNorm(-1, create_scale=True, create_offset=True, + name=name + '_norm2') + self.norm3 = hk.LayerNorm(-1, create_scale=True, create_offset=True, + name=name + '_norm3') + + self.W1 = hk.Linear(num_hidden, with_bias=True, name=name + '_W1') + self.W2 = hk.Linear(num_hidden, with_bias=True, name=name + '_W2') + self.W3 = hk.Linear(num_hidden, with_bias=True, name=name + '_W3') + self.W11 = hk.Linear(num_hidden, with_bias=True, name=name + '_W11') + self.W12 = hk.Linear(num_hidden, with_bias=True, name=name + '_W12') + self.W13 = hk.Linear(num_hidden, with_bias=True, name=name + '_W13') + self.act = Gelu + self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4, + name=name + '_dense') + + def __call__(self, h_V, h_E, E_idx, + mask_V=None, mask_attend=None): + """ Parallel computation of full transformer layer """ + + h_EV = cat_neighbors_nodes(h_V, h_E, E_idx) + h_V_expand = jnp.tile(jnp.expand_dims(h_V, -2),[1, h_EV.shape[-2], 1]) + h_EV = jnp.concatenate([h_V_expand, h_EV], -1) + + h_message = self.W3(self.act(self.W2(self.act(self.W1(h_EV))))) + if mask_attend is not None: + h_message = jnp.expand_dims(mask_attend, -1)* h_message + dh = jnp.sum(h_message, -2) / self.scale + h_V = 
self.norm1(h_V + self.dropout1(dh)) + + dh = self.dense(h_V) + h_V = self.norm2(h_V + self.dropout2(dh)) + if mask_V is not None: + mask_V = jnp.expand_dims(mask_V, -1) + h_V = mask_V * h_V + + h_EV = cat_neighbors_nodes(h_V, h_E, E_idx) + h_V_expand = jnp.tile(jnp.expand_dims(h_V, -2),[1, h_EV.shape[-2], 1]) + h_EV = jnp.concatenate([h_V_expand, h_EV], -1) + + h_message = self.W13(self.act(self.W12(self.act(self.W11(h_EV))))) + h_E = self.norm3(h_E + self.dropout3(h_message)) + return h_V, h_E + +class DecLayer(hk.Module): + def __init__(self, num_hidden, num_in, + dropout=0.1, num_heads=None, + scale=30, name=None): + super(DecLayer, self).__init__() + self.num_hidden = num_hidden + self.num_in = num_in + self.scale = scale + self.dropout1 = dropout_cust(dropout) + self.dropout2 = dropout_cust(dropout) + self.norm1 = hk.LayerNorm(-1, create_scale=True, create_offset=True, + name=name + '_norm1') + self.norm2 = hk.LayerNorm(-1, create_scale=True, create_offset=True, + name=name + '_norm2') + + self.W1 = hk.Linear(num_hidden, with_bias=True, name=name + '_W1') + self.W2 = hk.Linear(num_hidden, with_bias=True, name=name + '_W2') + self.W3 = hk.Linear(num_hidden, with_bias=True, name=name + '_W3') + self.act = Gelu + self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4, + name=name + '_dense') + + + def __call__(self, h_V, h_E, + mask_V=None, mask_attend=None): + """ Parallel computation of full transformer layer """ + + # Concatenate h_V_i to h_E_ij + h_V_expand = jnp.tile(jnp.expand_dims(h_V, -2),[1, h_E.shape[-2], 1]) + h_EV = jnp.concatenate([h_V_expand, h_E], -1) + + h_message = self.W3(self.act(self.W2(self.act(self.W1(h_EV))))) + if mask_attend is not None: + h_message = jnp.expand_dims(mask_attend, -1) * h_message + dh = jnp.sum(h_message, -2) / self.scale + + h_V = self.norm1(h_V + self.dropout1(dh)) + + # Position-wise feedforward + dh = self.dense(h_V) + h_V = self.norm2(h_V + self.dropout2(dh)) + + if mask_V is not None: + mask_V = jnp.expand_dims(mask_V, -1) + h_V = mask_V * h_V + return h_V + +class PositionWiseFeedForward(hk.Module): + def __init__(self, num_hidden, num_ff, name=None): + super(PositionWiseFeedForward, self).__init__() + self.W_in = hk.Linear(num_ff, with_bias=True, name=name + '_W_in') + self.W_out = hk.Linear(num_hidden, with_bias=True, name=name + '_W_out') + self.act = Gelu + def __call__(self, h_V): + h = self.act(self.W_in(h_V), approximate=False) + h = self.W_out(h) + return h + +class PositionalEncodings(hk.Module): + def __init__(self, num_embeddings, max_relative_feature=32): + super(PositionalEncodings, self).__init__() + self.num_embeddings = num_embeddings + self.max_relative_feature = max_relative_feature + self.linear = hk.Linear(num_embeddings, name='embedding_linear') + + def __call__(self, offset, mask): + d = jnp.clip(offset + self.max_relative_feature, 0, 2*self.max_relative_feature) * mask + \ + (1 - mask) * (2*self.max_relative_feature + 1) + d_onehot = jax.nn.one_hot(d, 2*self.max_relative_feature + 1 + 1) + E = self.linear(d_onehot) + return E + +class RunModel: + def __init__(self, config) -> None: + self.config = config + + def _forward_score(inputs): + model = ProteinMPNN(**self.config) + return model.score(inputs) + self.score = hk.transform(_forward_score).apply + + def _forward_sample(inputs): + model = ProteinMPNN(**self.config) + return model.sample(inputs) + self.sample = hk.transform(_forward_sample).apply + + def load_params(self, path): + self.params = joblib.load(path) + +class ProteinFeatures(hk.Module): + def 
__init__(self, edge_features, node_features, + num_positional_embeddings=16, + num_rbf=16, top_k=30, + augment_eps=0., num_chain_embeddings=16): + + """ Extract protein features """ + super(ProteinFeatures, self).__init__() + self.edge_features = edge_features + self.node_features = node_features + self.top_k = top_k + self.augment_eps = augment_eps + self.num_rbf = num_rbf + self.num_positional_embeddings = num_positional_embeddings + + self.embeddings = PositionalEncodings(num_positional_embeddings) + node_in, edge_in = 6, num_positional_embeddings + num_rbf*25 + self.edge_embedding = hk.Linear(edge_features, with_bias=False, name='edge_embedding') + self.norm_edges = hk.LayerNorm(-1, create_scale=True, create_offset=True, name='norm_edges') + + self.safe_key = SafeKey(hk.next_rng_key()) + + def _get_edge_idx(self, X, mask, eps=1E-6): + ''' get edge index + input: mask.shape = (...,L), X.shape = (...,L,3) + return: (...,L,k) + ''' + mask_2D = mask[...,None,:] * mask[...,:,None] + dX = X[...,None,:,:] - X[...,:,None,:] + D = jnp.sqrt(jnp.square(dX).sum(-1) + eps) + D_masked = jnp.where(mask_2D,D,D.max(-1,keepdims=True)) + k = min(self.top_k, X.shape[-2]) + return jax.lax.approx_min_k(D_masked, k, reduction_dimension=-1)[1] + + def _rbf(self, D): + ''' radial basis function (RBF) + input: (...,L,k) + output: (...,L,k,?) + ''' + D_min, D_max, D_count = 2., 22., self.num_rbf + D_mu = jnp.linspace(D_min, D_max, D_count) + D_sigma = (D_max - D_min) / D_count + return jnp.exp(-((D[...,None] - D_mu) / D_sigma)**2) + + def _get_rbf(self, A, B, E_idx): + D = jnp.sqrt(jnp.square(A[...,:,None,:] - B[...,None,:,:]).sum(-1) + 1e-6) + D_neighbors = jnp.take_along_axis(D, E_idx, 1) + return self._rbf(D_neighbors) + + def __call__(self, I): + if self.augment_eps > 0: + self.safe_key, use_key = self.safe_key.split() + X = I["X"] + self.augment_eps * jax.random.normal(use_key.get(), I["X"].shape) + else: + X = I["X"] + + ########################## + # get atoms + ########################## + # N,Ca,C,O,Cb + Y = X.swapaxes(0,1) #(length, atoms, 3) -> (atoms, length, 3) + if Y.shape[0] == 4: + # add Cb + b,c = (Y[1]-Y[0]),(Y[2]-Y[1]) + Cb = -0.58273431*jnp.cross(b,c) + 0.56802827*b - 0.54067466*c + Y[1] + Y = jnp.concatenate([Y,Cb[None]],0) + + ########################## + # gather edge features + ########################## + # get edge indices (based on ca-ca distances) + E_idx = self._get_edge_idx(Y[1], I["mask"]) + + # rbf encode distances between atoms + edges = jnp.array([[1,1],[0,0],[2,2],[3,3],[4,4], + [1,0],[1,2],[1,3],[1,4],[0,2], + [0,3],[0,4],[4,2],[4,3],[3,2], + [0,1],[2,1],[3,1],[4,1],[2,0], + [3,0],[4,0],[2,4],[3,4],[2,3]]) + RBF_all = jax.vmap(lambda x:self._get_rbf(Y[x[0]],Y[x[1]],E_idx))(edges) + RBF_all = RBF_all.transpose((1,2,0,3)) + RBF_all = RBF_all.reshape(RBF_all.shape[:-2]+(-1,)) + + ########################## + # position embedding + ########################## + # residue index offset + if "offset" not in I: + I["offset"] = I["residue_idx"][:,None] - I["residue_idx"][None,:] + offset = jnp.take_along_axis(I["offset"], E_idx, 1) + + # chain index offset + E_chains = (I["chain_idx"][:,None] == I["chain_idx"][None,:]).astype(int) + E_chains = jnp.take_along_axis(E_chains, E_idx, 1) + E_positional = self.embeddings(offset, E_chains) + + ########################## + # define edges + ########################## + E = jnp.concatenate((E_positional, RBF_all), -1) + E = self.edge_embedding(E) + E = self.norm_edges(E) + return E, E_idx + +class EmbedToken(hk.Module): + def __init__(self, 
vocab_size, embed_dim): + super().__init__() + self.vocab_size = vocab_size + self.embed_dim = embed_dim + self.w_init = hk.initializers.TruncatedNormal() + + @property + def embeddings(self): + return hk.get_parameter("W_s", + [self.vocab_size, self.embed_dim], + init=self.w_init) + + def __call__(self, arr): + if jnp.issubdtype(arr.dtype, jnp.integer): + one_hot = jax.nn.one_hot(arr, self.vocab_size) + else: + one_hot = arr + return jnp.tensordot(one_hot, self.embeddings, 1) + +class ProteinMPNN(hk.Module, mpnn_sample, mpnn_score): + def __init__(self, num_letters, + node_features, edge_features, hidden_dim, + num_encoder_layers=3, num_decoder_layers=3, + vocab=21, k_neighbors=64, + augment_eps=0.05, dropout=0.1): + super(ProteinMPNN, self).__init__() + + # Hyperparameters + self.node_features = node_features + self.edge_features = edge_features + self.hidden_dim = hidden_dim + + # Featurization layers + self.features = ProteinFeatures(edge_features, + node_features, + top_k=k_neighbors, + augment_eps=augment_eps) + + self.W_e = hk.Linear(hidden_dim, with_bias=True, name='W_e') + self.W_s = EmbedToken(vocab_size=vocab, embed_dim=hidden_dim) + + # Encoder layers + self.encoder_layers = [ + EncLayer(hidden_dim, hidden_dim*2, dropout=dropout, name='enc' + str(i)) + for i in range(num_encoder_layers) + ] + + # Decoder layers + self.decoder_layers = [ + DecLayer(hidden_dim, hidden_dim*3, dropout=dropout, name='dec' + str(i)) + for i in range(num_decoder_layers) + ] + self.W_out = hk.Linear(num_letters, with_bias=True, name='W_out') \ No newline at end of file diff --git a/build/lib/colabdesign/mpnn/sample.py b/build/lib/colabdesign/mpnn/sample.py new file mode 100644 index 00000000..401e1884 --- /dev/null +++ b/build/lib/colabdesign/mpnn/sample.py @@ -0,0 +1,107 @@ +import jax +import jax.numpy as jnp +import haiku as hk +import numpy as np + +from .utils import cat_neighbors_nodes, get_ar_mask + +class mpnn_sample: + def sample(self, I): + """ + I = { + [[required]] + 'X' = (L,4,3) + 'mask' = (L,) + 'residue_index' = (L,) + 'chain_idx' = (L,) + 'decoding_order' = (L,) + + [[optional]] + 'ar_mask' = (L,L) + 'bias' = (L,21) + 'temperature' = 1.0 + } + """ + + key = hk.next_rng_key() + L = I["X"].shape[0] + temperature = I.get("temperature",1.0) + + # prepare node and edge embeddings + E, E_idx = self.features(I) + h_V = jnp.zeros((E.shape[0], E.shape[-1])) + h_E = self.W_e(E) + + ############## + # encoder + ############## + mask_attend = jnp.take_along_axis(I["mask"][:,None] * I["mask"][None,:], E_idx, 1) + for layer in self.encoder_layers: + h_V, h_E = layer(h_V, h_E, E_idx, I["mask"], mask_attend) + + # get autoregressive mask + ar_mask = I.get("ar_mask",get_ar_mask(I["decoding_order"])) + + mask_attend = jnp.take_along_axis(ar_mask, E_idx, 1) + mask_1D = I["mask"][:,None] + mask_bw = mask_1D * mask_attend + mask_fw = mask_1D * (1 - mask_attend) + + h_EX_encoder = cat_neighbors_nodes(jnp.zeros_like(h_V), h_E, E_idx) + h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx) + h_EXV_encoder = mask_fw[...,None] * h_EXV_encoder + + def fwd(x, t, key): + h_EXV_encoder_t = h_EXV_encoder[t] + E_idx_t = E_idx[t] + mask_t = I["mask"][t] + mask_bw_t = mask_bw[t] + h_ES_t = cat_neighbors_nodes(x["h_S"], h_E[t], E_idx_t) + + ############## + # decoder + ############## + for l,layer in enumerate(self.decoder_layers): + h_V = x["h_V"][l] + h_ESV_decoder_t = cat_neighbors_nodes(h_V, h_ES_t, E_idx_t) + h_ESV_t = mask_bw_t[...,None] * h_ESV_decoder_t + h_EXV_encoder_t + h_V_t = layer(h_V[t], h_ESV_t, 
mask_V=mask_t) + # update + x["h_V"] = x["h_V"].at[l+1,t].set(h_V_t) + + logits_t = self.W_out(h_V_t) + x["logits"] = x["logits"].at[t].set(logits_t) + + ############## + # sample + ############## + + # add bias + if "bias" in I: logits_t += I["bias"][t] + + # sample character + logits_t = logits_t/temperature + jax.random.gumbel(key, logits_t.shape) + + # tie positions + logits_t = logits_t.mean(0, keepdims=True) + + S_t = jax.nn.one_hot(logits_t[...,:20].argmax(-1), 21) + + # update + x["h_S"] = x["h_S"].at[t].set(self.W_s(S_t)) + x["S"] = x["S"].at[t].set(S_t) + return x, None + + # initial values + X = {"h_S": jnp.zeros_like(h_V), + "h_V": jnp.array([h_V] + [jnp.zeros_like(h_V)] * len(self.decoder_layers)), + "S": jnp.zeros((L,21)), + "logits": jnp.zeros((L,21))} + + # scan over decoding order + t = I["decoding_order"] + if t.ndim == 1: t = t[:,None] + XS = {"t":t, "key":jax.random.split(key,t.shape[0])} + X = hk.scan(lambda x, xs: fwd(x, xs["t"], xs["key"]), X, XS)[0] + + return {"S":X["S"], "logits":X["logits"], "decoding_order":t} \ No newline at end of file diff --git a/build/lib/colabdesign/mpnn/score.py b/build/lib/colabdesign/mpnn/score.py new file mode 100644 index 00000000..575d3896 --- /dev/null +++ b/build/lib/colabdesign/mpnn/score.py @@ -0,0 +1,80 @@ +import jax +import jax.numpy as jnp +import haiku as hk +import numpy as np + +from .utils import cat_neighbors_nodes, get_ar_mask + +class mpnn_score: + def score(self, I): + """ + I = { + [[required]] + 'X' = (L,4,3) + 'mask' = (L,) + 'residue_index' = (L,) + 'chain_idx' = (L,) + + [[optional]] + 'S' = (L,21) + 'decoding_order' = (L,) + 'ar_mask' = (L,L) + } + """ + + key = hk.next_rng_key() + # Prepare node and edge embeddings + E, E_idx = self.features(I) + h_V = jnp.zeros((E.shape[0], E.shape[-1])) + h_E = self.W_e(E) + + # Encoder is unmasked self-attention + mask_attend = jnp.take_along_axis(I["mask"][:,None] * I["mask"][None,:], E_idx, 1) + + for layer in self.encoder_layers: + h_V, h_E = layer(h_V, h_E, E_idx, I["mask"], mask_attend) + + # Build encoder embeddings + h_EX_encoder = cat_neighbors_nodes(jnp.zeros_like(h_V), h_E, E_idx) + h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx) + + if "S" not in I: + ########################################## + # unconditional_probs + ########################################## + h_EXV_encoder_fw = h_EXV_encoder + for layer in self.decoder_layers: + h_V = layer(h_V, h_EXV_encoder_fw, I["mask"]) + decoding_order = None + else: + ########################################## + # conditional_probs + ########################################## + + # Concatenate sequence embeddings for autoregressive decoder + h_S = self.W_s(I["S"]) + h_ES = cat_neighbors_nodes(h_S, h_E, E_idx) + + # get autoregressive mask + if "ar_mask" in I: + decoding_order = None + ar_mask = I["ar_mask"] + else: + decoding_order = I["decoding_order"] + ar_mask = get_ar_mask(decoding_order) + + mask_attend = jnp.take_along_axis(ar_mask, E_idx, 1) + mask_1D = I["mask"][:,None] + mask_bw = mask_1D * mask_attend + mask_fw = mask_1D * (1 - mask_attend) + + h_EXV_encoder_fw = mask_fw[...,None] * h_EXV_encoder + for layer in self.decoder_layers: + # Masked positions attend to encoder information, unmasked see. 
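      # i.e. neighbors decoded earlier (mask_bw) contribute sequence-aware
      # embeddings (h_ESV), while neighbors decoded later (mask_fw) fall back to
      # the sequence-free encoder embeddings (h_EXV_encoder_fw).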
+ h_ESV = cat_neighbors_nodes(h_V, h_ES, E_idx) + h_ESV = mask_bw[...,None] * h_ESV + h_EXV_encoder_fw + h_V = layer(h_V, h_ESV, I["mask"]) + + logits = self.W_out(h_V) + S = I.get("S",None) + return {"logits": logits, "decoding_order":decoding_order, "S":S} \ No newline at end of file diff --git a/build/lib/colabdesign/mpnn/utils.py b/build/lib/colabdesign/mpnn/utils.py new file mode 100644 index 00000000..2c49296b --- /dev/null +++ b/build/lib/colabdesign/mpnn/utils.py @@ -0,0 +1,26 @@ +import jax.numpy as jnp +import jax + +def gather_nodes(nodes, neighbor_idx): + # Features [B,N,C] at Neighbor indices [B,N,K] => [B,N,K,C] + # Flatten and expand indices per batch [B,N,K] => [B,NK] => [B,NK,C] + neighbors_flat = neighbor_idx.reshape([neighbor_idx[None].shape[0], -1]) + neighbors_flat = jnp.tile(jnp.expand_dims(neighbors_flat, -1),[1, 1, nodes[None].shape[2]]) + # Gather and re-pack + neighbor_features = jnp.take_along_axis(nodes[None], neighbors_flat, 1) + neighbor_features = neighbor_features.reshape(list(neighbor_idx[None].shape[:3]) + [-1]) + return neighbor_features[0] + +def cat_neighbors_nodes(h_nodes, h_neighbors, E_idx): + h_nodes = gather_nodes(h_nodes, E_idx)[None] + h_nn = jnp.concatenate([h_neighbors[None], h_nodes], -1) + return h_nn[0] + +def get_ar_mask(order): + '''compute autoregressive mask, given order of positions''' + order = order.flatten() + L = order.shape[-1] + tri = jnp.tri(L, k=-1) + idx = order.argsort() + ar_mask = tri[idx,:][:,idx] + return ar_mask \ No newline at end of file diff --git a/build/lib/colabdesign/mpnn/weights/__init__.py b/build/lib/colabdesign/mpnn/weights/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/build/lib/colabdesign/mpnn/weights/__init__.py @@ -0,0 +1 @@ + diff --git a/build/lib/colabdesign/mpnn/weights/v_48_002.pkl b/build/lib/colabdesign/mpnn/weights/v_48_002.pkl new file mode 100644 index 00000000..61e67126 Binary files /dev/null and b/build/lib/colabdesign/mpnn/weights/v_48_002.pkl differ diff --git a/build/lib/colabdesign/mpnn/weights/v_48_010.pkl b/build/lib/colabdesign/mpnn/weights/v_48_010.pkl new file mode 100644 index 00000000..2c2a9206 Binary files /dev/null and b/build/lib/colabdesign/mpnn/weights/v_48_010.pkl differ diff --git a/build/lib/colabdesign/mpnn/weights/v_48_020.pkl b/build/lib/colabdesign/mpnn/weights/v_48_020.pkl new file mode 100644 index 00000000..d7d3d52d Binary files /dev/null and b/build/lib/colabdesign/mpnn/weights/v_48_020.pkl differ diff --git a/build/lib/colabdesign/mpnn/weights/v_48_030.pkl b/build/lib/colabdesign/mpnn/weights/v_48_030.pkl new file mode 100644 index 00000000..17e6c520 Binary files /dev/null and b/build/lib/colabdesign/mpnn/weights/v_48_030.pkl differ diff --git a/build/lib/colabdesign/mpnn/weights_soluble/__init__.py b/build/lib/colabdesign/mpnn/weights_soluble/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/build/lib/colabdesign/mpnn/weights_soluble/__init__.py @@ -0,0 +1 @@ + diff --git a/build/lib/colabdesign/mpnn/weights_soluble/v_48_002.pkl b/build/lib/colabdesign/mpnn/weights_soluble/v_48_002.pkl new file mode 100644 index 00000000..252fac9f Binary files /dev/null and b/build/lib/colabdesign/mpnn/weights_soluble/v_48_002.pkl differ diff --git a/build/lib/colabdesign/mpnn/weights_soluble/v_48_010.pkl b/build/lib/colabdesign/mpnn/weights_soluble/v_48_010.pkl new file mode 100644 index 00000000..10408ec8 Binary files /dev/null and b/build/lib/colabdesign/mpnn/weights_soluble/v_48_010.pkl differ diff --git 
a/build/lib/colabdesign/mpnn/weights_soluble/v_48_020.pkl b/build/lib/colabdesign/mpnn/weights_soluble/v_48_020.pkl new file mode 100644 index 00000000..51b30e52 Binary files /dev/null and b/build/lib/colabdesign/mpnn/weights_soluble/v_48_020.pkl differ diff --git a/build/lib/colabdesign/mpnn/weights_soluble/v_48_030.pkl b/build/lib/colabdesign/mpnn/weights_soluble/v_48_030.pkl new file mode 100644 index 00000000..37a27d8c Binary files /dev/null and b/build/lib/colabdesign/mpnn/weights_soluble/v_48_030.pkl differ diff --git a/build/lib/colabdesign/rf/__init__.py b/build/lib/colabdesign/rf/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/build/lib/colabdesign/rf/designability_test.py b/build/lib/colabdesign/rf/designability_test.py new file mode 100644 index 00000000..a33e6b31 --- /dev/null +++ b/build/lib/colabdesign/rf/designability_test.py @@ -0,0 +1,199 @@ +import os,sys + +from colabdesign.mpnn import mk_mpnn_model +from colabdesign.af import mk_af_model +from colabdesign.shared.protein import pdb_to_string +from colabdesign.shared.parse_args import parse_args + +import pandas as pd +import numpy as np +from string import ascii_uppercase, ascii_lowercase +alphabet_list = list(ascii_uppercase+ascii_lowercase) + +def get_info(contig): + F = [] + free_chain = False + fixed_chain = False + sub_contigs = [x.split("-") for x in contig.split("/")] + for n,(a,b) in enumerate(sub_contigs): + if a[0].isalpha(): + L = int(b)-int(a[1:]) + 1 + F += [1] * L + fixed_chain = True + else: + L = int(b) + F += [0] * L + free_chain = True + return F,[fixed_chain,free_chain] + +def main(argv): + ag = parse_args() + ag.txt("-------------------------------------------------------------------------------------") + ag.txt("Designability Test") + ag.txt("-------------------------------------------------------------------------------------") + ag.txt("REQUIRED") + ag.txt("-------------------------------------------------------------------------------------") + ag.add(["pdb=" ], None, str, ["input pdb"]) + ag.add(["loc=" ], None, str, ["location to save results"]) + ag.add(["contigs=" ], None, str, ["contig definition"]) + ag.txt("-------------------------------------------------------------------------------------") + ag.txt("OPTIONAL") + ag.txt("-------------------------------------------------------------------------------------") + ag.add(["copies=" ], 1, int, ["number of repeating copies"]) + ag.add(["num_seqs=" ], 8, int, ["number of mpnn designs to evaluate"]) + ag.add(["initial_guess" ], False, None, ["initialize previous coordinates"]) + ag.add(["use_multimer" ], False, None, ["use alphafold_multimer_v3"]) + ag.add(["use_soluble" ], False, None, ["use solubleMPNN"]) + ag.add(["num_recycles=" ], 3, int, ["number of recycles"]) + ag.add(["rm_aa="], "C", str, ["disable specific amino acids from being sampled"]) + ag.add(["num_designs=" ], 1, int, ["number of designs to evaluate"]) + ag.add(["mpnn_sampling_temp=" ], 0.1, float, ["sampling temperature used by proteinMPNN"]) + ag.txt("-------------------------------------------------------------------------------------") + o = ag.parse(argv) + + if None in [o.pdb, o.loc, o.contigs]: + ag.usage("Missing Required Arguments") + + if o.rm_aa == "": + o.rm_aa = None + + # filter contig input + contigs = [] + for contig_str in o.contigs.replace(" ",":").replace(",",":").split(":"): + if len(contig_str) > 0: + contig = [] + for x in contig_str.split("/"): + if x != "0": contig.append(x) + contigs.append("/".join(contig)) + + chains = 
alphabet_list[:len(contigs)] + info = [get_info(x) for x in contigs] + fixed_pos = [] + fixed_chains = [] + free_chains = [] + both_chains = [] + for pos,(fixed_chain,free_chain) in info: + fixed_pos += pos + fixed_chains += [fixed_chain and not free_chain] + free_chains += [free_chain and not fixed_chain] + both_chains += [fixed_chain and free_chain] + + flags = {"initial_guess":o.initial_guess, + "best_metric":"rmsd", + "use_multimer":o.use_multimer, + "model_names":["model_1_multimer_v3" if o.use_multimer else "model_1_ptm"]} + + if sum(both_chains) == 0 and sum(fixed_chains) > 0 and sum(free_chains) > 0: + protocol = "binder" + print("protocol=binder") + target_chains = [] + binder_chains = [] + for n,x in enumerate(fixed_chains): + if x: target_chains.append(chains[n]) + else: binder_chains.append(chains[n]) + af_model = mk_af_model(protocol="binder",**flags) + prep_flags = {"target_chain":",".join(target_chains), + "binder_chain":",".join(binder_chains), + "rm_aa":o.rm_aa} + opt_extra = {} + + elif sum(fixed_pos) > 0: + protocol = "partial" + print("protocol=partial") + af_model = mk_af_model(protocol="fixbb", + use_templates=True, + **flags) + rm_template = np.array(fixed_pos) == 0 + prep_flags = {"chain":",".join(chains), + "rm_template":rm_template, + "rm_template_seq":rm_template, + "copies":o.copies, + "homooligomer":o.copies>1, + "rm_aa":o.rm_aa} + else: + protocol = "fixbb" + print("protocol=fixbb") + af_model = mk_af_model(protocol="fixbb",**flags) + prep_flags = {"chain":",".join(chains), + "copies":o.copies, + "homooligomer":o.copies>1, + "rm_aa":o.rm_aa} + + batch_size = 8 + if o.num_seqs < batch_size: + batch_size = o.num_seqs + + print("running proteinMPNN...") + sampling_temp = o.mpnn_sampling_temp + mpnn_model = mk_mpnn_model(weights="soluble" if o.use_soluble else "original") + outs = [] + pdbs = [] + for m in range(o.num_designs): + if o.num_designs == 0: + pdb_filename = o.pdb + else: + pdb_filename = o.pdb.replace("_0.pdb",f"_{m}.pdb") + pdbs.append(pdb_filename) + af_model.prep_inputs(pdb_filename, **prep_flags) + if protocol == "partial": + p = np.where(fixed_pos)[0] + af_model.opt["fix_pos"] = p[p < af_model._len] + + mpnn_model.get_af_inputs(af_model) + outs.append(mpnn_model.sample(num=o.num_seqs//batch_size, batch=batch_size, temperature=sampling_temp)) + + if protocol == "binder": + af_terms = ["plddt","i_ptm","i_pae","rmsd"] + elif o.copies > 1: + af_terms = ["plddt","ptm","i_ptm","pae","i_pae","rmsd"] + else: + af_terms = ["plddt","ptm","pae","rmsd"] + + labels = ["design","n","score"] + af_terms + ["seq"] + data = [] + best = {"rmsd":np.inf,"design":0,"n":0} + print("running AlphaFold...") + os.system(f"mkdir -p {o.loc}/all_pdb") + with open(f"{o.loc}/design.fasta","w") as fasta: + for m,(out,pdb_filename) in enumerate(zip(outs,pdbs)): + out["design"] = [] + out["n"] = [] + af_model.prep_inputs(pdb_filename, **prep_flags) + for k in af_terms: out[k] = [] + for n in range(o.num_seqs): + out["design"].append(m) + out["n"].append(n) + sub_seq = out["seq"][n].replace("/","")[-af_model._len:] + af_model.predict(seq=sub_seq, num_recycles=o.num_recycles, verbose=False) + for t in af_terms: out[t].append(af_model.aux["log"][t]) + if "i_pae" in out: + out["i_pae"][-1] = out["i_pae"][-1] * 31 + if "pae" in out: + out["pae"][-1] = out["pae"][-1] * 31 + rmsd = out["rmsd"][-1] + if rmsd < best["rmsd"]: + best = {"design":m,"n":n,"rmsd":rmsd} + af_model.save_current_pdb(f"{o.loc}/all_pdb/design{m}_n{n}.pdb") + af_model._save_results(save_best=True, verbose=False) + 
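          # note: pae/i_pae above are multiplied by 31, presumably to convert
          # AfDesign's normalized PAE back to Angstroms; the counter below is then
          # advanced before each (design, sequence) pair is printed and written to
          # design.fasta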
af_model._k += 1 + score_line = [f'design:{m} n:{n}',f'mpnn:{out["score"][n]:.3f}'] + for t in af_terms: + score_line.append(f'{t}:{out[t][n]:.3f}') + print(" ".join(score_line)+" "+out["seq"][n]) + line = f'>{"|".join(score_line)}\n{out["seq"][n]}' + fasta.write(line+"\n") + data += [[out[k][n] for k in labels] for n in range(o.num_seqs)] + af_model.save_pdb(f"{o.loc}/best_design{m}.pdb") + + # save best + with open(f"{o.loc}/best.pdb", "w") as handle: + remark_text = f"design {best['design']} N {best['n']} RMSD {best['rmsd']:.3f}" + handle.write(f"REMARK 001 {remark_text}\n") + handle.write(open(f"{o.loc}/best_design{best['design']}.pdb", "r").read()) + + labels[2] = "mpnn" + df = pd.DataFrame(data, columns=labels) + df.to_csv(f'{o.loc}/mpnn_results.csv') + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/build/lib/colabdesign/rf/utils.py b/build/lib/colabdesign/rf/utils.py new file mode 100644 index 00000000..9b06f305 --- /dev/null +++ b/build/lib/colabdesign/rf/utils.py @@ -0,0 +1,246 @@ +import matplotlib +import matplotlib.pyplot as plt +from matplotlib import animation +from colabdesign.shared.plot import plot_pseudo_3D, pymol_cmap, _np_kabsch +from string import ascii_uppercase, ascii_lowercase +alphabet_list = list(ascii_uppercase+ascii_lowercase) +import numpy as np + +def sym_it(coords, center, cyclic_symmetry_axis, reflection_axis=None): + + def rotation_matrix(axis, theta): + axis = axis / np.linalg.norm(axis) + a = np.cos(theta / 2) + b, c, d = -axis * np.sin(theta / 2) + return np.array([[a*a+b*b-c*c-d*d, 2*(b*c-a*d), 2*(b*d+a*c)], + [2*(b*c+a*d), a*a+c*c-b*b-d*d, 2*(c*d-a*b)], + [2*(b*d-a*c), 2*(c*d+a*b), a*a+d*d-b*b-c*c]]) + + def align_axes(coords, source_axis, target_axis): + rotation_axis = np.cross(source_axis, target_axis) + rotation_angle = np.arccos(np.dot(source_axis, target_axis)) + rot_matrix = rotation_matrix(rotation_axis, rotation_angle) + return np.dot(coords, rot_matrix) + + # Center the coordinates + coords = coords - center + + # Align cyclic symmetry axis with Z-axis + z_axis = np.array([0, 0, 1]) + coords = align_axes(coords, cyclic_symmetry_axis, z_axis) + + if reflection_axis is not None: + # Align reflection axis with X-axis + x_axis = np.array([1, 0, 0]) + coords = align_axes(coords, reflection_axis, x_axis) + return coords + +def fix_partial_contigs(contigs, parsed_pdb): + INF = float("inf") + + # get unique chains + chains = [] + for c, i in parsed_pdb["pdb_idx"]: + if c not in chains: chains.append(c) + + # get observed positions and chains + ok = [] + for contig in contigs: + for x in contig.split("/"): + if x[0].isalpha: + C,x = x[0],x[1:] + S,E = -INF,INF + if x.startswith("-"): + E = int(x[1:]) + elif x.endswith("-"): + S = int(x[:-1]) + elif "-" in x: + (S,E) = (int(y) for y in x.split("-")) + elif x.isnumeric(): + S = E = int(x) + for c, i in parsed_pdb["pdb_idx"]: + if c == C and i >= S and i <= E: + if [c,i] not in ok: ok.append([c,i]) + + # define new contigs + new_contigs = [] + for C in chains: + new_contig = [] + unseen = [] + seen = [] + for c,i in parsed_pdb["pdb_idx"]: + if c == C: + if [c,i] in ok: + L = len(unseen) + if L > 0: + new_contig.append(f"{L}-{L}") + unseen = [] + seen.append([c,i]) + else: + L = len(seen) + if L > 0: + new_contig.append(f"{seen[0][0]}{seen[0][1]}-{seen[-1][1]}") + seen = [] + unseen.append([c,i]) + L = len(unseen) + if L > 0: + new_contig.append(f"{L}-{L}") + L = len(seen) + if L > 0: + new_contig.append(f"{seen[0][0]}{seen[0][1]}-{seen[-1][1]}") + new_contigs.append("/".join(new_contig)) + 
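    # each entry describes one chain, e.g. (hypothetical) residues A5-A25 observed
    # followed by 7 unresolved residues -> "A5-25/7-7"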
+ return new_contigs + +def fix_contigs(contigs,parsed_pdb): + def fix_contig(contig): + INF = float("inf") + X = contig.split("/") + Y = [] + for n,x in enumerate(X): + if x[0].isalpha(): + C,x = x[0],x[1:] + S,E = -INF,INF + if x.startswith("-"): + E = int(x[1:]) + elif x.endswith("-"): + S = int(x[:-1]) + elif "-" in x: + (S,E) = (int(y) for y in x.split("-")) + elif x.isnumeric(): + S = E = int(x) + new_x = "" + c_,i_ = None,0 + for c, i in parsed_pdb["pdb_idx"]: + if c == C and i >= S and i <= E: + if c_ is None: + new_x = f"{c}{i}" + else: + if c != c_ or i != i_+1: + new_x += f"-{i_}/{c}{i}" + c_,i_ = c,i + Y.append(new_x + f"-{i_}") + elif "-" in x: + # sample length + s,e = x.split("-") + m = np.random.randint(int(s),int(e)+1) + Y.append(f"{m}-{m}") + elif x.isnumeric() and x != "0": + Y.append(f"{x}-{x}") + return "/".join(Y) + return [fix_contig(x) for x in contigs] + +def fix_pdb(pdb_str, contigs): + def get_range(contig): + L_init = 1 + R = [] + sub_contigs = [x.split("-") for x in contig.split("/")] + for n,(a,b) in enumerate(sub_contigs): + if a[0].isalpha(): + if n > 0: + pa,pb = sub_contigs[n-1] + if pa[0].isalpha() and a[0] == pa[0]: + L_init += int(a[1:]) - int(pb) - 1 + L = int(b)-int(a[1:]) + 1 + else: + L = int(b) + R += range(L_init,L_init+L) + L_init += L + return R + + contig_ranges = [get_range(x) for x in contigs] + R,C = [],[] + for n,r in enumerate(contig_ranges): + R += r + C += [alphabet_list[n]] * len(r) + + pdb_out = [] + r_, c_,n = None, None, 0 + for line in pdb_str.split("\n"): + if line[:4] == "ATOM": + c = line[21:22] + r = int(line[22:22+5]) + if r_ is None: r_ = r + if c_ is None: c_ = c + if r != r_ or c != c_: + n += 1 + r_,c_ = r,c + pdb_out.append("%s%s%4i%s" % (line[:21],C[n],R[n],line[26:])) + if line[:5] == "MODEL" or line[:3] == "TER" or line[:6] == "ENDMDL": + pdb_out.append(line) + r_, c_,n = None, None, 0 + return "\n".join(pdb_out) + +def get_ca(pdb_filename, get_bfact=False): + xyz = [] + bfact = [] + for line in open(pdb_filename, "r"): + line = line.rstrip() + if line[:4] == "ATOM": + atom = line[12:12+4].strip() + if atom == "CA": + x = float(line[30:30+8]) + y = float(line[38:38+8]) + z = float(line[46:46+8]) + xyz.append([x, y, z]) + if get_bfact: + b_factor = float(line[60:60+6].strip()) + bfact.append(b_factor) + if get_bfact: + return np.array(xyz), np.array(bfact) + else: + return np.array(xyz) + +def get_Ls(contigs): + Ls = [] + for contig in contigs: + L = 0 + for n,(a,b) in enumerate(x.split("-") for x in contig.split("/")): + if a[0].isalpha(): + L += int(b)-int(a[1:]) + 1 + else: + L += int(b) + Ls.append(L) + return Ls + +def make_animation(pos, plddt=None, Ls=None, ref=0, line_w=2.0, dpi=100): + if plddt is None: + plddt = [None] * len(pos) + + # center inputs + pos = pos - pos[ref,None].mean(1,keepdims=True) + + # align to best view + best_view = _np_kabsch(pos[ref], pos[ref], return_v=True, use_jax=False) + pos = np.asarray([p @ best_view for p in pos]) + + fig, (ax1) = plt.subplots(1) + fig.set_figwidth(5) + fig.set_figheight(5) + fig.set_dpi(dpi) + + xy_min = pos[...,:2].min() - 1 + xy_max = pos[...,:2].max() + 1 + z_min = None #pos[...,-1].min() - 1 + z_max = None #pos[...,-1].max() + 1 + + for ax in [ax1]: + ax.set_xlim(xy_min, xy_max) + ax.set_ylim(xy_min, xy_max) + ax.axis(False) + + ims=[] + for pos_,plddt_ in zip(pos,plddt): + if plddt_ is None: + if Ls is None: + img = plot_pseudo_3D(pos_, ax=ax1, line_w=line_w, zmin=z_min, zmax=z_max) + else: + c = np.concatenate([[n]*L for n,L in enumerate(Ls)]) + img = 
plot_pseudo_3D(pos_, c=c, cmap=pymol_cmap, cmin=0, cmax=39, line_w=line_w, ax=ax1, zmin=z_min, zmax=z_max) + else: + img = plot_pseudo_3D(pos_, c=plddt_, cmin=50, cmax=90, line_w=line_w, ax=ax1, zmin=z_min, zmax=z_max) + ims.append([img]) + + ani = animation.ArtistAnimation(fig, ims, blit=True, interval=120) + plt.close() + return ani.to_html5_video() \ No newline at end of file diff --git a/build/lib/colabdesign/seq/__init__.py b/build/lib/colabdesign/seq/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/build/lib/colabdesign/seq/kmeans.py b/build/lib/colabdesign/seq/kmeans.py new file mode 100644 index 00000000..03b7c496 --- /dev/null +++ b/build/lib/colabdesign/seq/kmeans.py @@ -0,0 +1,133 @@ +import jax +import jax.numpy as jnp +import numpy +from math import log + +def _kmeans(X, X_weight, n_clusters=8, n_init=10, max_iter=300, tol=1e-4, seed=0): + '''kmeans implemented in jax''' + + def _dist(a,b): + sm = a @ b.T + + a_norm = jnp.square(a).sum(-1) + b_norm = jnp.square(b).sum(-1) + + return jnp.abs(a_norm[:,None] + b_norm[None,:] - 2 * sm) + + def _kmeans_plus_plus(key, X, X_weight, n_clusters): + '''kmeans++ implemented in jax, for initialization''' + n_samples, n_features = X.shape + n_candidates = 2 + int(log(n_clusters)) + + def loop(m,c): + n,k = c + + inf_mask = jnp.inf * (jnp.arange(n_clusters) > n) + p = (inf_mask + _dist(X,m)).min(-1) + + # sample candidates + candidates = jax.random.choice(k, jnp.arange(n_samples), + shape=(n_candidates,), + p=p/p.sum(), replace=False) + + # pick sample that decreases inertia the most + dist = jnp.minimum(p[:,None],_dist(X,X[candidates])) + i = candidates[(X_weight[:,None] * dist).sum(0).argmin()] + return m.at[n].set(X[i]), None + + i = jax.random.choice(key,jnp.arange(n_samples)) + init_means = jnp.zeros((n_clusters,n_features)).at[0].set(X[i]) + carry = (jnp.arange(1,n_clusters), jax.random.split(key, n_clusters-1)) + return jax.lax.scan(loop, init_means, carry)[0] + + def _E(means): + # get labels + return _dist(X,means).argmin(-1) + + def _M(labels): + # get means + labels = jax.nn.one_hot(labels, n_clusters) + labels = labels * X_weight[:,None] + labels /= labels.sum(0) + 1e-8 + return labels.T @ X + + def _inertia(means): + # compute score: sum(min(dist(X,means))) + sco = _dist(X,means).min(-1) + return (X_weight * sco).sum() + + def single_run(key): + # initialize + init_means = _kmeans_plus_plus(key, X, X_weight, n_clusters) + + # run EM + if tol == 0: + means = jax.lax.scan(lambda mu,_:(_M(_E(mu)),None), init_means, + None, length=max_iter)[0] + else: + def EM(x): + old_mu, old_sco, _, n = x + new_mu = _M(_E(old_mu)) + new_sco = _inertia(new_mu) + return new_mu, new_sco, old_sco, n+1 + def check(x): + _, new_sco, old_sco, n = x + return ((old_sco-new_sco) > tol) & (n < max_iter) + init = EM((init_means,jnp.inf,None,0)) + means = jax.lax.while_loop(check, EM, init)[0] + + return {"labels":_E(means), + "means":means, + "inertia":_inertia(means)} + + # mulitple runs + key = jax.random.PRNGKey(seed) + if n_init > 0: + out = jax.vmap(single_run)(jax.random.split(key,n_init)) + i = out["inertia"].argmin() + out = jax.tree_map(lambda x:x[i],out) + else: + out = single_run(key) + + labels = jax.nn.one_hot(out["labels"],n_clusters) + cat = (labels * X_weight[:,None]).sum(0) / X_weight.sum() + return {**out, "cat":cat} + +def kmeans(x, x_weights, k, seed=0, max_iter=300): + N,L,A = x.shape + if k == 1: + kms = {"means":(x*x_weights[:,None,None]).sum(0,keepdims=True)/x_weights.sum(), + "labels":jnp.zeros(N,dtype=int), + 
"cat":jnp.ones((1,))} + else: + kms = _kmeans(x.reshape(N,-1), x_weights, n_clusters=k, max_iter=max_iter, seed=seed) + kms["means"] = kms["means"].reshape(k,L,A) + return kms + +def kmeans_sample(msa, msa_weights, k=1, samples=None, seed=0): + + assert k > 0 + + # run kmeans + kms = kmeans(jnp.asarray(msa), jnp.asarray(msa_weights), k=k, seed=seed) + + # sample sequences from kmeans + key = jax.random.PRNGKey(seed) + N,L,A = msa.shape + if samples is None: + # if number of samples is undefined, set to size of input MSA + samples = N + sampled_labels = kms["labels"] + else: + # sample labels + key, key_ = jax.random.split(key) + sampled_labels = jnp.sort(jax.random.choice(key_,jnp.arange(k),shape=(samples,),p=kms["cat"])) + + # sample MSA + sampled_msa = kms["means"][sampled_labels] + sampled_msa = (sampled_msa.cumsum(-1) >= jax.random.uniform(key, shape=(samples,L,1))).argmax(-1) + o = {"kms":kms, + "sampled_labels":sampled_labels, + "sampled_msa":sampled_msa} + + return jax.tree_map(lambda x:np.asarray(x),o) \ No newline at end of file diff --git a/build/lib/colabdesign/seq/learn_msa.py b/build/lib/colabdesign/seq/learn_msa.py new file mode 100644 index 00000000..380383b7 --- /dev/null +++ b/build/lib/colabdesign/seq/learn_msa.py @@ -0,0 +1,89 @@ +import jax +import jax.numpy as jnp +import numpy + +from colabdesign.seq.kmeans import kmeans +from colabdesign.seq.stats import get_stats, get_eff + +# LEARN SEQUENCES +# "parameter-free" model, where we learn msa to match the statistics. +# We can take kmeans to the "next" level and directly optimize sequences to match desired stats. + +class LEARN_MSA: + def __init__(self, X, X_weight=None, samples=None, + mode="tied", k=1, + seed=0, learning_rate=1e-3): + + assert mode in ["tied","full"] + assert k > 0 + + key = jax.random.PRNGKey(seed) + self.k = k + + # collect X stats + N,L,A = X.shape + if samples is None: samples = N + X = jnp.asarray(X) + X_weight = get_eff(X) if X_weight is None else jnp.asarray(X_weight) + + # run kmeans + self.kms = kmeans(X, X_weight, k=self.k) + stats_args = dict(add_f_ij=True, add_mf_ij=(mode=="full"), add_c=True) + self.X_stats = get_stats(X, X_weight, labels=jax.nn.one_hot(self.kms["labels"],self.k), **stats_args) + + if samples == N: + self.Y_labels = self.kms["labels"] + else: + # sample labels + key,key_ = jax.random.split(key) + self.Y_labels = jnp.sort(jax.random.choice(key_, jnp.arange(k), shape=(samples,), p=self.kms["cat"])) + + key, key_ = jax.random.split(key) + Neff = X_weight.sum() + Y_logits = jnp.log(self.kms["means"] * Neff + 0.01 * jnp.log(Neff))[self.Y_labels] + Y = jax.nn.softmax(Y_logits + jax.random.gumbel(key_,(samples,L,A))) + + # setup the model + def model(params, X_stats): + # categorical reparameterization of Y + Y_hard = jax.nn.one_hot(params["Y"].argmax(-1),A) + Y = jax.lax.stop_gradient(Y_hard - params["Y"]) + params["Y"] + + # collect Y stats + Y_stats = get_stats(Y, labels=jax.nn.one_hot(self.Y_labels, self.k), **stats_args) + + # define loss function + i,ij = ("f_i","c_ij") if k == 1 else ("mf_i",("c_ij" if mode == "tied" else "mc_ij")) + loss_i = jnp.square(X_stats[i] - Y_stats[i]).sum((-1,-2)) + loss_ij = jnp.square(X_stats[ij] - Y_stats[ij]).sum((-1,-2,-3)).mean(-1) + + if self.k > 1: + loss_i = (loss_i * self.kms["cat"]).sum() + if mode == "full": + loss_ij = (loss_ij * self.kms["cat"]).sum() + + loss = loss_i + loss_ij + + aux = {"r":get_r(X_stats["c_ij"], Y_stats["c_ij"])} + return loss, aux + + # setup optimizer + self.n = 0 + init_fun, self.update_fun, self.get_params = 
adam(learning_rate) + self.state = init_fun({"Y":Y}) + self.grad = jax.jit(jax.value_and_grad(model, has_aux=True)) + + def get_msa(self): + Y = np.array(self.get_params(self.state)["Y"]) + return {"kms":self.kms, + "sampled_msa":Y.argmax(-1), + "sampled_labels":self.Y_labels} + + def fit(self, steps=100, verbose=True): + '''train model''' + for n in range(steps): + (loss, aux), grad = self.grad(self.get_params(self.state), self.X_stats) + self.state = self.update_fun(self.n, grad, self.state) + self.n += 1 + if (n+1) % (steps // 10) == 0: + print(self.n, loss, aux["r"]) \ No newline at end of file diff --git a/build/lib/colabdesign/seq/mrf.py b/build/lib/colabdesign/seq/mrf.py new file mode 100644 index 00000000..db61db10 --- /dev/null +++ b/build/lib/colabdesign/seq/mrf.py @@ -0,0 +1,314 @@ +############################ +# TODO: remove reference to laxy, clean up the code +############################ + +def sample_msa(samples=10000, burn_in=1, temp=1.0, + order=None, ar=False, diff=False, seq=True): + + def sample_cat(key, logits=None, probs=None): + if logits is not None: + hard = jax.nn.one_hot(jax.random.categorical(key,logits/temp),logits.shape[-1]) + probs = jax.nn.softmax(logits,-1) + elif probs is not None: + hard = (probs.cumsum(-1) >= jax.random.uniform(key, shape=probs.shape[:-1])).argmax(-1) + if diff: hard = jax.lax.stop_gradient(hard - probs) + probs + return hard + + def sample_pll(key,msa,par): + N,L,A = msa.shape + if seq and ("w" in par or "mw" in par): + # sequential sampling + # burn_in = 1: autoregressive + # burn_in > 1: gibbs + def loop(m,x): + i,k = x + m_logits = [] + if "w" in par: m_logits.append(jnp.einsum("njb,ajb->na", m, par["w"][i])) + if "b" in par: m_logits.append(par["b"][i]) + if "mw" in par: m_logits.append(jnp.einsum("nc,njb,cajb->na", par["labels"], m, par["mw"][:,i])) + if "mb" in par: m_logits.append(jnp.einsum("nc,ca->na", par["labels"], par["mb"][:,i])) + return m.at[:,i].set(sample_cat(k,sum(m_logits))), None + # scan over positions + if order is not None: i = order + elif ar: i = jnp.arange(L) + else: i = jax.random.permutation(key,jnp.arange(L)) + k = jax.random.split(key,L) + return jax.lax.scan(loop,msa,(i,k))[0] + else: + # sample all position independently + logits = [] + if "b" in par: logits.append(par["b"]) + if "mb" in par: logits.append(jnp.einsum("nc,cia->nia", par["labels"], par["mb"])) + if burn_in > 1: + if "w" in par: logits.append(jnp.einsum("njb,iajb->nia", msa, par["w"])) + if "mw" in par: logits.append(jnp.einsum("nc,njb,ciajb->nia", par["labels"], msa, par["mw"])) + return sample_cat(key,sum(logits)) + + def sample(key, params): + for p in ["b","w","mb","mw"]: + if p in params: + L,A = params[p].shape[-2:] + break + msa = jnp.zeros((samples,L,A)) + + # sample from mixture + if "c" in params and ("mb" in params or "mw" in params): + c = params["c"] + labels_logits = jnp.tile(c,(samples,1)) + params["labels"] = jax.nn.one_hot(jax.random.categorical(key,labels_logits),c.shape[0]) + if diff: + labels_soft = jax.nn.softmax(labels_logits,-1) + params["labels"] = jax.lax.stop_gradient(params["labels"] - labels_soft) + labels_soft + else: + params["labels"] = None + + # number of iterations (burn-in) + sample_loop = lambda m,k:(sample_pll(k,m,params),None) + iters = jax.random.split(key,burn_in) + msa = jax.lax.scan(sample_loop,msa,iters)[0] + return {"msa":msa, "labels":params["labels"]} + + return sample + +def reg_loss(params, lam): + reg_loss = [] + if "b" in params: + reg_loss.append(lam * jnp.square(params["b"]).sum()) + if 
"w" in params: + L,A = params["w"].shape[-2:] + reg_loss.append(lam/2*(L-1)*(A-1) * jnp.square(params["w"]).sum()) + if "mb" in params: + reg_loss.append(lam * jnp.square(params["mb"]).sum()) + if "mw" in params: + L,A = params["mw"].shape[-2:] + reg_loss.append(lam/2*(L-1)*(A-1) * jnp.square(params["mw"]).sum()) + return sum(reg_loss) + +def pll_loss(params, inputs, order=None, labels=None): + logits = [] + + L = inputs["x"].shape[1] + w_mask = 1-jnp.eye(L) + if order is not None: + w_mask *= ar_mask(order) + + if "b" in params: + logits.append(params["b"]) + if "w" in params: + w = params["w"] + w = 0.5 * (w + w.transpose([2,3,0,1])) * w_mask[:,None,:,None] + logits.append(jnp.einsum("nia,iajb->njb", inputs["x"], w)) + + # MIXTURES + if "mb" in params: + logits.append(jnp.einsum("nc,cia->nia", labels, params["mb"])) + + if "mw" in params: + mw = params["mw"] + mw = 0.5 * (mw + mw.transpose([0,3,4,1,2])) * w_mask[None,:,None,:,None] + logits.append(jnp.einsum("nc,nia,ciajb->njb", labels, inputs["x"], mw)) + + # categorical-crossentropy (or pseudo-likelihood) + cce_loss = -(inputs["x"] * jax.nn.log_softmax(sum(logits))).sum([1,2]) + + return (cce_loss*inputs["x_weight"]).sum() + +class MRF: + def __init__(self, X, X_weight=None, + batch_size=None, + ar=False, ar_ent=False, + lam=0.01, + k=1, lr=0.1, shared=False, tied=True, full=False): + + ## MODE ## + inc = ["b","w"] if (tied or full) else ["b"] + if k > 1: + if shared: + if tied: inc += ["mb"] + if full: inc += ["mb","mw"] + else: + if tied: inc = ["mb","w"] + if full: inc = ["mb","mw"] + + N,L,A = X.shape + self.batch_size = batch_size + self.k = k + + # weight per sequence + X = jnp.asarray(X) + X_weight = get_eff(X) if X_weight is None else jnp.asarray(X_weight) + self.Neff = X_weight.sum() + + if batch_size is None: + learning_rate = lr * np.log(N)/L + else: + lam = lam * batch_size/N + learning_rate = lr * jnp.log(batch_size)/L + + if ar: + self.order = jnp.arange(L) + elif ar_ent: + f_i = (X * X_weight[:,None,None]).sum(0)/self.Neff + self.order = (-f_i * jnp.log(f_i + 1e-8)).sum(-1).argsort() + else: + self.order = None + + # setup the model + def model(params, inputs): + labels = inputs["labels"] if "labels" in inputs else None + pll = pll_loss(params, inputs, self.order, labels) + reg = reg_loss(params, lam) + loss = pll + reg + return None, loss + + # initialize inputs + self.inputs = {"x":X, "x_weight":X_weight} + + # initialize params + self.params = {} + if "w" in inc: self.params["w"] = jnp.zeros((L,A,L,A)) + if "mw" in inc: self.params["mw"] = jnp.zeros((k,L,A,L,A)) + if "b" in inc: + b = jnp.log((X * X_weight[:,None,None]).sum(0) + (lam+1e-8) * jnp.log(self.Neff)) + self.params["b"] = b - b.mean(-1,keepdims=True) + if "mb" in inc or "mw" in inc: + kms = kmeans(X, X_weight, k=k) + self.inputs["labels"] = kms["labels"] + mb = jnp.log(kms["means"] * self.Neff + (lam+1e-8) * jnp.log(self.Neff)) + self.params["mb"] = mb - mb.mean(-1,keepdims=True) + if "b" in self.params: + self.params["mb"] -= self.params["b"] + + # setup optimizer + self.opt = laxy.OPT(model, self.params, lr=learning_rate) + + def get_msa(self, samples=1000, burn_in=1): + self.params = self.opt.get_params() + if "labels" in self.inputs: + self.params["c"] = jnp.log((self.inputs["x_weight"][:,None] * self.inputs["labels"]).sum(0) + 1e-8) + + key = laxy.get_random_key() + return sample_msa(samples=samples,burn_in=burn_in,order=self.order)(key, self.params) + + def get_w(self): + self.params = self.opt.get_params() + w = [] + if "w" in self.params: 
w.append(self.params["w"]) + if "mw" in self.params: w.append(self.params["mw"].sum(0)) + w = sum(w) + w = (w + w.transpose(2,3,0,1))/2 + w = w - w.mean((1,3),keepdims=True) + return w + + def fit(self, steps=100, verbose=True, return_losses=False): + '''train model''' + losses = self.opt.fit(self.inputs, steps=steps, batch_size=self.batch_size, + verbose=verbose, return_losses=return_losses) + if return_losses: return losses + +class MRF_BM: + def __init__(self, X, X_weight=None, samples=1000, + burn_in=1, temp=1.0, + ar=False, ar_ent=True, + lr=0.05, lam=0.01, + k=1, mode="tied"): + + ## MODE ## + inc = ["b","mb"] if k > 1 else ["b"] + if mode == "tied": inc += ["w"] + if mode == "full": inc += ["w","mw"] if k > 1 else ["w"] + + self.X = jnp.asarray(X) + N,L,A = self.X.shape + learning_rate = lr * np.log(N)/L + + # weight per sequence + self.X_weight = get_eff(X) if X_weight is None else jnp.asarray(X_weight) + self.Neff = self.X_weight.sum() + + # collect stats + if k > 1: + self.kms = kmeans(self.X, self.X_weight, k=k) + self.labels = self.kms["labels"] + self.inputs = get_stats(self.X, self.X_weight, labels=self.kms["labels"], + add_mf_ij=("mw" in inc)) + self.inputs["c"] = self.kms["cat"] + else: + self.labels = None + self.inputs = get_stats(self.X, self.X_weight) + + # low entropy to high entropy + if ar_ent: + ent = -(self.inputs["f_i"] * jnp.log(self.inputs["f_i"] + 1e-8)).sum(-1) + self.order = ent.argsort() + elif ar: self.order = jnp.arange(L) + else: self.order = None + + self.burn_in = burn_in + self.temp = temp + + # setup the model + def model(params, inputs): + + # sample msa + sample = sample_msa(samples=samples, burn_in=burn_in, + temp=temp, order=self.order)(inputs["key"], params) + + # compute stats + stats = get_stats(sample["msa"], labels=sample["labels"], + add_mf_ij=("mw" in params)) + + # define gradients + grad = {} + I = (1-jnp.eye(L))[:,None,:,None] + if "c" in params: grad["c"] = sample["labels"].mean(0) - inputs["c"] + if "b" in params: grad["b"] = stats["f_i"] - inputs["f_i"] + if "w" in params: grad["w"] = (stats["f_ij"] - inputs["f_ij"]) * I + if "mb" in params: grad["mb"] = stats["mf_i"] - inputs["mf_i"] + if "mw" in params: grad["mw"] = (stats["mf_ij"] - inputs["mf_ij"]) * I[None] + + # add regularization + reg_grad = jax.grad(reg_loss)(params,lam) + + for g in grad.keys(): + if g in reg_grad: grad[g] = grad[g] * self.Neff + reg_grad[g] + else: grad[g] = grad[g] * self.Neff + + return None, None, grad + + # initialize model params + self.params = {} + if "w" in inc: + self.params["w"] = jnp.zeros((L,A,L,A)) + if "b" in inc: + b = jnp.log(self.inputs["f_i"] * self.Neff + (lam+1e-8) * jnp.log(self.Neff)) + self.params["b"] = b - b.mean(-1,keepdims=True) + + # setup mixture params + if "mb" in inc or "mw" in inc: + c = jnp.log(self.inputs["c"] * self.Neff + 1e-8) + self.params["c"] = c - c.mean(-1,keepdims=True) + + if "mw" in inc: + self.params["mw"] = jnp.zeros((k,L,A,L,A)) + if "mb" in inc: + mb = jnp.log(self.inputs["mf_i"] * self.Neff + (lam+1e-8) * jnp.log(self.Neff)) + self.params["mb"] = mb - mb.mean(-1,keepdims=True) + if "b" in self.params: self.params["mb"] -= self.params["b"] + + # setup optimizer + self.opt = laxy.OPT(model, self.params, lr=learning_rate, has_grad=True) + + def get_msa(self, samples=1000, burn_in=None, temp=None, seed=0): + if burn_in is None: burn_in = self.burn_in + if temp is None: temp = self.temp + + self.params = self.opt.get_params() + key = jax.random.PRNGKey(seed) + return sample_msa(samples=samples, + 
burn_in=self.burn_in, + order=self.order)(key, self.params)["msa"] + + def fit(self, steps=1000, verbose=True): + '''train model''' + self.opt.fit(self.inputs, steps=steps, verbose=verbose) \ No newline at end of file diff --git a/build/lib/colabdesign/seq/stats.py b/build/lib/colabdesign/seq/stats.py new file mode 100644 index 00000000..83046a6d --- /dev/null +++ b/build/lib/colabdesign/seq/stats.py @@ -0,0 +1,77 @@ +import jax +import jax.numpy as jnp +import numpy as np + +def get_stats(X, X_weight=None, labels=None, add_f_ij=True, add_mf_ij=False, add_c=False): + '''compute f_i/f_ij/f_ijk given msa ''' + n = None + if X_weight is None: + Xn = Xs = X + else: + Xn, Xs = X*X_weight[:,n,n], X*jnp.sqrt(X_weight[:,n,n]) + f_i = Xn.sum(0) + o = {"f_i": f_i / f_i.sum(1,keepdims=True)} + + if add_f_ij: + f_ij = jnp.tensordot(Xs,Xs,[0,0]) + o["f_ij"] = f_ij / f_ij.sum((1,3),keepdims=True) + if add_c: o["c_ij"] = o["f_ij"] - o["f_i"][:,:,n,n] * o["f_i"][n,n,:,:] + + if labels is not None: + # compute mixture stats + if jnp.issubdtype(labels, jnp.integer): + labels = jax.nn.one_hot(labels,labels.max()+1) + mf_i = jnp.einsum("nc,nia->cia", labels, Xn) + o["mf_i"] = mf_i/mf_i.sum((0,2),keepdims=True) + if add_mf_ij: + mf_ij = jnp.einsum("nc,nia,njb->ciajb", labels, Xs, Xs) + o["mf_ij"] = mf_ij/mf_ij.sum((0,2,4),keepdims=True) + if add_c: o["mc_ij"] = o["mf_ij"] - o["mf_i"][:,:,:,n,n] * o["mf_i"][:,n,n,:,:] + return o + +def get_r(a,b): + a = jnp.array(a).flatten() + b = jnp.array(b).flatten() + return jnp.corrcoef(a,b)[0,1] + +def inv_cov(X, X_weight=None): + X = jnp.asarray(X) + N,L,A = X.shape + if X_weight is None: + num_points = N + else: + X_weight = jnp.asarray(X_weight) + num_points = X_weight.sum() + c = get_stats(X, X_weight, add_mf_ij=True, add_c=True)["c_ij"] + c = c.reshape(L*A,L*A) + shrink = 4.5/jnp.sqrt(num_points) * jnp.eye(c.shape[0]) + ic = jnp.linalg.inv(c + shrink) + return ic.reshape(L,A,L,A) + +def get_mtx(W): + W = jnp.asarray(W) + # l2norm of 20x20 matrices (note: we ignore gaps) + raw = jnp.sqrt(jnp.sum(np.square(W[:,1:,:,1:]),(1,3))) + raw = raw.at[jnp.diag_indices_from(raw)].set(0) + + # apc (average product correction) + ap = raw.sum(0,keepdims=True) * raw.sum(1,keepdims=True) / raw.sum() + apc = raw - ap + apc = apc.at[jnp.diag_indices_from(apc)].set(0) + return raw, apc + +def con_auc(true, pred, mask=None): + '''compute agreement between predicted and measured contact map''' + true = jnp.asarray(true) + pred = jnp.asarray(pred) + if mask is not None: + mask = jnp.asarray(mask) + idx = mask.sum(-1) > 0 + true = true[idx,:][:,idx] + pred = pred[idx,:][:,idx] + eval_idx = jnp.triu_indices_from(true, 6) + pred_, true_ = pred[eval_idx], true[eval_idx] + L = (jnp.linspace(0.1,1.0,10)*len(true)).astype(jnp.int32) + sort_idx = jnp.argsort(pred_)[::-1] + return jnp.asarray([true_[sort_idx[:l]].mean() for l in L]) + diff --git a/build/lib/colabdesign/seq/utils.py b/build/lib/colabdesign/seq/utils.py new file mode 100644 index 00000000..579638f1 --- /dev/null +++ b/build/lib/colabdesign/seq/utils.py @@ -0,0 +1,60 @@ +import os, string +import numpy as np +import jax +import jax.numpy as jnp + +ALPHABET = list("ARNDCQEGHILKMFPSTWYV-") + +def parse_fasta(filename, a3m=False, stop=100000): + '''function to parse fasta file''' + + if a3m: + # for a3m files the lowercase letters are removed + # as these do not align to the query sequence + rm_lc = str.maketrans(dict.fromkeys(string.ascii_lowercase)) + + header, sequence = [],[] + lines = open(filename, "r") + for line in lines: + line 
= line.rstrip() + if len(line) > 0: + if line[0] == ">": + if len(header) == stop: + break + else: + header.append(line[1:]) + sequence.append([]) + else: + if a3m: line = line.translate(rm_lc) + else: line = line.upper() + sequence[-1].append(line) + lines.close() + sequence = [''.join(seq) for seq in sequence] + + return header, sequence + +def mk_msa(seqs): + '''one hot encode msa''' + states = len(ALPHABET) + a2n = {a:n for n,a in enumerate(ALPHABET)} + msa_ori = np.array([[a2n.get(aa, states-1) for aa in seq] for seq in seqs]) + return np.eye(states)[msa_ori] + +def get_eff(msa, eff_cutoff=0.8): + '''compute weight per sequence''' + if msa.shape[0] > 10000: + # loop one-to-all (to avoid memory issues) + msa = msa.argmax(-1) + def get_w(seq): return 1/((seq==msa).mean(-1) > eff_cutoff).sum() + return jax.lax.scan(lambda _,x:(_,get_w(x)),None,msa,unroll=2)[1] + else: + # all-to-all + msa_ident = jnp.tensordot(msa,msa,[[1,2],[1,2]])/msa.shape[1] + return 1/(msa_ident >= eff_cutoff).sum(-1) + +def ar_mask(order, diag=True): + '''compute autoregressive mask, given order of positions''' + L = order.shape[0] + r = order[::-1].argsort() + tri = jnp.triu(jnp.ones((L,L)),k=not diag) + return tri[r[None,:],r[:,None]] diff --git a/build/lib/colabdesign/shared/__init__.py b/build/lib/colabdesign/shared/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/build/lib/colabdesign/shared/chunked_vmap.py b/build/lib/colabdesign/shared/chunked_vmap.py new file mode 100644 index 00000000..2e81135d --- /dev/null +++ b/build/lib/colabdesign/shared/chunked_vmap.py @@ -0,0 +1,356 @@ +# Copyright 2021 The NetKet Authors - All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# WORK IN PROGRESS + +import sys + +from functools import partial + +import jax +import jax.numpy as jnp + +from jax.extend import linear_util as lu +from jax.api_util import argnums_partial + +from functools import partial +from typing import Optional, Callable + +_tree_add = partial(jax.tree_util.tree_map, jax.lax.add) +_tree_zeros_like = partial(jax.tree_util.tree_map, lambda x: jnp.zeros(x.shape, dtype=x.dtype)) + + + +def _treeify(f): + def _f(x, *args, **kwargs): + return jax.tree_util.tree_map(lambda y: f(y, *args, **kwargs), x) + + return _f + + +@_treeify +def _unchunk(x): + return x.reshape((-1,) + x.shape[2:]) + + +@_treeify +def _chunk(x, chunk_size=None): + # chunk_size=None -> add just a dummy chunk dimension, same as np.expand_dims(x, 0) + n = x.shape[0] + if chunk_size is None: + chunk_size = n + + n_chunks, residual = divmod(n, chunk_size) + if residual != 0: + raise ValueError( + "The first dimension of x must be divisible by chunk_size." + + f"\n Got x.shape={x.shape} but chunk_size={chunk_size}." 
+ ) + return x.reshape((n_chunks, chunk_size) + x.shape[1:]) + + +def _chunk_size(x): + b = set(map(lambda x: x.shape[:2], jax.tree_util.tree_leaves(x))) + if len(b) != 1: + raise ValueError( + "The arrays in x have inconsistent chunk_size or number of chunks" + ) + return b.pop()[1] + + +def unchunk(x_chunked): + """ + Merge the first two axes of an array (or a pytree of arrays) + Args: + x_chunked: an array (or pytree of arrays) of at least 2 dimensions + Returns: a pair (x, chunk_fn) + where x is x_chunked reshaped to (-1,)+x.shape[2:] + and chunk_fn is a function which restores x given x_chunked + """ + return _unchunk(x_chunked), partial(_chunk, chunk_size=_chunk_size(x_chunked)) + + +def chunk(x, chunk_size=None): + """ + Split an array (or a pytree of arrays) into chunks along the first axis + Args: + x: an array (or pytree of arrays) + chunk_size: an integer or None (default) + The first axis in x must be a multiple of chunk_size + Returns: a pair (x_chunked, unchunk_fn) where + - x_chunked is x reshaped to (-1, chunk_size)+x.shape[1:] + if chunk_size is None then it defaults to x.shape[0], i.e. just one chunk + - unchunk_fn is a function which restores x given x_chunked + """ + return _chunk(x, chunk_size), _unchunk + +# TODO put it somewher + +def _multimap(f, *args): + try: + return tuple(map(lambda a: f(*a), zip(*args))) + except TypeError: + return f(*args) + + +def scan_append_reduce(f, x, append_cond, op=_tree_add): + """Evaluate f element by element in x while appending and/or reducing the results + Args: + f: a function that takes elements of the leading dimension of x + x: a pytree where each leaf array has the same leading dimension + append_cond: a bool (if f returns just one result) or a tuple of bools (if f returns multiple values) + which indicates whether the individual result should be appended or reduced + op: a function to (pairwise) reduce the specified results. Defaults to a sum. 
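    (added note, not in the original docstring) all leaves of x are scanned
    together along axis 0 with jax.lax.scan, so they must share the same
    leading dimension.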
+ Returns: + returns the (tuple of) results corresponding to the output of f + where each result is given by: + if append_cond is True: + a (pytree of) array(s) with leading dimension same as x, + containing the evaluation of f at each element in x + else (append_cond is False): + a (pytree of) array(s) with the same shape as the corresponding output of f, + containing the reduction over op of f evaluated at each x + Example: + import jax.numpy as jnp + from netket.jax import scan_append_reduce + def f(x): + y = jnp.sin(x) + return y, y, y**2 + N = 100 + x = jnp.linspace(0.,jnp.pi,N) + y, s, s2 = scan_append_reduce(f, x, (True, False, False)) + mean = s/N + var = s2/N - mean**2 + """ + # TODO: different op for each result + + x0 = jax.tree_util.tree_map(lambda x: x[0], x) + + # special code path if there is only one element + # to avoid having to rely on xla/llvm to optimize the overhead away + if jax.tree_util.tree_leaves(x)[0].shape[0] == 1: + return _multimap( + lambda c, x: jnp.expand_dims(x, 0) if c else x, append_cond, f(x0) + ) + + # the original idea was to use pytrees, however for now just operate on the return value tuple + _get_append_part = partial(_multimap, lambda c, x: x if c else None, append_cond) + _get_op_part = partial(_multimap, lambda c, x: x if not c else None, append_cond) + _tree_select = partial(_multimap, lambda c, t1, t2: t1 if c else t2, append_cond) + + carry_init = True, _get_op_part(_tree_zeros_like(jax.eval_shape(f, x0))) + + def f_(carry, x): + is_first, y_carry = carry + y = f(x) + y_op = _get_op_part(y) + y_append = _get_append_part(y) + # select here to avoid the user having to specify the zero element for op + y_reduce = jax.tree_util.tree_map( + partial(jax.lax.select, is_first), y_op, op(y_carry, y_op) + ) + return (False, y_reduce), y_append + + (_, res_op), res_append = jax.lax.scan(f_, carry_init, x, unroll=1) + # reconstruct the result from the reduced and appended parts in the two trees + return _tree_select(res_append, res_op) + + +scan_append = partial(scan_append_reduce, append_cond=True) +scan_reduce = partial(scan_append_reduce, append_cond=False) + + +# TODO in_axes a la vmap? +def scanmap(fun, scan_fun, argnums=0): + """ + A helper function to wrap f with a scan_fun + Example: + import jax.numpy as jnp + from functools import partial + from netket.jax import scanmap, scan_append_reduce + scan_fun = partial(scan_append_reduce, append_cond=(True, False, False)) + @partial(scanmap, scan_fun=scan_fun, argnums=1) + def f(c, x): + y = jnp.sin(x) + c + return y, y, y**2 + N = 100 + x = jnp.linspace(0.,jnp.pi,N) + c = 1. + y, s, s2 = f(c, x) + mean = s/N + var = s2/N - mean**2 + """ + + def f_(*args, **kwargs): + f = lu.wrap_init(fun, kwargs) + f_partial, dyn_args = argnums_partial( + f, argnums, args, require_static_args_hashable=False + ) + return scan_fun(lambda x: f_partial.call_wrapped(*x), dyn_args) + + return f_ + + +class HashablePartial(partial): + """ + A class behaving like functools.partial, but that retains it's hash + if it's created with a lexically equivalent (the same) function and + with the same partially applied arguments and keywords. + It also stores the computed hash for faster hashing. 
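    Example (illustrative, not part of the original docstring):
        def f(x, y): return x + y
        p1 = HashablePartial(f, 1)
        p2 = HashablePartial(f, 1)
        assert p1 == p2 and hash(p1) == hash(p2)  # same code object, args and keywords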
+ """ + + # TODO remove when dropping support for Python < 3.10 + def __new__(cls, func, *args, **keywords): + # In Python 3.10+ if func is itself a functools.partial instance, + # functools.partial.__new__ would merge the arguments of this HashablePartial + # instance with the arguments of the func + # Pre 3.10 this does not happen, so here we emulate this behaviour recursively + # This is necessary since functools.partial objects do not have a __code__ + # property which we use for the hash + # For python 3.10+ we still need to take care of merging with another HashablePartial + while isinstance( + func, partial if sys.version_info < (3, 10) else HashablePartial + ): + original_func = func + func = original_func.func + args = original_func.args + args + keywords = {**original_func.keywords, **keywords} + return super().__new__(cls, func, *args, **keywords) + + def __init__(self, *args, **kwargs): + self._hash = None + + def __eq__(self, other): + return ( + type(other) is HashablePartial + and self.func.__code__ == other.func.__code__ + and self.args == other.args + and self.keywords == other.keywords + ) + + def __hash__(self): + if self._hash is None: + self._hash = hash( + (self.func.__code__, self.args, frozenset(self.keywords.items())) + ) + + return self._hash + + def __repr__(self): + return f"" + + + +def _fun(vmapped_fun, chunk_size, argnums, *args, **kwargs): + n_elements = jax.tree_util.tree_leaves(args[argnums[0]])[0].shape[0] + n_chunks, n_rest = divmod(n_elements, chunk_size) + + if n_chunks == 0 or chunk_size >= n_elements: + y = vmapped_fun(*args, **kwargs) + else: + # split inputs + def _get_chunks(x): + x_chunks = jax.tree_util.tree_map(lambda x_: x_[: n_elements - n_rest, ...], x) + x_chunks = _chunk(x_chunks, chunk_size) + return x_chunks + + def _get_rest(x): + x_rest = jax.tree_util.tree_map(lambda x_: x_[n_elements - n_rest :, ...], x) + return x_rest + + args_chunks = [ + _get_chunks(a) if i in argnums else a for i, a in enumerate(args) + ] + args_rest = [_get_rest(a) if i in argnums else a for i, a in enumerate(args)] + + y_chunks = _unchunk( + scanmap(vmapped_fun, scan_append, argnums)(*args_chunks, **kwargs) + ) + + if n_rest == 0: + y = y_chunks + else: + y_rest = vmapped_fun(*args_rest, **kwargs) + y = jax.tree_util.tree_map(lambda y1, y2: jnp.concatenate((y1, y2)), y_chunks, y_rest) + return y + + +def _chunk_vmapped_function( + vmapped_fun: Callable, chunk_size: Optional[int], argnums=0 +) -> Callable: + """takes a vmapped function and computes it in chunks""" + + if chunk_size is None: + return vmapped_fun + + if isinstance(argnums, int): + argnums = (argnums,) + + return HashablePartial(_fun, vmapped_fun, chunk_size, argnums) + + +def _parse_in_axes(in_axes): + if isinstance(in_axes, int): + in_axes = (in_axes,) + + if not set(in_axes).issubset((0, None)): + raise NotImplementedError("Only in_axes 0/None are currently supported") + + argnums = tuple( + map(lambda ix: ix[0], filter(lambda ix: ix[1] is not None, enumerate(in_axes))) + ) + return in_axes, argnums + + +def apply_chunked(f: Callable, in_axes=0, *, chunk_size: Optional[int]) -> Callable: + """ + Takes an implicitly vmapped function over the axis 0 and uses scan to + do the computations in smaller chunks over the 0-th axis of all input arguments. + For this to work, the function `f` should be `vectorized` along the `in_axes` + of the arguments. This means that the function `f` should respect the following + condition: + .. 
code-block:: python + assert f(x) == jnp.concatenate([f(x_i) for x_i in x], axis=0) + which is automatically satisfied if `f` is obtained by vmapping a function, + such as: + .. code-block:: python + f = jax.vmap(f_orig) + Args: + f: A function that satisfies the condition above + in_axes: The axes that should be scanned along. Only supports `0` or `None` + chunk_size: The maximum size of the chunks to be used. If it is `None`, chunking + is disabled + """ + _, argnums = _parse_in_axes(in_axes) + return _chunk_vmapped_function(f, chunk_size, argnums) + + +def vmap_chunked(f: Callable, in_axes=0, *, chunk_size: Optional[int]) -> Callable: + """ + Behaves like jax.vmap but uses scan to chunk the computations in smaller chunks. + This function is essentially equivalent to: + .. code-block:: python + nk.jax.apply_chunked(jax.vmap(f, in_axes), in_axes, chunk_size) + Some limitations to `in_axes` apply. + Args: + f: The function to be vectorised. + in_axes: The axes that should be scanned along. Only supports `0` or `None` + chunk_size: The maximum size of the chunks to be used. If it is `None`, chunking + is disabled + Returns: + A vectorised and chunked function + """ + in_axes, argnums = _parse_in_axes(in_axes) + vmapped_fun = jax.vmap(f, in_axes=in_axes) + return _chunk_vmapped_function(vmapped_fun, chunk_size, argnums) diff --git a/build/lib/colabdesign/shared/model.py b/build/lib/colabdesign/shared/model.py new file mode 100644 index 00000000..e02a3692 --- /dev/null +++ b/build/lib/colabdesign/shared/model.py @@ -0,0 +1,223 @@ +import jax +import jax.numpy as jnp +import numpy as np +import optax + +from colabdesign.shared.utils import copy_dict, update_dict, softmax, Key +from colabdesign.shared.prep import rewire +from colabdesign.af.alphafold.common import residue_constants + +aa_order = residue_constants.restype_order +order_aa = {b:a for a,b in aa_order.items()} + +class design_model: + def set_weights(self, *args, **kwargs): + ''' + set weights + ------------------- + note: model.restart() resets the weights to their defaults + use model.set_weights(..., set_defaults=True) to avoid this + ------------------- + model.set_weights(rmsd=1) + ''' + if kwargs.pop("set_defaults", False): + update_dict(self._opt["weights"], *args, **kwargs) + update_dict(self.opt["weights"], *args, **kwargs) + + def set_seq(self, seq=None, mode=None, bias=None, rm_aa=None, set_state=True, **kwargs): + ''' + set sequence params and bias + ----------------------------------- + -seq=str or seq=[str,str] or seq=array(shape=(L,20) or shape=(?,L,20)) + -mode= + -"wildtype"/"wt" = initialize sequence with sequence saved from input PDB + -"gumbel" = initial sequence with gumbel distribution + -"soft_???" = apply softmax-activation to initailized sequence (eg. "soft_gumbel") + -bias=array(shape=(20,) or shape=(L,20)) - bias the sequence + -rm_aa="C,W" = specify which amino acids to remove (aka. 
add a negative-infinity bias to these aa) + ----------------------------------- + ''' + + # backward compatibility + seq_init = kwargs.pop("seq_init",None) + if seq_init is not None: + modes = ["soft","gumbel","wildtype","wt"] + if isinstance(seq_init,str): + seq_init = seq_init.split("_") + if isinstance(seq_init,list) and seq_init[0] in modes: + mode = seq_init + else: + seq = seq_init + + if mode is None: mode = [] + + # decide on shape + shape = (self._num, self._len, self._args.get("alphabet_size",20)) + + # initialize bias + if bias is None: + b = np.zeros(shape[1:]) + else: + b = np.array(np.broadcast_to(bias, shape[1:])) + + # disable certain amino acids + if rm_aa is not None: + for aa in rm_aa.split(","): + b[...,aa_order[aa]] -= 1e6 + + # use wildtype sequence + if ("wildtype" in mode or "wt" in mode) and hasattr(self,"_wt_aatype"): + wt_seq = np.eye(shape[-1])[self._wt_aatype] + wt_seq[self._wt_aatype == -1] = 0 + if "pos" in self.opt and self.opt["pos"].shape[0] == wt_seq.shape[0]: + seq = np.zeros(shape) + seq[:,self.opt["pos"],:] = wt_seq + else: + seq = wt_seq + + # initialize sequence + if seq is None: + if hasattr(self,"key"): + x = 0.01 * np.random.normal(size=shape) + else: + x = np.zeros(shape) + else: + if isinstance(seq, str): + seq = [seq] + if isinstance(seq, list): + if isinstance(seq[0], str): + aa_dict = copy_dict(aa_order) + if shape[-1] > 21: + aa_dict["-"] = 21 # add gap character + seq = np.asarray([[aa_dict.get(aa,-1) for aa in s] for s in seq]) + else: + seq = np.asarray(seq) + else: + seq = np.asarray(seq) + + if np.issubdtype(seq.dtype, np.integer): + seq_ = np.eye(shape[-1])[seq] + seq_[seq == -1] = 0 + seq = seq_ + + if kwargs.pop("add_seq",False): + b = b + seq * 1e7 + + if seq.ndim == 2: + x = np.pad(seq[None],[[0,shape[0]-1],[0,0],[0,0]]) + elif shape[0] > seq.shape[0]: + x = np.pad(seq,[[0,shape[0]-seq.shape[0]],[0,0],[0,0]]) + else: + x = seq + + if "gumbel" in mode: + y_gumbel = jax.random.gumbel(self.key(),shape) + if "soft" in mode: + y = softmax(x + b + y_gumbel) + elif "alpha" in self.opt: + y = x + y_gumbel / self.opt["alpha"] + else: + y = x + y_gumbel + + x = np.where(x.sum(-1,keepdims=True) == 1, x, y) + + # set seq/bias/state + self._params["seq"] = x + self._inputs["bias"] = b + + def _norm_seq_grad(self): + g = self.aux["grad"]["seq"] + eff_L = (np.square(g).sum(-1,keepdims=True) > 0).sum(-2,keepdims=True) + gn = np.linalg.norm(g,axis=(-1,-2),keepdims=True) + self.aux["grad"]["seq"] = g * np.sqrt(eff_L) / (gn + 1e-7) + + def set_optimizer(self, optimizer=None, learning_rate=None, norm_seq_grad=None, **kwargs): + ''' + set/reset optimizer + ---------------------------------- + supported optimizers include: [adabelief, adafactor, adagrad, adam, adamw, + fromage, lamb, lars, noisy_sgd, dpsgd, radam, rmsprop, sgd, sm3, yogi] + ''' + optimizers = {'adabelief':optax.adabelief,'adafactor':optax.adafactor, + 'adagrad':optax.adagrad,'adam':optax.adam, + 'adamw':optax.adamw,'fromage':optax.fromage, + 'lamb':optax.lamb,'lars':optax.lars, + 'noisy_sgd':optax.noisy_sgd,'dpsgd':optax.dpsgd, + 'radam':optax.radam,'rmsprop':optax.rmsprop, + 'sgd':optax.sgd,'sm3':optax.sm3,'yogi':optax.yogi} + + if optimizer is None: optimizer = self._args["optimizer"] + if learning_rate is not None: self.opt["learning_rate"] = learning_rate + if norm_seq_grad is not None: self.opt["norm_seq_grad"] = norm_seq_grad + + o = optimizers[optimizer](1.0, **kwargs) + self._state = o.init(self._params) + + def update_grad(state, grad, params): + updates, state = o.update(grad, 
state, params) + grad = jax.tree_map(lambda x:-x, updates) + return state, grad + + self._optimizer = jax.jit(update_grad) + + def set_seed(self, seed=None): + np.random.seed(seed=seed) + self.key = Key(seed=seed).get + + def get_seq(self, get_best=True): + ''' + get sequences as strings + - set get_best=False, to get the last sampled sequence + ''' + aux = self._tmp["best"]["aux"] if (get_best and "aux" in self._tmp["best"]) else self.aux + x = aux["seq"]["hard"].argmax(-1) + return ["".join([order_aa[a] for a in s]) for s in x] + + def get_seqs(self, get_best=True): + return self.get_seq(get_best) + + def rewire(self, order=None, offset=0, loops=0): + ''' + helper function for "partial" protocol + ----------------------------------------- + -order=[0,1,2] - change order of specified segments + -offset=0 - specify start position of the first segment + -loops=[3,2] - specified loop lengths between segments + ----------------------------------------- + ''' + self.opt["pos"] = rewire(length=self._pos_info["length"], order=order, + offset=offset, loops=loops) + + # make default + if hasattr(self,"_opt"): self._opt["pos"] = self.opt["pos"] + +def soft_seq(x, bias, opt, key=None, num_seq=None, shuffle_first=True): + seq = {"input":x} + # shuffle msa + if x.ndim == 3 and x.shape[0] > 1 and key is not None: + key, sub_key = jax.random.split(key) + if num_seq is None or x.shape[0] == num_seq: + # randomly pick which sequence is query + if shuffle_first: + n = jax.random.randint(sub_key,[],0,x.shape[0]) + seq["input"] = seq["input"].at[0].set(seq["input"][n]).at[n].set(seq["input"][0]) + else: + n = jnp.arange(x.shape[0]) + if shuffle_first: + n = jax.random.permutation(sub_key,n) + else: + n = jnp.append(0,jax.random.permutation(sub_key,n[1:])) + seq["input"] = seq["input"][n[:num_seq]] + + # straight-through/reparameterization + seq["logits"] = seq["input"] * opt["alpha"] + if bias is not None: seq["logits"] = seq["logits"] + bias + seq["pssm"] = jax.nn.softmax(seq["logits"]) + seq["soft"] = jax.nn.softmax(seq["logits"] / opt["temp"]) + seq["hard"] = jax.nn.one_hot(seq["soft"].argmax(-1), seq["soft"].shape[-1]) + seq["hard"] = jax.lax.stop_gradient(seq["hard"] - seq["soft"]) + seq["soft"] + + # create pseudo sequence + seq["pseudo"] = opt["soft"] * seq["soft"] + (1-opt["soft"]) * seq["input"] + seq["pseudo"] = opt["hard"] * seq["hard"] + (1-opt["hard"]) * seq["pseudo"] + return seq diff --git a/build/lib/colabdesign/shared/parse_args.py b/build/lib/colabdesign/shared/parse_args.py new file mode 100644 index 00000000..45c28448 --- /dev/null +++ b/build/lib/colabdesign/shared/parse_args.py @@ -0,0 +1,55 @@ +import sys, getopt +# class for parsing arguments +class parse_args: + def __init__(self): + self.long,self.short = [],[] + self.info,self.help = [],[] + + def txt(self,help): + self.help.append(["txt",help]) + + def add(self, arg, default, type, help=None): + self.long.append(arg[0]) + key = arg[0].replace("=","") + self.info.append({"key":key, "type":type, + "value":default, "arg":[f"--{key}"]}) + if len(arg) == 2: + self.short.append(arg[1]) + s_key = arg[1].replace(":","") + self.info[-1]["arg"].append(f"-{s_key}") + if help is not None: + self.help.append(["opt",[arg,help]]) + + def parse(self,argv): + for opt, arg in getopt.getopt(argv,"".join(self.short),self.long)[0]: + for x in self.info: + if opt in x["arg"]: + if x["type"] is None: x["value"] = (x["value"] == False) + else: x["value"] = x["type"](arg) + + opts = {x["key"]:x["value"] for x in self.info} + print(str(opts).replace(" 
","")) + return dict2obj(opts) + + def usage(self, err): + for type,info in self.help: + if type == "txt": print(info) + if type == "opt": + arg, helps = info + help = helps[0] + if len(arg) == 1: print("--%-15s : %s" % (arg[0],help)) + if len(arg) == 2: print("--%-10s -%-3s : %s" % (arg[0],arg[1].replace(":",""),help)) + for help in helps[1:]: print("%19s %s" % ("",help)) + print(f"< {err} >") + print(" "+"-"*(len(err)+2)) + print(" \ ^__^ ") + print(" \ (oo)\_______ ") + print(" (__)\ )\/\ ") + print(" ||----w | ") + print(" || || ") + sys.exit() + +class dict2obj(): + def __init__(self, dictionary): + for key in dictionary: + setattr(self, key, dictionary[key]) \ No newline at end of file diff --git a/build/lib/colabdesign/shared/plot.py b/build/lib/colabdesign/shared/plot.py new file mode 100644 index 00000000..cbd0309b --- /dev/null +++ b/build/lib/colabdesign/shared/plot.py @@ -0,0 +1,337 @@ +# import matplotlib +import numpy as np +from scipy.special import expit as sigmoid +from colabdesign.shared.protein import _np_kabsch, alphabet_list + +import matplotlib +import matplotlib.pyplot as plt +import matplotlib.patheffects +from matplotlib import animation +from matplotlib.gridspec import GridSpec +from matplotlib import collections as mcoll +try: + import py3Dmol +except: + print("py3Dmol not installed") + +pymol_color_list = ["#33ff33","#00ffff","#ff33cc","#ffff00","#ff9999","#e5e5e5","#7f7fff","#ff7f00", + "#7fff7f","#199999","#ff007f","#ffdd5e","#8c3f99","#b2b2b2","#007fff","#c4b200", + "#8cb266","#00bfbf","#b27f7f","#fcd1a5","#ff7f7f","#ffbfdd","#7fffff","#ffff7f", + "#00ff7f","#337fcc","#d8337f","#bfff3f","#ff7fff","#d8d8ff","#3fffbf","#b78c4c", + "#339933","#66b2b2","#ba8c84","#84bf00","#b24c66","#7f7f7f","#3f3fa5","#a5512b"] + +jalview_color_list = {"Clustal": ["#80a0f0","#f01505","#00ff00","#c048c0","#f08080","#00ff00","#c048c0","#f09048","#15a4a4","#80a0f0","#80a0f0","#f01505","#80a0f0","#80a0f0","#ffff00","#00ff00","#00ff00","#80a0f0","#15a4a4","#80a0f0"], + "Zappo": ["#ffafaf","#6464ff","#00ff00","#ff0000","#ffff00","#00ff00","#ff0000","#ff00ff","#6464ff","#ffafaf","#ffafaf","#6464ff","#ffafaf","#ffc800","#ff00ff","#00ff00","#00ff00","#ffc800","#ffc800","#ffafaf"], + "Taylor": ["#ccff00","#0000ff","#cc00ff","#ff0000","#ffff00","#ff00cc","#ff0066","#ff9900","#0066ff","#66ff00","#33ff00","#6600ff","#00ff00","#00ff66","#ffcc00","#ff3300","#ff6600","#00ccff","#00ffcc","#99ff00"], + "Hydrophobicity": ["#ad0052","#0000ff","#0c00f3","#0c00f3","#c2003d","#0c00f3","#0c00f3","#6a0095","#1500ea","#ff0000","#ea0015","#0000ff","#b0004f","#cb0034","#4600b9","#5e00a1","#61009e","#5b00a4","#4f00b0","#f60009","#0c00f3","#680097","#0c00f3"], + "Helix Propensity": ["#e718e7","#6f906f","#1be41b","#778877","#23dc23","#926d92","#ff00ff","#00ff00","#758a75","#8a758a","#ae51ae","#a05fa0","#ef10ef","#986798","#00ff00","#36c936","#47b847","#8a758a","#21de21","#857a85","#49b649","#758a75","#c936c9"], + "Strand Propensity": ["#5858a7","#6b6b94","#64649b","#2121de","#9d9d62","#8c8c73","#0000ff","#4949b6","#60609f","#ecec13","#b2b24d","#4747b8","#82827d","#c2c23d","#2323dc","#4949b6","#9d9d62","#c0c03f","#d3d32c","#ffff00","#4343bc","#797986","#4747b8"], + "Turn Propensity": ["#2cd3d3","#708f8f","#ff0000","#e81717","#a85757","#3fc0c0","#778888","#ff0000","#708f8f","#00ffff","#1ce3e3","#7e8181","#1ee1e1","#1ee1e1","#f60909","#e11e1e","#738c8c","#738c8c","#9d6262","#07f8f8","#f30c0c","#7c8383","#5ba4a4"], + "Buried Index": 
["#00a35c","#00fc03","#00eb14","#00eb14","#0000ff","#00f10e","#00f10e","#009d62","#00d52a","#0054ab","#007b84","#00ff00","#009768","#008778","#00e01f","#00d52a","#00db24","#00a857","#00e619","#005fa0","#00eb14","#00b649","#00f10e"]} + +pymol_cmap = matplotlib.colors.ListedColormap(pymol_color_list) + +def show_pdb(pdb_str, show_sidechains=False, show_mainchains=False, + color="pLDDT", chains=None, Ls=None, vmin=50, vmax=90, + color_HP=False, size=(800,480), hbondCutoff=4.0, + animate=False): + + if chains is None: + chains = 1 if Ls is None else len(Ls) + + view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js', width=size[0], height=size[1]) + if animate: + view.addModelsAsFrames(pdb_str,'pdb',{'hbondCutoff':hbondCutoff}) + else: + view.addModel(pdb_str,'pdb',{'hbondCutoff':hbondCutoff}) + if color == "pLDDT": + view.setStyle({'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min':vmin,'max':vmax}}}) + elif color == "rainbow": + view.setStyle({'cartoon': {'color':'spectrum'}}) + elif color == "chain": + for n,chain,color in zip(range(chains),alphabet_list,pymol_color_list): + view.setStyle({'chain':chain},{'cartoon': {'color':color}}) + if show_sidechains: + BB = ['C','O','N'] + HP = ["ALA","GLY","VAL","ILE","LEU","PHE","MET","PRO","TRP","CYS","TYR"] + if color_HP: + view.addStyle({'and':[{'resn':HP},{'atom':BB,'invert':True}]}, + {'stick':{'colorscheme':"yellowCarbon",'radius':0.3}}) + view.addStyle({'and':[{'resn':HP,'invert':True},{'atom':BB,'invert':True}]}, + {'stick':{'colorscheme':"whiteCarbon",'radius':0.3}}) + view.addStyle({'and':[{'resn':"GLY"},{'atom':'CA'}]}, + {'sphere':{'colorscheme':"yellowCarbon",'radius':0.3}}) + view.addStyle({'and':[{'resn':"PRO"},{'atom':['C','O'],'invert':True}]}, + {'stick':{'colorscheme':"yellowCarbon",'radius':0.3}}) + else: + view.addStyle({'and':[{'resn':["GLY","PRO"],'invert':True},{'atom':BB,'invert':True}]}, + {'stick':{'colorscheme':f"WhiteCarbon",'radius':0.3}}) + view.addStyle({'and':[{'resn':"GLY"},{'atom':'CA'}]}, + {'sphere':{'colorscheme':f"WhiteCarbon",'radius':0.3}}) + view.addStyle({'and':[{'resn':"PRO"},{'atom':['C','O'],'invert':True}]}, + {'stick':{'colorscheme':f"WhiteCarbon",'radius':0.3}}) + if show_mainchains: + BB = ['C','O','N','CA'] + view.addStyle({'atom':BB},{'stick':{'colorscheme':f"WhiteCarbon",'radius':0.3}}) + view.zoomTo() + if animate: view.animate() + return view + +def plot_pseudo_3D(xyz, c=None, ax=None, chainbreak=5, Ls=None, + cmap="gist_rainbow", line_w=2.0, + cmin=None, cmax=None, zmin=None, zmax=None, + shadow=0.95): + + def rescale(a, amin=None, amax=None): + a = np.copy(a) + if amin is None: amin = a.min() + if amax is None: amax = a.max() + a[a < amin] = amin + a[a > amax] = amax + return (a - amin)/(amax - amin) + + # make segments and colors for each segment + xyz = np.asarray(xyz) + if Ls is None: + seg = np.concatenate([xyz[:,None],np.roll(xyz,1,0)[:,None]],axis=1) + c_seg = np.arange(len(seg))[::-1] if c is None else (c + np.roll(c,1,0))/2 + else: + Ln = 0 + seg = [] + c_seg = [] + for L in Ls: + sub_xyz = xyz[Ln:Ln+L] + seg.append(np.concatenate([sub_xyz[:,None],np.roll(sub_xyz,1,0)[:,None]],axis=1)) + if c is not None: + sub_c = c[Ln:Ln+L] + c_seg.append((sub_c + np.roll(sub_c,1,0))/2) + Ln += L + seg = np.concatenate(seg,0) + c_seg = np.arange(len(seg))[::-1] if c is None else np.concatenate(c_seg,0) + + # set colors + c_seg = rescale(c_seg,cmin,cmax) + if isinstance(cmap, str): + if cmap == "gist_rainbow": + c_seg *= 0.75 + colors = matplotlib.cm.get_cmap(cmap)(c_seg) + else: + 
colors = cmap(c_seg) + + # remove segments that aren't connected + seg_len = np.sqrt(np.square(seg[:,0] - seg[:,1]).sum(-1)) + if chainbreak is not None: + idx = seg_len < chainbreak + seg = seg[idx] + seg_len = seg_len[idx] + colors = colors[idx] + + seg_mid = seg.mean(1) + seg_xy = seg[...,:2] + seg_z = seg[...,2].mean(-1) + order = seg_z.argsort() + + # add shade/tint based on z-dimension + z = rescale(seg_z,zmin,zmax)[:,None] + + # add shadow (make lines darker if they are behind other lines) + seg_len_cutoff = (seg_len[:,None] + seg_len[None,:]) / 2 + seg_mid_z = seg_mid[:,2] + seg_mid_dist = np.sqrt(np.square(seg_mid[:,None] - seg_mid[None,:]).sum(-1)) + shadow_mask = sigmoid(seg_len_cutoff * 2.0 - seg_mid_dist) * (seg_mid_z[:,None] < seg_mid_z[None,:]) + np.fill_diagonal(shadow_mask,0.0) + shadow_mask = shadow ** shadow_mask.sum(-1,keepdims=True) + + seg_mid_xz = seg_mid[:,:2] + seg_mid_xydist = np.sqrt(np.square(seg_mid_xz[:,None] - seg_mid_xz[None,:]).sum(-1)) + tint_mask = sigmoid(seg_len_cutoff/2 - seg_mid_xydist) * (seg_mid_z[:,None] < seg_mid_z[None,:]) + np.fill_diagonal(tint_mask,0.0) + tint_mask = 1 - tint_mask.max(-1,keepdims=True) + + colors[:,:3] = colors[:,:3] + (1 - colors[:,:3]) * (0.50 * z + 0.50 * tint_mask) / 3 + colors[:,:3] = colors[:,:3] * (0.20 + 0.25 * z + 0.55 * shadow_mask) + colors = np.clip(colors,0,1) + + set_lim = False + if ax is None: + fig, ax = plt.subplots() + fig.set_figwidth(5) + fig.set_figheight(5) + set_lim = True + else: + fig = ax.get_figure() + if ax.get_xlim() == (0,1): + set_lim = True + + if set_lim: + xy_min = xyz[:,:2].min() - line_w + xy_max = xyz[:,:2].max() + line_w + ax.set_xlim(xy_min,xy_max) + ax.set_ylim(xy_min,xy_max) + + ax.set_aspect('equal') + + # determine linewidths + width = fig.bbox_inches.width * ax.get_position().width + linewidths = line_w * 72 * width / np.diff(ax.get_xlim()) + + lines = mcoll.LineCollection(seg_xy[order], colors=colors[order], linewidths=linewidths, + path_effects=[matplotlib.patheffects.Stroke(capstyle="round")]) + + return ax.add_collection(lines) + +def plot_ticks(ax, Ls, Ln=None, add_yticks=False): + if Ln is None: Ln = sum(Ls) + L_prev = 0 + for L_i in Ls[:-1]: + L = L_prev + L_i + L_prev += L_i + ax.plot([0,Ln],[L,L],color="black") + ax.plot([L,L],[0,Ln],color="black") + + if add_yticks: + ticks = np.cumsum([0]+Ls) + ticks = (ticks[1:] + ticks[:-1])/2 + ax.yticks(ticks,alphabet_list[:len(ticks)]) + +def make_animation(seq, con=None, xyz=None, plddt=None, pae=None, + losses=None, pos_ref=None, line_w=2.0, + dpi=100, interval=60, color_msa="Taylor", + length=None, align_xyz=True, color_by="plddt", **kwargs): + + def nankabsch(a,b,**kwargs): + ok = np.isfinite(a).all(axis=1) & np.isfinite(b).all(axis=1) + a,b = a[ok],b[ok] + return _np_kabsch(a,b,**kwargs) + + if xyz is not None: + if pos_ref is None: + pos_ref = xyz[-1] + + if length is None: + L = len(pos_ref) + Ls = None + elif isinstance(length, list): + L = length[0] + Ls = length + else: + L = length + Ls = None + + # align to reference + if align_xyz: + + pos_ref_trim = pos_ref[:L] + pos_ref_trim_mu = np.nanmean(pos_ref_trim,0) + pos_ref_trim = pos_ref_trim - pos_ref_trim_mu + + # align to reference position + new_pos = [] + for x in xyz: + x_mu = np.nanmean(x[:L],0) + aln = nankabsch(x[:L]-x_mu, pos_ref_trim, use_jax=False) + new_pos.append((x-x_mu) @ aln) + + pos = np.array(new_pos) + + # rotate for best view + pos_mean = np.concatenate(pos,0) + m = np.nanmean(pos_mean,0) + rot_mtx = nankabsch(pos_mean - m, pos_mean - m, return_v=True, 
use_jax=False) + pos = (pos - m) @ rot_mtx + pos_ref_full = ((pos_ref - pos_ref_trim_mu) - m) @ rot_mtx + + else: + # rotate for best view + pos_mean = np.concatenate(xyz,0) + m = np.nanmean(pos_mean,0) + aln = nankabsch(pos_mean - m, pos_mean - m, return_v=True, use_jax=False) + pos = [(x - m) @ aln for x in xyz] + pos_ref_full = (pos_ref - m) @ aln + + # initialize figure + if pae is not None and len(pae) == 0: pae = None + fig = plt.figure() + gs = GridSpec(4,3, figure=fig) + if pae is not None: + ax1, ax2, ax3 = fig.add_subplot(gs[:3,:2]), fig.add_subplot(gs[3:,:]), fig.add_subplot(gs[:3,2:]) + else: + ax1, ax2 = fig.add_subplot(gs[:3,:]), fig.add_subplot(gs[3:,:]) + + fig.subplots_adjust(top=0.95,bottom=0.1,right=0.95,left=0.05,hspace=0,wspace=0) + fig.set_figwidth(8); fig.set_figheight(6); fig.set_dpi(dpi) + ax2.set_xlabel("positions"); ax2.set_yticks([]) + if seq[0].shape[0] > 1: ax2.set_ylabel("sequences") + else: ax2.set_ylabel("amino acids") + + if xyz is None: + ax1.set_title("predicted contact map") + else: + ax1.set_title("N→C") if plddt is None else ax1.set_title("pLDDT") + if pae is not None: + ax3.set_title("pAE") + ax3.set_xticks([]) + ax3.set_yticks([]) + + # set bounderies + if xyz is not None: + main_pos = pos_ref_full[np.isfinite(pos_ref_full).all(1)] + pred_pos = [np.isfinite(x).all(1) for x in pos] + x_min,y_min,z_min = np.minimum(np.mean([x.min(0) for x in pred_pos],0),main_pos.min(0)) - 5 + x_max,y_max,z_max = np.maximum(np.mean([x.max(0) for x in pred_pos],0),main_pos.max(0)) + 5 + + x_pad = ((y_max - y_min) * 2 - (x_max - x_min)) / 2 + y_pad = ((x_max - x_min) / 2 - (y_max - y_min)) / 2 + if x_pad > 0: + x_min -= x_pad + x_max += x_pad + else: + y_min -= y_pad + y_max += y_pad + + ax1.set_xlim(x_min, x_max) + ax1.set_ylim(y_min, y_max) + ax1.set_xticks([]) + ax1.set_yticks([]) + + # get animation frames + ims = [] + for k in range(len(seq)): + ims.append([]) + if xyz is not None: + flags = dict(ax=ax1, line_w=line_w, zmin=z_min, zmax=z_max) + if color_by == "plddt" and plddt is not None: + ims[-1].append(plot_pseudo_3D(pos[k], c=plddt[k], Ls=Ls, cmin=0.5, cmax=0.9, **flags)) + elif color_by == "chain": + c = np.concatenate([[n]*L for n,L in enumerate(length)]) + ims[-1].append(plot_pseudo_3D(pos[k], c=c, Ls=Ls, cmap=pymol_cmap, cmin=0, cmax=39, **flags)) + else: + L = pos[k].shape[0] + ims[-1].append(plot_pseudo_3D(pos[k], c=np.arange(L)[::-1], Ls=Ls, cmin=0, cmax=L, **flags)) + else: + L = con[k].shape[0] + ims[-1].append(ax1.imshow(con[k], animated=True, cmap="Greys",vmin=0, vmax=1, extent=(0, L, L, 0))) + + if seq[k].shape[0] == 1: + ims[-1].append(ax2.imshow(seq[k][0].T, animated=True, cmap="bwr_r",vmin=-1, vmax=1)) + else: + cmap = matplotlib.colors.ListedColormap(jalview_color_list[color_msa]) + vmax = len(jalview_color_list[color_msa]) - 1 + ims[-1].append(ax2.imshow(seq[k].argmax(-1), animated=True, cmap=cmap, vmin=0, vmax=vmax, interpolation="none")) + + if pae is not None: + L = pae[k].shape[0] + ims[-1].append(ax3.imshow(pae[k], animated=True, cmap="bwr",vmin=0, vmax=30, extent=(0, L, L, 0))) + + # add lines + if length is not None: + Ls = length if isinstance(length, list) else [length,None] + if con is not None: + plot_ticks(ax1, Ls, con[0].shape[0]) + if pae is not None: + plot_ticks(ax3, Ls, pae[0].shape[0]) + + # make animation! 
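  # (added note) ArtistAnimation stitches the per-frame artist lists collected
  # above into an animation; to_html5_video() then encodes it with matplotlib's
  # configured movie writer (ffmpeg by default), so ffmpeg must be available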
+ ani = animation.ArtistAnimation(fig, ims, blit=True, interval=interval) + plt.close() + return ani.to_html5_video() diff --git a/build/lib/colabdesign/shared/prep.py b/build/lib/colabdesign/shared/prep.py new file mode 100644 index 00000000..437a811d --- /dev/null +++ b/build/lib/colabdesign/shared/prep.py @@ -0,0 +1,72 @@ +import numpy as np +def prep_pos(pos, residue, chain): + ''' + given input [pos]itions (a string of segment ranges seperated by comma, + for example: "1,3-4,10-15"), return list of indices to constrain. + ''' + residue_set = [] + chain_set = [] + len_set = [] + for idx in pos.split(","): + i,j = idx.split("-") if "-" in idx else (idx, None) + + if i.isalpha() and j is None: + residue_set += [None] + chain_set += [i] + len_set += [i] + else: + # if chain defined + if i[0].isalpha(): + c,i = i[0], int(i[1:]) + else: + c,i = chain[0],int(i) + if j is None: + j = i + else: + j = int(j[1:] if j[0].isalpha() else j) + residue_set += list(range(i,j+1)) + chain_set += [c] * (j-i+1) + len_set += [j-i+1] + + residue = np.asarray(residue) + chain = np.asarray(chain) + pos_set = [] + for i,c in zip(residue_set, chain_set): + if i is None: + idx = np.where(chain == c)[0] + assert len(idx) > 0, f'ERROR: chain {c} not found' + pos_set += [n for n in idx] + len_set[len_set.index(c)] = len(idx) + else: + idx = np.where((residue == i) & (chain == c))[0] + assert len(idx) == 1, f'ERROR: positions {i} and chain {c} not found' + pos_set.append(idx[0]) + + return {"residue":np.array(residue_set), + "chain":np.array(chain_set), + "length":np.array(len_set), + "pos":np.asarray(pos_set)} + +def rewire(length, order=None, loops=0, offset=0): + ''' + Given a list of segment [length]s, move them around given an [offset], [order] and [loop] lengths. + The [order] of the segments and the length of [loops] between segments can be controlled. 
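  Example (illustrative, added for clarity; output traced from the code below):
    rewire(length=[3,4], order=[1,0], loops=2)
    # -> array([6, 7, 8, 0, 1, 2, 3])
    # segment 1 (length 4) is placed first at new positions 0-3, a loop of
    # length 2 occupies positions 4-5, and segment 0 (length 3) follows at
    # positions 6-8; the result lists the new positions of each original
    # segment, in the original segment order.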
+ ''' + seg_len = [length] if isinstance(length,int) else length + num_seg = len(seg_len) + + # define order of segments + if order is None: order = list(range(num_seg)) + assert len(order) == num_seg + + # define loop lengths between segments + loop_len = ([loops] * (num_seg - 1)) if isinstance(loops, int) else loops + assert len(loop_len) == num_seg - 1 + + # get positions we want to restrain/constrain within hallucinated protein + l,new_pos = offset,[] + for n,i in enumerate(np.argsort(order)): + new_pos.append(l + np.arange(seg_len[i])) + if n < num_seg - 1: l += seg_len[i] + loop_len[n] + + return np.concatenate([new_pos[i] for i in order]) \ No newline at end of file diff --git a/build/lib/colabdesign/shared/prng.py b/build/lib/colabdesign/shared/prng.py new file mode 100644 index 00000000..cb1eb996 --- /dev/null +++ b/build/lib/colabdesign/shared/prng.py @@ -0,0 +1,29 @@ +import jax + +# adopted from https://github.com/deepmind/alphafold/blob/main/alphafold/model/prng.py +class SafeKey: + """Safety wrapper for PRNG keys.""" + + def __init__(self, key): + self._key = key + self._used = False + + def _assert_not_used(self): + if self._used: + raise RuntimeError('Random key has been used previously.') + + def get(self): + self._assert_not_used() + self._used = True + return self._key + + def split(self, num_keys=2): + self._assert_not_used() + self._used = True + new_keys = jax.random.split(self._key, num_keys) + return jax.tree_util.tree_map(SafeKey, tuple(new_keys)) + + def duplicate(self, num_keys=2): + self._assert_not_used() + self._used = True + return tuple(SafeKey(self._key) for _ in range(num_keys)) diff --git a/build/lib/colabdesign/shared/protein.py b/build/lib/colabdesign/shared/protein.py new file mode 100644 index 00000000..e75d4bbd --- /dev/null +++ b/build/lib/colabdesign/shared/protein.py @@ -0,0 +1,288 @@ +import jax +import jax.numpy as jnp +import numpy as np + +from colabdesign.af.alphafold.common import residue_constants +from string import ascii_uppercase, ascii_lowercase +alphabet_list = list(ascii_uppercase+ascii_lowercase) + +MODRES = {'MSE':'MET','MLY':'LYS','FME':'MET','HYP':'PRO', + 'TPO':'THR','CSO':'CYS','SEP':'SER','M3L':'LYS', + 'HSK':'HIS','SAC':'SER','PCA':'GLU','DAL':'ALA', + 'CME':'CYS','CSD':'CYS','OCS':'CYS','DPR':'PRO', + 'B3K':'LYS','ALY':'LYS','YCM':'CYS','MLZ':'LYS', + '4BF':'TYR','KCX':'LYS','B3E':'GLU','B3D':'ASP', + 'HZP':'PRO','CSX':'CYS','BAL':'ALA','HIC':'HIS', + 'DBZ':'ALA','DCY':'CYS','DVA':'VAL','NLE':'LEU', + 'SMC':'CYS','AGM':'ARG','B3A':'ALA','DAS':'ASP', + 'DLY':'LYS','DSN':'SER','DTH':'THR','GL3':'GLY', + 'HY3':'PRO','LLP':'LYS','MGN':'GLN','MHS':'HIS', + 'TRQ':'TRP','B3Y':'TYR','PHI':'PHE','PTR':'TYR', + 'TYS':'TYR','IAS':'ASP','GPL':'LYS','KYN':'TRP', + 'CSD':'CYS','SEC':'CYS'} + +def pdb_to_string(pdb_file, chains=None, models=None): + '''read pdb file and return as string''' + + if chains is not None: + if "," in chains: chains = chains.split(",") + if not isinstance(chains,list): chains = [chains] + if models is not None: + if not isinstance(models,list): models = [models] + + modres = {**MODRES} + lines = [] + seen = [] + model = 1 + + if "\n" in pdb_file: + old_lines = pdb_file.split("\n") + else: + with open(pdb_file,"rb") as f: + old_lines = [line.decode("utf-8","ignore").rstrip() for line in f] + for line in old_lines: + if line[:5] == "MODEL": + model = int(line[5:]) + if models is None or model in models: + if line[:6] == "MODRES": + k = line[12:15] + v = line[24:27] + if k not in modres and v in 
residue_constants.restype_3to1: + modres[k] = v + if line[:6] == "HETATM": + k = line[17:20] + if k in modres: + line = "ATOM "+line[6:17]+modres[k]+line[20:] + if line[:4] == "ATOM": + chain = line[21:22] + if chains is None or chain in chains: + atom = line[12:12+4].strip() + resi = line[17:17+3] + resn = line[22:22+5].strip() + if resn[-1].isalpha(): # alternative atom + resn = resn[:-1] + line = line[:26]+" "+line[27:] + key = f"{model}_{chain}_{resn}_{resi}_{atom}" + if key not in seen: # skip alternative placements + lines.append(line) + seen.append(key) + if line[:5] == "MODEL" or line[:3] == "TER" or line[:6] == "ENDMDL": + lines.append(line) + return "\n".join(lines) + +def renum_pdb_str(pdb_str, Ls=None, renum=True, offset=1): + if Ls is not None: + L_init = 0 + new_chain = {} + for L,c in zip(Ls, alphabet_list): + new_chain.update({i:c for i in range(L_init,L_init+L)}) + L_init += L + + n,num,pdb_out = 0,offset,[] + resnum_ = None + chain_ = None + new_chain_ = new_chain[0] + for line in pdb_str.split("\n"): + if line[:4] == "ATOM": + chain = line[21:22] + resnum = int(line[22:22+5]) + if resnum_ is None: resnum_ = resnum + if chain_ is None: chain_ = chain + if resnum != resnum_ or chain != chain_: + num += (resnum - resnum_) + n += 1 + resnum_,chain_ = resnum,chain + if Ls is not None: + if new_chain[n] != new_chain_: + num = offset + new_chain_ = new_chain[n] + N = num if renum else resnum + if Ls is None: pdb_out.append("%s%4i%s" % (line[:22],N,line[26:])) + else: pdb_out.append("%s%s%4i%s" % (line[:21],new_chain[n],N,line[26:])) + return "\n".join(pdb_out) + +################################################################################# + +def _np_len_pw(x, use_jax=True): + '''compute pairwise distance''' + _np = jnp if use_jax else np + + x_norm = _np.square(x).sum(-1) + xx = _np.einsum("...ia,...ja->...ij",x,x) + sq_dist = x_norm[...,:,None] + x_norm[...,None,:] - 2 * xx + + # due to precision errors the values can sometimes be negative + if use_jax: sq_dist = jax.nn.relu(sq_dist) + else: sq_dist[sq_dist < 0] = 0 + + # return euclidean pairwise distance matrix + return _np.sqrt(sq_dist + 1e-8) + +def _np_rmsdist(true, pred, use_jax=True): + '''compute RMSD of distance matrices''' + _np = jnp if use_jax else np + t = _np_len_pw(true, use_jax=use_jax) + p = _np_len_pw(pred, use_jax=use_jax) + return _np.sqrt(_np.square(t-p).mean() + 1e-8) + +def _np_kabsch(a, b, return_v=False, use_jax=True): + '''get alignment matrix for two sets of coodinates''' + _np = jnp if use_jax else np + ab = a.swapaxes(-1,-2) @ b + u, s, vh = _np.linalg.svd(ab, full_matrices=False) + flip = _np.linalg.det(u @ vh) < 0 + u_ = _np.where(flip, -u[...,-1].T, u[...,-1].T).T + if use_jax: u = u.at[...,-1].set(u_) + else: u[...,-1] = u_ + return u if return_v else (u @ vh) + +def _np_rmsd(true, pred, use_jax=True): + '''compute RMSD of coordinates after alignment''' + _np = jnp if use_jax else np + p = true - true.mean(-2,keepdims=True) + q = pred - pred.mean(-2,keepdims=True) + p = p @ _np_kabsch(p, q, use_jax=use_jax) + return _np.sqrt(_np.square(p-q).sum(-1).mean(-1) + 1e-8) + +def _np_norm(x, axis=-1, keepdims=True, eps=1e-8, use_jax=True): + '''compute norm of vector''' + _np = jnp if use_jax else np + return _np.sqrt(_np.square(x).sum(axis,keepdims=keepdims) + 1e-8) + +def _np_len(a, b, use_jax=True): + '''given coordinates a-b, return length or distance''' + return _np_norm(a-b, use_jax=use_jax) + +def _np_ang(a, b, c, use_acos=False, use_jax=True): + '''given coordinates a-b-c, return angle''' 
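As a quick numerical sanity check of the Kabsch/RMSD helpers defined above (a sketch, assuming they are importable from `colabdesign.shared.protein`): a rigidly rotated and translated copy of a point cloud should align back to essentially zero RMSD, and its pairwise-distance matrix should be unchanged.

```python
import numpy as np
from colabdesign.shared.protein import _np_rmsd, _np_rmsdist

rng = np.random.default_rng(0)
xyz = rng.normal(size=(10, 3))          # toy "CA trace"

theta = np.deg2rad(30.0)                # rotate 30 degrees about z, then translate
R = np.array([[np.cos(theta), -np.sin(theta), 0.0],
              [np.sin(theta),  np.cos(theta), 0.0],
              [0.0,            0.0,           1.0]])
xyz_moved = xyz @ R + 1.0

print(_np_rmsd(xyz, xyz_moved, use_jax=False))     # ~0 (up to the 1e-8 epsilon)
print(_np_rmsdist(xyz, xyz_moved, use_jax=False))  # ~0, distance matrices match
```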
+ _np = jnp if use_jax else np + norm = lambda x: _np_norm(x, use_jax=use_jax) + ba, bc = b-a, b-c + cos_ang = (ba * bc).sum(-1,keepdims=True) / (norm(ba) * norm(bc)) + # note the derivative at acos(-1 or 1) is inf, to avoid nans we use cos(ang) + if use_acos: return _np.arccos(cos_ang) + else: return cos_ang + +def _np_dih(a, b, c, d, use_atan2=False, standardize=False, use_jax=True): + '''given coordinates a-b-c-d, return dihedral''' + _np = jnp if use_jax else np + normalize = lambda x: x/_np_norm(x, use_jax=use_jax) + ab, bc, cd = normalize(a-b), normalize(b-c), normalize(c-d) + n1,n2 = _np.cross(ab, bc), _np.cross(bc, cd) + sin_ang = (_np.cross(n1, bc) * n2).sum(-1,keepdims=True) + cos_ang = (n1 * n2).sum(-1,keepdims=True) + if use_atan2: + return _np.arctan2(sin_ang, cos_ang) + else: + angs = _np.concatenate([sin_ang, cos_ang],-1) + if standardize: return normalize(angs) + else: return angs + +def _np_extend(a,b,c, L,A,D, use_jax=True): + ''' + given coordinates a-b-c, + c-d (L)ength, b-c-d (A)ngle, and a-b-c-d (D)ihedral + return 4th coordinate d + ''' + _np = jnp if use_jax else np + normalize = lambda x: x/_np_norm(x, use_jax=use_jax) + bc = normalize(b-c) + n = normalize(_np.cross(b-a, bc)) + return c + sum([L * _np.cos(A) * bc, + L * _np.sin(A) * _np.cos(D) * _np.cross(n, bc), + L * _np.sin(A) * _np.sin(D) * -n]) + +def _np_get_cb(N,CA,C, use_jax=True): + '''compute CB placement from N, CA, C''' + return _np_extend(C, N, CA, 1.522, 1.927, -2.143, use_jax=use_jax) + +def _np_get_6D(all_atom_positions, all_atom_mask=None, use_jax=True, for_trrosetta=False): + '''get 6D features (see TrRosetta paper)''' + + # get CB coordinate + atom_idx = {k:residue_constants.atom_order[k] for k in ["N","CA","C"]} + out = {k:all_atom_positions[...,i,:] for k,i in atom_idx.items()} + out["CB"] = _np_get_cb(**out, use_jax=use_jax) + + if all_atom_mask is not None: + idx = np.fromiter(atom_idx.values(),int) + out["CB_mask"] = all_atom_mask[...,idx].prod(-1) + + # get pairwise features + N,A,B = (out[k] for k in ["N","CA","CB"]) + n0 = N[...,:,None,:] + a0,a1 = A[...,:,None,:],A[...,None,:,:] + b0,b1 = B[...,:,None,:],B[...,None,:,:] + + if for_trrosetta: + out.update({"dist": _np_len(b0,b1, use_jax=use_jax), + "phi": _np_ang(a0,b0,b1, use_jax=use_jax, use_acos=True), + "omega": _np_dih(a0,b0,b1,a1, use_jax=use_jax, use_atan2=True), + "theta": _np_dih(n0,a0,b0,b1, use_jax=use_jax, use_atan2=True)}) + else: + out.update({"dist": _np_len(b0,b1, use_jax=use_jax), + "phi": _np_ang(a0,b0,b1, use_jax=use_jax, use_acos=False), + "omega": _np_dih(a0,b0,b1,a1, use_jax=use_jax, use_atan2=False), + "theta": _np_dih(n0,a0,b0,b1, use_jax=use_jax, use_atan2=False)}) + return out + +#################### +# losses +#################### + +# RMSD +def jnp_rmsdist(true, pred): + return _np_rmsdist(true, pred) + +def jnp_rmsd(true, pred, add_dist=False): + rmsd = _np_rmsd(true, pred) + if add_dist: rmsd = (rmsd + _np_rmsdist(true, pred))/2 + return rmsd + +def jnp_kabsch_w(a, b, weights): + return _np_kabsch(a * weights[:,None], b) + +def jnp_rmsd_w(true, pred, weights): + p = true - (true * weights[:,None]).sum(0,keepdims=True)/weights.sum() + q = pred - (pred * weights[:,None]).sum(0,keepdims=True)/weights.sum() + p = p @ _np_kabsch(p * weights[:,None], q) + return jnp.sqrt((weights*jnp.square(p-q).sum(-1)).sum()/weights.sum() + 1e-8) + +# 6D (see TrRosetta paper) +def _np_get_6D_loss(true, pred, mask=None, use_theta=True, use_dist=False, use_jax=True): + _np = jnp if use_jax else np + + f = {"T":_np_get_6D(true, 
mask, use_jax=use_jax), + "P":_np_get_6D(pred, use_jax=use_jax)} + + for k in f: f[k]["dist"] /= 10.0 + + keys = ["omega","phi"] + if use_theta: keys.append("theta") + if use_dist: keys.append("dist") + sq_diff = sum([_np.square(f["T"][k]-f["P"][k]).sum(-1) for k in keys]) + + mask = _np.ones(true.shape[0]) if mask is None else f["T"]["CB_mask"] + mask = mask[:,None] * mask[None,:] + loss = (sq_diff * mask).sum((-1,-2)) / mask.sum((-1,-2)) + + return _np.sqrt(loss + 1e-8).mean() + +def _np_get_6D_binned(all_atom_positions, all_atom_mask, use_jax=None): + # TODO: make differentiable, add use_jax option + ref = _np_get_6D(all_atom_positions, + all_atom_mask, + use_jax=False, for_trrosetta=True) + ref = jax.tree_map(jnp.squeeze,ref) + + def mtx2bins(x_ref, start, end, nbins, mask): + bins = np.linspace(start, end, nbins) + x_true = np.digitize(x_ref, bins).astype(np.uint8) + x_true = np.where(mask,0,x_true) + return np.eye(nbins+1)[x_true][...,:-1] + + mask = (ref["dist"] > 20) | (np.eye(ref["dist"].shape[0]) == 1) + return {"dist": mtx2bins(ref["dist"], 2.0, 20.0, 37, mask=mask), + "omega":mtx2bins(ref["omega"], -np.pi, np.pi, 25, mask=mask), + "theta":mtx2bins(ref["theta"], -np.pi, np.pi, 25, mask=mask), + "phi": mtx2bins(ref["phi"], 0.0, np.pi, 13, mask=mask)} \ No newline at end of file diff --git a/build/lib/colabdesign/shared/utils.py b/build/lib/colabdesign/shared/utils.py new file mode 100644 index 00000000..afd184a6 --- /dev/null +++ b/build/lib/colabdesign/shared/utils.py @@ -0,0 +1,111 @@ +import random +import jax +import numpy as np +import jax.numpy as jnp +import sys, gc + +def clear_mem(): + # clear vram (GPU) + backend = jax.lib.xla_bridge.get_backend() + if hasattr(backend,'live_buffers'): + for buf in backend.live_buffers(): + buf.delete() + + # TODO: clear ram (CPU) + gc.collect() + +def update_dict(D, *args, **kwargs): + '''robust function for updating dictionary''' + def set_dict(d, x, override=False): + for k,v in x.items(): + if v is not None: + if k in d: + if isinstance(v, dict): + set_dict(d[k], x[k], override=override) + elif override or d[k] is None: + d[k] = v + elif isinstance(d[k],(np.ndarray,jnp.ndarray)): + d[k] = np.asarray(v) + elif isinstance(d[k], dict): + d[k] = jax.tree.map(lambda x: type(x)(v), d[k]) + else: + d[k] = type(d[k])(v) + else: + print(f"ERROR: '{k}' not found in {list(d.keys())}") + override = kwargs.pop("override", False) + while len(args) > 0 and isinstance(args[0],str): + D,args = D[args[0]],args[1:] + for a in args: + if isinstance(a, dict): set_dict(D, a, override=override) + set_dict(D, kwargs, override=override) + +def copy_dict(x): + '''deepcopy dictionary''' + return jax.tree.map(lambda y:y, x) + +def to_float(x): + '''convert to float''' + if hasattr(x,"tolist"): x = x.tolist() + if isinstance(x,dict): x = {k:to_float(y) for k,y in x.items()} + elif hasattr(x,"__iter__"): x = [to_float(y) for y in x] + else: x = float(x) + return x + +def dict_to_str(x, filt=None, keys=None, ok=None, print_str=None, f=2): + '''convert dictionary to string for print out''' + if keys is None: keys = [] + if filt is None: filt = {} + if print_str is None: print_str = "" + if ok is None: ok = [] + + # gather keys + for k in x.keys(): + if k not in keys: + keys.append(k) + + for k in keys: + if k in x and (filt.get(k,True) or k in ok): + v = x[k] + if isinstance(v,float): + if int(v) == v: + print_str += f" {k} {int(v)}" + else: + print_str += f" {k} {v:.{f}f}" + else: + print_str += f" {k} {v}" + return print_str + +class Key(): + '''random key 
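For reference, a small sketch of the bin layout produced by `_np_get_6D_binned` above (shapes inferred from the `mtx2bins` calls; the random coordinates are only there to make the snippet self-contained and show the output shapes):

```python
import numpy as np
from colabdesign.af.alphafold.common import residue_constants
from colabdesign.shared.protein import _np_get_6D_binned

L = 50  # toy protein length
positions = np.random.randn(L, residue_constants.atom_type_num, 3)  # fake all_atom_positions
mask = np.ones((L, residue_constants.atom_type_num))                # fake all_atom_mask

feats = _np_get_6D_binned(positions, mask)
for k, v in feats.items():
    print(k, v.shape)
# dist  -> (L, L, 37)  distances binned between 2 and 20 Angstroms
# omega -> (L, L, 25), theta -> (L, L, 25)  dihedrals binned over -pi..pi
# phi   -> (L, L, 13)  angles binned over 0..pi
# pairs beyond 20 Angstroms (and the diagonal) are pushed into the first bin
```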
generator''' + def __init__(self, key=None, seed=None): + if key is None: + self.seed = random.randint(0,2147483647) if seed is None else seed + self.key = jax.random.PRNGKey(self.seed) + else: + self.key = key + def get(self, num=1): + if num > 1: + self.key, *sub_keys = jax.random.split(self.key, num=(num+1)) + return sub_keys + else: + self.key, sub_key = jax.random.split(self.key) + return sub_key + +def softmax(x, axis=-1): + x = x - x.max(axis,keepdims=True) + x = np.exp(x) + return x / x.sum(axis,keepdims=True) + +def categorical(p): + return (p.cumsum(-1) >= np.random.uniform(size=p.shape[:-1])[..., None]).argmax(-1) + +def to_list(xs): + if not isinstance(xs,list): xs = [xs] + return [x for x in xs if x is not None] + +def copy_missing(a,b): + for i,v in a.items(): + if i not in b: + b[i] = v + elif isinstance(v,dict): + copy_missing(v,b[i]) \ No newline at end of file diff --git a/build/lib/colabdesign/tr/__init__.py b/build/lib/colabdesign/tr/__init__.py new file mode 100644 index 00000000..712bf307 --- /dev/null +++ b/build/lib/colabdesign/tr/__init__.py @@ -0,0 +1,14 @@ +import os,jax +# disable triton_gemm for jax versions > 0.3 +if int(jax.__version__.split(".")[1]) > 3: + os.environ["XLA_FLAGS"] = "--xla_gpu_enable_triton_gemm=false" + +import warnings +warnings.simplefilter(action='ignore', category=FutureWarning) + +from colabdesign.shared.utils import clear_mem +from colabdesign.tr.model import mk_tr_model +from colabdesign.tr.joint_model import mk_af_tr_model + +# backward compatability +mk_design_model = mk_trdesign_model = mk_tr_model \ No newline at end of file diff --git a/build/lib/colabdesign/tr/joint_model.py b/build/lib/colabdesign/tr/joint_model.py new file mode 100644 index 00000000..801a7eec --- /dev/null +++ b/build/lib/colabdesign/tr/joint_model.py @@ -0,0 +1,66 @@ +from colabdesign.af.model import mk_af_model +from colabdesign.tr.model import mk_tr_model + +class mk_af_tr_model: + def __init__(self, protocol="fixbb", use_templates=False, + recycle_mode="last", num_recycles=0): + assert protocol in ["fixbb","partial","hallucination","binder"] + self.af = mk_af_model(protocol=protocol, use_templates=use_templates, + recycle_mode=recycle_mode, num_recycles=num_recycles) + + if protocol == "binder": + def _prep_inputs(pdb_filename, chain, binder_len=50, binder_chain=None, + ignore_missing=True, **kwargs): + self.af.prep_inputs(pdb_filename=pdb_filename, chain=chain, + binder_len=binder_len, binder_chain=binder_chain, + ignore_missing=ignore_missing, **kwargs) + flags = dict(ignore_missing=ignore_missing) + if binder_chain is None: + self.tr = mk_tr_model(protocol="hallucination") + self.tr.prep_inputs(length=binder_len, **flags) + else: + self.tr = mk_tr_model(protocol="fixbb") + self.tr.prep_inputs(pdb_filename=pdb_filename, chain=binder_chain, **flags) + else: + self.tr = mk_tr_model(protocol=protocol) + + if protocol == "fixbb": + def _prep_inputs(pdb_filename, chain, fix_pos=None, + ignore_missing=True, **kwargs): + flags = dict(pdb_filename=pdb_filename, chain=chain, + fix_pos=fix_pos, ignore_missing=ignore_missing) + self.af.prep_inputs(**flags, **kwargs) + self.tr.prep_inputs(**flags, chain=chain) + + if protocol == "partial": + def _prep_inputs(pdb_filename, chain, pos=None, length=None, + fix_pos=None, use_sidechains=False, atoms_to_exclude=None, + ignore_missing=True, **kwargs): + if use_sidechains: fix_seq = True + flags = dict(pdb_filename=pdb_filename, chain=chain, + length=length, pos=pos, fix_pos=fix_pos, + ignore_missing=ignore_missing) + 
af_a2e = kwargs.pop("af_atoms_to_exclude",atoms_to_exclude) + tr_a2e = kwargs.pop("tr_atoms_to_exclude",atoms_to_exclude) + self.af.prep_inputs(**flags, use_sidechains=use_sidechains, atoms_to_exclude=af_a2e, **kwargs) + self.tr.prep_inputs(**flags, atoms_to_exclude=tr_a2e) + + def _rewire(order=None, offset=0, loops=0): + self.af.rewire(order=order, offset=offset, loops=loops) + self.tr.rewire(order=order, offset=offset, loops=loops) + + self.rewire = _rewire + + if protocol == "hallucintion": + def _prep_inputs(length=None, **kwargs): + self.af.prep_inputs(length=length, **kwargs) + self.tr.prep_inputs(length=length) + + self.prep_inputs = _prep_inputs + + def set_opt(self,*args,**kwargs): + self.af.set_opt(*args,**kwargs) + self.tr.set_opt(*args,**kwargs) + + def joint_design(self, iters=100, tr_weight=1.0, tr_seed=None, **kwargs): + self.af.design(iters, callback=self.tr.af_callback(weight=tr_weight, seed=tr_seed), **kwargs) diff --git a/build/lib/colabdesign/tr/model.py b/build/lib/colabdesign/tr/model.py new file mode 100644 index 00000000..8262c924 --- /dev/null +++ b/build/lib/colabdesign/tr/model.py @@ -0,0 +1,341 @@ +import random, os +import numpy as np +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt + +from colabdesign.shared.utils import copy_dict, update_dict, Key, dict_to_str +from colabdesign.shared.prep import prep_pos +from colabdesign.shared.protein import _np_get_6D_binned +from colabdesign.shared.model import design_model, soft_seq + +from .trrosetta import TrRosetta, get_model_params + +# borrow some stuff from AfDesign +from colabdesign.af.prep import prep_pdb +from colabdesign.af.alphafold.common import protein + +class mk_tr_model(design_model): + def __init__(self, protocol="fixbb", num_models=1, + sample_models=True, data_dir="params/tr", + optimizer="sgd", learning_rate=0.1, + loss_callback=None): + + assert protocol in ["fixbb","hallucination","partial"] + + self.protocol = protocol + self._data_dir = "." 
if os.path.isfile(os.path.join("models",f"model_xaa.npy")) else data_dir + self._loss_callback = loss_callback + self._num = 1 + + # set default options + self.opt = {"temp":1.0, "soft":1.0, "hard":1.0, "dropout":False, + "num_models":num_models,"sample_models":sample_models, + "weights":{}, "lr":1.0, "alpha":1.0, + "learning_rate":learning_rate, "use_pssm":False, + "norm_seq_grad":True} + + self._args = {"optimizer":optimizer} + self._params = {} + self._inputs = {} + + # setup model + self._model = self._get_model() + self._model_params = [] + for k in list("abcde"): + p = os.path.join(self._data_dir,os.path.join("models",f"model_xa{k}.npy")) + self._model_params.append(get_model_params(p)) + + if protocol in ["hallucination","partial"]: + self._bkg_model = TrRosetta(bkg_model=True) + + def _get_model(self): + runner = TrRosetta() + def _get_loss(inputs, outputs): + opt = inputs["opt"] + aux = {"outputs":outputs, "losses":{}} + log_p = jax.tree_map(jax.nn.log_softmax, outputs) + + # bkg loss + if self.protocol in ["hallucination","partial"]: + p = jax.tree_map(jax.nn.softmax, outputs) + log_q = jax.tree_map(jax.nn.log_softmax, inputs["6D_bkg"]) + aux["losses"]["bkg"] = {} + for k in ["dist","omega","theta","phi"]: + aux["losses"]["bkg"][k] = -(p[k]*(log_p[k]-log_q[k])).sum(-1).mean() + + # cce loss + if self.protocol in ["fixbb","partial"]: + if "pos" in opt: + pos = opt["pos"] + log_p = jax.tree_map(lambda x:x[:,pos][pos,:], log_p) + + q = inputs["6D"] + aux["losses"]["cce"] = {} + for k in ["dist","omega","theta","phi"]: + aux["losses"]["cce"][k] = -(q[k]*log_p[k]).sum(-1).mean() + + if self._loss_callback is not None: + aux["losses"].update(self._loss_callback(outputs)) + + # weighted loss + w = opt["weights"] + tree_multi = lambda x,y: jax.tree_map(lambda a,b:a*b, x,y) + losses = {k:(tree_multi(v,w[k]) if k in w else v) for k,v in aux["losses"].items()} + loss = sum(jax.tree_leaves(losses)) + return loss, aux + + def _model(params, model_params, inputs, key): + inputs["params"] = params + opt = inputs["opt"] + seq = soft_seq(params["seq"], inputs["bias"], opt) + if "fix_pos" in opt: + if "pos" in self.opt: + seq_ref = jax.nn.one_hot(inputs["batch"]["aatype_sub"],20) + p = opt["pos"][opt["fix_pos"]] + fix_seq = lambda x:x.at[...,p,:].set(seq_ref) + else: + seq_ref = jax.nn.one_hot(inputs["batch"]["aatype"],20) + p = opt["fix_pos"] + fix_seq = lambda x:x.at[...,p,:].set(seq_ref[...,p,:]) + seq = jax.tree_map(fix_seq, seq) + + inputs.update({"seq":seq["pseudo"][0], + "prf":jnp.where(opt["use_pssm"],seq["pssm"],seq["pseudo"])[0]}) + rate = jnp.where(opt["dropout"],0.15,0.0) + outputs = runner(inputs, model_params, key, rate) + loss, aux = _get_loss(inputs, outputs) + aux.update({"seq":seq,"opt":opt}) + return loss, aux + + return {"grad_fn":jax.jit(jax.value_and_grad(_model, has_aux=True, argnums=0)), + "fn":jax.jit(_model)} + + def prep_inputs(self, pdb_filename=None, chain=None, length=None, + pos=None, fix_pos=None, atoms_to_exclude=None, ignore_missing=True, + **kwargs): + ''' + prep inputs for TrDesign + ''' + if self.protocol in ["fixbb", "partial"]: + # parse PDB file and return features compatible with TrRosetta + pdb = prep_pdb(pdb_filename, chain, ignore_missing=ignore_missing) + self._inputs["batch"] = pdb["batch"] + + if fix_pos is not None: + self.opt["fix_pos"] = prep_pos(fix_pos, **pdb["idx"])["pos"] + + if self.protocol == "partial" and pos is not None: + self._pos_info = prep_pos(pos, **pdb["idx"]) + p = self._pos_info["pos"] + aatype = self._inputs["batch"]["aatype"] + 
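As a usage sketch for the TrDesign wrapper being set up here (the PDB name is a placeholder, and the TrRosetta weights are assumed to be available under the `data_dir` passed to the constructor):

```python
from colabdesign.tr.model import mk_tr_model

# fixed-backbone design: find a sequence consistent with an existing backbone
tr_model = mk_tr_model(protocol="fixbb", data_dir="params/tr")
tr_model.prep_inputs(pdb_filename="1QYS.pdb", chain="A")  # placeholder PDB file

tr_model.design(iters=100, verbose=10)   # gradient-based sequence optimization
print(tr_model.get_loss())               # weighted 6D cce losses of the best design
tr_model.plot("preds")                   # predicted dist/omega/theta/phi maps
```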
self._inputs["batch"] = jax.tree_map(lambda x:x[p], self._inputs["batch"]) + self.opt["pos"] = p + if "fix_pos" in self.opt: + sub_i,sub_p = [],[] + p = p.tolist() + for i in self.opt["fix_pos"].tolist(): + if i in p: + sub_i.append(i) + sub_p.append(p.index(i)) + self.opt["fix_pos"] = np.array(sub_p) + self._inputs["batch"]["aatype_sub"] = aatype[sub_i] + + self._inputs["6D"] = _np_get_6D_binned(self._inputs["batch"]["all_atom_positions"], + self._inputs["batch"]["all_atom_mask"]) + + self._len = len(self._inputs["batch"]["aatype"]) + self.opt["weights"]["cce"] = {"dist":1/6,"omega":1/6,"theta":2/6,"phi":2/6} + if atoms_to_exclude is not None: + if "N" in atoms_to_exclude: + # theta = [N]-CA-CB-CB + self.opt["weights"]["cce"] = dict(dist=1/4,omega=1/4,phi=1/2,theta=0) + if "CA" in atoms_to_exclude: + # theta = N-[CA]-CB-CB + # omega = [CA]-CB-CB-[CA] + # phi = [CA]-CB-CB + self.opt["weights"]["cce"] = dict(dist=1,omega=0,phi=0,theta=0) + + if self.protocol in ["hallucination", "partial"]: + # compute background distribution + if length is not None: self._len = length + self._inputs["6D_bkg"] = [] + key = jax.random.PRNGKey(0) + for n in range(1,6): + p = os.path.join(self._data_dir,os.path.join("bkgr_models",f"bkgr0{n}.npy")) + self._inputs["6D_bkg"].append(self._bkg_model(get_model_params(p), key, self._len)) + self._inputs["6D_bkg"] = jax.tree_map(lambda *x:np.stack(x).mean(0), *self._inputs["6D_bkg"]) + + # reweight the background + self.opt["weights"]["bkg"] = dict(dist=1/6,omega=1/6,phi=2/6,theta=2/6) + + + self._opt = copy_dict(self.opt) + self.restart(**kwargs) + + def set_opt(self, *args, **kwargs): + ''' + set [opt]ions + ------------------- + note: model.restart() resets the [opt]ions to their defaults + use model.set_opt(..., set_defaults=True) + or model.restart(..., reset_opt=False) to avoid this + ------------------- + model.set_opt(num_models=1) + model.set_opt(con=dict(num=1)) or set_opt({"con":{"num":1}}) + model.set_opt(lr=1, set_defaults=True) + ''' + if kwargs.pop("set_defaults", False): + update_dict(self._opt, *args, **kwargs) + + update_dict(self.opt, *args, **kwargs) + + + def restart(self, seed=None, opt=None, weights=None, + seq=None, reset_opt=True, **kwargs): + + if reset_opt: + self.opt = copy_dict(self._opt) + + self.set_opt(opt) + self.set_weights(weights) + self.set_seed(seed) + + # set sequence + self.set_seq(seq, **kwargs) + + # setup optimizer + self._k = 0 + self.set_optimizer() + + # clear previous best + self._tmp = {"best":{}} + + def run(self, backprop=True): + '''run model to get outputs, losses and gradients''' + + # decide which model params to use + ns = np.arange(5) + m = min(self.opt["num_models"],len(ns)) + if self.opt["sample_models"] and m != len(ns): + model_num = np.random.choice(ns,(m,),replace=False) + else: + model_num = ns[:m] + model_num = np.array(model_num).tolist() + + # run in serial + aux_all = [] + for n in model_num: + model_params = self._model_params[n] + self._inputs["opt"] = self.opt + flags = [self._params, model_params, self._inputs, self.key()] + if backprop: + (loss,aux),grad = self._model["grad_fn"](*flags) + else: + loss,aux = self._model["fn"](*flags) + grad = jax.tree_map(np.zeros_like, self._params) + aux.update({"loss":loss, "grad":grad}) + aux_all.append(aux) + + # average results + self.aux = jax.tree_map(lambda *x:np.stack(x).mean(0), *aux_all) + self.aux["model_num"] = model_num + + + def step(self, backprop=True, callback=None, save_best=True, verbose=1): + self.run(backprop=backprop) + if callback is not 
None: callback(self) + + # modify gradients + if self.opt["norm_seq_grad"]: self._norm_seq_grad() + self._state, self.aux["grad"] = self._optimizer(self._state, self.aux["grad"], self._params) + + # apply gradients + lr = self.opt["learning_rate"] + self._params = jax.tree_map(lambda x,g:x-lr*g, self._params, self.aux["grad"]) + + # increment + self._k += 1 + + # save results + if save_best: + if "aux" not in self._tmp["best"] or self.aux["loss"] < self._tmp["best"]["aux"]["loss"]: + self._tmp["best"]["aux"] = self.aux + + # print + if verbose and (self._k % verbose) == 0: + x = self.get_loss(get_best=False) + x["models"] = self.aux["model_num"] + print(dict_to_str(x, print_str=f"{self._k}", keys=["models"])) + + def predict(self, seq=None, models=0): + self.set_opt(dropout=False) + if seq is not None: + self.set_seq(seq=seq, set_state=False) + self.run(backprop=False) + + def design(self, iters=100, opt=None, weights=None, save_best=True, verbose=1): + self.set_opt(opt) + self.set_weights(weights) + for _ in range(iters): + self.step(save_best=save_best, verbose=verbose) + + def plot(self, mode="preds", dpi=100, get_best=True): + '''plot predictions''' + + assert mode in ["preds","feats","bkg_feats"] + if mode == "preds": + aux = self._tmp["best"]["aux"] if (get_best and "aux" in self._tmp["best"]) else self.aux + x = aux["outputs"] + elif mode == "feats": + x = self._inputs["6D"] + elif mode == "bkg_feats": + x = self._inputs["6D_bkg"] + + x = jax.tree_map(np.asarray, x) + + plt.figure(figsize=(4*4,4), dpi=dpi) + for n,k in enumerate(["theta","phi","dist","omega"]): + v = x[k] + plt.subplot(1,4,n+1) + plt.title(k) + plt.imshow(v.argmax(-1),cmap="binary") + plt.show() + + def get_loss(self, k=None, get_best=True): + aux = self._tmp["best"]["aux"] if (get_best and "aux" in self._tmp["best"]) else self.aux + if k is None: + return {k:self.get_loss(k, get_best=get_best) for k in aux["losses"].keys()} + losses = aux["losses"][k] + weights = aux["opt"]["weights"][k] + weighted_losses = jax.tree_map(lambda l,w:l*w, losses, weights) + return float(sum(jax.tree_leaves(weighted_losses))) + + def af_callback(self, weight=1.0, seed=None): + + def callback(af_model): + # copy [opt]ions from afdesign + for k,v in af_model.opt.items(): + if k in self.opt and k not in ["weights"]: + self.opt[k] = af_model.opt[k] + + # update sequence input + self._params["seq"] = af_model._params["seq"] + + # run trdesign + self.run(backprop = weight > 0) + + # add gradients + af_model.aux["grad"]["seq"] += weight * self.aux["grad"]["seq"] + + # add loss + af_model.aux["loss"] += weight * self.aux["loss"] + + # for verbose printout + if self.protocol in ["hallucination","partial"]: + af_model.aux["losses"]["TrD_bkg"] = self.get_loss("bkg", get_best=False) + if self.protocol in ["fixbb","partial"]: + af_model.aux["losses"]["TrD_cce"] = self.get_loss("cce", get_best=False) + + self.restart(seed=seed) + return callback \ No newline at end of file diff --git a/build/lib/colabdesign/tr/trrosetta.py b/build/lib/colabdesign/tr/trrosetta.py new file mode 100644 index 00000000..8a27b8da --- /dev/null +++ b/build/lib/colabdesign/tr/trrosetta.py @@ -0,0 +1,115 @@ +import jax.numpy as jnp +import jax +import numpy as np + +def TrRosetta(bkg_model=False): + + def pseudo_mrf(inputs, prf=None): + '''single sequence''' + seq,prf = inputs["seq"],inputs["prf"] + L,A = seq.shape[0],21 + if prf.shape[1] == 20: + prf = jnp.pad(prf,[[0,0],[0,1]]) + + # 1D features + x_1D = jnp.concatenate([seq, prf],-1) + x_1D = jnp.pad(x_1D,[[0,0],[0,1]]) + 
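The `af_callback` defined above is the hook that mixes TrRosetta losses and gradients into an AfDesign trajectory; `mk_af_tr_model.joint_design` wires the same callback up internally. A hedged sketch of using the callback directly (placeholder PDB name; AlphaFold and TrRosetta parameters must already be on disk):

```python
from colabdesign.af.model import mk_af_model
from colabdesign.tr.model import mk_tr_model

af_model = mk_af_model(protocol="fixbb")
af_model.prep_inputs(pdb_filename="1QYS.pdb", chain="A")  # placeholder PDB file

tr_model = mk_tr_model(protocol="fixbb")
tr_model.prep_inputs(pdb_filename="1QYS.pdb", chain="A")

# every AfDesign step also runs TrDesign and adds weight * (loss, sequence gradient)
af_model.design(100, callback=tr_model.af_callback(weight=1.0))
```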
x_1D = jnp.repeat(x_1D[None],L,0) + + # 2D features + x_2D = jnp.diag(jnp.full(L*A,0.4)) + x_2D = x_2D.reshape(L,A,L,A).swapaxes(1,2).reshape(L,L,-1) + x_2D = jnp.pad(x_2D,[[0,0],[0,0],[0,1]]) + return jnp.concatenate([x_1D.swapaxes(0,1), x_1D, x_2D],-1) + + # layers + def instance_norm(x, params): + mu = x.mean((0,1),keepdims=True) + var = x.var((0,1),keepdims=True) + inv = jax.lax.rsqrt(var + 1e-6) * params["scale"] + return x * inv + params["offset"] - mu * inv + + def conv_2D(x, params, dilation=1, stride=1, padding="SAME"): + flags = dict(window_strides=(stride,stride), + rhs_dilation=(dilation,dilation), + padding=padding) + x = x.transpose([2,0,1]) + f = params["filters"].transpose([3,2,0,1]) + x = jax.lax.conv_general_dilated(x[None], f, **flags)[0] + x = x.transpose([1,2,0]) + return x + params["bias"] + + def dense(x, params): + return x @ params["filters"] + params["bias"] + + def dropout(x, key, rate): + keep_rate = 1.0 - rate + keep = jax.random.bernoulli(key, keep_rate, shape=x.shape) + return keep * x / keep_rate + + # meta layers + def encoder(x, params): + x = dense(x, params) + x = instance_norm(x, params) + return jax.nn.elu(x) + + def block(x, params, dilation, key, rate=0.15): + y = x + for n in [0,1]: + if n == 1: y = dropout(y, key, rate) + p = jax.tree_map(lambda x:x[n], params) + y = conv_2D(y, p, dilation) + y = instance_norm(y, p) + y = jax.nn.elu(y if n == 0 else (x+y)) + return y + + def resnet(x, params, key, rate=0.15): + def body(prev, sub_params): + (x,key) = prev + for n, dilation in enumerate([1,2,4,8,16]): + key, sub_key = jax.random.split(key) + p = jax.tree_map(lambda x:x[n], sub_params) + x = block(x, p, dilation, sub_key, rate) + return (x,key), None + return jax.lax.scan(body,(x,key),params)[0][0] + + def heads(x, params): + o = {k:dense(x,params[k]) for k in ["theta","phi"]} + x = (x + x.swapaxes(0,1)) / 2 + o.update({k:dense(x,params[k]) for k in ["dist","bb","omega"]}) + return o + + def trunk(x, params, key, rate=0.15): + key, sub_key = jax.random.split(key) + x = encoder(x, params["encoder"]) + x = resnet(x, params["resnet"], sub_key, rate) + x = block(x, params["block"], 1, key, rate) + return heads(x, params) + + # decide which model to use + if bkg_model: + def model(params, key, length=100): + key, sub_key = jax.random.split(key) + x = jax.random.normal(sub_key, (length, length, 64)) + return trunk(x, params, key, 0.0) + return jax.jit(model, static_argnums=2) + else: + def model(inputs, params, key, rate=0.15): + x = pseudo_mrf(inputs) + return trunk(x, params, key, rate) + return jax.jit(model) + +def get_model_params(npy): + '''parse TrRosetta params into dictionary''' + xaa = np.load(npy,allow_pickle=True).tolist() + layers = ["encoder","resnet","block","theta","phi","dist","bb","omega"] + num = np.array([4,0,8,2,2,2,2,2]) + num[1] = len(xaa) - num.sum() + idx = np.cumsum(num) - num + def split(params): + labels = ["filters","bias","offset","scale"] + steps = min(len(params),len(labels)) + return {labels[n]:np.squeeze(params[n::steps]) for n in range(steps)} + params = {k:split(xaa[i:i+n]) for k,i,n in zip(layers,idx,num)} + params["resnet"] = jax.tree_map(lambda x:x.reshape(-1,5,2,*x.shape[1:]), params["resnet"]) + return params \ No newline at end of file diff --git a/colabdesign.egg-info/PKG-INFO b/colabdesign.egg-info/PKG-INFO new file mode 100644 index 00000000..0352bb58 --- /dev/null +++ b/colabdesign.egg-info/PKG-INFO @@ -0,0 +1,28 @@ +Metadata-Version: 2.4 +Name: colabdesign +Version: 1.1.3 +Summary: Making Protein Design 
accessible to all via Google Colab! +Description-Content-Type: text/markdown +License-File: LICENSE.txt +Requires-Dist: py3Dmol +Requires-Dist: absl-py +Requires-Dist: biopython +Requires-Dist: chex +Requires-Dist: dm-haiku +Requires-Dist: dm-tree +Requires-Dist: immutabledict +Requires-Dist: jax +Requires-Dist: ml-collections +Requires-Dist: numpy +Requires-Dist: pandas +Requires-Dist: scipy +Requires-Dist: optax +Requires-Dist: joblib +Requires-Dist: matplotlib +Dynamic: description +Dynamic: description-content-type +Dynamic: license-file +Dynamic: requires-dist +Dynamic: summary + +Making Protein Design accessible to all via Google Colab! diff --git a/colabdesign.egg-info/SOURCES.txt b/colabdesign.egg-info/SOURCES.txt new file mode 100644 index 00000000..810a57e7 --- /dev/null +++ b/colabdesign.egg-info/SOURCES.txt @@ -0,0 +1,109 @@ +LICENSE.txt +MANIFEST.in +README.md +setup.py +colabdesign/__init__.py +colabdesign.egg-info/PKG-INFO +colabdesign.egg-info/SOURCES.txt +colabdesign.egg-info/dependency_links.txt +colabdesign.egg-info/requires.txt +colabdesign.egg-info/top_level.txt +colabdesign/af/__init__.py +colabdesign/af/design.py +colabdesign/af/inputs.py +colabdesign/af/loss.py +colabdesign/af/model.py +colabdesign/af/prep.py +colabdesign/af/utils.py +colabdesign/af/alphafold/__init__.py +colabdesign/af/alphafold/common/__init__.py +colabdesign/af/alphafold/common/confidence.py +colabdesign/af/alphafold/common/protein.py +colabdesign/af/alphafold/common/residue_constants.py +colabdesign/af/alphafold/data/__init__.py +colabdesign/af/alphafold/data/mmcif_parsing.py +colabdesign/af/alphafold/data/parsers.py +colabdesign/af/alphafold/data/pipeline.py +colabdesign/af/alphafold/data/pipeline_multimer.py +colabdesign/af/alphafold/data/prep_inputs.py +colabdesign/af/alphafold/data/tools/__init__.py +colabdesign/af/alphafold/data/tools/utils.py +colabdesign/af/alphafold/model/__init__.py +colabdesign/af/alphafold/model/all_atom.py +colabdesign/af/alphafold/model/all_atom_multimer.py +colabdesign/af/alphafold/model/common_modules.py +colabdesign/af/alphafold/model/config.py +colabdesign/af/alphafold/model/data.py +colabdesign/af/alphafold/model/folding.py +colabdesign/af/alphafold/model/folding_multimer.py +colabdesign/af/alphafold/model/layer_stack.py +colabdesign/af/alphafold/model/lddt.py +colabdesign/af/alphafold/model/mapping.py +colabdesign/af/alphafold/model/model.py +colabdesign/af/alphafold/model/modules.py +colabdesign/af/alphafold/model/modules_multimer.py +colabdesign/af/alphafold/model/prng.py +colabdesign/af/alphafold/model/quat_affine.py +colabdesign/af/alphafold/model/r3.py +colabdesign/af/alphafold/model/utils.py +colabdesign/af/alphafold/model/geometry/__init__.py +colabdesign/af/alphafold/model/geometry/rigid_matrix_vector.py +colabdesign/af/alphafold/model/geometry/rotation_matrix.py +colabdesign/af/alphafold/model/geometry/struct_of_array.py +colabdesign/af/alphafold/model/geometry/test_utils.py +colabdesign/af/alphafold/model/geometry/utils.py +colabdesign/af/alphafold/model/geometry/vector.py +colabdesign/af/alphafold/model/tf/__init__.py +colabdesign/af/alphafold/model/tf/shape_placeholders.py +colabdesign/af/contrib/__init__.py +colabdesign/af/contrib/crop.py +colabdesign/af/weights/__init__.py +colabdesign/af/weights/template_dgram_head.npy +colabdesign/esm_msa/__init__.py +colabdesign/esm_msa/axial_attention.py +colabdesign/esm_msa/config.py +colabdesign/esm_msa/constants.py +colabdesign/esm_msa/data.py +colabdesign/esm_msa/model.py +colabdesign/esm_msa/modules.py 
+colabdesign/esm_msa/pretrained.py +colabdesign/mpnn/__init__.py +colabdesign/mpnn/ensemble_model.py +colabdesign/mpnn/mdtraj_io.py +colabdesign/mpnn/model.py +colabdesign/mpnn/modules.py +colabdesign/mpnn/sample.py +colabdesign/mpnn/score.py +colabdesign/mpnn/utils.py +colabdesign/mpnn/weights/__init__.py +colabdesign/mpnn/weights/v_48_002.pkl +colabdesign/mpnn/weights/v_48_010.pkl +colabdesign/mpnn/weights/v_48_020.pkl +colabdesign/mpnn/weights/v_48_030.pkl +colabdesign/mpnn/weights_soluble/__init__.py +colabdesign/mpnn/weights_soluble/v_48_002.pkl +colabdesign/mpnn/weights_soluble/v_48_010.pkl +colabdesign/mpnn/weights_soluble/v_48_020.pkl +colabdesign/mpnn/weights_soluble/v_48_030.pkl +colabdesign/rf/__init__.py +colabdesign/rf/designability_test.py +colabdesign/rf/utils.py +colabdesign/seq/__init__.py +colabdesign/seq/kmeans.py +colabdesign/seq/learn_msa.py +colabdesign/seq/mrf.py +colabdesign/seq/stats.py +colabdesign/seq/utils.py +colabdesign/shared/__init__.py +colabdesign/shared/chunked_vmap.py +colabdesign/shared/model.py +colabdesign/shared/parse_args.py +colabdesign/shared/plot.py +colabdesign/shared/prep.py +colabdesign/shared/prng.py +colabdesign/shared/protein.py +colabdesign/shared/utils.py +colabdesign/tr/__init__.py +colabdesign/tr/joint_model.py +colabdesign/tr/model.py +colabdesign/tr/trrosetta.py \ No newline at end of file diff --git a/colabdesign.egg-info/dependency_links.txt b/colabdesign.egg-info/dependency_links.txt new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/colabdesign.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/colabdesign.egg-info/requires.txt b/colabdesign.egg-info/requires.txt new file mode 100644 index 00000000..d48b1a81 --- /dev/null +++ b/colabdesign.egg-info/requires.txt @@ -0,0 +1,15 @@ +py3Dmol +absl-py +biopython +chex +dm-haiku +dm-tree +immutabledict +jax +ml-collections +numpy +pandas +scipy +optax +joblib +matplotlib diff --git a/colabdesign.egg-info/top_level.txt b/colabdesign.egg-info/top_level.txt new file mode 100644 index 00000000..92c923d6 --- /dev/null +++ b/colabdesign.egg-info/top_level.txt @@ -0,0 +1 @@ +colabdesign diff --git a/colabdesign/mpnn/ensemble_model.py b/colabdesign/mpnn/ensemble_model.py new file mode 100644 index 00000000..5d3ef258 --- /dev/null +++ b/colabdesign/mpnn/ensemble_model.py @@ -0,0 +1,582 @@ +""" +ProteinMPNN adapted to molecular dynamics trajectories. 
+ +""" + +import jax +import jax.numpy as jnp +import numpy as np +import re +import copy +import random +import os +import joblib +from functools import partial + +from .modules import RunModel +from .mdtraj_io import prep_from_mdtraj + +from scipy.special import softmax, log_softmax + +from colabdesign.shared.prep import prep_pos +from colabdesign.shared.utils import Key, copy_dict +from colabdesign.shared.chunked_vmap import vmap_chunked as cvmap + +# borrow some stuff from AfDesign +from colabdesign.af.prep import prep_pdb +from colabdesign.af.alphafold.common import protein, residue_constants +aa_order = residue_constants.restype_order +order_aa = {b:a for a,b in aa_order.items()} + +class mk_mpnn_ensemble_model(): + def __init__(self, model_name="v_48_020", + backbone_noise=0.0, dropout=0.0, + seed=None, verbose=False, weights="original", # weights can be set to either original or soluble + batch_size=500): + + # load model + if weights == "original": + from .weights import __file__ as mpnn_path + elif weights == "soluble": + from .weights_soluble import __file__ as mpnn_path + else: + raise ValueError(f'Invalid value {weights} supplied for weights. Value must be either "original" or "soluble".') + + path = os.path.join(os.path.dirname(mpnn_path), f'{model_name}.pkl') + checkpoint = joblib.load(path) + config = {'num_letters': 21, + 'node_features': 128, + 'edge_features': 128, + 'hidden_dim': 128, + 'num_encoder_layers': 3, + 'num_decoder_layers': 3, + 'augment_eps': backbone_noise, + 'k_neighbors': checkpoint['num_edges'], + 'dropout': dropout} + + self._model = RunModel(config) + self._model.params = jax.tree_util.tree_map(np.array, checkpoint['model_state_dict']) + self.batch_size = batch_size + self.set_seed(seed) + self._num = 1 + self._inputs = {} + self._tied_lengths = False + self._setup() + + def prep_inputs( + self, + traj=None, + chain=None, + homooligomer=False, + fix_pos=None, + inverse=False, + rm_aa=None, + verbose=False, + **kwargs, + ): + """Get inputs from an MDTraj object.""" + if traj is not None: + traj = prep_from_mdtraj(traj, chain=chain,) + else: + raise ValueError( + "One of 'mdtraj_frame', 'pdb_filename', or 'pdb_string' must be provided." 
+ ) + + # atom idx + atom_idx = tuple(residue_constants.atom_order[k] for k in ["N", "CA", "C", "O"]) + chain_idx = np.concatenate([[n] * l for n, l in enumerate(traj["lengths"])]) + self._lengths = traj["lengths"] + L = sum(self._lengths) + + self._inputs = { + "X": traj["batch"]["all_atom_positions"][:, :, atom_idx], # atom_idx_moved_back_one + "mask": traj["batch"]["all_atom_mask"][:, 1], + "S": traj["batch"]["aatype"], + "residue_idx": traj["residue_index"], + "chain_idx": chain_idx, + "lengths": np.array(self._lengths), + "bias": np.zeros((L, 20)), + } + + if rm_aa is not None: + for aa in rm_aa.split(","): + self._inputs["bias"][..., aa_order[aa]] -= 1e6 + + if fix_pos is not None: + p = prep_pos(fix_pos, **traj["idx"])["pos"] + if inverse: + p = np.delete(np.arange(L), p) + self._inputs["fix_pos"] = p + self._inputs["bias"][p] = 1e7 * np.eye(21)[self._inputs["S"]][p, :20] + + if homooligomer: + #raise NotImplementedError("'homooligomer=True' not yet implemented") + assert min(self._lengths) == max(self._lengths) + self._tied_lengths = True + self._len = self._lengths[0] + else: + self._tied_lengths = False + self._len = sum(self._lengths) + + self.traj = traj + + if verbose: + print("lengths", self._lengths) + if "fix_pos" in self._inputs: + print("the following positions will be fixed:") + print(self._inputs["fix_pos"]) + + def get_af_inputs(self, af): + '''get inputs from alphafold model''' + + self._lengths = af._lengths + self._len = af._len + + self._inputs["residue_idx"] = af._inputs["residue_index"] + self._inputs["chain_idx"] = af._inputs["asym_id"] + self._inputs["lengths"] = np.array(self._lengths) + + # set bias + L = sum(self._lengths) + self._inputs["bias"] = np.zeros((L,20)) + self._inputs["bias"][-af._len:] = af._inputs["bias"] + + if "offset" in af._inputs: + self._inputs["offset"] = af._inputs["offset"] + + if "batch" in af._inputs: + atom_idx = tuple(residue_constants.atom_order[k] for k in ["N","CA","C","O"]) + batch = af._inputs["batch"] + self._inputs["X"] = batch["all_atom_positions"][:,atom_idx] + self._inputs["mask"] = batch["all_atom_mask"][:,1] + self._inputs["S"] = batch["aatype"] + + # fix positions + if af.protocol == "binder": + p = np.arange(af._target_len) + else: + p = af.opt.get("fix_pos",None) + + if p is not None: + self._inputs["fix_pos"] = p + self._inputs["bias"][p] = 1e7 * np.eye(21)[self._inputs["S"]][p,:20] + + # tie positions + if af._args["homooligomer"]: + assert min(self._lengths) == max(self._lengths) + self._tied_lengths = True + else: + self._tied_lengths = False + + def sample(self, temperature=0.1, rescore=False, **kwargs): + '''Sample one sequence for each conformer''' + I = copy_dict(self._inputs) + I.update(kwargs) + key = I.pop("key",self.key()) + keys = jax.random.split(key,1) + O = self._sample_conformers(keys, I, temperature, self._tied_lengths) + if rescore: + O = self._rescore_parallel(keys, I, O["S"], O["decoding_order"]) + + # must squeeze here, unlike regular model + O = jax.tree_util.tree_map(lambda x: jnp.squeeze(jnp.array(x)), O) + + # process outputs to human-readable form + O.update(self._get_seq(O)) + O.update(self._get_score(I,O)) + return O + + def sample_minimal(self, temperature=0.1, rescore=False, **kwargs): + '''Sample one sequence for each conformer''' + I = copy_dict(self._inputs) + I.update(kwargs) + key = I.pop("key",self.key()) + keys = jax.random.split(key,1) + O = self._sample_conformers(keys, I, temperature, self._tied_lengths) + + # must squeeze here, unlike regular model + O = 
jax.tree_util.tree_map(lambda x: jnp.squeeze(jnp.array(x)), O) + return O + + def sample_parallel(self, batch=10, temperature=0.1, rescore=False, **kwargs): + '''sample new sequence(s) in parallel + NOT IMPLEMENTED + + ''' + if batch != 1: + raise NotImplementedError("Batched sampling not implemented for conformational ensembles.") + + + def _get_seq(self, O): + """one_hot to amino acid sequence (still returns Python strings)""" + + def split_seq(seq_str, lengths, tied_lengths): # pass lengths and tied_lengths explicitly + if len(lengths) > 1: + # This string manipulation cannot be JITted. + # If this were inside a JITted function, it would be a host callback. + seq_str = "".join(np.insert(list(seq_str), np.cumsum(lengths[:-1]), "/")) + if tied_lengths: + seq_str = seq_str.split("/")[0] + return seq_str + + seqs = [] + # Assuming O["S"] is (batch, L, 21) or (L, 21) + # Convert JAX array to NumPy for iteration and string conversion + S_numpy = np.array(O["S"].argmax(axis=1)).T # modified axis=1 and transposed for ensemble + if S_numpy.ndim == 1: + S_numpy = S_numpy[None, :] # ensure batch dimension + + for s_np in S_numpy: + # This part is Python string manipulation + seq = "".join([order_aa[a_idx] for a_idx in s_np]) + seq = split_seq(seq, self._lengths, self._tied_lengths) # pass necessary attributes + seqs.append(seq) + return {"seq": np.array(seqs)} + + def _get_score(self, I, O): + ''' + logits to score/sequence_recovery + return {"score":score (L, n_frames), "seqid":seqid (L, n_frames)} + ''' + # this is reasonably fast, even without jax + mask = I["mask"].copy() + if "fix_pos" in I: + mask[I["fix_pos"]] = 0 + + mask = np.expand_dims(mask, -1) + + # softmaxes are now mapped over axis=1 + log_q = log_softmax(O["logits"], axis=1)[:,:20,:] + q = softmax(O["logits"][:,:20,:], axis=1) + + # sums are over axis 0 + if "S" in O: + S = O["S"][:,:20,:] + score = -(S * log_q).sum(axis=1) + seqid = S.argmax(axis=1) == np.expand_dims(self._inputs["S"], -1) + else: + score = -(q * log_q).sum(axis=0) + seqid = np.zeros_like(score) + + score = (score * mask).sum(axis=0) / (mask.sum() + 1e-8) + seqid = (seqid * mask).sum(axis=0) / (mask.sum() + 1e-8) + + return {"score":score, "seqid":seqid} + + + + def _get_score_jax(self, inputs_S, logits, mask, fix_pos=None): # Pass necessary inputs directly + ''' logits to score/sequence_recovery - JAX compatible version ''' + + + current_mask = mask + if fix_pos is not None and fix_pos.shape[0] > 0: # Ensure fix_pos is not empty + # Ensure mask is a JAX array for .at.set to work + current_mask = jnp.array(current_mask) # Convert if it's numpy + current_mask = current_mask.at[fix_pos].set(0) + + # Use jax.nn functions + log_q = jax.nn.log_softmax(logits, axis=-1)[..., :20] + q = jax.nn.softmax(logits[..., :20], axis=-1) + + S_scored_one_hot = jax.nn.one_hot(inputs_S, num_classes=21)[...,:20] # Assuming inputs_S is integer encoded for the sequence to score + # This would be I["S"] from the score() method + + score = -(S_scored_one_hot * log_q).sum(-1) + + seqid = (inputs_S == self._inputs["S"]) + + masked_score_sum = (score * current_mask).sum(-1) + masked_seqid_sum = (seqid * current_mask).sum(-1) + mask_sum = current_mask.sum() + 1e-8 + + final_score = masked_score_sum / mask_sum + final_seqid = masked_seqid_sum / mask_sum + + return {"score": final_score, "seqid": final_seqid} + + + def score(self, seq_numeric=None, **kwargs): # seq_numeric is an integer array + '''score sequence - JAX compatible version (mostly)''' + current_inputs = 
jax.tree_util.tree_map(jnp.array, self._inputs) + + if seq_numeric is not None: + # seq_numeric is expected to be an integer array of amino acid indices + p = jnp.arange(current_inputs["S"].shape[0]) + s_shape_0 = current_inputs["S"].shape[0] # Store shape for JAX tracing + + if self._tied_lengths and seq_numeric.shape[0] == self._lengths[0]: + # Assuming self._lengths is available and compatible + # seq_numeric might need tiling if it represents one chain of a homooligomer + num_repeats = len(self._lengths) + seq_numeric = jnp.tile(seq_numeric, num_repeats) + + if "fix_pos" in current_inputs and current_inputs["fix_pos"].shape[0] > 0: + # Ensure shapes are concrete or JAX can trace them + if seq_numeric.shape[0] == (s_shape_0 - current_inputs["fix_pos"].shape[0]): + p = jnp.delete(p, current_inputs["fix_pos"], axis=0) + + # Update S using .at[].set() + # Ensure seq_numeric is correctly broadcasted or indexed if p is tricky + current_inputs["S"] = current_inputs["S"].at[p].set(seq_numeric) + + # Combine kwargs with current_inputs, ensuring JAX types + for k, v in kwargs.items(): + current_inputs[k] = jnp.asarray(v) if not isinstance(v, jax.Array) else v + + + key_to_use = current_inputs.pop("key", self.key()) # self.key() provides a JAX key + + # _score is already JITted and expects JAX-compatible inputs + # The arguments to _score are X, mask, residue_idx, chain_idx, key, S, bias, decoding_order etc. + # Ensure all these are present in current_inputs and are JAX arrays. + + # Prepare arguments for self._score, ensuring they are all JAX arrays + score_fn_args = {k: current_inputs[k] for k in [ + 'X', 'mask', 'residue_idx', 'chain_idx', 'S', 'bias' + ] if k in current_inputs} + + if "decoding_order" in current_inputs: + score_fn_args["decoding_order"] = current_inputs["decoding_order"] + if "fix_pos" in current_inputs: # _score uses fix_pos to adjust decoding_order + score_fn_args["fix_pos"] = current_inputs["fix_pos"] + + + # O will be a dictionary of JAX arrays + O = self._score(**score_fn_args, key=key_to_use) + + # Call the JAX-compatible _get_score + # It needs: current_inputs["S"] (the sequence being scored, possibly modified), + # O["logits"], current_inputs["mask"], and current_inputs.get("fix_pos") + score_info = self._get_score( + inputs_S=current_inputs["S"], # This is the S that was actually scored by _score + logits=O["logits"], + mask=current_inputs["mask"], + fix_pos=current_inputs.get("fix_pos") + ) + O.update(score_info) # O remains a dict of JAX arrays + + # If you need to convert to NumPy arrays for external use, do it here, + # but the function itself now primarily deals with JAX arrays. + # For full JAX compatibility of `score` itself (e.g. to JIT it), + # this conversion should be outside. 
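Putting the pieces above together, a minimal usage sketch for sampling one ProteinMPNN sequence per trajectory frame (file names are placeholders, MDTraj must be installed, and the class is assumed importable from `colabdesign.mpnn.ensemble_model`):

```python
import mdtraj as md
from colabdesign.mpnn.ensemble_model import mk_mpnn_ensemble_model

traj = md.load("traj.xtc", top="topology.pdb")   # placeholder trajectory/topology files

mpnn = mk_mpnn_ensemble_model(model_name="v_48_020", weights="original", batch_size=500)
mpnn.prep_inputs(traj=traj, chain="A", rm_aa="C")  # e.g. disallow cysteines

out = mpnn.sample(temperature=0.1)
print(out["seq"][:5])   # one designed sequence per frame
print(out["score"])     # per-frame average negative log-probability of the sampled sequence
print(out["seqid"])     # per-frame sequence recovery relative to the input sequence
```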
+ # return jax.tree_map(np.array, O) + return O # Returns dict of JAX array + + def get_logits(self, **kwargs): + '''get logits''' + return self.score(**kwargs)["logits"] + + def get_unconditional_logits(self, **kwargs): + L = self._inputs["X"].shape[0] + kwargs["ar_mask"] = np.zeros((L,L)) + return self.score(**kwargs)["logits"] + + def set_seed(self, seed=None): + np.random.seed(seed=seed) + self.key = Key(seed=seed).get + + def _setup(self): + def _score_internal( + X, mask, residue_idx, chain_idx, key, S, bias, **kwargs + ): # Added S and bias + I = { + "X": X, + "mask": mask, + "residue_idx": residue_idx, + "chain_idx": chain_idx, + "S": S, # Pass S + "bias": bias, # Pass bias + } + I.update(kwargs) + + if "decoding_order" not in I: + key, sub_key = jax.random.split(key) + randn = _randomize_sophie(sub_key, X) + randn = jnp.where(I["mask"], randn, randn + 1) + if "fix_pos" in I and I["fix_pos"].shape[0] > 0: # check if fix_pos is not empty + randn = randn.at[I["fix_pos"]].add(-1) + I["decoding_order"] = randn.argsort() + + # _aa_convert is JAX-compatible + for k_item in ["S", "bias"]: # Use k_item to avoid conflict with key + if k_item in I: + I[k_item] = _aa_convert(I[k_item]) + + output_dict = self._model.score(self._model.params, key, I) + output_dict["S"] = _aa_convert(output_dict["S"], rev=True) + output_dict["logits"] = _aa_convert(output_dict["logits"], rev=True) + return output_dict + + self._score = jax.jit(_score_internal) + + def _sample_internal( + X, + mask, + residue_idx, + chain_idx, + key, + temperature=0.1, + tied_lengths=False, + bias=None, + **kwargs, + ): # added bias + # single conformer sampling + I = { + "X": X, + "mask": mask, + "residue_idx": residue_idx, + "chain_idx": chain_idx, + "temperature": temperature, + "bias": bias, # Pass bias + } + I.update(kwargs) + + # define decoding order (as in original _sample) + if "decoding_order" in I: + if I["decoding_order"].ndim == 1: + I["decoding_order"] = I["decoding_order"][:, None] + else: + key, sub_key = jax.random.split(key) + #randn = jax.random.uniform(sub_key, (I["X"].shape[0],)) + randn = _randomize_sophie(sub_key, X) + + + + randn = jnp.where(I["mask"], randn, randn + 1) + if "fix_pos" in I and I["fix_pos"].shape[0] > 0: + randn = randn.at[I["fix_pos"]].add(-1) + if tied_lengths: + copies = I["lengths"].shape[0] + decoding_order_tied = randn.reshape(copies, -1).mean(0).argsort() + I["decoding_order"] = ( + jnp.arange(I["X"].shape[0]).reshape(copies, -1).T[decoding_order_tied] + ) + else: + I["decoding_order"] = randn.argsort()[:, None] + + # S is not an input to _model.sample, but bias is + if "S" in I: + I["S"] = _aa_convert( + I["S"] + ) # If S is somehow passed (e.g. 
for conditioning, though MPNN typically doesn't) + if "bias" in I: + I["bias"] = _aa_convert(I["bias"]) + + O_dict = self._model.sample(self._model.params, key, I) + O_dict["S"] = _aa_convert(O_dict["S"], rev=True) # This is the sampled S + O_dict["logits"] = _aa_convert(O_dict["logits"], rev=True) + return O_dict + + self._sample = jax.jit(_sample_internal, static_argnames=["tied_lengths"]) + + # + def _vmap_sample_seqs_from_conformers(key, inputs, temperature, tied_lengths): + inputs_copy = dict(inputs) # Shallow copy for modification + inputs_copy.pop("temperature", None) + inputs_copy.pop("key", None) + # Ensure 'bias' is correctly handled if it's part of 'inputs' + f_of_X = jax.jit( + partial(self._sample, key=key, + **{k : v for k,v in inputs_copy.items() if k not in ("X",)}, + temperature=temperature, tied_lengths=tied_lengths), static_argnames=["tied_lengths"] + ) + # vmap over positions + return cvmap(f_of_X, chunk_size=min(self.batch_size, inputs_copy["X"].shape[0]))(inputs_copy["X"]) + + # this is vmap over keys, but there's only one. + # this is just the easiest way to square with earliest code + # but might be good to refactor + fn_vmap_sample_conformers = jax.vmap(_vmap_sample_seqs_from_conformers, in_axes=[0, None, None, None]) + # difference, no jit for now + self._sample_conformers = fn_vmap_sample_conformers + + @jax.jit + def _vmap_rescore_parallel(key, inputs, S_rescore, decoding_order_rescore): + inputs_copy = dict(inputs) # Shallow copy + inputs_copy.pop("S", None) + inputs_copy.pop("decoding_order", None) + inputs_copy.pop("key", None) + # Ensure 'bias' from original inputs is used, and S_rescore is the new S + return self._score( + **inputs_copy, key=key, S=S_rescore, decoding_order=decoding_order_rescore + ) # Pass S and decoding_order + fn_vmap_rescore = jax.vmap(_vmap_rescore_parallel, in_axes=[0, None, 0, 0]) + self._rescore_parallel = fn_vmap_rescore + +####################################################################################### + +def _aa_convert(x, rev=False): + mpnn_alphabet = 'ACDEFGHIKLMNPQRSTVWYX' + af_alphabet = 'ARNDCQEGHILKMFPSTWYVX' + if x is None: + return x + else: + if rev: + return x[...,tuple(mpnn_alphabet.index(k) for k in af_alphabet)] + else: + x = jax.nn.one_hot(x,21) if jnp.issubdtype(x.dtype, jnp.integer) else x + if x.shape[-1] == 20: + x = jnp.pad(x,[[0,0],[0,1]]) + return x[...,tuple(af_alphabet.index(k) for k in mpnn_alphabet)] + +unknown_aa_index = aa_order.get('X', 20) # Default index for unknown AAs + +def convert_sequence_to_numeric(sequence_str: str, + aa_map: dict = aa_order, + all_chain_lengths: list = None, + is_homooligomer_tied: bool = False) -> jnp.array: + """ + Converts a protein sequence string into a JAX integer array. + + Args: + sequence_str: The amino acid sequence string. + - For monomers: "ACEG..." + - For heteromers (chains separated by '/'): "ACEG.../FGHI..." + - For homooligomers where is_homooligomer_tied is True and + only one chain's sequence is provided: "ACEG..." (will be tiled). + aa_map: Dictionary mapping amino acid characters to integers (e.g., aa_order). + all_chain_lengths: List of lengths of all chains in the complex. + Example: [100, 100] for a dimer of length 100 each. + Used for homooligomer tiling. + is_homooligomer_tied: Boolean. If True and sequence_str is for a single + chain of a homooligomer, the sequence will be tiled. + + Returns: + jnp.array: A JAX array of integers representing the full sequence. 
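A quick check of the alphabet shuffling done by `_aa_convert` above: converting an AlphaFold-ordered one-hot into ProteinMPNN order and back should be the identity (a sketch, assuming the helper is importable from `colabdesign.mpnn.ensemble_model`):

```python
import numpy as np
import jax.numpy as jnp
from colabdesign.mpnn.ensemble_model import _aa_convert

x = jnp.asarray(np.eye(21)[np.array([0, 5, 20])])  # AF-ordered one-hots (A, Q, X)
y = _aa_convert(x)                                  # AF alphabet -> MPNN alphabet
x_back = _aa_convert(y, rev=True)                   # and back
assert np.allclose(x, x_back)
```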
+ """ + numeric_sequence_list = [] + + # Handle homooligomer case where a single chain sequence is provided to be tiled + if is_homooligomer_tied and \ + all_chain_lengths and \ + len(all_chain_lengths) > 0 and \ + "/" not in sequence_str: + # Check if the provided sequence string matches the length of one chain + if len(sequence_str) == all_chain_lengths[0]: # Assuming all chains have the same length + num_chains = len(all_chain_lengths) + # Tile the string sequence before converting to numeric + sequence_str = "/".join([sequence_str] * num_chains) + # TODO: add a warning or error if the lengths don't match + + # Process chain by chain if '/' is present, otherwise process the whole string + chains = sequence_str.split('/') + + for chain_seq_str in chains: + for aa_char in chain_seq_str: + # Use .get(key, default_value) to handle unexpected characters + numeric_sequence_list.append(aa_map.get(aa_char, unknown_aa_index)) + + return jnp.array(numeric_sequence_list, dtype=jnp.int32) + +### >:D >:D >:D >:D >:D >:D >:D >:D # >:D >:D >:D >:D >:D >:D >:D >:D # >:D >:D >:D >:D +def _randomize_sophie(key, X_conformer, max_freq=1e9, min_freq=1e3): + """ + WARNING: EXPERIMENTAL + + Use X_conformer as a natural entropy source to randomize decoding order + by transforming spatial coordinates with a random-frequency sine wave. + """ + randfreq = (max_freq - min_freq)*jax.random.uniform(key) + min_freq + randn = 0.5*(1 + jnp.sin(X_conformer * randfreq).sum(axis=(1,2))) + return randn \ No newline at end of file diff --git a/colabdesign/mpnn/mdtraj_io.py b/colabdesign/mpnn/mdtraj_io.py new file mode 100644 index 00000000..da4f03e5 --- /dev/null +++ b/colabdesign/mpnn/mdtraj_io.py @@ -0,0 +1,106 @@ +import jax +import numpy as np + +from colabdesign.af.alphafold.common import residue_constants +from colabdesign.shared.protein import _np_get_cb + +order_aa = {b: a for a, b in residue_constants.restype_order.items()} +aa_order = residue_constants.restype_order + + +def prep_from_mdtraj(traj, chain=None, **kwargs): + """ + Extracts features directly from an mdtraj.Trajectory object. 
+ """ + + chains_to_process = [] + if chain is None: + chains_to_process = list(traj.topology.chains) + else: + requested_chain_ids = list(chain) + for c in traj.topology.chains: + if c.chain_id in requested_chain_ids: + chains_to_process.append(c) + + all_chains_data = [] + last_res_idx = 0 + full_lengths = [] + + for chain_obj in chains_to_process: + chain_id = chain_obj.chain_id + atom_indices = [a.index for a in chain_obj.atoms] + + chain_top = traj.topology.subset(atom_indices) + chain_xyz = traj.xyz[:, atom_indices, :] * 10 # Convert nm to Angstroms + n_res = chain_top.n_residues + + all_atom_positions = np.zeros((traj.n_frames, n_res, residue_constants.atom_type_num, 3)) + all_atom_mask = np.zeros((n_res, residue_constants.atom_type_num)) + aatype = np.zeros(n_res, dtype=int) + residue_index = np.zeros(n_res, dtype=int) + + for res_idx, residue in enumerate(chain_top.residues): + res_name = residue.name + aatype[res_idx] = residue_constants.resname_to_idx.get( + res_name, residue_constants.resname_to_idx["UNK"] + ) + residue_index[res_idx] = residue.resSeq + + for atom in residue.atoms: + if atom.name in residue_constants.atom_order: + atom_type_idx = residue_constants.atom_order[atom.name] + chain_atom_index = next( + a.index for a in chain_top.atoms if a.serial == atom.serial + ) + all_atom_positions[:,res_idx, atom_type_idx] = chain_xyz[:,chain_atom_index] + all_atom_mask[res_idx, atom_type_idx] = 1 + + batch = { + "aatype": aatype, + "all_atom_positions": all_atom_positions, + "all_atom_mask": all_atom_mask, + } + + p, m = batch["all_atom_positions"], batch["all_atom_mask"] + atom_idx = residue_constants.atom_order + atoms = {k: p[..., atom_idx[k], :] for k in ["N", "CA", "C",]} + + cb_atoms = _np_get_cb(**atoms, use_jax=False) + cb_mask = np.prod([m[..., atom_idx[k]] for k in ["N", "CA", "C"]], 0) + cb_idx = atom_idx["CB"] + batch["all_atom_positions"][..., cb_idx, :] = np.where( + m[:, cb_idx, None], p[..., cb_idx, :], cb_atoms + ) + batch["all_atom_mask"][..., cb_idx] = (m[:, cb_idx] + cb_mask) > 0 + #batch["all_atom_positions"] = np.moveaxis(batch["all_atom_positions"], 0, -1) + + chain_data = { + "batch": batch, + "residue_index": residue_index + last_res_idx, + "chain_id": [chain_id] * n_res, + "res_indices_original": residue_index, + } + all_chains_data.append(chain_data) + + last_res_idx += n_res + 50 + full_lengths.append(n_res) + + if not all_chains_data: + raise ValueError("No valid chains found or processed from the mdtraj frame.") + + # the one with 4 dimensions is mapped of axis=1, this will be the coordinates + final_batch = jax.tree_util.tree_map( + lambda *x: np.concatenate(x, int(len(x[0].shape)==4)), *[d.pop("batch") for d in all_chains_data] + ) + final_residue_index = np.concatenate([d.pop("residue_index") for d in all_chains_data]) + final_idx = { + "residue": np.concatenate([d.pop("res_indices_original") for d in all_chains_data]), + "chain": np.concatenate([d.pop("chain_id") for d in all_chains_data]), + } + + return { + "batch": final_batch, + "residue_index": final_residue_index, + "idx": final_idx, + "lengths": full_lengths, + } diff --git a/colabdesign/mpnn/model.py b/colabdesign/mpnn/model.py index abfc1f41..077ed2fd 100644 --- a/colabdesign/mpnn/model.py +++ b/colabdesign/mpnn/model.py @@ -21,7 +21,8 @@ class mk_mpnn_model(): def __init__(self, model_name="v_48_020", backbone_noise=0.0, dropout=0.0, - seed=None, verbose=False, weights="original"): # weights can be set to either original or soluble + seed=None, verbose=False, weights="original", + 
batch_size=1000): # weights can be set to either original or soluble
     # load model
     if weights == "original":
       from .weights import __file__ as mpnn_path
@@ -50,6 +51,7 @@
     self._num = 1
     self._inputs = {}
     self._tied_lengths = False
+    self.batch_size = batch_size

   def prep_inputs(self, pdb_filename=None, chain=None, homooligomer=False, ignore_missing=True, fix_pos=None, inverse=False,
@@ -376,9 +378,11 @@ def _vmap_sample_parallel(key, inputs, temperature, tied_lengths):
       inputs_copy.pop("temperature",None)
       inputs_copy.pop("key",None)
       # Ensure 'bias' is correctly handled if it's part of 'inputs'
-      return self._sample(**inputs_copy, key=key, temperature=temperature, tied_lengths=tied_lengths)
+      return self._sample(**inputs_copy, key=key, temperature=temperature, tied_lengths=tied_lengths, batch_size=self.batch_size)

-    fn_vmap_sample = jax.vmap(_vmap_sample_parallel, in_axes=[0,None,None,None])
+    # jax.lax.map takes no in_axes argument; map sequentially over the keys and close over the shared arguments
+    def fn_vmap_sample(key, inputs, temperature, tied_lengths):
+      return jax.lax.map(lambda k: _vmap_sample_parallel(k, inputs, temperature, tied_lengths), key)
     self._sample_parallel = jax.jit(fn_vmap_sample, static_argnames=["tied_lengths"])

     def _vmap_rescore_parallel(key, inputs, S_rescore, decoding_order_rescore):
@@ -389,7 +393,10 @@ def _vmap_rescore_parallel(key, inputs, S_rescore, decoding_order_rescore):
       # Ensure 'bias' from original inputs is used, and S_rescore is the new S
       return self._score(**inputs_copy, key=key, S=S_rescore, decoding_order=decoding_order_rescore) # Pass S and decoding_order

-    fn_vmap_rescore = jax.vmap(_vmap_rescore_parallel, in_axes=[0,None,0,0])
+    # sequential map over matched (key, S, decoding_order) triples; 'inputs' is shared across all of them
+    def fn_vmap_rescore(key, inputs, S_rescore, decoding_order_rescore):
+      return jax.lax.map(lambda xs: _vmap_rescore_parallel(xs[0], inputs, xs[1], xs[2]),
+                         (key, S_rescore, decoding_order_rescore))
     self._rescore_parallel = jax.jit(fn_vmap_rescore)

 #######################################################################################
diff --git a/colabdesign/shared/chunked_vmap.py b/colabdesign/shared/chunked_vmap.py
new file mode 100755
index 00000000..2e81135d
--- /dev/null
+++ b/colabdesign/shared/chunked_vmap.py
@@ -0,0 +1,356 @@
+# Copyright 2021 The NetKet Authors - All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# WORK IN PROGRESS
+
+import sys
+
+from functools import partial
+from typing import Optional, Callable
+
+import jax
+import jax.numpy as jnp
+
+from jax.extend import linear_util as lu
+from jax.api_util import argnums_partial
+
+_tree_add = partial(jax.tree_util.tree_map, jax.lax.add)
+_tree_zeros_like = partial(jax.tree_util.tree_map, lambda x: jnp.zeros(x.shape, dtype=x.dtype))
+
+
+def _treeify(f):
+    def _f(x, *args, **kwargs):
+        return jax.tree_util.tree_map(lambda y: f(y, *args, **kwargs), x)
+
+    return _f
+
+
+@_treeify
+def _unchunk(x):
+    return x.reshape((-1,) + x.shape[2:])
+
+
+@_treeify
+def _chunk(x, chunk_size=None):
+    # chunk_size=None -> add just a dummy chunk dimension, same as np.expand_dims(x, 0)
+    n = x.shape[0]
+    if chunk_size is None:
+        chunk_size = n
+
+    n_chunks, residual = divmod(n, chunk_size)
+    if residual != 0:
+        raise ValueError(
+            "The first dimension of x must be divisible by chunk_size."
+ + f"\n Got x.shape={x.shape} but chunk_size={chunk_size}." + ) + return x.reshape((n_chunks, chunk_size) + x.shape[1:]) + + +def _chunk_size(x): + b = set(map(lambda x: x.shape[:2], jax.tree_util.tree_leaves(x))) + if len(b) != 1: + raise ValueError( + "The arrays in x have inconsistent chunk_size or number of chunks" + ) + return b.pop()[1] + + +def unchunk(x_chunked): + """ + Merge the first two axes of an array (or a pytree of arrays) + Args: + x_chunked: an array (or pytree of arrays) of at least 2 dimensions + Returns: a pair (x, chunk_fn) + where x is x_chunked reshaped to (-1,)+x.shape[2:] + and chunk_fn is a function which restores x given x_chunked + """ + return _unchunk(x_chunked), partial(_chunk, chunk_size=_chunk_size(x_chunked)) + + +def chunk(x, chunk_size=None): + """ + Split an array (or a pytree of arrays) into chunks along the first axis + Args: + x: an array (or pytree of arrays) + chunk_size: an integer or None (default) + The first axis in x must be a multiple of chunk_size + Returns: a pair (x_chunked, unchunk_fn) where + - x_chunked is x reshaped to (-1, chunk_size)+x.shape[1:] + if chunk_size is None then it defaults to x.shape[0], i.e. just one chunk + - unchunk_fn is a function which restores x given x_chunked + """ + return _chunk(x, chunk_size), _unchunk + +# TODO put it somewher + +def _multimap(f, *args): + try: + return tuple(map(lambda a: f(*a), zip(*args))) + except TypeError: + return f(*args) + + +def scan_append_reduce(f, x, append_cond, op=_tree_add): + """Evaluate f element by element in x while appending and/or reducing the results + Args: + f: a function that takes elements of the leading dimension of x + x: a pytree where each leaf array has the same leading dimension + append_cond: a bool (if f returns just one result) or a tuple of bools (if f returns multiple values) + which indicates whether the individual result should be appended or reduced + op: a function to (pairwise) reduce the specified results. Defaults to a sum. 
+ Returns: + returns the (tuple of) results corresponding to the output of f + where each result is given by: + if append_cond is True: + a (pytree of) array(s) with leading dimension same as x, + containing the evaluation of f at each element in x + else (append_cond is False): + a (pytree of) array(s) with the same shape as the corresponding output of f, + containing the reduction over op of f evaluated at each x + Example: + import jax.numpy as jnp + from netket.jax import scan_append_reduce + def f(x): + y = jnp.sin(x) + return y, y, y**2 + N = 100 + x = jnp.linspace(0.,jnp.pi,N) + y, s, s2 = scan_append_reduce(f, x, (True, False, False)) + mean = s/N + var = s2/N - mean**2 + """ + # TODO: different op for each result + + x0 = jax.tree_util.tree_map(lambda x: x[0], x) + + # special code path if there is only one element + # to avoid having to rely on xla/llvm to optimize the overhead away + if jax.tree_util.tree_leaves(x)[0].shape[0] == 1: + return _multimap( + lambda c, x: jnp.expand_dims(x, 0) if c else x, append_cond, f(x0) + ) + + # the original idea was to use pytrees, however for now just operate on the return value tuple + _get_append_part = partial(_multimap, lambda c, x: x if c else None, append_cond) + _get_op_part = partial(_multimap, lambda c, x: x if not c else None, append_cond) + _tree_select = partial(_multimap, lambda c, t1, t2: t1 if c else t2, append_cond) + + carry_init = True, _get_op_part(_tree_zeros_like(jax.eval_shape(f, x0))) + + def f_(carry, x): + is_first, y_carry = carry + y = f(x) + y_op = _get_op_part(y) + y_append = _get_append_part(y) + # select here to avoid the user having to specify the zero element for op + y_reduce = jax.tree_util.tree_map( + partial(jax.lax.select, is_first), y_op, op(y_carry, y_op) + ) + return (False, y_reduce), y_append + + (_, res_op), res_append = jax.lax.scan(f_, carry_init, x, unroll=1) + # reconstruct the result from the reduced and appended parts in the two trees + return _tree_select(res_append, res_op) + + +scan_append = partial(scan_append_reduce, append_cond=True) +scan_reduce = partial(scan_append_reduce, append_cond=False) + + +# TODO in_axes a la vmap? +def scanmap(fun, scan_fun, argnums=0): + """ + A helper function to wrap f with a scan_fun + Example: + import jax.numpy as jnp + from functools import partial + from netket.jax import scanmap, scan_append_reduce + scan_fun = partial(scan_append_reduce, append_cond=(True, False, False)) + @partial(scanmap, scan_fun=scan_fun, argnums=1) + def f(c, x): + y = jnp.sin(x) + c + return y, y, y**2 + N = 100 + x = jnp.linspace(0.,jnp.pi,N) + c = 1. + y, s, s2 = f(c, x) + mean = s/N + var = s2/N - mean**2 + """ + + def f_(*args, **kwargs): + f = lu.wrap_init(fun, kwargs) + f_partial, dyn_args = argnums_partial( + f, argnums, args, require_static_args_hashable=False + ) + return scan_fun(lambda x: f_partial.call_wrapped(*x), dyn_args) + + return f_ + + +class HashablePartial(partial): + """ + A class behaving like functools.partial, but that retains it's hash + if it's created with a lexically equivalent (the same) function and + with the same partially applied arguments and keywords. + It also stores the computed hash for faster hashing. 
+ """ + + # TODO remove when dropping support for Python < 3.10 + def __new__(cls, func, *args, **keywords): + # In Python 3.10+ if func is itself a functools.partial instance, + # functools.partial.__new__ would merge the arguments of this HashablePartial + # instance with the arguments of the func + # Pre 3.10 this does not happen, so here we emulate this behaviour recursively + # This is necessary since functools.partial objects do not have a __code__ + # property which we use for the hash + # For python 3.10+ we still need to take care of merging with another HashablePartial + while isinstance( + func, partial if sys.version_info < (3, 10) else HashablePartial + ): + original_func = func + func = original_func.func + args = original_func.args + args + keywords = {**original_func.keywords, **keywords} + return super().__new__(cls, func, *args, **keywords) + + def __init__(self, *args, **kwargs): + self._hash = None + + def __eq__(self, other): + return ( + type(other) is HashablePartial + and self.func.__code__ == other.func.__code__ + and self.args == other.args + and self.keywords == other.keywords + ) + + def __hash__(self): + if self._hash is None: + self._hash = hash( + (self.func.__code__, self.args, frozenset(self.keywords.items())) + ) + + return self._hash + + def __repr__(self): + return f"" + + + +def _fun(vmapped_fun, chunk_size, argnums, *args, **kwargs): + n_elements = jax.tree_util.tree_leaves(args[argnums[0]])[0].shape[0] + n_chunks, n_rest = divmod(n_elements, chunk_size) + + if n_chunks == 0 or chunk_size >= n_elements: + y = vmapped_fun(*args, **kwargs) + else: + # split inputs + def _get_chunks(x): + x_chunks = jax.tree_util.tree_map(lambda x_: x_[: n_elements - n_rest, ...], x) + x_chunks = _chunk(x_chunks, chunk_size) + return x_chunks + + def _get_rest(x): + x_rest = jax.tree_util.tree_map(lambda x_: x_[n_elements - n_rest :, ...], x) + return x_rest + + args_chunks = [ + _get_chunks(a) if i in argnums else a for i, a in enumerate(args) + ] + args_rest = [_get_rest(a) if i in argnums else a for i, a in enumerate(args)] + + y_chunks = _unchunk( + scanmap(vmapped_fun, scan_append, argnums)(*args_chunks, **kwargs) + ) + + if n_rest == 0: + y = y_chunks + else: + y_rest = vmapped_fun(*args_rest, **kwargs) + y = jax.tree_util.tree_map(lambda y1, y2: jnp.concatenate((y1, y2)), y_chunks, y_rest) + return y + + +def _chunk_vmapped_function( + vmapped_fun: Callable, chunk_size: Optional[int], argnums=0 +) -> Callable: + """takes a vmapped function and computes it in chunks""" + + if chunk_size is None: + return vmapped_fun + + if isinstance(argnums, int): + argnums = (argnums,) + + return HashablePartial(_fun, vmapped_fun, chunk_size, argnums) + + +def _parse_in_axes(in_axes): + if isinstance(in_axes, int): + in_axes = (in_axes,) + + if not set(in_axes).issubset((0, None)): + raise NotImplementedError("Only in_axes 0/None are currently supported") + + argnums = tuple( + map(lambda ix: ix[0], filter(lambda ix: ix[1] is not None, enumerate(in_axes))) + ) + return in_axes, argnums + + +def apply_chunked(f: Callable, in_axes=0, *, chunk_size: Optional[int]) -> Callable: + """ + Takes an implicitly vmapped function over the axis 0 and uses scan to + do the computations in smaller chunks over the 0-th axis of all input arguments. + For this to work, the function `f` should be `vectorized` along the `in_axes` + of the arguments. This means that the function `f` should respect the following + condition: + .. 
code-block:: python + assert f(x) == jnp.concatenate([f(x_i) for x_i in x], axis=0) + which is automatically satisfied if `f` is obtained by vmapping a function, + such as: + .. code-block:: python + f = jax.vmap(f_orig) + Args: + f: A function that satisfies the condition above + in_axes: The axes that should be scanned along. Only supports `0` or `None` + chunk_size: The maximum size of the chunks to be used. If it is `None`, chunking + is disabled + """ + _, argnums = _parse_in_axes(in_axes) + return _chunk_vmapped_function(f, chunk_size, argnums) + + +def vmap_chunked(f: Callable, in_axes=0, *, chunk_size: Optional[int]) -> Callable: + """ + Behaves like jax.vmap but uses scan to chunk the computations in smaller chunks. + This function is essentially equivalent to: + .. code-block:: python + nk.jax.apply_chunked(jax.vmap(f, in_axes), in_axes, chunk_size) + Some limitations to `in_axes` apply. + Args: + f: The function to be vectorised. + in_axes: The axes that should be scanned along. Only supports `0` or `None` + chunk_size: The maximum size of the chunks to be used. If it is `None`, chunking + is disabled + Returns: + A vectorised and chunked function + """ + in_axes, argnums = _parse_in_axes(in_axes) + vmapped_fun = jax.vmap(f, in_axes=in_axes) + return _chunk_vmapped_function(vmapped_fun, chunk_size, argnums) diff --git a/colabdesign/shared/utils.py b/colabdesign/shared/utils.py index 31670ff8..afd184a6 100644 --- a/colabdesign/shared/utils.py +++ b/colabdesign/shared/utils.py @@ -27,7 +27,7 @@ def set_dict(d, x, override=False): elif isinstance(d[k],(np.ndarray,jnp.ndarray)): d[k] = np.asarray(v) elif isinstance(d[k], dict): - d[k] = jax.tree_map(lambda x: type(x)(v), d[k]) + d[k] = jax.tree.map(lambda x: type(x)(v), d[k]) else: d[k] = type(d[k])(v) else: @@ -41,7 +41,7 @@ def set_dict(d, x, override=False): def copy_dict(x): '''deepcopy dictionary''' - return jax.tree_map(lambda y:y, x) + return jax.tree.map(lambda y:y, x) def to_float(x): '''convert to float'''
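
A minimal sketch of using the new MDTraj loader on its own. The file names and chain id below are placeholders, and it assumes your mdtraj build exposes `Chain.chain_id`, which `prep_from_mdtraj` relies on; the function signature and the returned keys come from the patch above.

```python
import mdtraj as md
from colabdesign.mpnn.mdtraj_io import prep_from_mdtraj

# hypothetical inputs: any conformational ensemble with a fixed protein topology
traj = md.load("ensemble.xtc", top="protein.pdb")
feats = prep_from_mdtraj(traj, chain="A")

# "all_atom_positions" keeps the leading frame axis, (n_frames, n_res, 37, 3),
# which is what allows ProteinMPNN sampling to be vectorized across the ensemble
print(feats["batch"]["all_atom_positions"].shape)
print(feats["lengths"], feats["idx"]["chain"][:3])
```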
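
`colabdesign/shared/chunked_vmap.py` (adapted from NetKet and still marked work in progress) gives `jax.vmap` semantics while evaluating the mapped axis in fixed-size chunks under `lax.scan`, bounding peak memory. A toy sketch, assuming the NetKet-derived helpers import cleanly against your JAX version (they use the internal `jax.extend.linear_util` and `jax.api_util.argnums_partial`):

```python
import jax.numpy as jnp
from colabdesign.shared.chunked_vmap import vmap_chunked

def score_one(x):
    # stand-in for an expensive per-sample computation, e.g. scoring one sequence
    return jnp.sin(x).sum()

xs = jnp.arange(2500.0).reshape(250, 10)

# behaves like jax.vmap(score_one)(xs), but evaluates at most 100 rows at a time
score_chunked = vmap_chunked(score_one, in_axes=0, chunk_size=100)
print(score_chunked(xs).shape)  # (250,)
```

The `batch_size` argument added to `mk_mpnn_model` appears intended to serve the same purpose when sampling many sequences in parallel.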