3 changes: 3 additions & 0 deletions .gitignore
@@ -5,3 +5,6 @@ __pycache__
docs/build
temp
.coverage
*.ipynb_checkpoints
*/.cache
*/lightning_logs
@@ -0,0 +1,8 @@
protein_property:
_target_: cortex.model.branch.TransformerBranch
out_dim: 8
channel_dim: ${channel_dim}
num_blocks: 2
num_heads: 4
dropout_prob: ${dropout_prob}
is_causal: false
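As a rough sketch of how a branch config node like this is consumed (the top-level channel_dim and dropout_prob values and the in_dim argument below are illustrative assumptions; Hydra resolves the ${...} interpolations against the composed config):

from hydra.utils import instantiate
from omegaconf import OmegaConf

# mirror the YAML node above inside a parent config so the interpolations resolve
cfg = OmegaConf.create(
    {
        "channel_dim": 64,    # assumed value
        "dropout_prob": 0.1,  # assumed value
        "protein_property": {
            "_target_": "cortex.model.branch.TransformerBranch",
            "out_dim": 8,
            "channel_dim": "${channel_dim}",
            "num_blocks": 2,
            "num_heads": 4,
            "dropout_prob": "${dropout_prob}",
            "is_causal": False,
        },
    }
)
# in_dim is not part of the YAML; it is supplied at instantiation time
branch = instantiate(cfg.protein_property, in_dim=64)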
19 changes: 19 additions & 0 deletions cortex/config/hydra/roots/protein_seq_transformer.yaml
@@ -0,0 +1,19 @@
protein_seq:
_target_: cortex.model.root.TransformerRoot
corruption_process:
_target_: cortex.corruption.MaskCorruptionProcess
tokenizer_transform:
_target_: cortex.transforms.HuggingFaceTokenizerTransform
tokenizer:
_target_: cortex.tokenization.ProteinSequenceTokenizerFast
max_len: 256
out_dim: ${embed_dim}
embed_dim: ${embed_dim}
channel_dim: ${channel_dim}
num_blocks: 2
num_heads: 4
is_causal: false
dropout_prob: ${dropout_prob}
pos_encoding: true
train_transforms: null
eval_transforms: null
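A minimal sketch of instantiating this root config with Hydra, assuming the file is loaded from the repository root and that embed_dim, channel_dim, and dropout_prob come from the composed parent config (the values below are placeholders). Nested _target_ nodes such as the corruption process and tokenizer transform are built recursively by instantiate:

from hydra.utils import instantiate
from omegaconf import OmegaConf

root_cfg = OmegaConf.load("cortex/config/hydra/roots/protein_seq_transformer.yaml")
# supply the interpolated values that a full experiment config would provide
cfg = OmegaConf.merge(
    OmegaConf.create({"embed_dim": 64, "channel_dim": 64, "dropout_prob": 0.1}),
    root_cfg,
)
protein_seq_root = instantiate(cfg.protein_seq)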
2 changes: 2 additions & 0 deletions cortex/model/block/__init__.py
@@ -1,5 +1,7 @@
from ._conv1d_resid_block import Conv1dResidBlock
from ._transformer_block import TransformerBlock

__all__ = [
"Conv1dResidBlock",
"TransformerBlock",
]
44 changes: 44 additions & 0 deletions cortex/model/block/_transformer_block.py
@@ -0,0 +1,44 @@
from torch import Tensor, nn

from cortex.model.elemental import MLP, BidirectionalSelfAttention, CausalSelfAttention


class TransformerBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_heads: int = 4,
bias: bool = False,
dropout_p: float = 0.0,
is_causal: bool = False,
):
super().__init__()
self.ln_1 = nn.LayerNorm(in_channels, bias=bias)

if is_causal:
self.attn = CausalSelfAttention(num_heads=num_heads, embed_dim=in_channels, dropout_p=dropout_p, bias=bias)
else:
self.attn = BidirectionalSelfAttention(
num_heads=num_heads, embed_dim=in_channels, dropout_p=dropout_p, bias=bias
)

self.ln_2 = nn.LayerNorm(in_channels, bias=bias)
self.mlp = MLP(in_channels, out_channels, bias=bias, dropout_p=dropout_p)

        if in_channels != out_channels:
            self.proj = nn.Linear(in_channels, out_channels, bias=bias)
        else:
            self.proj = None

def forward(self, inputs: tuple[Tensor, Tensor]) -> tuple[Tensor, Tensor]:
resid, padding_mask = inputs
x, padding_mask = self.attn((self.ln_1(resid), padding_mask))
x = resid + x

if self.proj is not None:
resid = self.proj(resid)

x = resid + self.mlp(self.ln_2(x))

return x, padding_mask
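A small usage sketch (shapes are illustrative): the block takes and returns a (features, padding_mask) tuple, so blocks can be chained with nn.Sequential, and a projection maps the residual stream when in_channels differs from out_channels:

import torch

from cortex.model.block import TransformerBlock

block = TransformerBlock(in_channels=32, out_channels=64, num_heads=4)
features = torch.randn(2, 16, 32)  # (batch_size, num_tokens, in_channels)
padding_mask = torch.ones(2, 16)   # assumed convention: 1.0 = real token, 0.0 = padding
out, mask = block((features, padding_mask))
print(out.shape)  # torch.Size([2, 16, 64])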
3 changes: 3 additions & 0 deletions cortex/model/branch/__init__.py
@@ -1,9 +1,12 @@
from ._abstract_branch import BranchNode, BranchNodeOutput
from ._conv1d_branch import Conv1dBranch, Conv1dBranchOutput
from ._transformer_branch import TransformerBranch, TransformerBranchOutput

__all__ = [
"BranchNode",
"BranchNodeOutput",
"Conv1dBranch",
"Conv1dBranchOutput",
"TransformerBranch",
"TransformerBranchOutput",
]
101 changes: 101 additions & 0 deletions cortex/model/branch/_transformer_branch.py
@@ -0,0 +1,101 @@
from dataclasses import dataclass

import torch
from torch import nn

from cortex.model.block import TransformerBlock
from cortex.model.branch import BranchNode, BranchNodeOutput
from cortex.model.elemental import (
Apply,
Expression,
MeanPooling,
WeightedMeanPooling,
identity,
)
from cortex.model.trunk import PaddedTrunkOutput


@dataclass
class TransformerBranchOutput(BranchNodeOutput):
branch_mask: torch.Tensor
pooled_features: torch.Tensor


class TransformerBranch(BranchNode):
"""
    Branch node which transforms aggregated trunk features into task-specific branch features.
"""

def __init__(
self,
in_dim: int,
out_dim: int = 64,
channel_dim: int = 64,
num_blocks: int = 2,
        num_heads: int = 4,
is_causal: bool = False,
dropout_prob: float = 0.0,
pooling_type: str = "mean",
**kwargs,
):
super().__init__()
# create encoder
self.in_dim = in_dim
self.out_dim = out_dim
self.channel_dim = channel_dim
self.num_blocks = num_blocks

if num_blocks == 0:
# add projection if dims don't match
encoder_modules = [
Expression(identity) if in_dim == out_dim else Apply(nn.Linear(in_dim, out_dim, bias=False))
]
else:
            # transformer blocks consume (features, padding_mask) tuples; features have shape (batch_size, num_tokens, channel_dim)
encoder_modules = []

block_kwargs = {
"num_heads": num_heads,
"is_causal": is_causal,
"dropout_p": dropout_prob,
}

if num_blocks == 1:
encoder_modules.append(TransformerBlock(in_dim, out_dim, **block_kwargs))
elif num_blocks > 1:
encoder_modules.append(TransformerBlock(in_dim, channel_dim, **block_kwargs))
encoder_modules.extend(
[TransformerBlock(channel_dim, channel_dim, **block_kwargs) for _ in range(num_blocks - 2)]
)
encoder_modules.append(TransformerBlock(channel_dim, out_dim, **block_kwargs))

self.encoder = nn.Sequential(*encoder_modules)
if pooling_type == "mean":
self.pooling_op = MeanPooling()
elif pooling_type == "weighted_mean":
self.pooling_op = WeightedMeanPooling(out_dim)
else:
            raise NotImplementedError(f"Unsupported pooling_type: {pooling_type}")

def forward(
self,
trunk_outputs: PaddedTrunkOutput,
) -> TransformerBranchOutput:
"""
        Args:
            trunk_outputs: PaddedTrunkOutput with trunk_features and padding_mask tensors.
        Returns:
            TransformerBranchOutput with branch_features, branch_mask, and pooled_features tensors.
"""
trunk_features = trunk_outputs.trunk_features
padding_mask = trunk_outputs.padding_mask

branch_features, branch_mask = self.encoder((trunk_features, padding_mask.to(trunk_features)))
pooled_features = self.pooling_op(branch_features, branch_mask)

branch_outputs = TransformerBranchOutput(
branch_features=branch_features.contiguous(),
branch_mask=branch_mask,
pooled_features=pooled_features,
)
return branch_outputs
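A usage sketch under stated assumptions: forward only reads the trunk_features and padding_mask attributes of the trunk output, so a stand-in object is used here instead of a real PaddedTrunkOutput; shapes and the 0/1 mask convention are illustrative:

import torch
from types import SimpleNamespace

from cortex.model.branch import TransformerBranch

branch = TransformerBranch(in_dim=32, out_dim=8, channel_dim=64, num_blocks=2, num_heads=4)
trunk_outputs = SimpleNamespace(
    trunk_features=torch.randn(2, 16, 32),  # (batch_size, num_tokens, in_dim)
    padding_mask=torch.ones(2, 16),
)
outputs = branch(trunk_outputs)
print(outputs.branch_features.shape)  # torch.Size([2, 16, 8])
print(outputs.pooled_features.shape)  # expected: torch.Size([2, 8])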
5 changes: 5 additions & 0 deletions cortex/model/elemental/__init__.py
@@ -1,13 +1,18 @@
from ._apply import Apply
from ._bidirectional_self_attention import BidirectionalSelfAttention
from ._causal_self_attention import CausalSelfAttention
from ._ddp_standardize import DDPStandardize
from ._expression import Expression
from ._functional import identity, permute_spatial_channel_dims, swish
from ._layernorm import MaskLayerNorm1d
from ._mean_pooling import MeanPooling, WeightedMeanPooling
from ._mlp import MLP
from ._sine_pos_encoder import SinePosEncoder

__all__ = [
"Apply",
"BidirectionalSelfAttention",
"CausalSelfAttention",
"DDPStandardize",
"Expression",
"identity",
37 changes: 37 additions & 0 deletions cortex/model/elemental/_bidirectional_self_attention.py
@@ -0,0 +1,37 @@
from torch import Tensor, nn


class BidirectionalSelfAttention(nn.Module):
def __init__(self, num_heads: int = 4, embed_dim: int = 32, dropout_p: float = 0.0, bias: bool = False):
super().__init__()
if embed_dim % num_heads != 0:
raise ValueError("num_heads must evenly divide embed_dim")

self.c_attn = nn.Linear(embed_dim, embed_dim * 3, bias=bias)
self.dropout = nn.Dropout(dropout_p)
self.dropout_p = dropout_p
self.head_dim = embed_dim // num_heads
self.num_heads = num_heads

def forward(self, inputs: tuple[Tensor, Tensor]) -> tuple[Tensor, Tensor]:
x, padding_mask = inputs
seq_len = x.size(-2)
queries, keys, values = self.c_attn(x).chunk(3, dim=-1)

queries = queries.view(-1, seq_len, self.num_heads, self.head_dim).transpose(-2, -3)
keys = keys.view(-1, seq_len, self.num_heads, self.head_dim).transpose(-2, -3)
values = values.view(-1, seq_len, self.num_heads, self.head_dim).transpose(-2, -3)

attn_mask = padding_mask[..., None, :, None]

res = nn.functional.scaled_dot_product_attention(
queries,
keys,
values,
attn_mask=attn_mask,
dropout_p=self.dropout_p if self.training else 0.0,
is_causal=False,
)

res = res.transpose(-2, -3).flatten(start_dim=-2)
return self.dropout(res), padding_mask
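A brief usage sketch (shapes illustrative); the module consumes and returns a (features, padding_mask) tuple, so it composes directly with the transformer block above:

import torch

from cortex.model.elemental import BidirectionalSelfAttention

attn = BidirectionalSelfAttention(num_heads=4, embed_dim=32)
x = torch.randn(2, 16, 32)
padding_mask = torch.ones(2, 16)
out, mask = attn((x, padding_mask))
print(out.shape)  # torch.Size([2, 16, 32])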
35 changes: 35 additions & 0 deletions cortex/model/elemental/_causal_self_attention.py
@@ -0,0 +1,35 @@
from torch import Tensor, nn


class CausalSelfAttention(nn.Module):
def __init__(self, num_heads: int = 4, embed_dim: int = 32, dropout_p: float = 0.0, bias: bool = False):
super().__init__()
if embed_dim % num_heads != 0:
raise ValueError("num_heads must evenly divide embed_dim")

self.c_attn = nn.Linear(embed_dim, embed_dim * 3, bias=bias)
self.dropout = nn.Dropout(dropout_p)
self.dropout_p = dropout_p
self.head_dim = embed_dim // num_heads
self.num_heads = num_heads

def forward(self, inputs: tuple[Tensor, Tensor]) -> tuple[Tensor, Tensor]:
x, padding_mask = inputs
seq_len = x.size(-2)
queries, keys, values = self.c_attn(x).chunk(3, dim=-1)

queries = queries.view(-1, seq_len, self.num_heads, self.head_dim).transpose(-2, -3)
keys = keys.view(-1, seq_len, self.num_heads, self.head_dim).transpose(-2, -3)
values = values.view(-1, seq_len, self.num_heads, self.head_dim).transpose(-2, -3)

res = nn.functional.scaled_dot_product_attention(
queries,
keys,
values,
attn_mask=None,
dropout_p=self.dropout_p if self.training else 0.0,
is_causal=True,
)

res = res.transpose(-2, -3).flatten(start_dim=-2)
return self.dropout(res), padding_mask
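A small sketch illustrating the causal property: because attention runs with is_causal=True and all other ops are position-wise, perturbing later positions should leave earlier outputs unchanged (the padding mask is passed through untouched here):

import torch

from cortex.model.elemental import CausalSelfAttention

attn = CausalSelfAttention(num_heads=4, embed_dim=32)
attn.eval()  # make the (already 0.0 by default) dropout inert

x = torch.randn(1, 16, 32)
padding_mask = torch.ones(1, 16)
out_full, _ = attn((x, padding_mask))

x_perturbed = x.clone()
x_perturbed[:, 8:] = torch.randn(1, 8, 32)  # change only the last 8 positions
out_perturbed, _ = attn((x_perturbed, padding_mask))

# earlier positions never attend to later ones
assert torch.allclose(out_full[:, :8], out_perturbed[:, :8], atol=1e-6)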
18 changes: 18 additions & 0 deletions cortex/model/elemental/_mlp.py
@@ -0,0 +1,18 @@
from torch import nn


class MLP(nn.Sequential):
def __init__(
self,
in_channels: int,
out_channels: int | None = None,
bias: bool = False,
dropout_p: float = 0.0,
):
        out_channels = out_channels if out_channels is not None else in_channels
super().__init__(
nn.Linear(in_channels, 4 * in_channels, bias=bias),
nn.GELU(),
nn.Linear(4 * in_channels, out_channels, bias=bias),
nn.Dropout(dropout_p),
)
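A quick sketch of the shape behaviour: the hidden layer is 4x the input width and the final linear maps to out_channels (or back to in_channels when out_channels is omitted):

import torch

from cortex.model.elemental import MLP

mlp = MLP(in_channels=32, out_channels=8)
x = torch.randn(2, 16, 32)
print(mlp(x).shape)  # torch.Size([2, 16, 8])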
3 changes: 3 additions & 0 deletions cortex/model/root/__init__.py
@@ -1,9 +1,12 @@
from ._abstract_root import RootNode, RootNodeOutput
from ._conv1d_root import Conv1dRoot, Conv1dRootOutput
from ._transformer_root import TransformerRoot, TransformerRootOutput

__all__ = [
"RootNode",
"RootNodeOutput",
"Conv1dRoot",
"Conv1dRootOutput",
"TransformerRoot",
"TransformerRootOutput",
]