diff --git a/AntibodySequenceSampler.py b/AntibodySequenceSampler.py
index 25ea255..05b81df 100644
--- a/AntibodySequenceSampler.py
+++ b/AntibodySequenceSampler.py
@@ -15,9 +15,9 @@
from tqdm import tqdm
-from utils.utils_plotting \
- import plot_seq_logo, plot_histogram_for_array,\
- sequences_to_probabilities
+from utils.plotting \
+    import plot_seq_logo, \
+    sequences_to_probabilities  # plot_histogram_for_array
torch.set_default_dtype(torch.float64)
torch.set_grad_enabled(False)
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..2dc1dcb
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024 Gray Lab
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
index 53ad014..c17d567 100644
--- a/README.md
+++ b/README.md
@@ -33,7 +33,9 @@ Download and extract trained models from [Zenodo](https://zenodo.org/records/831
tar -xvzf model.tar.gz
```
-## Sampling protein sequences
+## Sampling sequences
+
+### Sampling protein sequences
To design/generate all positions on the protein, run:
```bash
MODEL=trained_models/ProtEnT_backup.ckpt
@@ -52,7 +54,7 @@ The above command samples all sequences at 100% masking (i.e. only coord informa
python3 ProteinSequenceSampler.py --help
```
-## Sampling antibody sequences without partner context
+### Sampling antibody sequences without partner context
To design/generate all positions on the protein, run:
```bash
MODEL=trained_models/ProtEnT_backup.ckpt
@@ -76,7 +78,7 @@ The above command samples all sequences at 100% masking (i.e. only coord informa
python3 ProteinSequenceSampler.py --help
```
-## Sampling interface residues with partner context
+### Sampling interface residues with partner context
To generate/design the interface residues for the first partner (order determined by partners.json), run:
```bash
@@ -99,7 +101,7 @@ python3 PPIAbAgSequenceSampler.py \
# --partner_name both
```
-## Sampling antibody interface residues with antigen context
+### Sampling antibody interface residues with antigen context
```
MODEL=trained_models/ProtAbAgEnT_backup.ckpt
OUTDIR=./sampled_abag_sequences
@@ -120,6 +122,179 @@ python3 PPIAbAgSequenceSampler.py \
# --mask_ab_indices 10,11,12
```
+### Performance: timing for protein design tasks (CPU vs. GPU)
+- Timing values are displayed in the format `mm:ss.000` (minutes:seconds.milliseconds); an example timing command is shown after the table.
+- Each GPU run was conducted on 1 node, using 6 processes per task on an NVIDIA A100 GPU.
+- Each CPU run was conducted on 1 node, with 8 processes per CPU.
+
+| Sequence Design Task | CPU/GPU | No. of Designs | Real Time | User Time | System Time |
+| --- | --- | --- | --- | --- | --- |
+| Protein Monomer Sequence Design (126 amino acids) | CPU | 100 | 01:18.639 | 00:28.506 | 00:03.628 |
+| | CPU | 1,000 | 00:46.927 | 00:33.980 | 00:04.286 |
+| | CPU | 10,000 | 01:10.349 | 01:03.870 | 00:04.842 |
+| | CPU | 100,000 | 02:16.108 | 06:14.454 | 00:12.911 |
+| | GPU | 100 | 00:49.538 | 00:06.923 | 00:01.888 |
+| | GPU | 1,000 | 00:56.923 | 00:12.329 | 00:02.101 |
+| | GPU | 10,000 | 00:37.218 | 00:57.562 | 00:02.771 |
+| | GPU | 100,000 | 01:59.589 | 08:33.594 | 00:09.799 |
+| Protein-Protein Interface | CPU | 100 | 01:13.022 | 01:16.282 | 00:08.224 |
+| | CPU | 1,000 | 00:43.972 | 01:22.581 | 00:08.596 |
+| | CPU | 10,000 | 01:19.130 | 02:22.664 | 00:09.561 |
+| | CPU | 100,000 | 03:28.817 | 12:41.153 | 00:17.398 |
+| | GPU | 100 | 00:11.688 | 00:09.020 | 00:03.329 |
+| | GPU | 1,000 | 00:39.591 | 00:18.655 | 00:03.423 |
+| | GPU | 10,000 | 00:49.310 | 01:46.022 | 00:04.493 |
+| | GPU | 100,000 | 03:01.718 | 16:08.428 | 00:14.877 |
+| Antibody-Antigen Interface | CPU | 100 | 01:18.330 | 02:45.636 | 00:16.683 |
+| | CPU | 1,000 | 00:48.824 | 03:00.106 | 00:16.751 |
+| | CPU | 10,000 | 01:37.904 | 05:21.302 | 00:18.257 |
+| | CPU | 100,000 | 05:27.519 | 27:10.781 | 00:27.179 |
+| | GPU | 100 | 01:35.224 | 00:13.541 | 00:04.228 |
+| | GPU | 1,000 | 00:47.984 | 00:29.034 | 00:03.739 |
+| | GPU | 10,000 | 01:11.780 | 03:00.415 | 00:04.555 |
+| | GPU | 100,000 | 04:24.885 | 28:10.995 | 00:14.905 |
+
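+The `Real`/`User`/`System` columns follow the `real`/`user`/`sys` convention of the shell's `time` builtin. A run can be timed along these lines (a minimal sketch; the model, inputs, and flags mirror the sampling sections above, and the output directory is illustrative):
+
+```bash
+MODEL=trained_models/ProtEnT_backup.ckpt
+OUTDIR=./timing_run
+PDB_DIR=data/proteins/
+time python3 ProteinSequenceSampler.py \
+    --output_dir $OUTDIR \
+    --model $MODEL \
+    --from_pdb $PDB_DIR \
+    --num_samples 100
+```
+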
## Training
### Installation
Model was trained with older versions of torch and pytorch_lightning. Newer versions are not backward compatible. The following instructions work for python 3.9 and cuda 11.1.
diff --git a/data/timing_table.md b/data/timing_table.md
new file mode 100644
index 0000000..e4f31dd
--- /dev/null
+++ b/data/timing_table.md
@@ -0,0 +1,167 @@
+| Sequence Design Task | CPU/GPU | No. of Designs | Real Time | User Time | System Time |
+| --- | --- | --- | --- | --- | --- |
+| Protein Monomer Sequence Design (126 amino acids) | CPU | 100 | 01:18.639 | 00:28.506 | 00:03.628 |
+| | CPU | 1,000 | 00:46.927 | 00:33.980 | 00:04.286 |
+| | CPU | 10,000 | 01:10.349 | 01:03.870 | 00:04.842 |
+| | CPU | 100,000 | 02:16.108 | 06:14.454 | 00:12.911 |
+| | GPU | 100 | 00:49.538 | 00:06.923 | 00:01.888 |
+| | GPU | 1,000 | 00:56.923 | 00:12.329 | 00:02.101 |
+| | GPU | 10,000 | 00:37.218 | 00:57.562 | 00:02.771 |
+| | GPU | 100,000 | 01:59.589 | 08:33.594 | 00:09.799 |
+| Protein-Protein Interface | CPU | 100 | 01:13.022 | 01:16.282 | 00:08.224 |
+| | CPU | 1,000 | 00:43.972 | 01:22.581 | 00:08.596 |
+| | CPU | 10,000 | 01:19.130 | 02:22.664 | 00:09.561 |
+| | CPU | 100,000 | 03:28.817 | 12:41.153 | 00:17.398 |
+| | GPU | 100 | 00:11.688 | 00:09.020 | 00:03.329 |
+| | GPU | 1,000 | 00:39.591 | 00:18.655 | 00:03.423 |
+| | GPU | 10,000 | 00:49.310 | 01:46.022 | 00:04.493 |
+| | GPU | 100,000 | 03:01.718 | 16:08.428 | 00:14.877 |
+| Antibody-Antigen Interface | CPU | 100 | 01:18.330 | 02:45.636 | 00:16.683 |
+| | CPU | 1,000 | 00:48.824 | 03:00.106 | 00:16.751 |
+| | CPU | 10,000 | 01:37.904 | 05:21.302 | 00:18.257 |
+| | CPU | 100,000 | 05:27.519 | 27:10.781 | 00:27.179 |
+| | GPU | 100 | 01:35.224 | 00:13.541 | 00:04.228 |
+| | GPU | 1,000 | 00:47.984 | 00:29.034 | 00:03.739 |
+| | GPU | 10,000 | 01:11.780 | 03:00.415 | 00:04.555 |
+| | GPU | 100,000 | 04:24.885 | 28:10.995 | 00:14.905 |
\ No newline at end of file
diff --git a/scripts/fine_tune_ppi-abag.sh b/scripts/fine_tune_ppi-abag.sh
index e723641..bb865b8 100755
--- a/scripts/fine_tune_ppi-abag.sh
+++ b/scripts/fine_tune_ppi-abag.sh
@@ -1,40 +1,47 @@
#!/bin/bash -l
-#source
+# ACTIVATE YOUR ENVIRONMENT, e.g.: source <path_to_venv>/bin/activate
-
-### MUST be set up #####
-### WANDB ENTITY
+### MUST be set up ###
+### WANDB ENTITY ###
WANDB_ENTITY="YOUR_WANDB_ENTITY"
if [ "$WANDB_ENTITY" = "YOUR_WANDB_ENTITY" ]; then
echo "Error: Please set your WANDB_ENTITY variable."
exit 1
fi
-### CHECK AND CREATE DATASETS DIRECTORY
-if [ ! -d datasets ]; then
- mkdir training_datasets
+### DOWNLOAD TRAINING DATASETS IF NECESSARY ###
+mkdir -p training_datasets
+if [ ! -f training_datasets/ids_train_casp12nr50_nr70Ig_nr40Others.fasta ]; then
+ wget -P training_datasets https://zenodo.org/records/13831403/files/ids_train_casp12nr50_nr70Ig_nr40Others.fasta
fi
+gd2_dataset_ids=training_datasets/ids_train_casp12nr50_nr70Ig_nr40Others.fasta
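+# list of training ids, passed below via --file_with_selected_scn_ids_for_training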
-### DOWNLOAD DATASET IF NOT EXISTS
+# DOWNLOAD SEPARATELY IF NECESSARY
+if [ ! -f training_datasets/sidechainnet_casp12_50.pkl ]; then
+ wget -P training_datasets https://zenodo.org/records/13831403/files/sidechainnet_casp12_50.pkl
+fi
-# Download ids_train_casp12nr50_nr70Ig_nr40Others.fasta
-#if [ ! -f training_datasets/ids_train_casp12nr50_nr70Ig_nr40Others.fasta ]; then
-# wget -P training_datasets https://zenodo.org/records/13831403/files/ids_train_casp12nr50_nr70Ig_nr40Others.fasta
-#fi
-## DOWNLOAD SEPARATELY
-# # Download sidechainnet_casp12_50.pkl
-# if [ ! -f training_datasets/sidechainnet_casp12_50.pkl ]; then
-# wget -P training_datasets https://zenodo.org/records/13831403/files/sidechainnet_casp12_50.pkl
-# fi
+if [ ! -f training_datasets/AbSCSAbDAb_trainnr90_bkandcbcoords_aug2022.h5 ]; then
+ wget -P training_datasets https://zenodo.org/records/13831403/files/AbSCSAbDAb_trainnr90_bkandcbcoords_aug2022.h5
+fi
-gd2_dataset_ids=/scratch16/jgray21/smahaja4_active/datasets/nredundant_train_test_lists_New/ids_train_casp12nr50_nr70Ig_nr40Others.fasta
+if [ ! -f training_datasets/ppi_trainset_5032_noabag_aug2022.h5 ]; then
+ wget -P training_datasets https://zenodo.org/records/13831403/files/ppi_trainset_5032_noabag_aug2022.h5
+fi
-#### procs, gpus ###############
+### DOWNLOAD AND EXTRACT PRETRAINED MODEL IF NEEDED ###
+mkdir -p models
+if [ ! -f models/model.tar.gz ] && [ ! -f models/ProtEnT_backup.ckpt ]; then
+ wget -P models https://zenodo.org/records/8313466/files/model.tar.gz
+fi
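+# --skip-old-files skips members already on disk; --strip-components=1 drops the leading directory in the archive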
+tar --skip-old-files --strip-components=1 -C models -xvzf models/model.tar.gz models/ProtEnT_backup.ckpt
+
+### procs, gpus ###
n_proc=12
num_gpus=1
-#Default settings
+### Default settings ###
LAYERS=4
HEADS=8
DIM=256
@@ -45,9 +52,12 @@ gmodel=egnn-trans-ma
atom_types=backbone_and_cb
NN=48
-if [ ! -f /tmp/trainset_highres_nr90_vhh-rabd-dms_abnr90agnr70_aug2022.h5 ]
-then
- cp /scratch16/jgray21/smahaja4_active/datasets/trainset_highres_nr90_vhh-rabd-dms_abnr90agnr70_aug2022.h5 /tmp/.
+if [ ! -f /tmp/AbSCSAbDAb_trainnr90_bkandcbcoords_aug2022.h5 ]; then
+ cp training_datasets/AbSCSAbDAb_trainnr90_bkandcbcoords_aug2022.h5 /tmp/.
+fi
+
+if [ ! -f /tmp/ppi_trainset_5032_noabag_aug2022.h5 ]; then
+ cp training_datasets/ppi_trainset_5032_noabag_aug2022.h5 /tmp/.
fi
if [ ! -f /tmp/trainset_highres_nr90_vhh-rabd-dms_abnr90agnr70_aug2022.h5 ]
@@ -55,19 +65,18 @@ then
cp /scratch16/jgray21/smahaja4_active/datasets/trainset_highres_nr90_vhh-rabd-dms_abnr90agnr70_aug2022.h5 /tmp/.
fi
-cp /scratch16/jgray21/smahaja4_active/datasets/ppi_trainset_5032_noabag_aug2022.h5 /tmp/.
+scn_dataset_path='training_datasets'
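+# local sidechainnet directory, passed below via --scn_path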
+
H5_FILE_PPI=/tmp/ppi_trainset_5032_noabag_aug2022.h5
+# H5_FILE_AbAg=/tmp/AbSCSAbDAb_trainnr90_bkandcbcoords_aug2022.h5
H5_FILE_AbAg=/tmp/trainset_highres_nr90_vhh-rabd-dms_abnr90agnr70_aug2022.h5
-#### training #######
+### training ###
SEED=1
MODELS_DIR=models
EPOCHS=600
-MODEL=/home/smahaja4/backup/models/ProtEnT_backup.ckpt
+MODEL=models/ProtEnT_backup.ckpt
date
which python3 #MAKE SURE THIS MATCHES THE INSTALLED ENV
-cd /scratch16/jgray21/smahaja4_active/repositories/240519/MaskedProteinEnT/
python3 train_masked_model.py \
- --save_every ${save_every} --lr 0.00001 --batch_size $BS --heads $HEADS --model_dim $DIM --epochs $EPOCHS --dropout 0.2 --masking_rate_max 0.15 --topk_metrics 1 --layers $LAYERS --num_gpus $num_gpus --crop_sequences --scn_sequence_similarity ${SS} --protein_gmodel $gmodel --lr_patience 350 --lr_cooldown 20 --max_ag_neighbors ${NN} --atom_types ${atom_types} --file_with_selected_scn_ids_for_training ${gd2_dataset_ids} --lightning_save_last_model --use_scn --num_proc $n_proc --output_dir ${MODELS_DIR} --seed $SEED --wandb_entity ${WANDB_ENTITY} --model $MODEL --h5_file_ppi $H5_FILE_PPI --h5_file $H5_FILE_AbAg --fine_tune
-
-
+ --save_every ${save_every} --lr 0.00001 --batch_size $BS --heads $HEADS --model_dim $DIM --epochs $EPOCHS --dropout 0.2 --masking_rate_max 0.15 --topk_metrics 1 --layers $LAYERS --num_gpus $num_gpus --crop_sequences --scn_sequence_similarity ${SS} --protein_gmodel $gmodel --lr_patience 350 --lr_cooldown 20 --max_ag_neighbors ${NN} --atom_types ${atom_types} --file_with_selected_scn_ids_for_training ${gd2_dataset_ids} --lightning_save_last_model --use_scn --num_proc $n_proc --output_dir ${MODELS_DIR} --seed $SEED --wandb_entity ${WANDB_ENTITY} --model $MODEL --h5_file_ppi $H5_FILE_PPI --h5_file $H5_FILE_AbAg --fine_tune --scn_path ${scn_dataset_path}
diff --git a/scripts/sample_abag_sequences.sh b/scripts/sample_abag_sequences.sh
index 6699a64..72af1d2 100755
--- a/scripts/sample_abag_sequences.sh
+++ b/scripts/sample_abag_sequences.sh
@@ -1,42 +1,35 @@
#!/bin/bash -l
-module unload python
-module load cuda/11.1.0
-module load python/3.9.0
-module load git
+# ACTIVATE YOUR ENVIRONMENT, e.g.: source <path_to_venv>/bin/activate
-module list
-
-source /home/smahaja4/repositories/clone_masked_model/venv_py39_torch19/bin/activate
-
-
-#### procs, gpus ###############
-n_proc=$SLURM_NTASKS_PER_NODE
-num_gpus=0
-env | grep -a SLURM | tee slurm_env
-qu="a100"
-if [ "$SLURM_JOB_PARTITION" = "$qu" ]; then
- IFS=','
- read -a strarr <<< "$SLURM_STEP_GPUS"
- num_gpus=${#strarr[*]}
+### DOWNLOAD AND EXTRACT PRETRAINED MODEL IF NEEDED ###
+mkdir -p models
+if [ ! -f models/model.tar.gz ] && [ ! -f models/ProtAbAgEnT_backup.ckpt ]; then
+ wget -P models https://zenodo.org/records/8313466/files/model.tar.gz
fi
-##############################
-
-
-MODEL=../trained_models/ProtAbAgEnT_backup.ckpt
-TEST_RESULTS_BASE=/scratch16/jgray21/smahaja4_active/tmp_abag/
-PDB_DIR=/scratch16/jgray21/smahaja4_active/repositories/MaskedProteinEnT/data/abag/
-PPI_PARTNERS_DICT=/scratch16/jgray21/smahaja4_active/repositories/MaskedProteinEnT/data/abag/1n8z_partners.json
-python3 /scratch16/jgray21/smahaja4_active/repositories/MaskedProteinEnT/PPIAbAgSequenceSampler.py \
- --output_dir ${TEST_RESULTS_BASE} \
- --num_gpus $num_gpus \
- --num_procs $n_proc \
- --model $MODEL \
- --from_pdb $PDB_DIR \
- --sample_temperatures 0.2,0.5 \
- --num_samples 100 \
- --partners_json ${PPI_PARTNERS_DICT} \
- --partner_name Ab \
- --antibody
-
-
+tar --skip-old-files --strip-components=1 -C models -xvzf models/model.tar.gz models/ProtAbAgEnT_backup.ckpt
+
+### procs, gpus ###
+n_proc=6
+num_gpus=1
+
+echo "Running inference"
+date
+
+MODEL=models/ProtAbAgEnT_backup.ckpt
+OUTDIR=sampled_sequences
+PDB_DIR=data/abag/
+PPI_PARTNERS_DICT=data/abag/1n8z_partners.json
+SAMPLER_SCRIPT=PPIAbAgSequenceSampler.py
+# SAMPLER_SCRIPT=AntibodySequenceSampler.py
+python3 $SAMPLER_SCRIPT \
+ --output_dir ${OUTDIR} \
+ --model $MODEL \
+ --from_pdb $PDB_DIR \
+ --sample_temperatures 0.2,0.5 \
+ --num_samples 100 \
+ --partners_json ${PPI_PARTNERS_DICT} \
+ --partner_name Ab \
+ --antibody \
+ --num_gpus $num_gpus \
+ --num_procs $n_proc
diff --git a/scripts/sample_antibody_sequences.sh b/scripts/sample_antibody_sequences.sh
index d406fd7..751d026 100755
--- a/scripts/sample_antibody_sequences.sh
+++ b/scripts/sample_antibody_sequences.sh
@@ -1,37 +1,34 @@
#!/bin/bash -l
-module unload python
-module load cuda/11.1.0
-module load python/3.9.0
-module load git
+# ACTIVATE YOUR ENVIRONMENT, e.g.: source <path_to_venv>/bin/activate
-source /home/smahaja4/repositories/clone_masked_model/venv_py39_torch19/bin/activate
+### DOWNLOAD AND EXTRACT PRETRAINED MODEL IF NEEDED ###
+mkdir -p models
+if [ ! -f models/model.tar.gz ] && [ ! -f models/AbPlusEnT_backup.ckpt ]; then
+ wget -P models https://zenodo.org/records/8313466/files/model.tar.gz
+fi
+tar --skip-old-files --strip-components=1 -C models -xvzf models/model.tar.gz models/AbPlusEnT_backup.ckpt
-export OMP_NUM_THREADS=12
-gmodel=egnn-trans-ma-ppi
+### procs, gpus ###
+n_proc=6
+num_gpus=1
-#### procs, gpus ###############
-n_proc=$SLURM_NTASKS_PER_NODE
-num_gpus=0
-env | grep -a SLURM | tee slurm_env
-qu="a100"
-if [ "$SLURM_JOB_PARTITION" = "$qu" ]; then
- IFS=','
- read -a strarr <<< "$SLURM_STEP_GPUS"
- num_gpus=${#strarr[*]}
-fi
-##############################
+echo "Running inference"
+date
-MODEL=../trained_models/AbPlusEnT_backup.ckpt
-TEST_RESULTS_BASE=/scratch16/jgray21/smahaja4_active/tmp_ab/
-PDB_DIR=/scratch16/jgray21/smahaja4_active/repositories/MaskedProteinEnT/data/antibodies
-python3 /scratch16/jgray21/smahaja4_active/repositories/MaskedProteinEnT/ProteinSequenceSampler.py \
- --output_dir ${TEST_RESULTS_BASE} \
- --num_gpus $num_gpus \
- --num_procs $n_proc \
+MODEL=models/AbPlusEnT_backup.ckpt
+OUTDIR=sampled_sequences
+PDB_DIR=data/antibodies
+SAMPLER_SCRIPT=ProteinSequenceSampler.py
+# SAMPLER_SCRIPT=AntibodySequenceSampler.py
+python3 $SAMPLER_SCRIPT \
+ --output_dir ${OUTDIR} \
--model $MODEL \
--from_pdb $PDB_DIR \
- --sample_temperatures 0.5 \
+ --sample_temperatures 0.2,0.5 \
--num_samples 100 \
- --antibody
-
+ --antibody \
+ --mask_ab_region cdrs \
+ --num_gpus $num_gpus \
+ --num_procs $n_proc
+ # --mask_ab_indices 10,11,12  # 0-indexed
diff --git a/scripts/sample_ppi_sequences.sh b/scripts/sample_ppi_sequences.sh
index 853c0e3..b73b1ae 100755
--- a/scripts/sample_ppi_sequences.sh
+++ b/scripts/sample_ppi_sequences.sh
@@ -1,18 +1,32 @@
#!/bin/bash -l
-#ACTIVATE ENV
+# ACTIVATE YOUR ENVIRONMENT, e.g.: source <path_to_venv>/bin/activate
+
+### DOWNLOAD AND EXTRACT PRETRAINED MODEL IF NEEDED ###
+mkdir -p models
+if [ ! -f models/model.tar.gz ] && [ ! -f models/ProtPPIEnT_backup.ckpt ]; then
+ wget -P models https://zenodo.org/records/8313466/files/model.tar.gz
+fi
+tar --skip-old-files --strip-components=1 -C models -xvzf models/model.tar.gz models/ProtPPIEnT_backup.ckpt
+
+### procs, gpus ###
+n_proc=6
+num_gpus=1
+
+echo "Running inference"
+date
MODEL=models/ProtPPIEnT_backup.ckpt
-OUTDIR=outdir
+OUTDIR=sampled_sequences
PDB_DIR=data/ppis
PPI_PARTNERS_DICT=data/ppis/heteromers_partners_example.json
python3 PPIAbAgSequenceSampler.py \
- --output_dir ${OUTDIR} \
- --model $MODEL \
- --from_pdb $PDB_DIR \
+ --output_dir ${OUTDIR} \
+ --model $MODEL \
+ --from_pdb $PDB_DIR \
--sample_temperatures 0.2,0.5 \
- --num_samples 100 \
+ --num_samples 100 \
--partners_json ${PPI_PARTNERS_DICT} \
- --partner_name p1
-
-
+ --partner_name p1 \
+ --num_gpus $num_gpus \
+ --num_procs $n_proc
diff --git a/scripts/sample_protein_sequences.sh b/scripts/sample_protein_sequences.sh
index a0e1881..d992ba4 100755
--- a/scripts/sample_protein_sequences.sh
+++ b/scripts/sample_protein_sequences.sh
@@ -1,36 +1,29 @@
#!/bin/bash -l
-module unload python
-module load cuda/11.1.0
-module load python/3.9.0
-module load git
+# ACTIVATE YOUR ENVIRONMENT, e.g.: source <path_to_venv>/bin/activate
-source /home/smahaja4/repositories/clone_masked_model/venv_py39_torch19/bin/activate
+### DOWNLOAD AND EXTRACT PRETRAINED MODEL IF NEEDED ###
+mkdir -p models
+if [ ! -f models/model.tar.gz ] && [ ! -f models/ProtEnT_backup.ckpt ]; then
+ wget -P models https://zenodo.org/records/8313466/files/model.tar.gz
+fi
+tar --skip-old-files --strip-components=1 -C models -xvzf models/model.tar.gz models/ProtEnT_backup.ckpt
-export OMP_NUM_THREADS=12
-gmodel=egnn-trans-ma-ppi
+### procs, gpus ###
+n_proc=6
+num_gpus=1
-#### procs, gpus ###############
-n_proc=$SLURM_NTASKS_PER_NODE
-num_gpus=0
-env | grep -a SLURM | tee slurm_env
-qu="a100"
-if [ "$SLURM_JOB_PARTITION" = "$qu" ]; then
- IFS=','
- read -a strarr <<< "$SLURM_STEP_GPUS"
- num_gpus=${#strarr[*]}
-fi
-##############################
+echo "Running inference"
+date
-MODEL=../trained_models/ProtEnT_backup.ckpt
-TEST_RESULTS_BASE=/scratch16/jgray21/smahaja4_active/tmp/
-PDB_DIR=/scratch16/jgray21/smahaja4_active/repositories/MaskedProteinEnT/data/proteins
-python3 /scratch16/jgray21/smahaja4_active/repositories/MaskedProteinEnT/ProteinSequenceSampler.py \
- --output_dir ${TEST_RESULTS_BASE} \
- --num_gpus $num_gpus \
- --num_procs $n_proc \
+MODEL=models/ProtEnT_backup.ckpt
+OUTDIR=sampled_sequences
+PDB_DIR=data/proteins
+python3 ProteinSequenceSampler.py \
+ --output_dir ${OUTDIR} \
--model $MODEL \
--from_pdb $PDB_DIR \
--sample_temperatures 0.2,0.5 \
- --num_samples 100
-
+ --num_samples 100 \
+ --num_gpus $num_gpus \
+ --num_procs $n_proc
diff --git a/scripts/train_protein_model.sh b/scripts/train_protein_model.sh
index 547b7d1..8afdc95 100755
--- a/scripts/train_protein_model.sh
+++ b/scripts/train_protein_model.sh
@@ -10,9 +10,7 @@ if [ "$WANDB_ENTITY" = "YOUR_WANDB_ENTITY" ]; then
fi
### CHECK AND CREATE DATASETS DIRECTORY
-if [ ! -d datasets ]; then
- mkdir training_datasets
-fi
+mkdir -p training_datasets
### DOWNLOAD DATASET IF NOT EXISTS
@@ -23,13 +21,13 @@ fi
gd2_dataset_ids=$(pwd)/training_datasets/ids_train_casp12nr50_nr70Ig_nr40Others.fasta
-## DOWNLOAD SEPARATELY
-### UGH, download works but is taking way too long... (5 hours)
-# # Download sidechainnet_casp12_50.pkl
-# if [ ! -f training_datasets/sidechainnet_casp12_50.pkl ]; then
-# wget -P training_datasets https://zenodo.org/records/13831403/files/sidechainnet_casp12_50.pkl
-# fi
+### DOWNLOAD SEPARATELY IF NECESSARY ###
+## Download sidechainnet_casp12_50.pkl
+if [ ! -f training_datasets/sidechainnet_casp12_50.pkl ]; then
+ wget -P training_datasets https://zenodo.org/records/13831403/files/sidechainnet_casp12_50.pkl
+fi
+scn_dataset_path='training_datasets'
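+# local sidechainnet directory, passed below via --scn_path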
#### procs, gpus ###############
n_proc=6
@@ -55,5 +53,4 @@ EPOCHS=100
date
which python3 #MAKE SURE THIS MATCHES THE INSTALLED ENV
python3 train_masked_model.py \
- --save_every ${save_every} --lr 0.00001 --batch_size $BS --heads $HEADS --model_dim $DIM --epochs $EPOCHS --dropout 0.2 --masking_rate_max 0.15 --topk_metrics 1 --layers $LAYERS --num_gpus $num_gpus --crop_sequences --scn_sequence_similarity ${SS} --protein_gmodel $gmodel --lr_patience 350 --lr_cooldown 20 --max_ag_neighbors ${NN} --atom_types ${atom_types} --file_with_selected_scn_ids_for_training ${gd2_dataset_ids} --lightning_save_last_model --use_scn --num_proc $n_proc --output_dir ${MODELS_DIR} --seed $SEED --wandb_entity ${WANDB_ENTITY}
-
+ --save_every ${save_every} --lr 0.00001 --batch_size $BS --heads $HEADS --model_dim $DIM --epochs $EPOCHS --dropout 0.2 --masking_rate_max 0.15 --topk_metrics 1 --layers $LAYERS --num_gpus $num_gpus --crop_sequences --scn_sequence_similarity ${SS} --protein_gmodel $gmodel --lr_patience 350 --lr_cooldown 20 --max_ag_neighbors ${NN} --atom_types ${atom_types} --file_with_selected_scn_ids_for_training ${gd2_dataset_ids} --lightning_save_last_model --use_scn --num_proc $n_proc --output_dir ${MODELS_DIR} --seed $SEED --wandb_entity ${WANDB_ENTITY} --scn_path ${scn_dataset_path}
diff --git a/src/datamodules/MaskedSequenceStructureMADataModule.py b/src/datamodules/MaskedSequenceStructureMADataModule.py
index be23051..531ed74 100644
--- a/src/datamodules/MaskedSequenceStructureMADataModule.py
+++ b/src/datamodules/MaskedSequenceStructureMADataModule.py
@@ -23,9 +23,9 @@ def setup(self, stage: Optional[str] = None):
shared_arguments_protein = get_protein_dataset_setup(args)
casp_version = args.scn_casp_version
thinning = args.scn_sequence_similarity
- scn_path = '/scratch16/jgray21/smahaja4_active/datasets/sidechainnet'
- input_file = '{}_c{}_ss{}/sidechainnet_casp{}_{}.pkl'.format(
- scn_path, casp_version, thinning, casp_version, thinning)
+ scn_path = args.scn_path
+ input_file = '{}/sidechainnet_casp{}_{}.pkl'.format(
+ scn_path, casp_version, thinning)
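+        # e.g. --scn_path training_datasets -> training_datasets/sidechainnet_casp12_50.pkl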
self.train_protein_dataset = None
self.validation_protein_dataset = None
self.test_protein_dataset = None
diff --git a/utils/command_line_utils.py b/utils/command_line_utils.py
index 5727409..c2bb9fd 100644
--- a/utils/command_line_utils.py
+++ b/utils/command_line_utils.py
@@ -90,6 +90,9 @@ def _get_args():
action='store_true',
default=False,
help='Use SidechainNet Dataset also')
+ parser.add_argument('--scn_path',
+ type=str,
+ help='Path to sidechainnet dataset. Avoid trailing "/"')
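+    # consumed by MaskedSequenceStructureMADataModule to build <scn_path>/sidechainnet_casp<version>_<thinning>.pkl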
parser.add_argument('--atom_types',
type=str,
default='backbone_and_cb',