
Conversation

@skinnider
Collaborator

Fresh start trying to merge #275

anushka255 previously approved these changes Jan 6, 2026
@anushka255 anushka255 requested review from anushka255 and removed request for anushka255 January 6, 2026 20:14
@anushka255 anushka255 dismissed their stale review January 6, 2026 20:16

I was just testing something.

Michael A. Skinnider and others added 2 commits January 6, 2026 15:22
@skinnider skinnider requested a review from seungchan-an January 7, 2026 16:06
Collaborator

@seungchan-an seungchan-an left a comment


S4 integration looks solid, the unused models are cleanly commented out, and the loss refactor and NaN early stopping are reasonable. S4 tests are in place and CI is green. Looks good to me.
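For context, a generic sketch of NaN-aware early stopping of the kind referred to above; this is illustrative only, not the PR's actual code, and the function and argument names are placeholders:

```python
import math

def should_stop(current_loss, steps_without_improvement, patience):
    """Generic NaN-aware early stopping check (illustrative, not the PR's code)."""
    if math.isnan(current_loss):
        # A NaN loss will not recover, so abort training immediately.
        return True
    # Otherwise fall back to patience-based early stopping on validation loss.
    return steps_without_improvement >= patience
```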

@skinnider
Collaborator Author

skinnider commented Jan 9, 2026

@GuptaVishu2002 I was going to launch a few small runs of RNN vs. Transformer vs. S4 to check that the changes are non-breaking, but I realized I can't actually select the Transformer or S4 models from config.yaml, and src/clm/commands/sample_molecules_RNN.py does not actually import either of these classes.

Could you take a look at integrating config.yaml -> Snakemake -> sample_molecules_RNN.py so that the user (here, me) can specify Transformer or S4 in the config and run the whole pipeline with one of these models?

We might need to add new parameters to the config; e.g., a "model_type" parameter might be worth considering, since the current model_params all relate to RNNs. A sketch of what the dispatch could look like follows the excerpt below:

# Parameters that define the neural network model and training process.
model_params:
  # Type of Recurrent Neural Network (RNN) to use.
  # Available options are 'LSTM' and 'GRU'
  rnn_type: LSTM
  embedding_size: 128 # Size of the embedding vectors that represent each token in the input sequence.
  hidden_size: 1024 # Size of the hidden state of the RNN.
  n_layers: 3 # Number of stacked RNN layers in the model.
  dropout: 0 # Dropout rate applied to the RNN layer for regularization.
  batch_size: 64 # Number of samples processed before the model's internal parameters are updated.
  learning_rate: 0.001 # Used by the optimizer to update model parameters.
  max_epochs: 999999 # Maximum number of training epochs (complete passes through the training dataset).
  patience: 50000 # Number of steps with no improvement in the validation loss after which early stopping is triggered.

  # An RNN model conditioned on input descriptors (experimentally obtained properties of the input SMILES).
  # Note that rnn_type and other RNN architecture parameters are still applicable in this case.
  conditional:
    # Is the conditional model enabled?
    enabled: false

    # Note: emb and emb_l below cannot both be true at the same time.
    # Concatenate the descriptors directly to the token embeddings at each step in the sequence?
    emb: false
    # Concatenate the descriptors to the token embeddings, but by first passing them through a
    # linear layer to obtain embeddings of dimensionality equal to that of the token embeddings?
    emb_l: true

    # Note: dec and dec_l below cannot both be true at the same time.
    # Concatenate the descriptors directly to the output of the RNN layers
    # (prior to the decoder layer)?
    dec: false
    # Concatenate the descriptors to the output of the RNN layers
    # (prior to the decoder layer), but by first passing them through a
    # linear layer to obtain embeddings of dimensionality equal to that of the token embeddings?
    dec_l: true

    # Instantiate the hidden states based on learned transformations of the descriptors
    # (with a single linear layer), as in Kotsias et al.?
    h: false

See also issue #283.
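As a minimal sketch of how such a model_type dispatch could look inside sample_molecules_RNN.py; the class names, constructor signatures, and config layout here are placeholders, not the actual clm API:

```python
# Placeholder sketch of a model_type dispatch; names and signatures are
# illustrative, not the actual clm code.

def build_model(model_params, model_classes):
    """Instantiate the model class named by model_params['model_type'].

    model_classes maps a model_type string (e.g. 'RNN', 'Transformer', 'S4')
    to a constructor that accepts the remaining model_params as keyword
    arguments; parameters irrelevant to the chosen class would still need
    to be filtered out in the real implementation.
    """
    params = dict(model_params)
    # Default to 'RNN' so existing configs without the new key keep working.
    model_type = params.pop("model_type", "RNN")
    try:
        cls = model_classes[model_type]
    except KeyError:
        raise ValueError(
            f"Unknown model_type {model_type!r}; expected one of {sorted(model_classes)}"
        )
    return cls(**params)

# Hypothetical usage:
#   models = {"RNN": RNN, "Transformer": Transformer, "S4": S4}
#   model = build_model(config["model_params"], models)
```

Snakemake would then only need to forward the model_params section of config.yaml (or a --model_type flag) to sample_molecules_RNN.py.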

@skinnider
Collaborator Author

Update on this: I tested the S4 implementation with the NPS training set. The model is not outperforming the LSTM by any means, but it seems to be doing reasonably well, which rules out any major issues in the implementation. The Transformer, on the other hand, is failing immediately at the train_models step; Vishu will look into this.

Vishu Gupta added 2 commits January 21, 2026 14:24
…r does not maintain recurrent state. Also add torch.cuda.empty_cache() and torch.no_grad() during sampling for GPU memory management
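A generic illustration of the sampling-side memory management described in that commit message; everything other than the torch calls (the function name, model.sample, and its arguments) is a placeholder, not the project's actual sampling code:

```python
import torch

def sample_smiles(model, batch_size, max_len):
    """Generate a batch of samples without tracking gradients (placeholder sketch)."""
    model.eval()
    with torch.no_grad():  # no autograd graph is needed at sampling time
        samples = model.sample(batch_size=batch_size, max_len=max_len)
    if torch.cuda.is_available():
        # Return cached allocator blocks to the driver between batches.
        torch.cuda.empty_cache()
    return samples
```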
