# train_vae.py
import os
from argparse import ArgumentParser

import torch
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

from datasets import DATAMODULES
from models import ConvVAEModule
from vae_embed import embed_dataset
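
# Assumed interfaces for the local modules above -- a sketch inferred from how
# they are used in this script, not from their source:
#   - ConvVAEModule: a LightningModule wrapping a convolutional VAE, assumed
#     to expose the underlying network as `model.vae`.
#   - DATAMODULES: a dict mapping dataset names to LightningDataModule
#     factories that accept `batch_size`.
#   - embed_dataset: encodes a dataset with a trained VAE and saves the
#     embeddings to an .npz file at `output_path`.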
if __name__ == "__main__":
    parser = ArgumentParser(
        prog="vae-for-ddf",
        description="Train a VAE for downstream use with differentiable decision forests",
    )
    parser.add_argument("--batch-size", type=int, default=512, help="batch size")
    parser.add_argument("--epochs", type=int, default=50, help="number of epochs")
    parser.add_argument("--latent-dim", type=int, default=32, help="size of the latent dim for our VAE")
    parser.add_argument("--lr", type=float, default=0.001, help="learning rate")
    parser.add_argument("--kl-coeff", type=float, default=1.0, help="KL coefficient, a.k.a. the beta term in the ELBO loss")
    parser.add_argument("--output-dir", type=str, default=os.path.join("results", "vae"), help="output directory")
    parser.add_argument("--anomaly-detect", action="store_true", default=False, help="enable autograd anomaly detection")
    parser.add_argument("--name", type=str, default="vae-for-ddf", help="wandb name of the run")
    parser.add_argument("--checkpoint", type=str, default=None, help="checkpoint to resume training from")
    parser.add_argument("--dataset", type=str, default="mnist", help="dataset to use", choices=DATAMODULES.keys())
    parser.add_argument("--seed", type=int, default=42, help="random seed")
    args = parser.parse_args()

    seed_everything(args.seed)

    # Nest outputs by dataset and latent size, e.g. results/vae/mnist/latent_dim_32.
    args.output_dir = os.path.join(args.output_dir, args.dataset, f"latent_dim_{args.latent_dim}")
    os.makedirs(args.output_dir, exist_ok=True)
    wandb_logger = WandbLogger(
        name=args.name,
        project=parser.prog,
        save_dir=args.output_dir,
        log_model=True,  # log checkpoints only at the end of training (to stop wandb running out of storage!)
    )

    # Record the run config in wandb, minus the bookkeeping args (name, output_dir).
    config = vars(args).copy()
    config.pop("name")
    config.pop("output_dir")
    wandb_logger.experiment.config.update(config)
    # Architecture for single-channel 28x28 inputs (e.g. MNIST).
    model_params = {
        "input_shape": (1, 28, 28),
        "encoder_conv_filters": [28, 64, 64],
        "decoder_conv_t_filters": [64, 28, 1],
        "latent_dim": args.latent_dim,
        "kl_coeff": args.kl_coeff,
        "lr": args.lr,
    }
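
    # A sketch of the objective ConvVAEModule is assumed to optimise (the
    # actual implementation lives in models.py): a beta-VAE ELBO of the form
    #
    #     loss = reconstruction_loss + kl_coeff * kl_divergence
    #
    # so --kl-coeff is the beta weighting of the KL term, per its help text.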
    # Keep the checkpoint with the lowest validation loss...
    val_checkpoint = ModelCheckpoint(
        monitor="val_loss",
        mode="min",
        auto_insert_metric_name=True,
    )
    # ...and always keep the most recent epoch so training can be resumed.
    latest_checkpoint = ModelCheckpoint(
        filename="latest-checkpoint",
        every_n_epochs=1,
        save_top_k=1,
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ConvVAEModule(**model_params)
    model.to(device)  # explicit move; Lightning also handles device placement during fit

    trainer = Trainer(
        accelerator=str(device),  # "cuda" or "cpu"
        logger=wandb_logger,
        callbacks=[latest_checkpoint, val_checkpoint],
        detect_anomaly=args.anomaly_detect,
        max_epochs=args.epochs,
    )
    datamodule = DATAMODULES[args.dataset](batch_size=args.batch_size)
    trainer.fit(
        model,
        datamodule=datamodule,
        ckpt_path=args.checkpoint,  # None starts from scratch; otherwise resume
    )
    # With no model argument, Lightning tests the best checkpoint from the fit above.
    trainer.test(datamodule=datamodule)

    # Embed the full dataset with the trained VAE for the downstream forest experiments.
    embeddings_path = os.path.join(args.output_dir, "embeddings.npz")
    embed_dataset(
        model.vae,
        dataset=args.dataset,
        batch_size=args.batch_size,
        output_path=embeddings_path,
        device=device,
    )
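
# Example invocation (flags as defined above; values shown are the defaults):
#
#     python train_vae.py --dataset mnist --latent-dim 32 --epochs 50 \
#         --kl-coeff 1.0 --output-dir results/vae
#
# Checkpoints, wandb logs, and embeddings.npz are written under
# results/vae/mnist/latent_dim_32/.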