|
316 | 316 | " error = torch.abs(x[:,-1, :, :] - 0.9).mean() \n", |
317 | 317 | " return error\n", |
318 | 318 | "\n", |
| 319 | + "# MSE loss from init\n", |
| 320 | + "def make_mse_loss(target):\n", |
| 321 | + " def mse_loss(x, sigma, **kwargs):\n", |
| 322 | + " return (x - target).square().mean()\n", |
| 323 | + " return mse_loss\n", |
319 | 324 | "\n", |
320 | 325 | "###\n", |
321 | 326 | "# Conditioning helper functions\n", |
|
326 | 331 | " # loss_fn (function): func(x, sigma, denoised) -> number\n", |
327 | 332 | " # scale (number): how much this loss is applied to the image\n", |
328 | 333 | " def cond_fn(x, sigma, denoised, **kwargs):\n", |
329 | | - " # x = x.detach().requires_grad_()\n", |
330 | | - " # denoised = denoised.detach().requires_grad_()\n", |
331 | 334 | " with torch.enable_grad():\n", |
332 | 335 | " denoised_sample = model.differentiable_decode_first_stage(denoised).requires_grad_()\n", |
333 | 336 | " loss = loss_fn(denoised_sample, sigma, **kwargs) * scale\n", |
|
662 | 665 | " if args.sampler in ['plms','ddim']:\n", |
663 | 666 | " sampler.make_schedule(ddim_num_steps=args.steps, ddim_eta=args.ddim_eta, ddim_discretize='fill', verbose=False)\n", |
664 | 667 | "\n", |
| 668 | + " if args.init_mse_scale > 0 and init_latent is None:\n", |
| 669 | + " raise Exception(\"Cannot use mse loss without an init image\")\n", |
| 670 | + "\n", |
665 | 671 | " cond_fns = [\n", |
666 | | - " make_cond_fn(blue_loss_fn, args.blue_loss_scale, verbose=True) if args.blue_loss_scale > 0 else None\n", |
| 672 | + " make_cond_fn(blue_loss_fn, args.blue_loss_scale, verbose=True) if args.blue_loss_scale > 0 else None,\n", |
| 673 | + " make_cond_fn(make_mse_loss(init_image), args.init_mse_scale, verbose=True) if args.init_mse_scale > 0 else None,\n", |
667 | 674 | " ]\n", |
668 | 675 | "\n", |
669 | 676 | " callback = make_callback(sampler_name=args.sampler,\n", |
|
1064 | 1071 | "\n", |
1065 | 1072 | " #@markdown **Conditioning Settings**\n", |
1066 | 1073 | " blue_loss_scale = 200 #@param {type:\"number\"}\n", |
| 1074 | + " init_mse_scale = 200 #@param {type:\"number\"}\n", |
1067 | 1075 | "\n", |
1068 | 1076 | " n_samples = 1 # doesnt do anything\n", |
1069 | 1077 | " precision = 'autocast' \n", |
|
0 commit comments