diff --git a/.gitignore b/.gitignore index c339a42..caa8e50 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,6 @@ __pycache__ docs/build temp .coverage +*.ipynb_checkpoints +*/.cache +*/lightning_logs diff --git a/cortex/config/hydra/branches/protein_property_transformer.yaml b/cortex/config/hydra/branches/protein_property_transformer.yaml new file mode 100644 index 0000000..c0e20d4 --- /dev/null +++ b/cortex/config/hydra/branches/protein_property_transformer.yaml @@ -0,0 +1,8 @@ +protein_property: + _target_: cortex.model.branch.TransformerBranch + out_dim: 8 + channel_dim: ${channel_dim} + num_blocks: 2 + num_heads: 4 + dropout_prob: ${dropout_prob} + is_causal: false diff --git a/cortex/config/hydra/roots/protein_seq_transformer.yaml b/cortex/config/hydra/roots/protein_seq_transformer.yaml new file mode 100644 index 0000000..3127d25 --- /dev/null +++ b/cortex/config/hydra/roots/protein_seq_transformer.yaml @@ -0,0 +1,19 @@ +protein_seq: + _target_: cortex.model.root.TransformerRoot + corruption_process: + _target_: cortex.corruption.MaskCorruptionProcess + tokenizer_transform: + _target_: cortex.transforms.HuggingFaceTokenizerTransform + tokenizer: + _target_: cortex.tokenization.ProteinSequenceTokenizerFast + max_len: 256 + out_dim: ${embed_dim} + embed_dim: ${embed_dim} + channel_dim: ${channel_dim} + num_blocks: 2 + num_heads: 4 + is_causal: false + dropout_prob: ${dropout_prob} + pos_encoding: true + train_transforms: null + eval_transforms: null diff --git a/cortex/model/block/__init__.py b/cortex/model/block/__init__.py index d05a78e..c604d6d 100644 --- a/cortex/model/block/__init__.py +++ b/cortex/model/block/__init__.py @@ -1,5 +1,7 @@ from ._conv1d_resid_block import Conv1dResidBlock +from ._transformer_block import TransformerBlock __all__ = [ "Conv1dResidBlock", + "TransformerBlock", ] diff --git a/cortex/model/block/_transformer_block.py b/cortex/model/block/_transformer_block.py new file mode 100644 index 0000000..af6c1a9 --- /dev/null +++ b/cortex/model/block/_transformer_block.py @@ -0,0 +1,44 @@ +from torch import Tensor, nn + +from cortex.model.elemental import MLP, BidirectionalSelfAttention, CausalSelfAttention + + +class TransformerBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + num_heads: int = 4, + bias: bool = False, + dropout_p: float = 0.0, + is_causal: bool = False, + ): + super().__init__() + self.ln_1 = nn.LayerNorm(in_channels, bias=bias) + + if is_causal: + self.attn = CausalSelfAttention(num_heads=num_heads, embed_dim=in_channels, dropout_p=dropout_p, bias=bias) + else: + self.attn = BidirectionalSelfAttention( + num_heads=num_heads, embed_dim=in_channels, dropout_p=dropout_p, bias=bias + ) + + self.ln_2 = nn.LayerNorm(in_channels, bias=bias) + self.mlp = MLP(in_channels, out_channels, bias=bias, dropout_p=dropout_p) + + if not in_channels == out_channels: + self.proj = nn.Linear(in_channels, out_channels, bias=bias) + else: + self.proj = None + + def forward(self, inputs: tuple[Tensor, Tensor]) -> tuple[Tensor, Tensor]: + resid, padding_mask = inputs + x, padding_mask = self.attn((self.ln_1(resid), padding_mask)) + x = resid + x + + if self.proj is not None: + resid = self.proj(resid) + + x = resid + self.mlp(self.ln_2(x)) + + return x, padding_mask diff --git a/cortex/model/branch/__init__.py b/cortex/model/branch/__init__.py index 16ed371..a9f9538 100644 --- a/cortex/model/branch/__init__.py +++ b/cortex/model/branch/__init__.py @@ -1,9 +1,12 @@ from ._abstract_branch import BranchNode, BranchNodeOutput from 
._conv1d_branch import Conv1dBranch, Conv1dBranchOutput +from ._transformer_branch import TransformerBranch, TransformerBranchOutput __all__ = [ "BranchNode", "BranchNodeOutput", "Conv1dBranch", "Conv1dBranchOutput", + "TransformerBranch", + "TransformerBranchOutput", ] diff --git a/cortex/model/branch/_transformer_branch.py b/cortex/model/branch/_transformer_branch.py new file mode 100644 index 0000000..29e7a58 --- /dev/null +++ b/cortex/model/branch/_transformer_branch.py @@ -0,0 +1,101 @@ +from dataclasses import dataclass + +import torch +from torch import nn + +from cortex.model.block import TransformerBlock +from cortex.model.branch import BranchNode, BranchNodeOutput +from cortex.model.elemental import ( + Apply, + Expression, + MeanPooling, + WeightedMeanPooling, + identity, +) +from cortex.model.trunk import PaddedTrunkOutput + + +@dataclass +class TransformerBranchOutput(BranchNodeOutput): + branch_mask: torch.Tensor + pooled_features: torch.Tensor + + +class TransformerBranch(BranchNode): + """ + Branch node which transforms aggregated trunk features to task branch specific features + """ + + def __init__( + self, + in_dim: int, + out_dim: int = 64, + channel_dim: int = 64, + num_blocks: int = 2, + num_heads: int = 5, + is_causal: bool = False, + dropout_prob: float = 0.0, + pooling_type: str = "mean", + **kwargs, + ): + super().__init__() + # create encoder + self.in_dim = in_dim + self.out_dim = out_dim + self.channel_dim = channel_dim + self.num_blocks = num_blocks + + if num_blocks == 0: + # add projection if dims don't match + encoder_modules = [ + Expression(identity) if in_dim == out_dim else Apply(nn.Linear(in_dim, out_dim, bias=False)) + ] + else: + # transformer blocks expect inputs with shape (batch_size, num_tokens, channel_dim) + encoder_modules = [] + + block_kwargs = { + "num_heads": num_heads, + "is_causal": is_causal, + "dropout_p": dropout_prob, + } + + if num_blocks == 1: + encoder_modules.append(TransformerBlock(in_dim, out_dim, **block_kwargs)) + elif num_blocks > 1: + encoder_modules.append(TransformerBlock(in_dim, channel_dim, **block_kwargs)) + encoder_modules.extend( + [TransformerBlock(channel_dim, channel_dim, **block_kwargs) for _ in range(num_blocks - 2)] + ) + encoder_modules.append(TransformerBlock(channel_dim, out_dim, **block_kwargs)) + + self.encoder = nn.Sequential(*encoder_modules) + if pooling_type == "mean": + self.pooling_op = MeanPooling() + elif pooling_type == "weighted_mean": + self.pooling_op = WeightedMeanPooling(out_dim) + else: + raise NotImplementedError + + def forward( + self, + trunk_outputs: PaddedTrunkOutput, + ) -> TransformerBranchOutput: + """ + Args: + trunk_outputs: {'trunk_features': torch.Tensor, 'padding_mask': torch.Tensor} + Returns: + outputs: {'branch_features': torch.Tensor, 'branch_mask': torch.Tensor, 'pooled_features': torch.Tensor} + """ + trunk_features = trunk_outputs.trunk_features + padding_mask = trunk_outputs.padding_mask + + branch_features, branch_mask = self.encoder((trunk_features, padding_mask.to(trunk_features))) + pooled_features = self.pooling_op(branch_features, branch_mask) + + branch_outputs = TransformerBranchOutput( + branch_features=branch_features.contiguous(), + branch_mask=branch_mask, + pooled_features=pooled_features, + ) + return branch_outputs diff --git a/cortex/model/elemental/__init__.py b/cortex/model/elemental/__init__.py index 020dbc8..f305dbd 100644 --- a/cortex/model/elemental/__init__.py +++ b/cortex/model/elemental/__init__.py @@ -1,13 +1,18 @@ from ._apply import Apply +from
._bidirectional_self_attention import BidirectionalSelfAttention +from ._causal_self_attention import CausalSelfAttention from ._ddp_standardize import DDPStandardize from ._expression import Expression from ._functional import identity, permute_spatial_channel_dims, swish from ._layernorm import MaskLayerNorm1d from ._mean_pooling import MeanPooling, WeightedMeanPooling +from ._mlp import MLP from ._sine_pos_encoder import SinePosEncoder __all__ = [ "Apply", + "BidirectionalSelfAttention", + "CausalSelfAttention", "DDPStandardize", "Expression", "identity", diff --git a/cortex/model/elemental/_bidirectional_self_attention.py b/cortex/model/elemental/_bidirectional_self_attention.py new file mode 100644 index 0000000..49b1dda --- /dev/null +++ b/cortex/model/elemental/_bidirectional_self_attention.py @@ -0,0 +1,37 @@ +from torch import Tensor, nn + + +class BidirectionalSelfAttention(nn.Module): + def __init__(self, num_heads: int = 4, embed_dim: int = 32, dropout_p: float = 0.0, bias: bool = False): + super().__init__() + if embed_dim % num_heads != 0: + raise ValueError("num_heads must evenly divide embed_dim") + + self.c_attn = nn.Linear(embed_dim, embed_dim * 3, bias=bias) + self.dropout = nn.Dropout(dropout_p) + self.dropout_p = dropout_p + self.head_dim = embed_dim // num_heads + self.num_heads = num_heads + + def forward(self, inputs: tuple[Tensor, Tensor]) -> tuple[Tensor, Tensor]: + x, padding_mask = inputs + seq_len = x.size(-2) + queries, keys, values = self.c_attn(x).chunk(3, dim=-1) + + queries = queries.view(-1, seq_len, self.num_heads, self.head_dim).transpose(-2, -3) + keys = keys.view(-1, seq_len, self.num_heads, self.head_dim).transpose(-2, -3) + values = values.view(-1, seq_len, self.num_heads, self.head_dim).transpose(-2, -3) + + attn_mask = padding_mask[..., None, None, :].bool() + + res = nn.functional.scaled_dot_product_attention( + queries, + keys, + values, + attn_mask=attn_mask, + dropout_p=self.dropout_p if self.training else 0.0, + is_causal=False, + ) + + res = res.transpose(-2, -3).flatten(start_dim=-2) + return self.dropout(res), padding_mask diff --git a/cortex/model/elemental/_causal_self_attention.py b/cortex/model/elemental/_causal_self_attention.py new file mode 100644 index 0000000..0f1b76b --- /dev/null +++ b/cortex/model/elemental/_causal_self_attention.py @@ -0,0 +1,35 @@ +from torch import Tensor, nn + + +class CausalSelfAttention(nn.Module): + def __init__(self, num_heads: int = 4, embed_dim: int = 32, dropout_p: float = 0.0, bias: bool = False): + super().__init__() + if embed_dim % num_heads != 0: + raise ValueError("num_heads must evenly divide embed_dim") + + self.c_attn = nn.Linear(embed_dim, embed_dim * 3, bias=bias) + self.dropout = nn.Dropout(dropout_p) + self.dropout_p = dropout_p + self.head_dim = embed_dim // num_heads + self.num_heads = num_heads + + def forward(self, inputs: tuple[Tensor, Tensor]) -> tuple[Tensor, Tensor]: + x, padding_mask = inputs + seq_len = x.size(-2) + queries, keys, values = self.c_attn(x).chunk(3, dim=-1) + + queries = queries.view(-1, seq_len, self.num_heads, self.head_dim).transpose(-2, -3) + keys = keys.view(-1, seq_len, self.num_heads, self.head_dim).transpose(-2, -3) + values = values.view(-1, seq_len, self.num_heads, self.head_dim).transpose(-2, -3) + + res = nn.functional.scaled_dot_product_attention( + queries, + keys, + values, + attn_mask=None, + dropout_p=self.dropout_p if self.training else 0.0, + is_causal=True, + ) + + res = res.transpose(-2, -3).flatten(start_dim=-2) + return self.dropout(res),
padding_mask diff --git a/cortex/model/elemental/_mlp.py b/cortex/model/elemental/_mlp.py new file mode 100644 index 0000000..dd1c09b --- /dev/null +++ b/cortex/model/elemental/_mlp.py @@ -0,0 +1,18 @@ +from torch import nn + + +class MLP(nn.Sequential): + def __init__( + self, + in_channels: int, + out_channels: int | None = None, + bias: bool = False, + dropout_p: float = 0.0, + ): + out_channels = out_channels if out_channels else in_channels + super().__init__( + nn.Linear(in_channels, 4 * in_channels, bias=bias), + nn.GELU(), + nn.Linear(4 * in_channels, out_channels, bias=bias), + nn.Dropout(dropout_p), + ) diff --git a/cortex/model/root/__init__.py b/cortex/model/root/__init__.py index 0a5f1fe..b2f1736 100644 --- a/cortex/model/root/__init__.py +++ b/cortex/model/root/__init__.py @@ -1,9 +1,12 @@ from ._abstract_root import RootNode, RootNodeOutput from ._conv1d_root import Conv1dRoot, Conv1dRootOutput +from ._transformer_root import TransformerRoot, TransformerRootOutput __all__ = [ "RootNode", "RootNodeOutput", "Conv1dRoot", "Conv1dRootOutput", + "TransformerRoot", + "TransformerRootOutput", ] diff --git a/cortex/model/root/_transformer_root.py b/cortex/model/root/_transformer_root.py new file mode 100644 index 0000000..75c4f2f --- /dev/null +++ b/cortex/model/root/_transformer_root.py @@ -0,0 +1,341 @@ +import math +import warnings +from dataclasses import dataclass +from typing import Optional, Union + +import numpy as np +import torch +from torch import LongTensor, nn + +from cortex.corruption import CorruptionProcess, GaussianCorruptionProcess, MaskCorruptionProcess +from cortex.model.block import TransformerBlock +from cortex.model.elemental import SinePosEncoder +from cortex.model.root import RootNode +from cortex.transforms import HuggingFaceTokenizerTransform, PadTransform, ToTensor + + +@dataclass +class TransformerRootOutput: + """Output of TransformerRoot.""" + + root_features: torch.Tensor + padding_mask: torch.Tensor + corrupt_frac: Optional[torch.Tensor] = None + src_tok_idxs: Optional[torch.LongTensor] = None + tgt_tok_idxs: Optional[torch.LongTensor] = None + src_tok_embs: Optional[torch.Tensor] = None + is_corrupted: Optional[torch.Tensor] = None + + +class TransformerRoot(RootNode): + """ + A root node transforming an array of discrete sequences to an array of continuous sequence embeddings + """ + + def __init__( + self, + tokenizer_transform: HuggingFaceTokenizerTransform, + max_len: int, + out_dim: int = 64, + embed_dim: int = 64, + channel_dim: int = 256, + num_blocks: int = 2, + num_heads: int = 4, + is_causal: bool = False, + dropout_prob: float = 0.0, + pos_encoding: bool = True, + train_transforms=None, + eval_transforms=None, + corruption_process: Optional[CorruptionProcess] = None, + **kwargs, + ) -> None: + super().__init__() + self.tokenizer = tokenizer_transform.tokenizer + self.vocab_size = len(self.tokenizer.vocab) + self.max_len = max_len + self.pad_tok_idx = self.tokenizer.padding_idx + if num_blocks >= 1: + self.tok_encoder = nn.Embedding(self.vocab_size, embed_dim, padding_idx=self.pad_tok_idx) + # optional positional encoding + if pos_encoding: + self.pos_encoder = SinePosEncoder(embed_dim, dropout_prob, max_len, batch_first=True) + else: + self.pos_encoder = None + + # create encoder + self.embed_dim = embed_dim + self.num_blocks = num_blocks + if num_blocks >= 1: + self.out_dim = out_dim + encoder_modules = [] + resid_block_kwargs = { + "num_heads": num_heads, + "dropout_p": dropout_prob, + "is_causal": is_causal, + } + if num_blocks
== 1: + encoder_modules.append(TransformerBlock(embed_dim, out_dim, **resid_block_kwargs)) + else: + encoder_modules.append(TransformerBlock(embed_dim, channel_dim, **resid_block_kwargs)) + + encoder_modules.extend( + [ + TransformerBlock( + channel_dim, + channel_dim, + **resid_block_kwargs, + ) + for _ in range(num_blocks - 2) + ] + ) + + encoder_modules.append( + TransformerBlock( + channel_dim, + out_dim, + **resid_block_kwargs, + ) + ) + self.encoder = nn.Sequential(*encoder_modules) + + shared_transforms = [ + tokenizer_transform, # convert np.array([str, str, ...]) to list[list[int, int, ...]] + ToTensor(padding_value=self.pad_tok_idx), # convert list[list[int, int, ...]] to tensor + PadTransform(max_length=self.max_len, pad_value=self.pad_tok_idx), # pad to max_len + ] + train_transforms = [] if train_transforms is None else list(train_transforms.values()) + eval_transforms = [] if eval_transforms is None else list(eval_transforms.values()) + self.train_transform = nn.Sequential(*(train_transforms + shared_transforms)) + self.eval_transform = nn.Sequential(*(eval_transforms + shared_transforms)) + self.corruption_process = corruption_process + + def initialize_weights(self, **kwargs): + # default random initialization + pass + + def get_token_embedding(self, tok_idx: int): + return self.tok_encoder(torch.tensor(tok_idx, device=self.device)) + + @property + def device(self): + return self.tok_encoder.weight.device + + def init_seq( + self, + inputs: Optional[Union[np.ndarray, torch.Tensor]] = None, # TODO deprecate + seq_array: Optional[np.ndarray] = None, + tgt_tok_idxs: Optional[LongTensor] = None, + src_tok_embs: Optional[torch.Tensor] = None, + corrupt_frac: float = 0.0, + **kwargs, + ): + # infer input type if not specified + if inputs is not None: + if isinstance(inputs, np.ndarray): + seq_array = inputs + if isinstance(inputs, LongTensor): + tgt_tok_idxs = inputs + elif isinstance(inputs, torch.Tensor): + src_tok_embs = inputs + msg = "inputs is deprecated, use a specific argument instead" + warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) + + # Determine batch size from any available input + batch_size = None + if seq_array is not None: + batch_size = seq_array.shape[0] + elif tgt_tok_idxs is not None: + batch_size = tgt_tok_idxs.shape[0] + elif src_tok_embs is not None: + batch_size = src_tok_embs.shape[0] + + # Fallback to default batch size of 1 if no inputs are provided + if batch_size is None: + batch_size = 1 + + if "mask_frac" in kwargs: + corrupt_frac = kwargs["mask_frac"] + msg = "mask_frac is deprecated, use corrupt_frac instead." 
+ warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) + + if self.corruption_process is not None and corrupt_frac is None: + corrupt_frac = self.corruption_process.sample_corrupt_frac(n=batch_size).to(self.device) + elif isinstance(corrupt_frac, float): + corrupt_frac = torch.full((batch_size,), corrupt_frac, device=self.device) + elif isinstance(corrupt_frac, torch.Tensor): + # Move tensor to the correct device + corrupt_frac = corrupt_frac.to(self.device) + else: + corrupt_frac = torch.full((batch_size,), 0.0, device=self.device) + + return seq_array, tgt_tok_idxs, src_tok_embs, corrupt_frac + + def tokenize_seq( + self, + seq_array: Optional[np.ndarray] = None, + tgt_tok_idxs: Optional[LongTensor] = None, + src_tok_embs: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + corrupt_frac: Union[float, torch.Tensor] = 0.0, + is_corrupted: Optional[torch.Tensor] = None, + corruption_allowed: Optional[torch.Tensor] = None, + ): + # begin forward pass from raw sequence + if seq_array is not None: + assert tgt_tok_idxs is None + assert src_tok_embs is None + if self.training: + tgt_tok_idxs = self.train_transform(seq_array) + else: + tgt_tok_idxs = self.eval_transform(seq_array) + tgt_tok_idxs = tgt_tok_idxs.to(self.device) + + # truncate token sequence to max context length + if tgt_tok_idxs is not None: + assert src_tok_embs is None + # truncate to max context length, keep final stop token + if tgt_tok_idxs.size(-1) > self.max_len: + tmp_tok_idxs = tgt_tok_idxs[..., : self.max_len - 1] + tgt_tok_idxs = torch.cat([tmp_tok_idxs, tgt_tok_idxs[..., -1:]], dim=-1) + + if corruption_allowed is None and tgt_tok_idxs is not None: + corruption_allowed = self.tokenizer.get_corruptible_mask(tgt_tok_idxs) + + # begin forward pass from tokenized sequence + if tgt_tok_idxs is not None: + # apply masking corruption + if isinstance(self.corruption_process, MaskCorruptionProcess) and ( + (isinstance(corrupt_frac, float) and corrupt_frac > 0.0) + or (isinstance(corrupt_frac, torch.Tensor) and torch.any(corrupt_frac > 0.0)) + ): + src_tok_idxs, is_corrupted = self.corruption_process( + x_start=tgt_tok_idxs, + mask_val=self.tokenizer.masking_idx, + corruption_allowed=corruption_allowed, + corrupt_frac=corrupt_frac, + ) + else: + src_tok_idxs = tgt_tok_idxs + is_corrupted = ( + torch.full_like(src_tok_idxs, False, dtype=torch.bool) if is_corrupted is None else is_corrupted + ) + + padding_mask = src_tok_idxs != self.pad_tok_idx + + if src_tok_embs is not None: + assert seq_array is None + assert padding_mask is not None + src_tok_idxs = None + + return ( + src_tok_idxs, + tgt_tok_idxs, + corruption_allowed, + is_corrupted, + padding_mask, + ) + + def embed_seq( + self, + src_tok_idxs: Optional[LongTensor] = None, + src_tok_embs: Optional[torch.Tensor] = None, + corrupt_frac: Union[float, torch.Tensor] = 0.0, + is_corrupted: Optional[torch.Tensor] = None, + corruption_allowed: Optional[torch.Tensor] = None, + normalize_embeds: bool = True, + ): + # begin forward pass from token embeddings + if src_tok_embs is None: + src_tok_embs = self.tok_encoder(src_tok_idxs) + if normalize_embeds: + src_tok_embs = src_tok_embs / src_tok_embs.norm(dim=-1, keepdim=True).clamp_min(1e-6) + src_tok_embs = src_tok_embs * math.sqrt(self.embed_dim) + + # apply gaussian embedding corruption + if isinstance(self.corruption_process, GaussianCorruptionProcess) and ( + (isinstance(corrupt_frac, float) and corrupt_frac > 0.0) + or (isinstance(corrupt_frac, torch.Tensor) and torch.any(corrupt_frac > 0.0)) 
+ ): + assert corruption_allowed is not None + src_tok_embs, is_corrupted = self.corruption_process( + x_start=src_tok_embs, + corruption_allowed=corruption_allowed[..., None], + corrupt_frac=corrupt_frac, + ) + is_corrupted = is_corrupted.sum(-1).bool() + else: + none_corrupted = torch.zeros(*src_tok_embs.shape[:-1], dtype=torch.bool).to(src_tok_embs.device) + is_corrupted = none_corrupted if is_corrupted is None else is_corrupted + + return src_tok_embs, is_corrupted + + def process_seq( + self, + src_tok_embs: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + ): + # apply positional encoding if it exists + if self.pos_encoder is not None: + src_features = self.pos_encoder(src_tok_embs) + else: + src_features = src_tok_embs + + # main forward pass + src_features, _ = self.encoder((src_features, padding_mask.to(src_features))) + + return src_features + + def forward( + self, + inputs: Optional[Union[np.ndarray, torch.Tensor]] = None, # TODO deprecate + seq_array: Optional[np.ndarray] = None, + tgt_tok_idxs: Optional[LongTensor] = None, + src_tok_embs: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + corrupt_frac: Union[float, torch.Tensor] = 0.0, + is_corrupted: Optional[torch.Tensor] = None, + corruption_allowed: Optional[torch.Tensor] = None, + **kwargs, + ) -> TransformerRootOutput: + """ + Args: + seq_array: (batch_size,) array of discrete sequences (e.g. text strings) + Returns: + outputs: {'root_features': torch.Tensor, 'padding_mask': torch.Tensor} + """ + seq_array, tgt_tok_idxs, src_tok_embs, corrupt_frac = self.init_seq( + inputs, seq_array, tgt_tok_idxs, src_tok_embs, corrupt_frac, **kwargs + ) + ( + src_tok_idxs, + tgt_tok_idxs, + corruption_allowed, + is_corrupted, + padding_mask, + ) = self.tokenize_seq( + seq_array, + tgt_tok_idxs, + src_tok_embs, + padding_mask, + corrupt_frac, + is_corrupted, + corruption_allowed, + ) + src_tok_embs, is_corrupted = self.embed_seq( + src_tok_idxs, src_tok_embs, corrupt_frac, is_corrupted, corruption_allowed + ) + src_features = self.process_seq(src_tok_embs, padding_mask) + # Make sure corrupt_frac is on the same device as other tensors + if isinstance(corrupt_frac, torch.Tensor): + corrupt_frac = corrupt_frac.to(src_tok_embs.device) + + outputs = TransformerRootOutput( + root_features=src_features.contiguous(), + padding_mask=padding_mask, + src_tok_embs=src_tok_embs, + src_tok_idxs=src_tok_idxs, + tgt_tok_idxs=tgt_tok_idxs, + is_corrupted=is_corrupted, + corrupt_frac=corrupt_frac, + ) + return outputs diff --git a/tests/cortex/model/block/test_transformer_block.py b/tests/cortex/model/block/test_transformer_block.py new file mode 100644 index 0000000..8faddfc --- /dev/null +++ b/tests/cortex/model/block/test_transformer_block.py @@ -0,0 +1,38 @@ +import torch + +from cortex.model.block import TransformerBlock + +BATCH_SIZE = 2 +NUM_HEADS = 3 +EMBED_DIM = 12 +SEQ_LEN = 5 + + +def test_transformer_encoder_block(): + module = TransformerBlock( + in_channels=EMBED_DIM, + out_channels=EMBED_DIM, + num_heads=NUM_HEADS, + is_causal=False, + ) + + x = torch.randn(BATCH_SIZE, SEQ_LEN, EMBED_DIM) + padding_mask = torch.ones(BATCH_SIZE, SEQ_LEN, dtype=torch.bool) + x_prime, _ = module((x, padding_mask)) + + assert x_prime.shape == x.shape + + +def test_transformer_decoder_block(): + module = TransformerBlock( + in_channels=EMBED_DIM, + out_channels=EMBED_DIM, + num_heads=NUM_HEADS, + is_causal=True, + ) + + x = torch.randn(BATCH_SIZE, SEQ_LEN, EMBED_DIM) + padding_mask = 
torch.ones(BATCH_SIZE, SEQ_LEN, dtype=torch.bool) + x_prime, _ = module((x, padding_mask)) + + assert x_prime.shape == x.shape diff --git a/tests/cortex/model/branch/test_transformer_branch.py b/tests/cortex/model/branch/test_transformer_branch.py new file mode 100644 index 0000000..e8a12f9 --- /dev/null +++ b/tests/cortex/model/branch/test_transformer_branch.py @@ -0,0 +1,88 @@ +import torch + +from cortex.model.branch import TransformerBranch, TransformerBranchOutput +from cortex.model.trunk import PaddedTrunkOutput + + +def test_transformer_encoder_branch(): + in_dim = 12 + out_dim = 12 + embed_dim = 12 + channel_dim = 12 + num_blocks = 7 + num_heads = 3 + max_seq_len = 13 + batch_size = 17 + dropout_prob = 0.125 + is_causal = False + + branch_node = TransformerBranch( + in_dim=in_dim, + out_dim=out_dim, + embed_dim=embed_dim, + channel_dim=channel_dim, + num_blocks=num_blocks, + num_heads=num_heads, + is_causal=is_causal, + dropout_prob=dropout_prob, + ) + + trunk_output = PaddedTrunkOutput( + trunk_features=torch.rand(batch_size, max_seq_len, in_dim), + padding_mask=torch.ones(batch_size, max_seq_len, dtype=torch.float), + ) + branch_output = branch_node(trunk_output) + assert isinstance(branch_output, TransformerBranchOutput) + branch_features = branch_output.branch_features + branch_mask = branch_output.branch_mask + pooled_features = branch_output.pooled_features + + assert torch.is_tensor(branch_features) + assert torch.is_tensor(branch_mask) + assert torch.is_tensor(pooled_features) + + assert branch_features.size() == torch.Size((batch_size, max_seq_len, out_dim)) + assert branch_mask.size() == torch.Size((batch_size, max_seq_len)) + assert pooled_features.size() == torch.Size((batch_size, out_dim)) + + +def test_transformer_decoder_branch(): + in_dim = 12 + out_dim = 12 + embed_dim = 12 + channel_dim = 12 + num_blocks = 7 + num_heads = 3 + max_seq_len = 13 + batch_size = 17 + dropout_prob = 0.125 + is_causal = True + + branch_node = TransformerBranch( + in_dim=in_dim, + out_dim=out_dim, + embed_dim=embed_dim, + channel_dim=channel_dim, + num_blocks=num_blocks, + num_heads=num_heads, + is_causal=is_causal, + dropout_prob=dropout_prob, + ) + + trunk_output = PaddedTrunkOutput( + trunk_features=torch.rand(batch_size, max_seq_len, in_dim), + padding_mask=torch.ones(batch_size, max_seq_len, dtype=torch.float), + ) + branch_output = branch_node(trunk_output) + assert isinstance(branch_output, TransformerBranchOutput) + branch_features = branch_output.branch_features + branch_mask = branch_output.branch_mask + pooled_features = branch_output.pooled_features + + assert torch.is_tensor(branch_features) + assert torch.is_tensor(branch_mask) + assert torch.is_tensor(pooled_features) + + assert branch_features.size() == torch.Size((batch_size, max_seq_len, out_dim)) + assert branch_mask.size() == torch.Size((batch_size, max_seq_len)) + assert pooled_features.size() == torch.Size((batch_size, out_dim)) diff --git a/tests/cortex/model/elemental/test_bidirectional_self_attention.py b/tests/cortex/model/elemental/test_bidirectional_self_attention.py new file mode 100644 index 0000000..c0d6842 --- /dev/null +++ b/tests/cortex/model/elemental/test_bidirectional_self_attention.py @@ -0,0 +1,18 @@ +import torch + +from cortex.model.elemental import BidirectionalSelfAttention + +BATCH_SIZE = 2 +NUM_HEADS = 3 +EMBED_DIM = 12 +SEQ_LEN = 5 + + +def test_bidirectional_self_attention(): + module = BidirectionalSelfAttention(num_heads=NUM_HEADS, embed_dim=EMBED_DIM, dropout_p=0.0, bias=False) + + x = 
torch.randn(BATCH_SIZE, SEQ_LEN, EMBED_DIM) + padding_mask = torch.ones(BATCH_SIZE, SEQ_LEN, dtype=torch.bool) + x_prime, _ = module((x, padding_mask)) + + assert x_prime.shape == x.shape diff --git a/tests/cortex/model/elemental/test_causal_self_attention.py b/tests/cortex/model/elemental/test_causal_self_attention.py new file mode 100644 index 0000000..b8edb42 --- /dev/null +++ b/tests/cortex/model/elemental/test_causal_self_attention.py @@ -0,0 +1,18 @@ +import torch + +from cortex.model.elemental import CausalSelfAttention + +BATCH_SIZE = 2 +NUM_HEADS = 3 +EMBED_DIM = 12 +SEQ_LEN = 5 + + +def test_causal_self_attention(): + module = CausalSelfAttention(num_heads=NUM_HEADS, embed_dim=EMBED_DIM, dropout_p=0.0, bias=False) + + x = torch.randn(BATCH_SIZE, SEQ_LEN, EMBED_DIM) + padding_mask = torch.ones(BATCH_SIZE, SEQ_LEN, dtype=torch.bool) + x_prime, _ = module((x, padding_mask)) + + assert x_prime.shape == x.shape diff --git a/tests/cortex/model/elemental/test_mlp.py b/tests/cortex/model/elemental/test_mlp.py new file mode 100644 index 0000000..f36a7ac --- /dev/null +++ b/tests/cortex/model/elemental/test_mlp.py @@ -0,0 +1,13 @@ +import torch + +from cortex.model.elemental import MLP + + +def test_mlp(): + in_channels = 32 + module = MLP(in_channels) + + x = torch.randn(2, 3, in_channels) + res = module(x) + + assert res.shape == x.shape diff --git a/tests/cortex/model/root/test_transformer_root.py b/tests/cortex/model/root/test_transformer_root.py new file mode 100644 index 0000000..59fd3a0 --- /dev/null +++ b/tests/cortex/model/root/test_transformer_root.py @@ -0,0 +1,172 @@ +import numpy as np +import torch + +from cortex.constants import COMPLEX_SEP_TOKEN +from cortex.corruption import MaskCorruptionProcess +from cortex.model.root import TransformerRoot, TransformerRootOutput +from cortex.tokenization import ProteinSequenceTokenizerFast +from cortex.transforms import HuggingFaceTokenizerTransform + + +def test_transformer_encoder_root(): + batch_size = 2 + out_dim = 12 + embed_dim = 12 + channel_dim = 12 + num_heads = 3 + is_causal = False + num_blocks = 7 + + max_seq_len = 13 + dropout_prob = 0.125 + pos_encoding = True + tokenizer = ProteinSequenceTokenizerFast() + + root_node = TransformerRoot( + tokenizer_transform=HuggingFaceTokenizerTransform(tokenizer), + max_len=max_seq_len, + out_dim=out_dim, + embed_dim=embed_dim, + channel_dim=channel_dim, + num_blocks=num_blocks, + num_heads=num_heads, + is_causal=is_causal, + dropout_prob=dropout_prob, + pos_encoding=pos_encoding, + ) + + # src_tok_idxs = torch.randint(0, vocab_size, (batch_size, max_seq_len)) + seq_array = np.array( + [ + f"{COMPLEX_SEP_TOKEN} A V {COMPLEX_SEP_TOKEN} A V {COMPLEX_SEP_TOKEN} A V C C", + f"{COMPLEX_SEP_TOKEN} A V {COMPLEX_SEP_TOKEN} A V {COMPLEX_SEP_TOKEN} A V C C", + ] + ) + root_output = root_node(seq_array) + assert isinstance(root_output, TransformerRootOutput) + root_features = root_output.root_features + padding_mask = root_output.padding_mask + + assert torch.is_tensor(root_features) + assert torch.is_tensor(padding_mask) + + assert root_features.size() == torch.Size((batch_size, max_seq_len, out_dim)) + assert padding_mask.size() == torch.Size((batch_size, max_seq_len)) + + +def test_transformer_encoder_root_with_per_element_corrupt_frac(): + """Test TransformerEncoderRoot handles per-element corrupt_frac correctly.""" + batch_size = 4 + out_dim = 12 + embed_dim = 12 + channel_dim = 12 + num_heads = 3 + is_causal = False + max_seq_len = 13 + tokenizer = ProteinSequenceTokenizerFast() + + # Create 
a root node with corruption process + corruption_process = MaskCorruptionProcess() + root_node = TransformerRoot( + tokenizer_transform=HuggingFaceTokenizerTransform(tokenizer), + max_len=max_seq_len, + out_dim=out_dim, + embed_dim=embed_dim, + channel_dim=channel_dim, + num_heads=num_heads, + is_causal=is_causal, + corruption_process=corruption_process, + ) + + # Create input sequences + seq_array = np.array( + [ + f"{COMPLEX_SEP_TOKEN} A V {COMPLEX_SEP_TOKEN} A V C C", + f"{COMPLEX_SEP_TOKEN} A V {COMPLEX_SEP_TOKEN} A V C C", + f"{COMPLEX_SEP_TOKEN} A V {COMPLEX_SEP_TOKEN} A V C C", + f"{COMPLEX_SEP_TOKEN} A V {COMPLEX_SEP_TOKEN} A V C C", + ] + ) + + # Test case 1: Scalar corrupt_frac + scalar_corrupt_frac = 0.3 + root_output1 = root_node(seq_array, corrupt_frac=scalar_corrupt_frac) + + # Verify corrupt_frac is a tensor with batch dimension + assert isinstance(root_output1.corrupt_frac, torch.Tensor) + assert root_output1.corrupt_frac.shape[0] == batch_size + assert torch.allclose( + root_output1.corrupt_frac, + torch.tensor([scalar_corrupt_frac] * batch_size, device=root_output1.corrupt_frac.device), + ) + + # Test case 2: Per-element corrupt_frac + per_element_corrupt_frac = torch.tensor([0.1, 0.2, 0.3, 0.4]) + root_output2 = root_node(seq_array, corrupt_frac=per_element_corrupt_frac) + + # Verify corrupt_frac maintains per-element values + assert isinstance(root_output2.corrupt_frac, torch.Tensor) + assert root_output2.corrupt_frac.shape[0] == batch_size + assert torch.allclose(root_output2.corrupt_frac, per_element_corrupt_frac.to(root_output2.corrupt_frac.device)) + + # Test case 3: None corrupt_frac (should sample from corruption process) + root_output3 = root_node(seq_array, corrupt_frac=None) + + # Verify corrupt_frac is a tensor with batch dimension + assert isinstance(root_output3.corrupt_frac, torch.Tensor) + assert root_output3.corrupt_frac.shape[0] == batch_size + # Values should be between 0 and 1 + assert torch.all(root_output3.corrupt_frac >= 0.0) + assert torch.all(root_output3.corrupt_frac <= 1.0) + + +def test_transformer_decoder_root(): + batch_size = 2 + out_dim = 12 + embed_dim = 12 + channel_dim = 12 + num_heads = 3 + is_causal = True + num_blocks = 7 + + max_seq_len = 13 + dropout_prob = 0.125 + pos_encoding = True + tokenizer = ProteinSequenceTokenizerFast() + + root_node = TransformerRoot( + tokenizer_transform=HuggingFaceTokenizerTransform(tokenizer), + max_len=max_seq_len, + out_dim=out_dim, + embed_dim=embed_dim, + channel_dim=channel_dim, + num_blocks=num_blocks, + num_heads=num_heads, + is_causal=is_causal, + dropout_prob=dropout_prob, + pos_encoding=pos_encoding, + ) + + # src_tok_idxs = torch.randint(0, vocab_size, (batch_size, max_seq_len)) + seq_array = np.array( + [ + f"{COMPLEX_SEP_TOKEN} A V {COMPLEX_SEP_TOKEN} A V {COMPLEX_SEP_TOKEN} A V C C", + f"{COMPLEX_SEP_TOKEN} A V {COMPLEX_SEP_TOKEN} A V {COMPLEX_SEP_TOKEN} A V C C", + ] + ) + root_output = root_node(seq_array) + assert isinstance(root_output, TransformerRootOutput) + root_features = root_output.root_features + padding_mask = root_output.padding_mask + + assert torch.is_tensor(root_features) + assert torch.is_tensor(padding_mask) + + assert root_features.size() == torch.Size((batch_size, max_seq_len, out_dim)) + assert padding_mask.size() == torch.Size((batch_size, max_seq_len)) diff --git
a/tutorials/hydra/branches/protein_property_transformer.yaml b/tutorials/hydra/branches/protein_property_transformer.yaml new file mode 100644 index 0000000..5a4ede3 --- /dev/null +++ b/tutorials/hydra/branches/protein_property_transformer.yaml @@ -0,0 +1,7 @@ +protein_property: + _target_: cortex.model.branch.TransformerBranch + out_dim: 8 + channel_dim: ${feature_dim} + num_blocks: 1 + num_heads: 4 + is_causal: false diff --git a/tutorials/hydra/roots/protein_seq_transformer.yaml b/tutorials/hydra/roots/protein_seq_transformer.yaml new file mode 100644 index 0000000..ea93943 --- /dev/null +++ b/tutorials/hydra/roots/protein_seq_transformer.yaml @@ -0,0 +1,15 @@ +protein_seq: + _target_: cortex.model.root.TransformerRoot + corruption_process: + _target_: cortex.corruption.MaskCorruptionProcess + tokenizer_transform: + _target_: cortex.transforms.HuggingFaceTokenizerTransform + tokenizer: + _target_: cortex.tokenization.ProteinSequenceTokenizerFast + max_len: 256 + embed_dim: ${feature_dim} + channel_dim: ${feature_dim} + out_dim: ${feature_dim} + num_blocks: 2 + num_heads: 4 + is_causal: false
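The configs above are consumed through Hydra, but the new modules can also be constructed directly. Below is a minimal usage sketch modeled on the unit tests in this diff; the example sequences and dimensions are illustrative, and feeding the root output straight into the branch is a simplification (in a full model tree a trunk node sits between them).

import numpy as np

from cortex.model.branch import TransformerBranch
from cortex.model.root import TransformerRoot
from cortex.model.trunk import PaddedTrunkOutput
from cortex.tokenization import ProteinSequenceTokenizerFast
from cortex.transforms import HuggingFaceTokenizerTransform

# Root: raw protein strings -> tokenized, embedded, transformer-encoded features
# of shape (batch, max_len, out_dim), plus a padding mask.
root = TransformerRoot(
    tokenizer_transform=HuggingFaceTokenizerTransform(ProteinSequenceTokenizerFast()),
    max_len=13,
    out_dim=12,
    embed_dim=12,
    channel_dim=12,
    num_blocks=2,
    num_heads=3,  # must evenly divide the channel dimensions
    is_causal=False,
)
root_out = root(seq_array=np.array(["A V C C", "A V A V C C"]))  # illustrative sequences

# Branch: padded per-token features -> task-specific per-token features plus a
# pooled per-sequence representation (mean pooling by default).
branch = TransformerBranch(in_dim=12, out_dim=8, channel_dim=12, num_blocks=2, num_heads=4)
branch_out = branch(
    PaddedTrunkOutput(trunk_features=root_out.root_features, padding_mask=root_out.padding_mask)
)

print(branch_out.branch_features.shape)  # (2, 13, 8)
print(branch_out.pooled_features.shape)  # (2, 8)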