Gramba is a hybrid Transformer-RNN deep learning architecture that interleaves minGRU layers with sub-quadratic attention layers.
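The interleaving idea can be sketched in PyTorch. Everything below is illustrative rather than the actual implementation: the layer names, the residual placement, the 3:1 mix ratio, and full `nn.MultiheadAttention` standing in for Gramba's sub-quadratic attention are all assumptions.

```python
import torch
import torch.nn as nn

class MinGRULayer(nn.Module):
    """Sequential minGRU: the gate and candidate depend only on the input x_t,
    not on the previous hidden state h_{t-1}."""
    def __init__(self, dim):
        super().__init__()
        self.to_z = nn.Linear(dim, dim)
        self.to_h = nn.Linear(dim, dim)

    def forward(self, x):  # x: (batch, seq, dim)
        z = torch.sigmoid(self.to_z(x))
        h_tilde = self.to_h(x)
        h = torch.zeros_like(x[:, 0])
        outs = []
        for t in range(x.shape[1]):
            # h_t = (1 - z_t) * h_{t-1} + z_t * h~_t
            h = (1 - z[:, t]) * h + z[:, t] * h_tilde[:, t]
            outs.append(h)
        return torch.stack(outs, dim=1)

class HybridBlock(nn.Module):
    """Several minGRU layers followed by one attention layer (illustrative ratio)."""
    def __init__(self, dim, ratio=3):
        super().__init__()
        self.rnns = nn.ModuleList(MinGRULayer(dim) for _ in range(ratio))
        self.attn = nn.MultiheadAttention(dim, num_heads=2, batch_first=True)

    def forward(self, x):
        for rnn in self.rnns:
            x = x + rnn(x)          # residual minGRU layers
        a, _ = self.attn(x, x, x)   # full attention as a stand-in for sub-quadratic attention
        return x + a

x = torch.randn(4, 16, 32)
y = HybridBlock(32)(x)
print(y.shape)  # torch.Size([4, 16, 32])
```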
- Conda
- Clone the repository onto your machine.
- Create the virtual environment:
conda env create -f environment.yml
Note: This may take several minutes.
- Activate the environment:
conda activate gramba
- Ensure the src directory is on your PYTHONPATH:
export PYTHONPATH="$PYTHONPATH:/path/to/repository/src"
Note: This will only add the src directory to your path for your current session. Write the export command to your .bashrc, .zshrc, or equivalent to add the path permanently.
- Verify that test cases work:
pytest
This is a small suite that checks the model architecture is working correctly. These tests should pass before pull requests are merged into main.
The src directory contains the Gramba model logic, scripts to load the dataset, train the model, visualize log files, etc.
The glove directory contains the logic to load the GloVe embeddings into a NumPy matrix, which is later used as the embeddings for Gramba.
- Download the GloVe embeddings and unzip them into the `glove` directory.
- Run the `load_glove.py` script to process the embeddings and save the resulting embedding matrix:

python load_glove.py --embedding_dim <DIM> --glove_path <PATH_TO_GLOVE_TXT> --output_path <OUTPUT_FILE>
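Once saved, the matrix can be loaded as the weights of an embedding layer. A minimal sketch; the random matrix below stands in for the file written to `--output_path`, whose exact format depends on `load_glove.py`:

```python
import numpy as np
import torch
import torch.nn as nn

# Stand-in for the saved matrix (shape: vocab_size x embedding_dim);
# in practice you would np.load the file written to --output_path.
emb_matrix = np.random.rand(1000, 50).astype(np.float32)

# freeze=False allows the embeddings to be fine-tuned during training.
embedding = nn.Embedding.from_pretrained(torch.from_numpy(emb_matrix), freeze=False)
print(embedding.weight.shape)  # torch.Size([1000, 50])
```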
The twitter directory contains the logic to load and process the Twitter dataset.
- Download the Sentiment140 dataset and place it inside the `twitter` directory.
- Run the `load_twitter.py` script to process the dataset and save the cleaned version:

python load_twitter.py --dataset_in_path <INPUT_CSV> --dataset_out_path <OUTPUT_CSV>

The imdb directory contains the logic to load and process the IMDb movie reviews dataset.
- Download the IMDb dataset and place it inside the `imdb` directory.
- Run the `load_imdb.py` script to process the dataset and save the cleaned version:

python load_imdb.py --dataset_in_path <INPUT_CSV> --dataset_out_path <OUTPUT_CSV>

Helper functions for loading the SQuAD dataset.
Scripts to train models on various datasets. These scripts accept no arguments; edit the code to change the model's hyperparameters. They should be run on a machine with access to CUDA (e.g. U of T's SLURM cluster). All of these scripts produce log files, which can be visualized by scripts in the utils directory.
- `train_gramba_squad.py`: Train a Gramba model on SQuAD.
- `train_gramba_sequence_classification`: Train a Gramba model on either Twitter or IMDB.
Various utilities, including scripts which accept log files and produce visualizations.
Two scripts for profiling the Gramba model with respect to sequence length and model hyperparameters.
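As a rough illustration of profiling against sequence length (a plain `nn.Linear` stands in for the Gramba model; the actual profiling scripts may measure differently):

```python
import time
import torch
import torch.nn as nn

# Stand-in module; the real scripts profile the Gramba model itself.
model = nn.Linear(64, 64)

for seq_len in (128, 256, 512):
    x = torch.randn(1, seq_len, 64)
    start = time.perf_counter()
    with torch.no_grad():
        model(x)
    elapsed = time.perf_counter() - start
    print(f"seq_len={seq_len}: {elapsed * 1e3:.3f} ms")
```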
Building blocks of the Gramba model.
Three models:
- The base Gramba model (`GrambaModel`), which outputs an encoded vector for each token in the sequence. One could add a task-specific head to this model to use it for question answering, named entity recognition, co-reference resolution, etc., or for classification, by making a prediction on the `[CLS]` token.
- Gramba for sequence classification (`GrambaForSequenceClassification`).
- Gramba for question answering (`GrambaSQuADModel`).
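The `[CLS]`-token classification idea can be sketched generically; the tensor below stands in for `GrambaModel`'s per-token output, and the head dimensions are illustrative:

```python
import torch
import torch.nn as nn

batch, seq_len, dim, num_classes = 8, 128, 50, 2

# Stand-in for the encoder's output: one encoded vector per token.
encoded = torch.randn(batch, seq_len, dim)

# Classification head applied to the first ([CLS]) token's encoding.
cls_head = nn.Linear(dim, num_classes)
logits = cls_head(encoded[:, 0])
print(logits.shape)  # torch.Size([8, 2])
```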
These can be instantiated with an instance of the `GrambaConfig` class, which specifies the default hyperparameters of the Gramba model.
```python
import torch
# Adjust the import to match the src layout:
from gramba_model import GrambaConfig, GrambaModel

config = GrambaConfig(
    num_classes=2,
    embedding_dim=50,
    expansion_factor=2,
    # Other hyperparameters
)

# A batch of 32 sequences of 512 token ids
x = torch.randint(0, config.vocab_size, (32, 512))
attention_mask = torch.ones_like(x).bool()
longformer_mask = torch.zeros_like(x).bool()

model = GrambaModel(config)
output = model(x, attention_mask, longformer_mask)
```