From f6cb831ee97bab16271aef61d188f6eb79e5d94c Mon Sep 17 00:00:00 2001
From: jsch
Date: Mon, 30 Jun 2025 20:18:28 +0900
Subject: [PATCH] betas no longer accept ints

---
 train.py                  | 4 ++--
 training/training_loop.py | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/train.py b/train.py
index b611c981..037c1a71 100644
--- a/train.py
+++ b/train.py
@@ -156,8 +156,8 @@ def main(**kwargs):
     c.G_kwargs = dnnlib.EasyDict(class_name='training.networks.Generator')
     c.D_kwargs = dnnlib.EasyDict(class_name='training.networks.Discriminator')
 
-    c.G_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', betas=[0,0], eps=1e-8)
-    c.D_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', betas=[0,0], eps=1e-8)
+    c.G_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', betas=[0.0, 0.0], eps=1e-8)
+    c.D_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', betas=[0.0, 0.0], eps=1e-8)
     c.loss_kwargs = dnnlib.EasyDict(class_name='training.loss.R3GANLoss')
     c.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, prefetch_factor=2)
 
diff --git a/training/training_loop.py b/training/training_loop.py
index 8747c428..ccf6c573 100644
--- a/training/training_loop.py
+++ b/training/training_loop.py
@@ -328,7 +328,7 @@ def training_loop(
             # Update weights.
             for g in phase.opt.param_groups:
                 g['lr'] = cur_lr
-                g['betas'] = (0, cur_beta2)
+                g['betas'] = (0.0, cur_beta2)
             with torch.autograd.profiler.record_function(phase.name + '_opt'):
                 params = [param for param in phase.module.parameters() if param.grad is not None]
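
Note (not part of the patch): the snippet below is a minimal, self-contained sketch of the
pattern the patch fixes: construct torch.optim.Adam with float betas, then retune lr and
beta2 through param_groups during training. The model and the schedule values are
illustrative stand-ins, not names from this repository, and whether integer betas are
actually rejected depends on the installed torch version, per the commit message.

import torch

# Stand-in module; the real code builds Generator/Discriminator via dnnlib.EasyDict.
model = torch.nn.Linear(4, 4)

# Floats throughout, matching the patched EasyDicts in train.py.
opt = torch.optim.Adam(model.parameters(), lr=2e-3, betas=(0.0, 0.0), eps=1e-8)

# Hypothetical schedule values; training_loop.py computes cur_lr / cur_beta2 per step.
cur_lr, cur_beta2 = 2e-3, 0.99
for g in opt.param_groups:
    g['lr'] = cur_lr
    g['betas'] = (0.0, cur_beta2)  # float beta1, as in the patched line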