-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathdebug.slurm
More file actions
187 lines (168 loc) · 5.87 KB
/
debug.slurm
File metadata and controls
187 lines (168 loc) · 5.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
#!/bin/bash
# Slurm debug job: runs scripts/train.py under srun, then scripts/eval.py on the
# checkpoint directory reported by the training step.
#
# Job directives below: account SYB114, job name debug-grm, 1 node, 15 minute
# wall clock, "batch" partition with the "debug" QOS. sbatch only reads
# directive lines that appear before the first executable command, so keep
# them in this header.
#SBATCH -A SYB114
#SBATCH -J debug-grm
#SBATCH -N 1
#SBATCH -t 00:15:00
#SBATCH -p batch
#SBATCH -q debug
# stdout/stderr go to logs/<job-name>-<job-id>.out/.err (%x = job name,
# %j = job id) and are truncated rather than appended on reuse.
# NOTE(review): Slurm opens these files before the script body runs, so logs/
# must already exist at submit time; the mkdir later in this script is too
# late to create it -- confirm.
#SBATCH -o logs/%x-%j.out # Out Path
#SBATCH -e logs/%x-%j.err # Err Path
#SBATCH --open-mode=truncate # Overwrite .out/.err
### env, modules, and settings ###
# Exit on the first failing command, including a failing stage of a pipeline.
set -eo pipefail
ulimit -S -c 0 # disable core dumps (soft core-file size limit of 0)
# Start from a clean module environment, then load the GNU programming
# environment, ROCm 6.3.1 and the gfx90a accelerator target.
module purge
module load PrgEnv-gnu/8.6.0
module load rocm/6.3.1
module load craype-accel-amd-gfx90a
# Project-shared miniconda setup, then the shared pytorch-rocm conda env.
source /lustre/orion/syb111/proj-shared/Environments/source_miniconda_frontier.sh
source activate /lustre/orion/syb111/world-shared/environments/pytorch-rocm/
### train hyperparameters ###
# Architecture switches: which encoders are enabled and how they are built.
export G_ENC=True
export E_ENC=True
export LD_ENC=True
export GXE_ENC=tf # options: "tf", "mlp", "cnn"
export WG=True # weighted gate for 3-prong architecture
export G_ENCODER_TYPE=moe # options: "dense", "moe"
export MOE_NUM_EXPERTS=16
export MOE_TOP_K=4
export MOE_SHARED_EXPERT=True
export MOE_LOSS_WEIGHT=0.01
export FULL_TRANSFORMER=True
# calculate microbatch size given global batch size and N nodes * 8 gpus/node
export GBS=2048
GPUS_PER_NODE=8
# SLURM_NNODES is set inside an allocation; default to 1 (the -N requested in
# the job header) so the arithmetic does not error out when it is unset.
NGPUS=$(( ${SLURM_NNODES:-1} * GPUS_PER_NODE ))
# Every rank receives the same microbatch, so GBS must split evenly across
# NGPUS; otherwise integer division would silently shrink the effective global
# batch (or divide by zero when NGPUS is 0).
if (( NGPUS <= 0 )) || (( GBS % NGPUS != 0 )); then
  echo "[ERROR] cannot split GBS=$GBS evenly across NGPUS=$NGPUS GPUs" >&2
  exit 1
