Skip to content
Open

S4 #281

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,6 @@ __pycache__

# version file (managed by setuptools-scm)
src/clm/_version.py

# Snakemake logs
.snakemake
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,14 @@ classifiers = [

dependencies = [
"deepsmiles",
"einops",
"fcd_torch",
"numpy",
"opt_einsum",
"pandas",
"pulp<2.8.0",
"rdkit",
"s4dd @ git+https://github.com/GuptaVishu2002/s4-for-de-novo-drug-design.git@fix-module-library-packaging",
"scikit-learn",
"scipy==1.11.1",
"selfies",
Expand Down
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,6 @@ virtualenv==20.26.2
wrapt==1.16.0
yte==1.5.4
zipp==3.19.0
einops==0.6.0
opt_einsum==3.3.0
s4dd @ git+https://github.com/GuptaVishu2002/s4-for-de-novo-drug-design.git@fix-module-library-packaging
145 changes: 125 additions & 20 deletions src/clm/commands/sample_molecules_RNN.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
from tqdm import tqdm

from clm.datasets import Vocabulary, SelfiesVocabulary
from clm.models import RNN, ConditionalRNN
from clm.models import (
RNN,
ConditionalRNN,
Transformer,
StructuredStateSpaceSequenceModel,
) # , H3Model, H3ConvModel, HyenaModel
from clm.functions import load_dataset, write_to_csv_file

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -132,7 +137,8 @@ def sample_molecules_RNN(
vocab = Vocabulary(vocab_file=vocab_file)

heldout_dataset = None
if conditional:

if rnn_type == "S4":
assert (
heldout_file is not None
), "heldout_file must be provided for conditional RNN Model"
Expand All @@ -141,29 +147,124 @@ def sample_molecules_RNN(
input_file=heldout_file,
vocab_file=vocab_file,
)
model = ConditionalRNN(
vocab,
rnn_type=rnn_type,
model = StructuredStateSpaceSequenceModel(
vocabulary=vocab, # heldout_dataset.vocabulary
model_dim=embedding_size,
state_dim=64,
n_layers=n_layers,
embedding_size=embedding_size,
hidden_size=hidden_size,
n_ssm=1,
dropout=dropout,
num_descriptors=heldout_dataset.n_descriptors,
conditional_emb=conditional_emb,
conditional_emb_l=conditional_emb_l,
conditional_dec=conditional_dec,
conditional_dec_l=conditional_dec_l,
conditional_h=conditional_h,
)
else:
model = RNN(
vocab,
rnn_type=rnn_type,
n_layers=n_layers,
# elif rnn_type == "H3":
# assert (
# heldout_file is not None
# ), "heldout_file must be provided for conditional RNN Model"
# heldout_dataset = load_dataset(
# representation=representation,
# input_file=heldout_file,
# vocab_file=vocab_file,
# )
# model = H3Model(
# vocabulary=vocab,
# n_layers=n_layers,
# d_model=embedding_size,
# d_state=64,
# head_dim=1,
# dropout=dropout,
# max_len=250,
# use_fast_fftconv=False,
# )
# elif rnn_type == "H3Conv":
# assert (
# heldout_file is not None
# ), "heldout_file must be provided for conditional RNN Model"
# heldout_dataset = load_dataset(
# representation=representation,
# input_file=heldout_file,
# vocab_file=vocab_file,
# )
# model = H3ConvModel(
# vocabulary=vocab,
# n_layers=n_layers,
# d_model=embedding_size,
# head_dim=1,
# dropout=dropout,
# max_len=250,
# use_fast_fftconv=False,
# )
# elif rnn_type == "Hyena":
# assert (
# heldout_file is not None
# ), "heldout_file must be provided for conditional RNN Model"
# heldout_dataset = load_dataset(
# representation=representation,
# input_file=heldout_file,
# vocab_file=vocab_file,
# )
# model = HyenaModel(
# vocabulary=vocab,
# n_layers=n_layers,
# d_model=embedding_size,
# order=2,
# filter_order=64,
# num_heads=1,
# dropout=dropout,
# max_len=250,
# inner_factor=1,
# )

elif rnn_type == "Transformer":
assert (
heldout_file is not None
), "heldout_file must be provided for conditional RNN Model"
heldout_dataset = load_dataset(
representation=representation,
input_file=heldout_file,
vocab_file=vocab_file,
)
model = Transformer(
vocabulary=vocab,
n_blocks=n_layers,
n_heads=4,
embedding_size=embedding_size,
hidden_size=hidden_size,
dropout=dropout,
exp_factor=4,
bias=True,
)

else:
if conditional:
assert (
heldout_file is not None
), "heldout_file must be provided for conditional RNN Model"
heldout_dataset = load_dataset(
representation=representation,
input_file=heldout_file,
vocab_file=vocab_file,
)
model = ConditionalRNN(
vocab,
rnn_type=rnn_type,
n_layers=n_layers,
embedding_size=embedding_size,
hidden_size=hidden_size,
dropout=dropout,
num_descriptors=heldout_dataset.n_descriptors,
conditional_emb=conditional_emb,
conditional_emb_l=conditional_emb_l,
conditional_dec=conditional_dec,
conditional_dec_l=conditional_dec_l,
conditional_h=conditional_h,
)
else:
model = RNN(
vocab,
rnn_type=rnn_type,
n_layers=n_layers,
embedding_size=embedding_size,
hidden_size=hidden_size,
dropout=dropout,
)
logging.info(vocab.dictionary)

if torch.cuda.is_available():
Expand All @@ -183,8 +284,12 @@ def sample_molecules_RNN(
n_sequences = min(batch_size, sample_mols - i)
descriptors = None
if heldout_dataset is not None:
# Wrap indices with modulo so descriptors cycle through heldout_dataset
# when more molecules are requested than it contains
descriptor_indices = [
(i + j) % len(heldout_dataset) for j in range(n_sequences)
]
descriptors = torch.stack(
[heldout_dataset[_i][1] for _i in range(i, i + n_sequences)]
[heldout_dataset[idx][1] for idx in descriptor_indices]
)
descriptors = descriptors.to(model.device)
sampled_smiles, losses = model.sample(
Expand Down
105 changes: 86 additions & 19 deletions src/clm/commands/train_models_RNN.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,19 @@
from torch.utils.data import DataLoader
from tqdm import tqdm
from rdkit import rdBase
from clm.models import RNN, ConditionalRNN
from clm.models import (
RNN,
ConditionalRNN,
Transformer,
StructuredStateSpaceSequenceModel,
) # , H3Model, H3ConvModel, HyenaModel
from clm.loggers import EarlyStopping, track_loss, print_update
from clm.functions import write_smiles, load_dataset

import warnings

warnings.filterwarnings("ignore", category=FutureWarning)

# suppress Chem.MolFromSmiles error output
rdBase.DisableLog("rdApp.error")
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -197,31 +206,89 @@ def train_models_RNN(

dataset = load_dataset(representation, input_file, vocab_file)

if conditional:
model = ConditionalRNN(
dataset.vocabulary,
rnn_type=rnn_type,
if rnn_type == "S4":
model = StructuredStateSpaceSequenceModel(
vocabulary=dataset.vocabulary,
model_dim=embedding_size,
state_dim=64,
n_layers=n_layers,
embedding_size=embedding_size,
hidden_size=hidden_size,
n_ssm=1,
dropout=dropout,
num_descriptors=dataset.n_descriptors,
conditional_emb=conditional_emb,
conditional_emb_l=conditional_emb_l,
conditional_dec=conditional_dec,
conditional_dec_l=conditional_dec_l,
conditional_h=conditional_h,
)
else:
model = RNN(
dataset.vocabulary,
rnn_type=rnn_type,
n_layers=n_layers,

# elif rnn_type == "H3":
# model = H3Model(
# vocabulary=dataset.vocabulary,
# n_layers=n_layers,
# d_model=embedding_size,
# d_state=64,
# head_dim=1,
# dropout=dropout,
# max_len=250,
# use_fast_fftconv=False,
# )

# elif rnn_type == "H3Conv":
# model = H3ConvModel(
# vocabulary=dataset.vocabulary,
# n_layers=n_layers,
# d_model=embedding_size,
# head_dim=1,
# dropout=dropout,
# max_len=250,
# use_fast_fftconv=False,
# )

# elif rnn_type == "Hyena":
# model = HyenaModel(
# vocabulary=dataset.vocabulary,
# n_layers=n_layers,
# d_model=embedding_size,
# order=2,
# filter_order=64,
# num_heads=1,
# dropout=dropout,
# max_len=250,
# inner_factor=1,
# )

elif rnn_type == "Transformer":
model = Transformer(
vocabulary=dataset.vocabulary,
n_blocks=n_layers,
n_heads=4,
embedding_size=embedding_size,
hidden_size=hidden_size,
dropout=dropout,
exp_factor=4,
bias=True,
)

else:
if conditional:
model = ConditionalRNN(
dataset.vocabulary,
rnn_type=rnn_type,
n_layers=n_layers,
embedding_size=embedding_size,
hidden_size=hidden_size,
dropout=dropout,
num_descriptors=dataset.n_descriptors,
conditional_emb=conditional_emb,
conditional_emb_l=conditional_emb_l,
conditional_dec=conditional_dec,
conditional_dec_l=conditional_dec_l,
conditional_h=conditional_h,
)
else:
model = RNN(
dataset.vocabulary,
rnn_type=rnn_type,
n_layers=n_layers,
embedding_size=embedding_size,
hidden_size=hidden_size,
dropout=dropout,
)

logger.info(dataset.vocabulary.dictionary)

loader = DataLoader(
Expand Down
17 changes: 17 additions & 0 deletions src/clm/loggers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import pandas as pd
import torch
import math
from rdkit import Chem
from tqdm import tqdm
from clm.models import ConditionalRNN
Expand Down Expand Up @@ -34,9 +35,25 @@ def __init__(self, patience=100):
self.best_loss = None
self.step_at_best = 0
self.stop = False
self.nan_counter = 0
print("instantiated early stopping with patience=" + str(self.patience))

def __call__(self, val_loss, model, output_file, step_idx):
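# Guard against a non-finite (NaN/Inf) validation loss: count each
# occurrence and request a stop once 3 have been seen.
# NOTE(review): the message below says "consecutive" - confirm that
# nan_counter is reset when a finite loss is observed.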
if math.isnan(val_loss) or math.isinf(val_loss):
self.nan_counter += 1
print(
f"NaN/Inf loss detected at step {step_idx} ({self.nan_counter}/3)"
)
if self.nan_counter >= 3:
self.stop = True
print("Stopping training after 3 consecutive NaN/Inf losses.")
if self.best_loss is not None:
print(
f"Best model (loss={self.best_loss:.4f}) already saved."
)
return

# do nothing if early stopping is disabled
if self.patience > 0:
if self.best_loss is None:
Expand Down
Loading