diff --git a/.gitignore b/.gitignore index 04e1d60..f523615 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,6 @@ __pycache__ # version file (managed by setuptools-scm) src/clm/_version.py + +# Snakemake logs +.snakemake diff --git a/pyproject.toml b/pyproject.toml index 63eff6e..86aa9ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,11 +18,14 @@ classifiers = [ dependencies = [ "deepsmiles", + "einops", "fcd_torch", "numpy", + "opt_einsum", "pandas", "pulp<2.8.0", "rdkit", + "s4dd @ git+https://github.com/GuptaVishu2002/s4-for-de-novo-drug-design.git@fix-module-library-packaging", "scikit-learn", "scipy==1.11.1", "selfies", diff --git a/requirements.txt b/requirements.txt index 7657e54..e5502ff 100644 --- a/requirements.txt +++ b/requirements.txt @@ -128,3 +128,6 @@ virtualenv==20.26.2 wrapt==1.16.0 yte==1.5.4 zipp==3.19.0 +einops==0.6.0 +opt_einsum==3.3.0 +s4dd @ git+https://github.com/GuptaVishu2002/s4-for-de-novo-drug-design.git@fix-module-library-packaging diff --git a/src/clm/commands/sample_molecules_RNN.py b/src/clm/commands/sample_molecules_RNN.py index 7539969..383039e 100644 --- a/src/clm/commands/sample_molecules_RNN.py +++ b/src/clm/commands/sample_molecules_RNN.py @@ -6,7 +6,12 @@ from tqdm import tqdm from clm.datasets import Vocabulary, SelfiesVocabulary -from clm.models import RNN, ConditionalRNN +from clm.models import ( + RNN, + ConditionalRNN, + Transformer, + StructuredStateSpaceSequenceModel, +) # , H3Model, H3ConvModel, HyenaModel from clm.functions import load_dataset, write_to_csv_file logger = logging.getLogger(__name__) @@ -132,7 +137,8 @@ def sample_molecules_RNN( vocab = Vocabulary(vocab_file=vocab_file) heldout_dataset = None - if conditional: + + if rnn_type == "S4": assert ( heldout_file is not None ), "heldout_file must be provided for conditional RNN Model" @@ -141,29 +147,124 @@ def sample_molecules_RNN( input_file=heldout_file, vocab_file=vocab_file, ) - model = ConditionalRNN( - vocab, - rnn_type=rnn_type, + model = StructuredStateSpaceSequenceModel( + vocabulary=vocab, # heldout_dataset.vocabulary + model_dim=embedding_size, + state_dim=64, n_layers=n_layers, - embedding_size=embedding_size, - hidden_size=hidden_size, + n_ssm=1, dropout=dropout, - num_descriptors=heldout_dataset.n_descriptors, - conditional_emb=conditional_emb, - conditional_emb_l=conditional_emb_l, - conditional_dec=conditional_dec, - conditional_dec_l=conditional_dec_l, - conditional_h=conditional_h, ) - else: - model = RNN( - vocab, - rnn_type=rnn_type, - n_layers=n_layers, + # elif rnn_type == "H3": + # assert ( + # heldout_file is not None + # ), "heldout_file must be provided for conditional RNN Model" + # heldout_dataset = load_dataset( + # representation=representation, + # input_file=heldout_file, + # vocab_file=vocab_file, + # ) + # model = H3Model( + # vocabulary=vocab, + # n_layers=n_layers, + # d_model=embedding_size, + # d_state=64, + # head_dim=1, + # dropout=dropout, + # max_len=250, + # use_fast_fftconv=False, + # ) + # elif rnn_type == "H3Conv": + # assert ( + # heldout_file is not None + # ), "heldout_file must be provided for conditional RNN Model" + # heldout_dataset = load_dataset( + # representation=representation, + # input_file=heldout_file, + # vocab_file=vocab_file, + # ) + # model = H3ConvModel( + # vocabulary=vocab, + # n_layers=n_layers, + # d_model=embedding_size, + # head_dim=1, + # dropout=dropout, + # max_len=250, + # use_fast_fftconv=False, + # ) + # elif rnn_type == "Hyena": + # assert ( + # heldout_file is not None + # ), "heldout_file 
must be provided for conditional RNN Model" + # heldout_dataset = load_dataset( + # representation=representation, + # input_file=heldout_file, + # vocab_file=vocab_file, + # ) + # model = HyenaModel( + # vocabulary=vocab, + # n_layers=n_layers, + # d_model=embedding_size, + # order=2, + # filter_order=64, + # num_heads=1, + # dropout=dropout, + # max_len=250, + # inner_factor=1, + # ) + + elif rnn_type == "Transformer": + assert ( + heldout_file is not None + ), "heldout_file must be provided for the Transformer model" + heldout_dataset = load_dataset( + representation=representation, + input_file=heldout_file, + vocab_file=vocab_file, + ) + model = Transformer( + vocabulary=vocab, + n_blocks=n_layers, + n_heads=4, embedding_size=embedding_size, - hidden_size=hidden_size, dropout=dropout, + exp_factor=4, + bias=True, ) + + else: + if conditional: + assert ( + heldout_file is not None + ), "heldout_file must be provided for conditional RNN Model" + heldout_dataset = load_dataset( + representation=representation, + input_file=heldout_file, + vocab_file=vocab_file, + ) + model = ConditionalRNN( + vocab, + rnn_type=rnn_type, + n_layers=n_layers, + embedding_size=embedding_size, + hidden_size=hidden_size, + dropout=dropout, + num_descriptors=heldout_dataset.n_descriptors, + conditional_emb=conditional_emb, + conditional_emb_l=conditional_emb_l, + conditional_dec=conditional_dec, + conditional_dec_l=conditional_dec_l, + conditional_h=conditional_h, + ) + else: + model = RNN( + vocab, + rnn_type=rnn_type, + n_layers=n_layers, + embedding_size=embedding_size, + hidden_size=hidden_size, + dropout=dropout, + ) logging.info(vocab.dictionary) if torch.cuda.is_available(): @@ -183,8 +284,12 @@ def sample_molecules_RNN( n_sequences = min(batch_size, sample_mols - i) descriptors = None if heldout_dataset is not None: + # Cycle indices through heldout_dataset with modulo so sample_mols may exceed its length + descriptor_indices = [ + (i + j) % len(heldout_dataset) for j in range(n_sequences) + ] descriptors = torch.stack( - [heldout_dataset[_i][1] for _i in range(i, i + n_sequences)] + [heldout_dataset[idx][1] for idx in descriptor_indices] ) descriptors = descriptors.to(model.device) sampled_smiles, losses = model.sample( diff --git a/src/clm/commands/train_models_RNN.py b/src/clm/commands/train_models_RNN.py index 121e484..394064f 100644 --- a/src/clm/commands/train_models_RNN.py +++ b/src/clm/commands/train_models_RNN.py @@ -6,10 +6,19 @@ from torch.utils.data import DataLoader from tqdm import tqdm from rdkit import rdBase -from clm.models import RNN, ConditionalRNN +from clm.models import ( + RNN, + ConditionalRNN, + Transformer, + StructuredStateSpaceSequenceModel, +) # , H3Model, H3ConvModel, HyenaModel from clm.loggers import EarlyStopping, track_loss, print_update from clm.functions import write_smiles, load_dataset +import warnings + +warnings.filterwarnings("ignore", category=FutureWarning) # silence noisy dependency FutureWarnings + # suppress Chem.MolFromSmiles error output rdBase.DisableLog("rdApp.error") logger = logging.getLogger(__name__) @@ -197,31 +206,89 @@ def train_models_RNN( dataset = load_dataset(representation, input_file, vocab_file) - if conditional: - model = ConditionalRNN( - dataset.vocabulary, - rnn_type=rnn_type, + if rnn_type == "S4": + model = StructuredStateSpaceSequenceModel( + vocabulary=dataset.vocabulary, + model_dim=embedding_size, + state_dim=64, n_layers=n_layers, - embedding_size=embedding_size, - hidden_size=hidden_size, + n_ssm=1, dropout=dropout, - num_descriptors=dataset.n_descriptors, - conditional_emb=conditional_emb, - 
conditional_emb_l=conditional_emb_l, - conditional_dec=conditional_dec, - conditional_dec_l=conditional_dec_l, - conditional_h=conditional_h, ) - else: - model = RNN( - dataset.vocabulary, - rnn_type=rnn_type, - n_layers=n_layers, + + # elif rnn_type == "H3": + # model = H3Model( + # vocabulary=dataset.vocabulary, + # n_layers=n_layers, + # d_model=embedding_size, + # d_state=64, + # head_dim=1, + # dropout=dropout, + # max_len=250, + # use_fast_fftconv=False, + # ) + + # elif rnn_type == "H3Conv": + # model = H3ConvModel( + # vocabulary=dataset.vocabulary, + # n_layers=n_layers, + # d_model=embedding_size, + # head_dim=1, + # dropout=dropout, + # max_len=250, + # use_fast_fftconv=False, + # ) + + # elif rnn_type == "Hyena": + # model = HyenaModel( + # vocabulary=dataset.vocabulary, + # n_layers=n_layers, + # d_model=embedding_size, + # order=2, + # filter_order=64, + # num_heads=1, + # dropout=dropout, + # max_len=250, + # inner_factor=1, + # ) + + elif rnn_type == "Transformer": + model = Transformer( + vocabulary=dataset.vocabulary, + n_blocks=n_layers, + n_heads=4, embedding_size=embedding_size, - hidden_size=hidden_size, dropout=dropout, + exp_factor=4, + bias=True, ) + else: + if conditional: + model = ConditionalRNN( + dataset.vocabulary, + rnn_type=rnn_type, + n_layers=n_layers, + embedding_size=embedding_size, + hidden_size=hidden_size, + dropout=dropout, + num_descriptors=dataset.n_descriptors, + conditional_emb=conditional_emb, + conditional_emb_l=conditional_emb_l, + conditional_dec=conditional_dec, + conditional_dec_l=conditional_dec_l, + conditional_h=conditional_h, + ) + else: + model = RNN( + dataset.vocabulary, + rnn_type=rnn_type, + n_layers=n_layers, + embedding_size=embedding_size, + hidden_size=hidden_size, + dropout=dropout, + ) + logger.info(dataset.vocabulary.dictionary) loader = DataLoader( diff --git a/src/clm/loggers.py b/src/clm/loggers.py index 62e6b40..1b89de8 100644 --- a/src/clm/loggers.py +++ b/src/clm/loggers.py @@ -1,6 +1,7 @@ import os import pandas as pd import torch +import math from rdkit import Chem from tqdm import tqdm from clm.models import ConditionalRNN @@ -34,9 +35,26 @@ def __init__(self, patience=100): self.best_loss = None self.step_at_best = 0 self.stop = False + self.nan_counter = 0 print("instantiated early stopping with patience=" + str(self.patience)) def __call__(self, val_loss, model, output_file, step_idx): + # Check for NaN/Inf; stop after three consecutive bad losses + if math.isnan(val_loss) or math.isinf(val_loss): + self.nan_counter += 1 + print( + f"NaN/Inf loss detected at step {step_idx} ({self.nan_counter}/3)" + ) + if self.nan_counter >= 3: + self.stop = True + print("Stopping training after 3 consecutive NaN/Inf losses.") + if self.best_loss is not None: + print( + f"Best model (loss={self.best_loss:.4f}) already saved." 
+ ) + return + + self.nan_counter = 0 # do nothing if early stopping is disabled if self.patience > 0: if self.best_loss is None: diff --git a/src/clm/models.py b/src/clm/models.py index e301f5b..ad365a8 100644 --- a/src/clm/models.py +++ b/src/clm/models.py @@ -4,6 +4,822 @@ import torch.nn.functional as F from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence +# from clm.src.models.sequence.h3 import H3 +# from clm.src.models.sequence.h3_conv import H3Conv +# from clm.src.models.sequence.hyena_components import HyenaOperator + +from s4dd.module_library.sequence_model import SequenceModel + + +# class H3Model(nn.Module): +# def __init__( +# self, +# vocabulary, +# n_layers=4, +# d_model=256, +# d_state=64, +# head_dim=1, +# dropout=0.1, +# max_len=250, +# use_fast_fftconv=False, +# ): +# super(H3Model, self).__init__() + +# if H3 is None: +# raise ImportError( +# "H3 modules not found. Make sure src.models.sequence.h3 is available." +# ) + +# # detect device +# self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# # vocabulary +# self.vocabulary = vocabulary +# self.vocabulary_size = len(self.vocabulary) +# self.padding_idx = self.vocabulary.dictionary["<PAD>"] +# padding_t = torch.tensor(self.padding_idx).to(self.device) + +# # hyperparams +# self.n_layers = n_layers +# self.d_model = d_model +# self.d_state = d_state +# self.head_dim = head_dim +# self.dropout = dropout +# self.max_len = max_len +# self.use_fast_fftconv = use_fast_fftconv + +# # model components +# self.embedding = nn.Embedding( +# self.vocabulary_size, self.d_model, padding_idx=padding_t +# ) + +# # H3 layers +# self.layers = nn.ModuleList([ +# H3( +# d_model=self.d_model, +# d_state=self.d_state, +# l_max=self.max_len, +# head_dim=self.head_dim, +# use_fast_fftconv=self.use_fast_fftconv, +# dropout=self.dropout, +# layer_idx=i, +# ) +# for i in range(self.n_layers) +# ]) + +# # dropout and output +# self.norm = nn.LayerNorm(self.d_model) +# self.dropout_layer = nn.Dropout(dropout) +# self.output_projection = nn.Linear(self.d_model, self.vocabulary_size) + +# # loss function (ignoring padding) +# self.loss_fn = nn.CrossEntropyLoss( +# ignore_index=self.padding_idx, reduction="none" +# ) + +# # move to GPU +# if torch.cuda.is_available(): +# self.cuda() + +# def forward(self, x, inference_params=None): + +# batch_size, seq_len = x.size() + +# # Embed the input +# x = self.embedding(x) # (batch_size, seq_len, d_model) + +# # Pass through H3 layers +# for layer in self.layers: +# x = layer(x, inference_params=inference_params) +# if self.dropout > 0: +# x = self.dropout_layer(x) + +# # Normalize and project to vocabulary +# x = self.norm(x) +# logits = self.output_projection(x) # (batch_size, seq_len, vocab_size) + +# return logits + +# def loss(self, batch): +# if len(batch) == 3: +# padded, lengths, _ = batch +# else: +# padded, lengths = batch + +# padded = padded.to(self.device) + +# # Handle different input formats +# if padded.dim() == 2: +# if padded.shape[0] > padded.shape[1]: +# padded = padded.transpose(0, 1) + +# # Forward pass +# logits = self(padded) + +# # Calculate loss +# targets = padded[:, 1:] +# logits = logits[:, :-1, :] + +# loss = 0.0 +# actual_len = min(logits.shape[1], targets.shape[1]) + +# for char_idx in range(actual_len): +# loss += self.loss_fn(logits[:, char_idx, :], targets[:, char_idx]) + +# return loss.mean() + +# def sample( +# self, +# *, +# n_sequences, +# max_len=None, +# return_smiles=True, +# return_losses=False, +# descriptors=None, +# ): +# if max_len is None: +# 
max_len = self.max_len + +# self.eval() + +# # Get start/stop tokens +# start_token = self.vocabulary.dictionary["SOS"] +# stop_token = self.vocabulary.dictionary["EOS"] +# pad_token = self.vocabulary.dictionary["<PAD>"] + +# # Create inference params +# class InferenceParams: +# def __init__(self, max_seqlen, batch_size): +# self.max_seqlen = max_seqlen +# self.max_batch_size = batch_size +# self.sequence_len_offset = 0 +# self.key_value_memory_dict = {} + +# inference_params = InferenceParams(max_len, n_sequences) + +# # Initialize with start tokens - keep only current token for recurrent stepping +# current_token = torch.full( +# (n_sequences, 1), start_token, dtype=torch.long, device=self.device +# ) + +# loss_fn = nn.NLLLoss(reduction="none", ignore_index=pad_token) + +# finished = torch.zeros(n_sequences, dtype=torch.bool, device=self.device) +# log_probs = torch.zeros(n_sequences, device=self.device) +# sequences = [] + +# with torch.no_grad(): +# for step in range(max_len): +# # Process only the current token in recurrent mode +# logits = self(current_token, inference_params=inference_params) +# logits = logits[:, -1, :] # Get last (and only) position + +# logits = torch.clamp(logits, min=-1e4, max=1e4) +# prob = F.softmax(logits, dim=-1) + +# if torch.isnan(prob).any() or torch.isinf(prob).any(): +# break + +# outputs = torch.multinomial(prob, num_samples=1) +# sequences.append(outputs) + +# log_prob = F.log_softmax(logits, dim=-1) +# losses = loss_fn(log_prob, outputs.squeeze(1)) +# losses[finished] = 0 +# log_probs += losses + +# # Update current token for next step (don't accumulate) +# current_token = outputs +# inference_params.sequence_len_offset += 1 + +# finished = finished | (outputs.squeeze(1) == stop_token) +# if finished.all(): +# break + +# seqs = torch.cat(sequences, 1) if sequences else torch.full( +# (n_sequences, 1), start_token, dtype=torch.long, device=self.device +# ) + +# if return_smiles: +# outputs = [self.vocabulary.decode(seq.cpu().numpy()) for seq in seqs] +# else: +# outputs = sequences + +# if return_losses: +# return outputs, log_probs.detach().cpu().numpy() +# else: +# return outputs + + +# class H3ConvModel(nn.Module): +# def __init__( +# self, +# vocabulary, +# n_layers=4, +# d_model=256, +# head_dim=1, +# dropout=0.1, +# max_len=250, +# use_fast_fftconv=False, +# ): +# super(H3ConvModel, self).__init__() + +# if H3Conv is None: +# raise ImportError( +# "H3Conv modules not found. Make sure src.models.sequence.h3_conv is available." 
+# ) + +# # detect device +# self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# # vocabulary +# self.vocabulary = vocabulary +# self.vocabulary_size = len(self.vocabulary) +# self.padding_idx = self.vocabulary.dictionary["<PAD>"] +# padding_t = torch.tensor(self.padding_idx).to(self.device) + +# # hyperparams +# self.n_layers = n_layers +# self.d_model = d_model +# self.head_dim = head_dim +# self.dropout = dropout +# self.max_len = max_len +# self.use_fast_fftconv = use_fast_fftconv + +# # model components +# self.embedding = nn.Embedding( +# self.vocabulary_size, self.d_model, padding_idx=padding_t +# ) + +# # H3Conv layers +# self.layers = nn.ModuleList([ +# H3Conv( +# d_model=self.d_model, +# l_max=self.max_len, +# head_dim=self.head_dim, +# use_fast_fftconv=self.use_fast_fftconv, +# dropout=self.dropout, +# layer_idx=i, +# ) +# for i in range(self.n_layers) +# ]) + +# # dropout and output +# self.norm = nn.LayerNorm(self.d_model) +# self.dropout_layer = nn.Dropout(dropout) +# self.output_projection = nn.Linear(self.d_model, self.vocabulary_size) + +# # loss function (ignoring padding) +# self.loss_fn = nn.CrossEntropyLoss( +# ignore_index=self.padding_idx, reduction="none" +# ) + +# # move to GPU +# if torch.cuda.is_available(): +# self.cuda() + +# def forward(self, x, inference_params=None): +# batch_size, seq_len = x.size() + +# # Embed the input +# x = self.embedding(x) # (batch_size, seq_len, d_model) + +# # Pass through H3Conv layers +# for layer in self.layers: +# x = layer(x, inference_params=inference_params) +# if self.dropout > 0: +# x = self.dropout_layer(x) + +# # Normalize and project to vocabulary +# x = self.norm(x) +# logits = self.output_projection(x) + +# return logits + +# def loss(self, batch): +# if len(batch) == 3: +# padded, lengths, _ = batch +# else: +# padded, lengths = batch + +# padded = padded.to(self.device) + +# # Handle different input formats +# if padded.dim() == 2: +# if padded.shape[0] > padded.shape[1]: +# padded = padded.transpose(0, 1) + +# # Forward pass +# logits = self(padded) + +# # Calculate loss +# targets = padded[:, 1:] +# logits = logits[:, :-1, :] + +# loss = 0.0 +# actual_len = min(logits.shape[1], targets.shape[1]) + +# for char_idx in range(actual_len): +# loss += self.loss_fn(logits[:, char_idx, :], targets[:, char_idx]) + +# return loss.mean() + +# def sample( +# self, +# *, +# n_sequences, +# max_len=None, +# return_smiles=True, +# return_losses=False, +# descriptors=None, +# ): +# if max_len is None: +# max_len = self.max_len + +# self.eval() + +# start_token = self.vocabulary.dictionary["SOS"] +# stop_token = self.vocabulary.dictionary["EOS"] +# pad_token = self.vocabulary.dictionary["<PAD>"] + +# # H3Conv doesn't use stateful inference, process full sequence each time +# inputs = torch.full( +# (n_sequences, 1), start_token, dtype=torch.long, device=self.device +# ) + +# loss_fn = nn.NLLLoss(reduction="none", ignore_index=pad_token) + +# finished = torch.zeros(n_sequences, dtype=torch.bool, device=self.device) +# log_probs = torch.zeros(n_sequences, device=self.device) +# sequences = [] + +# with torch.no_grad(): +# for step in range(max_len): +# logits = self(inputs) +# logits = logits[:, -1, :] + +# logits = torch.clamp(logits, min=-1e4, max=1e4) +# prob = F.softmax(logits, dim=-1) + +# if torch.isnan(prob).any() or torch.isinf(prob).any(): +# break + +# outputs = torch.multinomial(prob, num_samples=1) +# sequences.append(outputs) + +# log_prob = F.log_softmax(logits, dim=-1) +# losses = loss_fn(log_prob, 
outputs.squeeze(1)) +# losses[finished] = 0 +# log_probs += losses + +# inputs = torch.cat([inputs, outputs], dim=1) + +# finished = finished | (outputs.squeeze(1) == stop_token) +# if finished.all(): +# break + +# seqs = torch.cat(sequences, 1) if sequences else torch.full( +# (n_sequences, 1), start_token, dtype=torch.long, device=self.device +# ) + +# if return_smiles: +# outputs = [self.vocabulary.decode(seq.cpu().numpy()) for seq in seqs] +# else: +# outputs = sequences + +# if return_losses: +# return outputs, log_probs.detach().cpu().numpy() +# else: +# return outputs + + +# class HyenaModel(nn.Module): +# def __init__( +# self, +# vocabulary, +# n_layers=4, +# d_model=256, +# order=2, +# filter_order=64, +# num_heads=1, +# dropout=0.1, +# max_len=250, +# inner_factor=1, +# ): +# super(HyenaModel, self).__init__() + +# # detect device +# self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# # vocabulary +# self.vocabulary = vocabulary +# self.vocabulary_size = len(self.vocabulary) +# self.padding_idx = self.vocabulary.dictionary["<PAD>"] +# padding_t = torch.tensor(self.padding_idx).to(self.device) + +# # hyperparams +# self.n_layers = n_layers +# self.d_model = d_model +# self.order = order +# self.filter_order = filter_order +# self.num_heads = num_heads +# self.dropout = dropout +# self.max_len = max_len +# self.inner_factor = inner_factor + +# # model components +# self.embedding = nn.Embedding( +# self.vocabulary_size, self.d_model, padding_idx=padding_t +# ) + +# # Hyena layers +# self.layers = nn.ModuleList([ +# HyenaOperator( +# d_model=self.d_model, +# l_max=self.max_len, +# order=self.order, +# filter_order=self.filter_order, +# num_heads=self.num_heads, +# inner_factor=self.inner_factor, +# dropout=self.dropout, +# ) +# for i in range(self.n_layers) +# ]) + +# # dropout and output +# self.norm = nn.LayerNorm(self.d_model) +# self.dropout_layer = nn.Dropout(dropout) +# self.output_projection = nn.Linear(self.d_model, self.vocabulary_size) + +# # loss function (ignoring padding) +# self.loss_fn = nn.CrossEntropyLoss( +# ignore_index=self.padding_idx, reduction="none" +# ) + +# # move to GPU +# if torch.cuda.is_available(): +# self.cuda() + +# def forward(self, x): +# batch_size, seq_len = x.size() + +# # Embed the input +# x = self.embedding(x) # (batch_size, seq_len, d_model) + +# # Pass through Hyena layers +# for layer in self.layers: +# residual = x +# x = layer(x) +# x = x + residual # Residual connection +# if self.dropout > 0: +# x = self.dropout_layer(x) + +# # Normalize and project to vocabulary +# x = self.norm(x) +# logits = self.output_projection(x) + +# return logits + +# def loss(self, batch): +# if len(batch) == 3: +# padded, lengths, _ = batch +# else: +# padded, lengths = batch + +# padded = padded.to(self.device) + +# # Handle different input formats +# if padded.dim() == 2: +# if padded.shape[0] > padded.shape[1]: +# padded = padded.transpose(0, 1) + +# # Forward pass +# logits = self(padded) + +# # Calculate loss +# targets = padded[:, 1:] +# logits = logits[:, :-1, :] + +# loss = 0.0 +# actual_len = min(logits.shape[1], targets.shape[1]) + +# for char_idx in range(actual_len): +# loss += self.loss_fn(logits[:, char_idx, :], targets[:, char_idx]) + +# return loss.mean() + +# def sample( +# self, +# *, +# n_sequences, +# max_len=None, +# return_smiles=True, +# return_losses=False, +# descriptors=None, +# ): +# if max_len is None: +# max_len = self.max_len + +# self.eval() + +# start_token = self.vocabulary.dictionary["SOS"] +# 
stop_token = self.vocabulary.dictionary["EOS"] +# pad_token = self.vocabulary.dictionary["<PAD>"] + +# # Initialize with start tokens +# inputs = torch.full( +# (n_sequences, 1), start_token, dtype=torch.long, device=self.device +# ) + +# loss_fn = nn.NLLLoss(reduction="none", ignore_index=pad_token) + +# finished = torch.zeros(n_sequences, dtype=torch.bool, device=self.device) +# log_probs = torch.zeros(n_sequences, device=self.device) +# sequences = [] + +# with torch.no_grad(): +# for step in range(max_len): +# # Hyena processes full sequence each time (stateless) +# logits = self(inputs) +# logits = logits[:, -1, :] + +# logits = torch.clamp(logits, min=-1e4, max=1e4) +# prob = F.softmax(logits, dim=-1) + +# if torch.isnan(prob).any() or torch.isinf(prob).any(): +# break + +# outputs = torch.multinomial(prob, num_samples=1) +# sequences.append(outputs) + +# log_prob = F.log_softmax(logits, dim=-1) +# losses = loss_fn(log_prob, outputs.squeeze(1)) +# losses[finished] = 0 +# log_probs += losses + +# inputs = torch.cat([inputs, outputs], dim=1) + +# finished = finished | (outputs.squeeze(1) == stop_token) +# if finished.all(): +# break + +# seqs = torch.cat(sequences, 1) if sequences else torch.full( +# (n_sequences, 1), start_token, dtype=torch.long, device=self.device +# ) + +# if return_smiles: +# outputs = [self.vocabulary.decode(seq.cpu().numpy()) for seq in seqs] +# else: +# outputs = sequences + +# if return_losses: +# return outputs, log_probs.detach().cpu().numpy() +# else: +# return outputs + + +class StructuredStateSpaceSequenceModel(nn.Module): + def __init__( + self, + vocabulary, + model_dim=256, + state_dim=64, + n_layers=4, + n_ssm=1, + dropout=0.25, + max_len=250, + ): + super(StructuredStateSpaceSequenceModel, self).__init__() + + # detect device + self.device = torch.device( + "cuda" if torch.cuda.is_available() else "cpu" + ) + + # vocabulary + self.vocabulary = vocabulary + self.vocabulary_size = len(self.vocabulary) + self.padding_idx = self.vocabulary.dictionary["<PAD>"] + padding_t = torch.tensor(self.padding_idx).to(self.device) + + # hyperparams + self.model_dim = model_dim + self.state_dim = state_dim + self.n_layers = n_layers + self.n_ssm = n_ssm + self.dropout = dropout + self.max_len = max_len + + # S4 layer configuration + self.layer_config = [ + { + "_name_": "s4", + "d_state": self.state_dim, + "n_ssm": self.n_ssm, + }, + { + "_name_": "s4", + "d_state": self.state_dim, + "n_ssm": self.n_ssm, + }, + {"_name_": "ff"}, + ] + self.pool_config = {"_name_": "pool", "stride": 1, "expand": None} + + # model components + self.embedding = nn.Embedding( + self.vocabulary_size, self.model_dim, padding_idx=padding_t + ) + + self.model = SequenceModel( + d_model=self.model_dim, + n_layers=self.n_layers, + transposed=False, # inputs are (batch, length, dim), not (batch, dim, length) + dropout=self.dropout, + layer=self.layer_config, + pool=self.pool_config, + ) + + self.output_embedding = nn.Linear(self.model_dim, self.vocabulary_size) + self.recurrent_state = None + + # loss function (ignoring padding) + self.loss_fn = nn.CrossEntropyLoss( + ignore_index=self.padding_idx, reduction="none" + ) + + # move to GPU + if torch.cuda.is_available(): + self.cuda() + + def forward(self, x): + batch_size, seq_len = x.size() + + # Embed the input + x = self.embedding(x) # (batch_size, seq_len, model_dim) + + # Pass through S4 model (without state in training mode) + x, _ = self.model(x, state=None) + + # Project to vocabulary + logits = self.output_embedding(x) # (batch_size, seq_len, vocab_size) + + 
return logits + + def reset_state(self, batch_size, device=None): + if device is None: + device = self.device + self.recurrent_state = self.model.default_state( + batch_size, device=device + ) + + def recurrent_step(self, x_t): + if x_t.dim() == 1: + x_t = x_t.unsqueeze(1) + + x_t = self.embedding(x_t).squeeze(1) # (batch_size, model_dim) + x_t, state = self.model.step(x_t, state=self.recurrent_state) + self.recurrent_state = state + x_t = self.output_embedding(x_t) # (batch_size, vocab_size) + + return x_t + + def loss(self, batch): + if len(batch) == 3: + padded, lengths, _ = batch + else: + padded, lengths = batch + + padded = padded.to(self.device) + + # Handle different input formats + # RNN collate returns (seq_len, batch_size) format + # S4 model expects (batch_size, seq_len) format + # Always transpose since collate always uses (seq_len, batch_size) + if padded.dim() == 2: + padded = padded.transpose(0, 1) + + # Don't use recurrent state during training - use full convolution mode + self.recurrent_state = None + + # Forward pass + logits = self(padded) # (batch_size, seq_len, vocab_size) + + # Calculate loss + # Shift targets: predict next token + targets = padded[:, 1:] # (batch_size, seq_len-1) + logits = logits[:, :-1, :] # (batch_size, seq_len-1, vocab_size) + + # Sum per-position losses over the sequence + loss = 0.0 + actual_len = min(logits.shape[1], targets.shape[1]) + + for char_idx in range(actual_len): + loss += self.loss_fn(logits[:, char_idx, :], targets[:, char_idx]) + + return loss.mean() + + def sample( + self, + *, + n_sequences, + max_len=None, + return_smiles=True, + return_losses=False, + descriptors=None, + ): + if max_len is None: + max_len = self.max_len + + # IMPORTANT: Set model to eval mode before sampling + self.eval() + + # Setup for recurrent mode + for module in self.model.modules(): + if hasattr(module, "setup_step"): + module.setup_step() + + # Reset state + self.reset_state(n_sequences, device=self.device) + + # Get start/stop tokens + start_token = self.vocabulary.dictionary["SOS"] + stop_token = self.vocabulary.dictionary["EOS"] + pad_token = self.vocabulary.dictionary["<PAD>"] + + # Create start token tensor + inputs = ( + torch.empty(n_sequences).fill_(start_token).long().to(self.device) + ) + + # Setup loss function + loss_fn = nn.NLLLoss(reduction="none", ignore_index=pad_token) + + # Sample sequences + finished = torch.zeros(n_sequences).byte().to(self.device) + log_probs = torch.zeros(n_sequences).to(self.device) + sequences = [] + + with torch.no_grad(): # no gradients needed while sampling + for step in range(max_len): + # Get logits for current input + logits = self.recurrent_step(inputs) + + # Clamp logits to prevent inf/nan + logits = torch.clamp(logits, min=-1e4, max=1e4) + + # Sample from distribution + prob = F.softmax(logits, dim=-1) + + # Check for invalid values + if torch.isnan(prob).any() or torch.isinf(prob).any(): + break + + outputs = torch.multinomial(prob, num_samples=1).squeeze(1) + + sequences.append(outputs.view(-1, 1)) + + # Calculate NLL + log_prob = F.log_softmax(logits, dim=-1) + losses = loss_fn(log_prob, outputs) + + # Zero losses if we are finished sampling + losses[finished.bool()] = 0 + log_probs += losses + + # Update inputs for next step + inputs = outputs + + # Track whether sampling is done for all molecules + finished = torch.ge(finished + (outputs == stop_token), 1) + if torch.prod(finished) == 1: + break + + # Concatenate sequences and decode + seqs = ( 
torch.cat(sequences, 1) + if sequences + else torch.empty(n_sequences, 1, dtype=torch.long) + .fill_(start_token) + .to(self.device) + ) + + if return_smiles: + outputs = [ + self.vocabulary.decode(seq.cpu().numpy()) for seq in seqs + ] + else: + outputs = sequences + + # Optionally return losses + if return_losses: + return outputs, log_probs.detach().cpu().numpy() + else: + return outputs + class RNN(nn.Module): def __init__( @@ -439,23 +1255,31 @@ def forward(self, x): return logits def loss(self, batch): - # extract the elements of a single minibatch - padded, lengths = batch - # tranpose padded to batch_size * seq_len - padded = padded.transpose(0, 1) - # move to the gpu + if len(batch) == 3: + padded, lengths, _ = batch + else: + padded, lengths = batch + padded = padded.to(self.device) - # pass through the entire transformer model - decoded = self(padded) - # -> decoded: batch_size x max_len x vocab_size + # RNN collate returns (seq_len, batch_size) format + # Transformer expects (batch_size, seq_len) format + if padded.dim() == 2: + padded = padded.transpose(0, 1) + + # Get actual sequence length from batch + actual_seq_len = padded.shape[1] + + decoded = self(padded) # batch_size x seq_len x vocab_size - # finally, calculate loss loss = 0.0 - max_len = max(lengths) - targets = padded[:, 1:] - for char_idx in range(max_len - 1): - loss += self.loss_fn(decoded[:, char_idx], targets[:, char_idx]) + targets = padded[:, 1:] # batch_size x (seq_len-1) + + # Loop only up to actual decoded sequence length minus 1 + for char_idx in range( + min(actual_seq_len - 1, decoded.shape[1], targets.shape[1]) + ): + loss += self.loss_fn(decoded[:, char_idx, :], targets[:, char_idx]) return loss.mean() @@ -467,6 +1291,12 @@ def sample( return_losses=False, descriptors=None, ): + # Reset recurrent state before sampling + # self.reset_state(n_sequences, device=self.device) + + self.eval() + torch.cuda.empty_cache() + # get start/stop tokens start_token = self.vocabulary.dictionary["SOS"] stop_token = self.vocabulary.dictionary["EOS"] @@ -488,28 +1318,42 @@ def sample( finished = torch.zeros(n_sequences).byte().to(self.device) log_probs = torch.zeros(n_sequences).to(self.device) sequences = [] - for step in range(self.max_len): - logits = self(inputs)[:, -1, :] - prob = F.softmax(logits, dim=-1) - outputs = torch.multinomial(prob, num_samples=1) - # append to growing sequence - inputs = torch.cat((inputs, outputs), dim=1) - sequences.append(outputs) - # calculate NLL too - log_prob = F.log_softmax(logits, dim=1) - losses = loss(log_prob, outputs.squeeze(1)) - # zero losses if we are finished sampling - losses[finished.bool()] = 0 - log_probs += losses - # track whether sampling is done for all molecules - finished = torch.ge( - finished + (outputs.squeeze(1) == stop_token), 1 - ) - if torch.prod(finished) == 1: - break + with torch.no_grad(): + for step in range(self.max_len): + logits = self(inputs)[:, -1, :] + # Clamp logits to prevent inf/nan + logits = torch.clamp(logits, min=-1e4, max=1e4) + prob = F.softmax(logits, dim=-1) + + # Check for invalid values and skip if found + if torch.isnan(prob).any() or torch.isinf(prob).any(): + break + + outputs = torch.multinomial(prob, num_samples=1) + # append to growing sequence + inputs = torch.cat((inputs, outputs), dim=1) + sequences.append(outputs) + # calculate NLL too + log_prob = F.log_softmax(logits, dim=1) + losses = loss(log_prob, outputs.squeeze(1)) + # zero losses if we are finished sampling + losses[finished.bool()] = 0 + log_probs += losses + # 
track whether sampling is done for all molecules + finished = torch.ge( + finished + (outputs.squeeze(1) == stop_token), 1 + ) + if torch.prod(finished) == 1: + break # concatenate sequences and decode - seqs = torch.cat(sequences, 1) + seqs = ( + torch.cat(sequences, 1) + if sequences + else torch.empty(n_sequences, 1, dtype=torch.long) + .fill_(start_token) + .to(self.device) + ) if return_smiles: outputs = [ self.vocabulary.decode(seq.cpu().numpy()) for seq in seqs @@ -517,6 +1361,8 @@ def sample( else: outputs = sequences + torch.cuda.empty_cache() + # optionally return losses if return_losses: return outputs, log_probs.detach().cpu().numpy() diff --git a/tests/test_snakemake_steps.py b/tests/test_snakemake_steps.py index 2f7869a..e1bc6ef 100644 --- a/tests/test_snakemake_steps.py +++ b/tests/test_snakemake_steps.py @@ -242,6 +242,33 @@ def test_02_train_models_conditional_RNN(tmp_path): # so we simply ensure that this step runs without errors. +def test_02_train_models_S4(tmp_path): + train_models_RNN.train_models_RNN( + representation="SMILES", + rnn_type="S4", + embedding_size=32, + hidden_size=256, + n_layers=3, + dropout=0, + batch_size=64, + learning_rate=0.001, + max_epochs=3, + patience=5000, + log_every_steps=100, + log_every_epochs=1, + sample_mols=100, + input_file=test_dir + / "0/prior/inputs/train_LOTUS_truncated_SMILES_0.smi", + vocab_file=test_dir + / "0/prior/inputs/train_LOTUS_truncated_SMILES_0.vocabulary", + model_file=tmp_path / "LOTUS_truncated_SMILES_0_0_model.pt", + loss_file=tmp_path / "LOTUS_truncated_SMILES_0_0_loss.csv", + smiles_file=None, + ) + # Model loss values can vary between platforms and architectures, + # so we simply ensure that this step runs without errors. + + def test_03_sample_molecules_RNN(tmp_path): output_file = ( tmp_path / "0/prior/samples/LOTUS_truncated_SMILES_0_0_0_samples.csv" diff --git a/workflow/Snakefile_data b/workflow/Snakefile_data index ecad188..fd4c212 100644 --- a/workflow/Snakefile_data +++ b/workflow/Snakefile_data @@ -206,7 +206,7 @@ rule sample_molecules_RNN: output_file = PATHS['input_file'] resources: mem_mb=12000, - runtime=15+MODEL_PARAMS["sample_mols"]//10000, + runtime=120+MODEL_PARAMS["sample_mols"]//10000, slurm_extra="--gres=gpu:1" shell: 'clm sample_molecules_RNN '
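Reviewer note: the new S4 code path is easiest to exercise through the added test. Below is a minimal sketch that mirrors the `test_02_train_models_S4` call from this diff; the file paths are illustrative placeholders, and `hidden_size` is accepted but unused by the S4 branch (the S4 model takes `embedding_size` as its `model_dim` and fixes `state_dim` at 64).

```python
from clm.commands import train_models_RNN

# Train a small S4 model on a SMILES corpus (mirrors test_02_train_models_S4).
# All paths below are placeholders -- point them at a real .smi/.vocabulary pair.
train_models_RNN.train_models_RNN(
    representation="SMILES",
    rnn_type="S4",  # dispatches to StructuredStateSpaceSequenceModel
    embedding_size=32,
    hidden_size=256,
    n_layers=3,
    dropout=0,
    batch_size=64,
    learning_rate=0.001,
    max_epochs=3,
    patience=5000,
    log_every_steps=100,
    log_every_epochs=1,
    sample_mols=100,
    input_file="train_SMILES.smi",
    vocab_file="train_SMILES.vocabulary",
    model_file="model.pt",
    loss_file="loss.csv",
    smiles_file=None,
)
```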