fi
MBS=$(( GBS / NGPUS ))
export BATCH_SIZE=$MBS
echo "Debug with global batch size $GBS on $NGPUS GPUs (microbatch size $MBS)"
# Optimization / regularization.
export NUM_EPOCHS=1
export LR=1e-4
export WEIGHT_DECAY=1e-5
export DROPOUT=0.15
export EARLY_STOP=50
# Model depth and width.
export G_LAYERS=1
export LD_LAYERS=2
export MLP_LAYERS=1
export GXE_LAYERS=2
export HEADS=4
export EMB_SIZE=256
# MoE expert hidden widths track the embedding size.
export MOE_EXPERT_HIDDEN_DIM=$EMB_SIZE
export MOE_SHARED_EXPERT_HIDDEN_DIM=$EMB_SIZE
# Inputs and loss configuration.
export SCALE_TARGETS=False
export G_INPUT_TYPE=grm # options: "tokens", "grm"
export ENV_CATEGORICAL_MODE=drop # "drop" matches best baseline behavior; set "onehot" to include categorical env features
export LOSS="envpcc" # + separated list of losses
export LOSS_WEIGHTS="1.0" # comma separated list of weights for each loss
export ENV_STRATIFIED=True # use environment-stratified batching for envpcc loss
export MIN_SAMPLES_PER_ENV=32 # samples per environment per batch for stable correlation
# Contrastive loss: encourages G embeddings to reflect genetic similarity
export CONTRASTIVE_MODE=g # ablation: none, g, e, g+e
export CONTRASTIVE_WEIGHT=0.1
export CONTRASTIVE_TEMPERATURE=0.5
export CONTRASTIVE_SIM_TYPE=grm # 'grm' (recommended) or 'ibs'
export CONTRASTIVE_LOSS_TYPE=mse # 'mse' (recommended), 'cosine', or 'kl'
# LEO (Leave-Environment-Out) validation: hold out entire environments for val
export LEO_VAL=True
export LEO_VAL_FRACTION=0.15
### wandb settings ###
# Outbound HTTP(S) goes through the site proxy; wandb gets a 90 s timeout.
export WANDB_HTTP_TIMEOUT=90
export https_proxy=http://proxy.ccs.ornl.gov:3128
export http_proxy=$https_proxy
# Per-job handoff files carrying the wandb run id and checkpoint directory to
# the eval step further down (presumably written by scripts/train.py, which
# inherits these variables through srun --export=ALL).
export WANDB_RUN_ID_FILE="logs/run_ids/wandb_run_id_${SLURM_JOB_ID}.txt"
export CHECKPOINT_DIR_FILE="logs/ckpt_ids/checkpoint_dir_${SLURM_JOB_ID}.txt"
# Create the handoff/checkpoint/results directories, then drop stale handoff
# files so eval only ever sees ones produced during this job.
mkdir -p \
  logs/run_ids \
  logs/ckpt_ids \
  checkpoints \
  data/results
rm -f -- "$WANDB_RUN_ID_FILE" "$CHECKPOINT_DIR_FILE"
### network / rccl settings ###
# Rendezvous on the first host of the allocation. The nodelist is quoted
# because compressed Slurm lists (e.g. host[001-004]) are glob patterns.
# Assigning before exporting keeps a scontrol failure visible to set -e
# (export would mask it); sed -n 1p reads all of scontrol's output, so
# pipefail cannot trip on a SIGPIPE the way it could with head -n1.
MASTER_ADDR=$(scontrol show hostnames "$SLURM_NODELIST" | sed -n 1p)
export MASTER_ADDR
export MASTER_PORT=6000
export NCCL_SOCKET_IFNAME=hsn0
### miopen settings ###
# Per-job MIOpen user DB and custom cache under /tmp, recreated empty each run.
# NOTE(review): only the node running this batch script creates the
# directory; on multi-node runs the other nodes presumably rely on MIOpen
# creating it -- confirm.
export MIOPEN_USER_DB_PATH=/tmp/miopen-$SLURM_JOB_ID
export MIOPEN_CUSTOM_CACHE_DIR=$MIOPEN_USER_DB_PATH
rm -rf "$MIOPEN_USER_DB_PATH" && mkdir -p "$MIOPEN_USER_DB_PATH"
export MIOPEN_FIND_MODE=NORMAL
### srun (model train) ###
# Launch one training rank per GPU across the allocation. NGPUS (nodes * 8)
# is already computed alongside the microbatch size, so it is reused here
# instead of being recomputed; abort loudly if it is somehow missing.
: "${NGPUS:?NGPUS must be set before launching training}"
# Flags are collected in an array with every value quoted, so an empty or
# space-containing setting stays one argument instead of shifting the flags.
# NOTE(review): G_ENCODER_TYPE and the MOE_* settings exported above are not
# passed as flags; presumably train.py reads them from the environment that
# --export=ALL forwards -- confirm.
train_args=(
  --batch_size "$BATCH_SIZE"
  --gbs "$GBS"
  --g_enc "$G_ENC"
  --e_enc "$E_ENC"
  --ld_enc "$LD_ENC"
  --gxe_enc "$GXE_ENC"
  --wg "$WG"
  --full_transformer "$FULL_TRANSFORMER"
  --num_epochs "$NUM_EPOCHS"
  --lr "$LR"
  --weight_decay "$WEIGHT_DECAY"
  --dropout "$DROPOUT"
  --early_stop "$EARLY_STOP"
  --g_layers "$G_LAYERS"
  --ld_layers "$LD_LAYERS"
  --mlp_layers "$MLP_LAYERS"
  --gxe_layers "$GXE_LAYERS"
  --heads "$HEADS"
  --emb_size "$EMB_SIZE"
  --g_input_type "$G_INPUT_TYPE"
  --env_categorical_mode "$ENV_CATEGORICAL_MODE"
  --loss "$LOSS"
  --loss_weights "$LOSS_WEIGHTS"
  --scale_targets "$SCALE_TARGETS"
  --env_stratified "$ENV_STRATIFIED"
  --min_samples_per_env "$MIN_SAMPLES_PER_ENV"
  --contrastive_mode "$CONTRASTIVE_MODE"
  --contrastive_weight "$CONTRASTIVE_WEIGHT"
  --contrastive_temperature "$CONTRASTIVE_TEMPERATURE"
  --contrastive_sim_type "$CONTRASTIVE_SIM_TYPE"
  --contrastive_loss_type "$CONTRASTIVE_LOSS_TYPE"
  --leo_val "$LEO_VAL"
  --leo_val_fraction "$LEO_VAL_FRACTION"
)
srun -N "${SLURM_NNODES}" \
  -n "${NGPUS}" \
  --gpus-per-task=1 \
  --cpu-bind=cores \
  --export=ALL \
  --kill-on-bad-exit=1 \
  python -u scripts/train.py "${train_args[@]}"
