Skip to content

Commit b54732e

Browse files
Merge pull request CompVis#56 from enzymezoo-code/conditioning
Added mse loss
2 parents fa8b51b + 98bdf58 commit b54732e

File tree

2 files changed

+22
-6
lines changed

2 files changed

+22
-6
lines changed

Deforum_Stable_Diffusion.ipynb

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,11 @@
316316
" error = torch.abs(x[:,-1, :, :] - 0.9).mean() \n",
317317
" return error\n",
318318
"\n",
319+
"# MSE loss from init\n",
320+
"def make_mse_loss(target):\n",
321+
" def mse_loss(x, sigma, **kwargs):\n",
322+
" return (x - target).square().mean()\n",
323+
" return mse_loss\n",
319324
"\n",
320325
"###\n",
321326
"# Conditioning helper functions\n",
@@ -326,8 +331,6 @@
326331
" # loss_fn (function): func(x, sigma, denoised) -> number\n",
327332
" # scale (number): how much this loss is applied to the image\n",
328333
" def cond_fn(x, sigma, denoised, **kwargs):\n",
329-
" # x = x.detach().requires_grad_()\n",
330-
" # denoised = denoised.detach().requires_grad_()\n",
331334
" with torch.enable_grad():\n",
332335
" denoised_sample = model.differentiable_decode_first_stage(denoised).requires_grad_()\n",
333336
" loss = loss_fn(denoised_sample, sigma, **kwargs) * scale\n",
@@ -662,8 +665,12 @@
662665
" if args.sampler in ['plms','ddim']:\n",
663666
" sampler.make_schedule(ddim_num_steps=args.steps, ddim_eta=args.ddim_eta, ddim_discretize='fill', verbose=False)\n",
664667
"\n",
668+
" if args.init_mse_scale > 0 and init_latent is None:\n",
669+
" raise Exception(\"Cannot use mse loss without an init image\")\n",
670+
"\n",
665671
" cond_fns = [\n",
666-
" make_cond_fn(blue_loss_fn, args.blue_loss_scale, verbose=True) if args.blue_loss_scale > 0 else None\n",
672+
" make_cond_fn(blue_loss_fn, args.blue_loss_scale, verbose=True) if args.blue_loss_scale > 0 else None,\n",
673+
" make_cond_fn(make_mse_loss(init_image), args.init_mse_scale, verbose=True) if args.init_mse_scale > 0 else None,\n",
667674
" ]\n",
668675
"\n",
669676
" callback = make_callback(sampler_name=args.sampler,\n",
@@ -1064,6 +1071,7 @@
10641071
"\n",
10651072
" #@markdown **Conditioning Settings**\n",
10661073
" blue_loss_scale = 200 #@param {type:\"number\"}\n",
1074+
" init_mse_scale = 200 #@param {type:\"number\"}\n",
10671075
"\n",
10681076
" n_samples = 1 # doesnt do anything\n",
10691077
" precision = 'autocast' \n",

Deforum_Stable_Diffusion.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,11 @@ def blue_loss_fn(x, sigma, **kwargs):
295295
error = torch.abs(x[:,-1, :, :] - 0.9).mean()
296296
return error
297297

298+
# MSE loss from init
299+
def make_mse_loss(target):
300+
def mse_loss(x, sigma, **kwargs):
301+
return (x - target).square().mean()
302+
return mse_loss
298303

299304
###
300305
# Conditioning helper functions
@@ -305,8 +310,6 @@ def make_cond_fn(loss_fn, scale, verbose=False):
305310
# loss_fn (function): func(x, sigma, denoised) -> number
306311
# scale (number): how much this loss is applied to the image
307312
def cond_fn(x, sigma, denoised, **kwargs):
308-
# x = x.detach().requires_grad_()
309-
# denoised = denoised.detach().requires_grad_()
310313
with torch.enable_grad():
311314
denoised_sample = model.differentiable_decode_first_stage(denoised).requires_grad_()
312315
loss = loss_fn(denoised_sample, sigma, **kwargs) * scale
@@ -641,8 +644,12 @@ def generate(args, return_latent=False, return_sample=False, return_c=False):
641644
if args.sampler in ['plms','ddim']:
642645
sampler.make_schedule(ddim_num_steps=args.steps, ddim_eta=args.ddim_eta, ddim_discretize='fill', verbose=False)
643646

647+
if args.init_mse_scale > 0 and init_latent is None:
648+
raise Exception("Cannot use mse loss without an init image")
649+
644650
cond_fns = [
645-
make_cond_fn(blue_loss_fn, args.blue_loss_scale, verbose=True) if args.blue_loss_scale > 0 else None
651+
make_cond_fn(blue_loss_fn, args.blue_loss_scale, verbose=True) if args.blue_loss_scale > 0 else None,
652+
make_cond_fn(make_mse_loss(init_image), args.init_mse_scale, verbose=True) if args.init_mse_scale > 0 else None,
646653
]
647654

648655
callback = make_callback(sampler_name=args.sampler,
@@ -1020,6 +1027,7 @@ def DeforumArgs():
10201027

10211028
#@markdown **Conditioning Settings**
10221029
blue_loss_scale = 200 #@param {type:"number"}
1030+
init_mse_scale = 200 #@param {type:"number"}
10231031

10241032
n_samples = 1 # doesnt do anything
10251033
precision = 'autocast'

0 commit comments

Comments
 (0)