forked from konpatp/diffae
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtemplates.py
More file actions
66 lines (58 loc) · 1.8 KB
/
templates.py
File metadata and controls
66 lines (58 loc) · 1.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from experiment import *
def autoenc_base():
    """
    Build the common starting configuration shared by the Diff-AE templates.

    A fresh ``TrainConfig`` is created, the fields below are assigned,
    ``make_model_conf()`` is called on it, and the config is returned.
    """
    cfg = TrainConfig()

    # data / optimisation settings
    cfg.data_name = 'tapev06'
    cfg.batch_size = 32
    cfg.sample_size = 32
    cfg.lr = 1e-4
    cfg.fp16 = True

    # evaluation intervals (the *_every_samples fields)
    cfg.eval_every_samples = 200_000
    cfg.eval_ema_every_samples = 200_000

    # diffusion-related settings
    cfg.diffusion_type = 'beatgans'
    cfg.beatgans_gen_type = GenerativeType.ddim
    cfg.beta_scheduler = 'linear'
    cfg.T = 1000
    cfg.T_eval = 20

    # model selection and net_* settings
    cfg.model_name = ModelName.beatgans_autoenc
    cfg.net_ch = 64
    cfg.net_ch_mult = (1, 2, 4, 8)
    cfg.net_attn = (16, )
    cfg.net_beatgans_attn_head = 1
    cfg.net_beatgans_embed_channels = 512
    cfg.net_beatgans_resnet_two_cond = True
    cfg.net_enc_channel_mult = (1, 2, 4, 8, 8)
    cfg.net_enc_pool = 'adaptivenonzero'

    # called last, once every field above has been assigned
    cfg.make_model_conf()
    return cfg
def tapev06_autoenc_base():
    """
    Derive the tapev06 base configuration from ``autoenc_base()``.

    Overrides the image size, channel settings and evaluation intervals,
    calls ``make_model_conf()`` again, and returns the config.
    """
    cfg = autoenc_base()
    cfg.data_name = 'tapev06'

    # NOTE(review): scale_up_gpus is defined elsewhere; keep this call ahead
    # of the explicit assignments below so they are applied after it.
    cfg.scale_up_gpus(4)

    cfg.eval_every_samples = 10_000_000
    cfg.eval_ema_every_samples = 10_000_000

    cfg.img_size = 128
    cfg.net_ch = 128
    # final resolution = 8x8
    cfg.net_ch_mult = (1, 1, 2, 3, 4)
    # final resolution = 4x4
    cfg.net_enc_channel_mult = (1, 1, 2, 3, 4, 4)

    # called again after the overrides above
    cfg.make_model_conf()
    return cfg
def tapev06_autoenc():
    """
    Return the tapev06 autoencoder configuration named 'tapev06_autoenc'.

    Starts from ``tapev06_autoenc_base()`` and sets the run name, the total
    sample count and the evaluation intervals.
    """
    cfg = tapev06_autoenc_base()
    cfg.name = 'tapev06_autoenc'
    cfg.total_samples = 130_000_000
    cfg.eval_every_samples = 10_000_000
    cfg.eval_ema_every_samples = 10_000_000
    return cfg
def pretrain_tapev06_autoenc():
    """
    Return a tapev06 base configuration that points at saved artefacts.

    ``pretrain`` is set to a ``PretrainConfig`` whose path is the
    ``last.ckpt`` file under the ``tapev06_autoenc()`` run's checkpoint
    directory, and ``latent_infer_path`` is set to ``latent.pkl`` in the
    same directory.
    """
    cfg = tapev06_autoenc_base()

    # The directory name comes from the name of the tapev06_autoenc() config.
    weights_path = f'checkpoints/{tapev06_autoenc().name}/last.ckpt'
    cfg.pretrain = PretrainConfig(name='130M', path=weights_path)

    latents_path = f'checkpoints/{tapev06_autoenc().name}/latent.pkl'
    cfg.latent_infer_path = latents_path
    return cfg