From ed623f3258b496a3c61c1ada94e21b1e454dcb04 Mon Sep 17 00:00:00 2001 From: Dilip Parasu Date: Thu, 7 Sep 2023 14:34:54 +0530 Subject: [PATCH] Update ddpm_conditional.py --- ddpm_conditional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ddpm_conditional.py b/ddpm_conditional.py index d2e0cecff..0fc104865 100644 --- a/ddpm_conditional.py +++ b/ddpm_conditional.py @@ -98,7 +98,7 @@ def train(args): logger.add_scalar("MSE", loss.item(), global_step=epoch * l + i) if epoch % 10 == 0: - labels = torch.arange(10).long().to(device) + labels = torch.arange(args.num_classes).long().to(device) sampled_images = diffusion.sample(model, n=len(labels), labels=labels) ema_sampled_images = diffusion.sample(ema_model, n=len(labels), labels=labels) plot_images(sampled_images)