
Gramba

Gramba is a hybrid RNN-Transformer deep learning architecture for natural language processing that scales sub-quadratically with sequence length by interleaving minGRU layers with sub-quadratic LongFormer attention layers to capture both local and global context.

Model Architecture

(Figure: diagram of the Gramba model architecture.)

Requirements

  • Conda

Setup

  1. Clone the repository onto your machine.
  2. Create the virtual environment:

conda env create -f environment.yml

Note: This may take several minutes.

  3. Activate the environment:

conda activate gramba

  4. Ensure the src directory is on your PYTHONPATH:

export PYTHONPATH="$PYTHONPATH:/path/to/repository/src"

Note: This will only add the src directory to your path for your current session. Write the export command to your .bashrc, .zshrc, or equivalent to add the path permanently.
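For example, to make it permanent in bash (adjust the path for your machine):

echo 'export PYTHONPATH="$PYTHONPATH:/path/to/repository/src"' >> ~/.bashrc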

  5. Verify that the tests pass:

pytest

test

A small test suite that checks the model architecture is working correctly. These tests should pass before pull requests are merged into main.

src

The src directory contains the Gramba model logic, along with scripts to load the datasets, train the model, visualize log files, and more.

glove

The glove directory contains the logic to load the GloVe embeddings into a NumPy matrix, which is later used as Gramba's embedding matrix.

Steps to Prepare GloVe Embeddings:

  1. Download the GloVe embeddings and unzip them into the glove directory.
  2. Run the load_glove.py script to process the embeddings and save the resulting embedding matrix.

Usage:

python load_glove.py --embedding_dim <DIM> --glove_path <PATH_TO_GLOVE_TXT> --output_path <OUTPUT_FILE>
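As a rough sketch of what this processing involves (assuming the standard GloVe text format of one token followed by its vector per line; the function and file names here are illustrative, not the script's actual API):

import numpy as np

def load_glove_matrix(glove_path, embedding_dim):
    # Each line of a GloVe .txt file is: <token> <v1> <v2> ... <v_dim>
    vocab, vectors = [], []
    with open(glove_path, encoding="utf-8") as f:
        for line in f:
            parts = line.rstrip().split(" ")
            vocab.append(parts[0])
            vectors.append(np.asarray(parts[1:], dtype=np.float32))
    matrix = np.stack(vectors)
    assert matrix.shape[1] == embedding_dim
    return vocab, matrix

vocab, matrix = load_glove_matrix("glove.6B.50d.txt", embedding_dim=50)
np.save("glove_50d.npy", matrix)  # the saved matrix becomes Gramba's embedding table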

twitter

The twitter directory contains the logic to load and process the Twitter dataset.

Steps to Prepare the Twitter Dataset:

  1. Download the Sentiment140 dataset and place it inside the twitter directory.
  2. Run the load_twitter.py script to process the dataset and save the cleaned version.

Usage:

python load_twitter.py --dataset_in_path <INPUT_CSV> --dataset_out_path <OUTPUT_CSV>
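A minimal sketch of the kind of cleaning this script performs; Sentiment140 ships as a headerless CSV with columns target, id, date, flag, user, text and labels 0/4 for negative/positive, but the exact handling below is an assumption, not the script's actual code:

import pandas as pd

cols = ["target", "id", "date", "flag", "user", "text"]
df = pd.read_csv("training.1600000.processed.noemoticon.csv",
                 names=cols, encoding="latin-1")

# Keep only the label and the tweet text; map labels 0/4 to 0/1.
df = df[["target", "text"]]
df["target"] = (df["target"] == 4).astype(int)

df.to_csv("twitter_clean.csv", index=False)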

imdb

The imdb directory contains the logic to load and process the IMDb movie reviews dataset.

Steps to Prepare the IMDb Dataset:

  1. Download the IMDb dataset and place it inside the imdb directory.
  2. Run the load_imdb.py script to process the dataset and save the cleaned version.

Usage:

python load_imdb.py --dataset_in_path <INPUT_CSV> --dataset_out_path <OUTPUT_CSV>
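As with the Twitter script, the cleaning step is straightforward; a hypothetical sketch, assuming the common review/sentiment CSV layout with string labels (the actual script may expect different columns):

import pandas as pd

df = pd.read_csv("IMDB_Dataset.csv")  # placeholder filename

# Map string sentiment labels to integers: positive -> 1, negative -> 0.
df["sentiment"] = (df["sentiment"] == "positive").astype(int)

df.to_csv("imdb_clean.csv", index=False)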

squad

Helper functions for loading the SQuAD dataset.

train

Scripts to train models on the various datasets. These scripts accept no arguments; to change the model's hyperparameters, edit the code directly. They should be run on a machine with access to CUDA (e.g. U of T's SLURM cluster). All of these scripts produce log files, which can be visualized by scripts in the utils directory.

  1. train_gramba_squad.py: Train a Gramba model on SQuAD.
  2. train_gramba_sequence_classification.py: Train a Gramba model on either Twitter or IMDb.
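A minimal SLURM submission script might look like the following; the resource requests and paths are placeholders for your cluster, not values from this repository:

#!/bin/bash
#SBATCH --gres=gpu:1
#SBATCH --time=04:00:00
#SBATCH --mem=16G

eval "$(conda shell.bash hook)"   # make conda activate available in batch jobs
conda activate gramba
export PYTHONPATH="$PYTHONPATH:/path/to/repository/src"
python /path/to/repository/src/train/train_gramba_squad.py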

utils

Various utilities, including scripts which accept log files and produce visualizations.

profile

Two scripts for profiling the Gramba model with respect to sequence length and model hyperparameters.
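In spirit, the sequence-length profiling amounts to timing forward passes at increasing lengths, as in this rough sketch (imports and constructor arguments assume the Usage section below; the actual scripts may measure differently):

import time
import torch

config = GrambaConfig(num_classes=2, embedding_dim=50, expansion_factor=2)
model = GrambaModel(config).cuda().eval()

for seq_len in [256, 512, 1024, 2048]:
    x = torch.randint(0, config.vocab_size, (8, seq_len), device="cuda")
    attention_mask = torch.ones_like(x).bool()
    longformer_mask = torch.zeros_like(x).bool()
    torch.cuda.synchronize()
    start = time.time()
    with torch.no_grad():
        model(x, attention_mask, longformer_mask)
    torch.cuda.synchronize()
    print(f"seq_len={seq_len}: {time.time() - start:.3f}s")  # should grow sub-quadratically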

layers

Building blocks of the Gramba model.

model

Three models:

  1. The base Gramba model (GrambaModel). This model outputs an encoded vector for each token in the sequence. One could add a task-specific head on top to use it for question answering, named entity recognition, co-reference resolution, etc., or for classification by making a prediction on the [CLS] token.
  2. Gramba for Sequence Classification (GrambaForSequenceClassification)
  3. Gramba for Question Answering (GrambaSQuADModel).

Each is instantiated with an instance of the GrambaConfig class, which specifies the hyperparameters of the Gramba model.

Usage

import torch

# Illustrative import; GrambaConfig and GrambaModel live under src in this
# repository, so adjust the module path to match.
from model import GrambaConfig, GrambaModel

config = GrambaConfig(
    num_classes=2,
    embedding_dim=50,
    expansion_factor=2,
    # Other hyperparameters
)

x = torch.randint(0, config.vocab_size, (32, 512))
attention_mask = torch.ones_like(x).bool()   # presumably marks valid (non-padding) tokens
longformer_mask = torch.zeros_like(x).bool() # presumably marks global-attention tokens

model = GrambaModel(config)
output = model(x, attention_mask, longformer_mask)
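A classification variant can be used the same way; a sketch assuming GrambaForSequenceClassification shares this forward signature and returns per-class logits:

model = GrambaForSequenceClassification(config)
logits = model(x, attention_mask, longformer_mask)  # shape: (batch, num_classes)
predictions = logits.argmax(dim=-1)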