### python (model eval) ###
# Resume the training step's wandb run when it left a non-empty run id file;
# otherwise make sure no inherited resume settings leak into eval.
if [[ -s "$WANDB_RUN_ID_FILE" ]]; then
  export WANDB_RESUME=allow
  # Assign separately from export so a failed read is not masked under set -e.
  WANDB_RUN_ID=$(< "$WANDB_RUN_ID_FILE")
  export WANDB_RUN_ID
else
  echo "[WARNING] WandB run ID file not found or empty... Eval will create a new run." >&2
  unset WANDB_RESUME
  unset WANDB_RUN_ID
fi
# The checkpoint directory is mandatory; an empty file would otherwise pass an
# empty --checkpoint_dir to eval.py.
if [[ -s "$CHECKPOINT_DIR_FILE" ]]; then
  CHECKPOINT_DIR=$(< "$CHECKPOINT_DIR_FILE")
  echo "[INFO] Using checkpoint dir from train: $CHECKPOINT_DIR"
else
  echo "[ERROR] checkpoint dir file not found or empty: $CHECKPOINT_DIR_FILE" >&2
  exit 2
fi
# Single-process eval on the batch node, with every flag value quoted.
# NOTE(review): the MoE, env-stratified, contrastive and LEO settings are not
# forwarded as flags here; presumably eval.py reads them from the environment
# or the checkpoint, or does not need them -- confirm.
eval_args=(
  --batch_size "$BATCH_SIZE"
  --gbs "$GBS"
  --g_enc "$G_ENC"
  --e_enc "$E_ENC"
  --ld_enc "$LD_ENC"
  --gxe_enc "$GXE_ENC"
  --wg "$WG"
  --full_transformer "$FULL_TRANSFORMER"
  --num_epochs "$NUM_EPOCHS"
  --lr "$LR"
  --weight_decay "$WEIGHT_DECAY"
  --dropout "$DROPOUT"
  --early_stop "$EARLY_STOP"
  --g_layers "$G_LAYERS"
  --ld_layers "$LD_LAYERS"
  --mlp_layers "$MLP_LAYERS"
  --gxe_layers "$GXE_LAYERS"
  --heads "$HEADS"
  --emb_size "$EMB_SIZE"
  --g_input_type "$G_INPUT_TYPE"
  --env_categorical_mode "$ENV_CATEGORICAL_MODE"
  --loss "$LOSS"
  --loss_weights "$LOSS_WEIGHTS"
  --scale_targets "$SCALE_TARGETS"
  --checkpoint_dir "$CHECKPOINT_DIR"
)
python -u scripts/eval.py "${eval_args[@]}"