Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
# ColabDesign
# ColabDesign (EXPERIMENTAL justktln2 REMIX)

The only modifications to the forked branch have been to introduce I/O with MDTraj and to vectorize ProteinMPNN sampling across trajectories. In principle, this can be done with any conformational ensemble that has a fixed protein topology, so it could also be used with structures generated by AlphaFold MSA subsampling.
You will therefore need to install MDTraj to use this.
The important parts of this will be refactored and folded into either [ciMIST](https://github.com/justktln2/ciMIST/) or ColabDesign.

Original README, which I had nothing to do with, follows.

### Making Protein Design accessible to all via Google Colab!
- P(structure | sequence)
- [TrDesign](/tr) - using TrRosetta for design
Expand Down
16 changes: 16 additions & 0 deletions build/lib/colabdesign/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import os
import jax

# Disable triton_gemm for jax versions newer than 0.3.
# Compare the (major, minor) tuple rather than the minor component alone,
# so the check stays correct if jax ever bumps its major version
# (the old `int(version.split(".")[1]) > 3` test would break for "1.0.x").
_jax_version = tuple(int(v) for v in jax.__version__.split(".")[:2])
if _jax_version > (0, 3):
  # Append to any user-supplied XLA_FLAGS instead of overwriting them.
  _xla_flags = os.environ.get("XLA_FLAGS", "")
  os.environ["XLA_FLAGS"] = (_xla_flags + " --xla_gpu_enable_triton_gemm=false").strip()

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

from colabdesign.shared.utils import clear_mem
from colabdesign.af.model import mk_af_model
from colabdesign.tr.model import mk_tr_model
from colabdesign.mpnn.model import mk_mpnn_model

# backward compatibility
mk_design_model = mk_afdesign_model = mk_af_model
mk_trdesign_model = mk_tr_model
13 changes: 13 additions & 0 deletions build/lib/colabdesign/af/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import os
import jax

# Disable triton_gemm for jax versions newer than 0.3.
# Compare the (major, minor) tuple rather than the minor component alone,
# so the check stays correct if jax ever bumps its major version
# (the old `int(version.split(".")[1]) > 3` test would break for "1.0.x").
_jax_version = tuple(int(v) for v in jax.__version__.split(".")[:2])
if _jax_version > (0, 3):
  # Append to any user-supplied XLA_FLAGS instead of overwriting them.
  _xla_flags = os.environ.get("XLA_FLAGS", "")
  os.environ["XLA_FLAGS"] = (_xla_flags + " --xla_gpu_enable_triton_gemm=false").strip()

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

from colabdesign.shared.utils import clear_mem
from colabdesign.af.model import mk_af_model

# backward compatibility
mk_design_model = mk_afdesign_model = mk_af_model
14 changes: 14 additions & 0 deletions build/lib/colabdesign/af/alphafold/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An implementation of the inference pipeline of AlphaFold v2.0."""
14 changes: 14 additions & 0 deletions build/lib/colabdesign/af/alphafold/common/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common data types and constants used within Alphafold."""
169 changes: 169 additions & 0 deletions build/lib/colabdesign/af/alphafold/common/confidence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Functions for processing confidence metrics."""

import jax.numpy as jnp
import jax
import numpy as np
from colabdesign.af.alphafold.common import residue_constants
import scipy.special

def compute_tol(prev_pos, current_pos, mask, use_jnp=False):
  """Masked RMS change in CA-CA pairwise distances between two recycles.

  Early-stopping criterion based on the one used in AF2Complex:
  https://www.nature.com/articles/s41467-022-29394-2

  Args:
    prev_pos: [num_res, num_atoms, 3] coordinates from the previous recycle.
    current_pos: [num_res, num_atoms, 3] coordinates from the current recycle.
    mask: [num_res] per-residue mask/weights.
    use_jnp: if True, compute with jax.numpy (jit/grad friendly).
  Returns:
    Scalar RMS difference between the two CA distance matrices.
  """
  xp = jnp if use_jnp else np
  ca = residue_constants.atom_order['CA']

  def pairwise_dist(coords):
    delta = coords[:, None] - coords[None, :]
    return xp.sqrt((delta ** 2).sum(-1))

  sq_delta = xp.square(pairwise_dist(prev_pos[:, ca]) - pairwise_dist(current_pos[:, ca]))
  pair_mask = mask[:, None] * mask[None, :]
  # 1e-8 keeps sqrt (and its gradient) well-defined when the difference is zero.
  return xp.sqrt((sq_delta * pair_mask).sum() / pair_mask.sum() + 1e-8)


def compute_plddt(logits, use_jnp=False):
  """Computes per-residue pLDDT from logits.

  Args:
    logits: [num_res, num_bins] output from the PredictedLDDTHead.
    use_jnp: if True, use jax.numpy / jax.nn instead of numpy / scipy.
  Returns:
    plddt: [num_res] per-residue pLDDT on a 0-100 scale.
  """
  if use_jnp:
    xp, softmax_fn = jnp, jax.nn.softmax
  else:
    xp, softmax_fn = np, scipy.special.softmax

  nbins = logits.shape[-1]
  width = 1.0 / nbins
  # Centers of the bins that tile [0, 1]: width/2, 3*width/2, ..., 1 - width/2.
  centers = xp.arange(0.5 * width, 1.0, width)
  # Expectation of the LDDT value under the predicted per-bin distribution.
  expected_lddt = (softmax_fn(logits, axis=-1) * centers[None, :]).sum(-1)
  return expected_lddt * 100

def _calculate_bin_centers(breaks, use_jnp=False):
"""Gets the bin centers from the bin edges.
Args:
breaks: [num_bins - 1] the error bin edges.
Returns:
bin_centers: [num_bins] the error bin centers.
"""
_np = jnp if use_jnp else np
step = breaks[1] - breaks[0]

# Add half-step to get the center
bin_centers = breaks + step / 2

# Add a catch-all bin at the end.
return _np.append(bin_centers, bin_centers[-1] + step)

def _calculate_expected_aligned_error(
alignment_confidence_breaks,
aligned_distance_error_probs,
use_jnp=False):
"""Calculates expected aligned distance errors for every pair of residues.
Args:
alignment_confidence_breaks: [num_bins - 1] the error bin edges.
aligned_distance_error_probs: [num_res, num_res, num_bins] the predicted
probs for each error bin, for each pair of residues.
Returns:
predicted_aligned_error: [num_res, num_res] the expected aligned distance
error for each pair of residues.
max_predicted_aligned_error: The maximum predicted error possible.
"""
bin_centers = _calculate_bin_centers(alignment_confidence_breaks, use_jnp=use_jnp)
# Tuple of expected aligned distance error and max possible error.
pae = (aligned_distance_error_probs * bin_centers).sum(-1)
return (pae, bin_centers[-1])

def compute_predicted_aligned_error(logits, breaks, use_jnp=False):
  """Computes aligned-error confidence metrics from PAE logits.

  Args:
    logits: [num_res, num_res, num_bins] the logits output from
      PredictedAlignedErrorHead.
    breaks: [num_bins - 1] the error bin edges.
    use_jnp: if True, use jax instead of numpy/scipy.
  Returns:
    Dict with:
      aligned_confidence_probs: [num_res, num_res, num_bins] the predicted
        aligned error probabilities over bins for each residue pair.
      predicted_aligned_error: [num_res, num_res] the expected aligned
        distance error for each pair of residues.
      max_predicted_aligned_error: the maximum predicted error possible.
  """
  softmax_fn = jax.nn.softmax if use_jnp else scipy.special.softmax
  probs = softmax_fn(logits, axis=-1)
  pae, max_pae = _calculate_expected_aligned_error(
      breaks, probs, use_jnp=use_jnp)
  return {
      'aligned_confidence_probs': probs,
      'predicted_aligned_error': pae,
      'max_predicted_aligned_error': max_pae,
  }

def predicted_tm_score(logits, breaks, residue_weights = None,
                       asym_id = None, use_jnp=False):
  """Computes predicted TM alignment or predicted interface TM alignment score.

  Args:
    logits: [num_res, num_res, num_bins] the logits output from
      PredictedAlignedErrorHead.
    breaks: [num_bins - 1] the error bin edges.
    residue_weights: [num_res] the per residue weights to use for the
      expectation (must lie in [0, 1]; may be floating-point, e.g. the
      experimentally-resolved head's probability).
    asym_id: [num_res] the asymmetric unit ID - the chain ID. Only needed for
      ipTM calculation.
    use_jnp: if True, use jax.numpy / jax.nn instead of numpy / scipy.

  Returns:
    ptm_score: The predicted TM alignment or the predicted iTM score.
  """
  if use_jnp:
    xp, softmax_fn = jnp, jax.nn.softmax
  else:
    xp, softmax_fn = np, scipy.special.softmax

  if residue_weights is None:
    residue_weights = xp.ones(logits.shape[0])

  centers = _calculate_bin_centers(breaks, use_jnp=use_jnp)
  num_res = residue_weights.shape[0]

  # Effective length, clipped at 19 so d0 below stays positive/defined.
  eff_num_res = xp.maximum(residue_weights.sum(), 19)

  # d_0(num_res) as defined by TM-score, eqn. (5) in Yang & Skolnick,
  # "Scoring function for automated assessment of protein structure template
  # quality", 2004: http://zhanglab.ccmb.med.umich.edu/papers/2004_3.pdf
  d0 = 1.24 * (eff_num_res - 15) ** (1./3) - 1.8

  probs = softmax_fn(logits, axis=-1)

  # TM-score kernel evaluated at every bin center, then averaged under the
  # predicted error distribution: E_distances[tm(distance)].
  tm_per_bin = 1. / (1 + xp.square(centers) / xp.square(d0))
  expected_tm = (probs * tm_per_bin).sum(-1)

  # pTM uses all residue pairs; ipTM restricts to inter-chain pairs.
  if asym_id is None:
    pair_mask = xp.full((num_res, num_res), True)
  else:
    pair_mask = asym_id[:, None] != asym_id[None, :]

  expected_tm = expected_tm * pair_mask

  pair_weights = pair_mask * (residue_weights[None, :] * residue_weights[:, None])
  # 1e-8 avoids division by zero for rows with no weighted pairs.
  normed_weights = pair_weights / (1e-8 + pair_weights.sum(-1, keepdims=True))
  per_alignment = (expected_tm * normed_weights).sum(-1)

  # Score of the best weighted alignment anchor residue.
  return (per_alignment * residue_weights).max()
Loading