
Commit 2dc4a93

Merge pull request CompVis#54 from pharmapsychotic/dev
diffusion_cadence (turbo) for 2D and 3D animations!
2 parents: 998c763 + a3d4320

File tree: 1 file changed (+96, -43 lines)


Deforum_Stable_Diffusion.py

Lines changed: 96 additions & 43 deletions
@@ -80,7 +80,7 @@
     all_process = [
         ['pip', 'install', 'torch==1.12.1+cu113', 'torchvision==0.13.1+cu113', '--extra-index-url', 'https://download.pytorch.org/whl/cu113'],
         ['pip', 'install', 'omegaconf==2.2.3', 'einops==0.4.1', 'pytorch-lightning==1.7.4', 'torchmetrics==0.9.3', 'torchtext==0.13.1', 'transformers==4.21.2', 'kornia==0.6.7'],
-        ['git', 'clone', 'https://github.com/deforum/stable-diffusion'],
+        ['git', 'clone', '-b', 'dev', 'https://github.com/deforum/stable-diffusion'],
         ['pip', 'install', '-e', 'git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers'],
         ['pip', 'install', '-e', 'git+https://github.com/openai/CLIP.git@main#egg=clip'],
         ['pip', 'install', 'accelerate', 'ftfy', 'jsonmerge', 'matplotlib', 'resize-right', 'timm', 'torchdiffeq'],
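
Note: the only change in this hunk is cloning the `dev` branch of deforum/stable-diffusion instead of the default branch. For context, the setup cell presumably iterates over `all_process` and runs each entry as a subprocess; a minimal sketch of that pattern (the loop itself is hypothetical and not part of this diff):

    import subprocess

    # hypothetical illustration of how the all_process command list is executed
    for cmd in all_process:
        print(' '.join(cmd))
        subprocess.run(cmd, check=True)  # raise if any install/clone step fails
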
@@ -178,7 +178,7 @@ def anim_frame_warp_2d(prev_img_cv2, args, anim_args, keys, frame_idx):
         borderMode=cv2.BORDER_WRAP if anim_args.border == 'wrap' else cv2.BORDER_REPLICATE
     )

-def anim_frame_warp_3d(prev_img_cv2, anim_args, keys, frame_idx, adabins_helper, midas_model, midas_transform):
+def anim_frame_warp_3d(prev_img_cv2, depth, anim_args, keys, frame_idx):
     TRANSLATION_SCALE = 1.0/200.0 # matches Disco
     translate_xyz = [
         -keys.translation_x_series[frame_idx] * TRANSLATION_SCALE,
@@ -191,7 +191,7 @@ def anim_frame_warp_3d(prev_img_cv2, anim_args, keys, frame_idx, adabins_helper, midas_model, midas_transform):
         math.radians(keys.rotation_3d_z_series[frame_idx])
     ]
     rot_mat = p3d.euler_angles_to_matrix(torch.tensor(rotate_xyz, device=device), "XYZ").unsqueeze(0)
-    result = transform_image_3d(prev_img_cv2, adabins_helper, midas_model, midas_transform, rot_mat, translate_xyz, anim_args)
+    result = transform_image_3d(prev_img_cv2, depth, rot_mat, translate_xyz, anim_args)
     torch.cuda.empty_cache()
     return result

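Note: `anim_frame_warp_3d` no longer estimates depth itself; the caller predicts a depth map once (via the new `predict_depth` further down) and passes it in. A minimal sketch of the new call order, using the names from this diff:

    # depth is estimated once per source frame, then reused by the 3D warp
    depth = predict_depth(prev_img_cv2, adabins_helper, midas_model, midas_transform, anim_args)
    warped = anim_frame_warp_3d(prev_img_cv2, depth, anim_args, keys, frame_idx)
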
@@ -376,27 +376,24 @@ def sample_from_cv2(sample: np.ndarray) -> torch.Tensor:
     sample = torch.from_numpy(sample)
     return sample

-def sample_to_cv2(sample: torch.Tensor) -> np.ndarray:
+def sample_to_cv2(sample: torch.Tensor, type=np.uint8) -> np.ndarray:
     sample_f32 = rearrange(sample.squeeze().cpu().numpy(), "c h w -> h w c").astype(np.float32)
     sample_f32 = ((sample_f32 * 0.5) + 0.5).clip(0, 1)
-    sample_int8 = (sample_f32 * 255).astype(np.uint8)
-    return sample_int8
+    sample_int8 = (sample_f32 * 255)
+    return sample_int8.astype(type)

 @torch.no_grad()
-def transform_image_3d(prev_img_cv2, adabins_helper, midas_model, midas_transform, rot_mat, translate, anim_args):
-    # adapted and optimized version of transform_image_3d from Disco Diffusion https://github.com/alembics/disco-diffusion
-
+def predict_depth(prev_img_cv2, adabins_helper, midas_model, midas_transform, anim_args) -> torch.Tensor:
     w, h = prev_img_cv2.shape[1], prev_img_cv2.shape[0]

     # predict depth with AdaBins
     use_adabins = anim_args.midas_weight < 1.0 and adabins_helper is not None
     if use_adabins:
-        print(f"Estimating depth of {w}x{h} image with AdaBins...")
         MAX_ADABINS_AREA = 500000
         MIN_ADABINS_AREA = 448*448

         # resize image if too large or too small
-        img_pil = Image.fromarray(cv2.cvtColor(prev_img_cv2, cv2.COLOR_RGB2BGR))
+        img_pil = Image.fromarray(cv2.cvtColor(prev_img_cv2.astype(np.uint8), cv2.COLOR_RGB2BGR))
         image_pil_area = w*h
         resized = True
         if image_pil_area > MAX_ADABINS_AREA:
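
Note: `sample_to_cv2` gains a `type` parameter so callers can keep frames in float32 instead of quantizing to uint8; the turbo path below relies on this. A short sketch of both uses:

    import numpy as np

    img_u8  = sample_to_cv2(sample)                   # default uint8, e.g. for saving to disk
    img_f32 = sample_to_cv2(sample, type=np.float32)  # float32 copy for repeated warping/blending
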
@@ -415,10 +412,10 @@ def transform_image_3d(prev_img_cv2, adabins_helper, midas_model, midas_transform, rot_mat, translate, anim_args):
         try:
             _, adabins_depth = adabins_helper.predict_pil(depth_input)
             if resized:
-                adabins_depth = torchvision.transforms.functional.resize(
+                adabins_depth = TF.resize(
                     torch.from_numpy(adabins_depth),
                     torch.Size([h, w]),
-                    interpolation=torchvision.transforms.functional.InterpolationMode.BICUBIC
+                    interpolation=TF.InterpolationMode.BICUBIC
                 )
             adabins_depth = adabins_depth.squeeze()
         except:
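
Note: the shorter `TF.resize` call assumes `torchvision.transforms.functional` is aliased as `TF` elsewhere in the script; the import is not shown in this hunk and is reproduced here only as an assumption:

    # assumed import alias; not part of this hunk
    import torchvision.transforms.functional as TF
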
@@ -432,7 +429,6 @@ def transform_image_3d(prev_img_cv2, adabins_helper, midas_model, midas_transform, rot_mat, translate, anim_args):
         img_midas_input = midas_transform({"image": img_midas})["image"]

         # MiDaS depth estimation implementation
-        print(f"Estimating depth of {w}x{h} image with MiDaS...")
         sample = torch.from_numpy(img_midas_input).float().to(device).unsqueeze(0)
         if device == torch.device("cuda"):
             sample = sample.to(memory_format=torch.channels_last)
@@ -461,6 +457,12 @@ def transform_image_3d(prev_img_cv2, adabins_helper, midas_model, midas_transform, rot_mat, translate, anim_args):
         depth_tensor = torch.from_numpy(depth_map).squeeze().to(device)
     else:
         depth_tensor = torch.ones((h, w), device=device)
+
+    return depth_tensor
+
+def transform_image_3d(prev_img_cv2, depth_tensor, rot_mat, translate, anim_args):
+    # adapted and optimized version of transform_image_3d from Disco Diffusion https://github.com/alembics/disco-diffusion
+    w, h = prev_img_cv2.shape[1], prev_img_cv2.shape[0]

     pixel_aspect = 1.0 # aspect of an individual pixel (so usually 1.0)
     near, far, fov_deg = anim_args.near_plane, anim_args.far_plane, anim_args.fov
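
Note: `predict_depth` always returns a usable tensor; when neither AdaBins nor MiDaS contributes it falls back to the constant `torch.ones((h, w))` map, so `transform_image_3d` can be called unconditionally. A minimal sketch of that fallback path, assuming the depth helpers are simply `None` and that `rot_mat`/`translate_xyz` come from the surrounding animation code:

    # with no depth models loaded, predict_depth degenerates to a flat depth map
    depth_tensor = predict_depth(prev_img_cv2, None, None, None, anim_args)
    result = transform_image_3d(prev_img_cv2, depth_tensor, rot_mat, translate_xyz, anim_args)
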
@@ -482,7 +484,7 @@ def transform_image_3d(prev_img_cv2, adabins_helper, midas_model, midas_transform, rot_mat, translate, anim_args):
     coords_2d = torch.nn.functional.affine_grid(identity_2d_batch, [1,1,h,w], align_corners=False)
     offset_coords_2d = coords_2d - torch.reshape(offset_xy, (h,w,2)).unsqueeze(0)

-    image_tensor = torchvision.transforms.functional.to_tensor(Image.fromarray(prev_img_cv2)).to(device)
+    image_tensor = rearrange(torch.from_numpy(prev_img_cv2.astype(np.float32)), 'h w c -> c h w').to(device)
     new_image = torch.nn.functional.grid_sample(
         image_tensor.add(1/512 - 0.0001).unsqueeze(0),
         offset_coords_2d,
@@ -491,11 +493,11 @@ def transform_image_3d(prev_img_cv2, adabins_helper, midas_model, midas_transform, rot_mat, translate, anim_args):
         align_corners=False
     )

-    # convert back to cv2 style numpy array 0->255 uint8
+    # convert back to cv2 style numpy array
     result = rearrange(
-        new_image.squeeze().clamp(0,1) * 255.0,
+        new_image.squeeze().clamp(0,255),
         'c h w -> h w c'
-    ).cpu().numpy().astype(np.uint8)
+    ).cpu().numpy().astype(prev_img_cv2.dtype)
     return result

 def generate(args, return_latent=False, return_sample=False, return_c=False):
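
Note: the warp no longer forces a 0-255 uint8 output; values are clamped to 0-255 and cast back to `prev_img_cv2.dtype`, so a float32 turbo frame stays float32 across repeated warps instead of being re-quantized each time. A small sketch of the dtype round-trip, with `prev_frame` as a hypothetical input frame:

    import numpy as np

    frame_f32 = prev_frame.astype(np.float32)   # e.g. a turbo in-between frame
    warped = transform_image_3d(frame_f32, depth, rot_mat, translate_xyz, anim_args)
    assert warped.dtype == np.float32           # dtype preserved by the final astype()
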
@@ -763,8 +765,9 @@ def DeforumAnimArgs():

     #@markdown ####**Coherence:**
     color_coherence = 'Match Frame 0 LAB' #@param ['None', 'Match Frame 0 HSV', 'Match Frame 0 LAB', 'Match Frame 0 RGB'] {type:'string'}
+    diffusion_cadence = '3' #@param ['1','2','3','4','5','6','7','8'] {type:'string'}

-    #@markdown #### Depth Warping
+    #@markdown ####**3D Depth Warping:**
     use_depth_warping = True #@param {type:"boolean"}
     midas_weight = 0.3#@param {type:"number"}
     near_plane = 200
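
Note: `diffusion_cadence` is the new "turbo" control. As the render loop below shows, it becomes the stride at which frames are actually diffused; the frames in between are produced by warping and cross-fading the neighbouring diffused frames:

    # taken from render_animation below: cadence sets the diffusion stride
    turbo_steps = 1 if using_vid_init else int(anim_args.diffusion_cadence)
    # with diffusion_cadence = '3', frames 0, 3, 6, ... are diffused;
    # frames 1, 2, 4, 5, ... are warped/cross-faded in-betweens
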
@@ -878,7 +881,7 @@ def parse_key_frames(string, prompt_parser=None):
 # !! "cellView": "form"
 # !! }}
 def DeforumArgs():
-
+
     #@markdown **Image Settings**
     W = 512 #@param
     H = 512 #@param
@@ -906,7 +909,7 @@ def DeforumArgs():
     make_grid = False #@param {type:"boolean"}
     grid_rows = 2 #@param
     outdir = get_output_folder(output_path, batch_name)
-
+
     #@markdown **Init Settings**
     use_init = False #@param {type:"boolean"}
     strength = 0.0 #@param {type:"number"}
@@ -940,7 +943,7 @@ def next_seed(args):
     elif args.seed_behavior == 'fixed':
         pass # always keep seed the same
     else:
-        args.seed = random.randint(0, 2**32)
+        args.seed = random.randint(0, 2**32 - 1)
     return args.seed

 def render_image_batch(args):
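
Note: `random.randint(a, b)` is inclusive of both endpoints, so the old upper bound of `2**32` could occasionally produce a seed one past the valid 32-bit range; the new bound keeps seeds in [0, 2**32 - 1]:

    import random

    seed = random.randint(0, 2**32 - 1)  # 0 .. 4294967295, always fits in 32 bits
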
@@ -984,7 +987,7 @@ def render_image_batch(args):
         args.prompt = prompt
         print(f"Prompt {iprompt+1} of {len(prompts)}")
         print(f"{args.prompt}")
-
+
         all_images = []

         for batch_index in range(args.n_batch):
@@ -1066,29 +1069,75 @@ def render_animation(args, anim_args):
     else:
         adabins_helper, midas_model, midas_transform = None, None, None

-    args.n_samples = 1
+    # state for interpolating between diffusion steps
+    turbo_steps = 1 if using_vid_init else int(anim_args.diffusion_cadence)
+    turbo_prev_image, turbo_prev_frame_idx = None, 0
+    turbo_next_image, turbo_next_frame_idx = None, 0
+
+    # resume animation
     prev_sample = None
     color_match_sample = None
-    for frame_idx in range(start_frame,anim_args.max_frames):
+    if anim_args.resume_from_timestring:
+        last_frame = start_frame-1
+        if turbo_steps > 1:
+            last_frame -= last_frame%turbo_steps
+        path = os.path.join(args.outdir,f"{args.timestring}_{last_frame:05}.png")
+        img = cv2.imread(path)
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+        prev_sample = sample_from_cv2(img)
+        if anim_args.color_coherence != 'None':
+            color_match_sample = img
+        if turbo_steps > 1:
+            turbo_next_image, turbo_next_frame_idx = sample_to_cv2(prev_sample, type=np.float32), last_frame
+            turbo_prev_image, turbo_prev_frame_idx = turbo_next_image, turbo_next_frame_idx
+            start_frame = last_frame+turbo_steps
+
+    args.n_samples = 1
+    frame_idx = start_frame
+    while frame_idx < anim_args.max_frames:
         print(f"Rendering animation frame {frame_idx} of {anim_args.max_frames}")
         noise = keys.noise_schedule_series[frame_idx]
         strength = keys.strength_schedule_series[frame_idx]
         contrast = keys.contrast_schedule_series[frame_idx]

-        # resume animation
-        if anim_args.resume_from_timestring:
-            path = os.path.join(args.outdir,f"{args.timestring}_{frame_idx-1:05}.png")
-            img = cv2.imread(path)
-            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
-            prev_sample = sample_from_cv2(img)
+        # emit in-between frames
+        if turbo_steps > 1:
+            tween_frame_start_idx = max(0, frame_idx-turbo_steps)
+            for tween_frame_idx in range(tween_frame_start_idx, frame_idx):
+                tween = float(tween_frame_idx - tween_frame_start_idx + 1) / float(frame_idx - tween_frame_start_idx)
+                print(f" creating in between frame {tween_frame_idx} tween:{tween:0.2f}")
+                if anim_args.animation_mode == '2D':
+                    if turbo_prev_image is not None and tween_frame_idx > turbo_prev_frame_idx:
+                        turbo_prev_image = anim_frame_warp_2d(turbo_prev_image, args, anim_args, keys, tween_frame_idx)
+                    if tween_frame_idx > turbo_next_frame_idx:
+                        turbo_next_image = anim_frame_warp_2d(turbo_next_image, args, anim_args, keys, tween_frame_idx)
+                else: # '3D'
+                    if turbo_prev_image is not None and tween_frame_idx > turbo_prev_frame_idx:
+                        prev_depth = predict_depth(turbo_prev_image, adabins_helper, midas_model, midas_transform, anim_args)
+                        turbo_prev_image = anim_frame_warp_3d(turbo_prev_image, prev_depth, anim_args, keys, tween_frame_idx)
+                    if tween_frame_idx > turbo_next_frame_idx:
+                        next_depth = predict_depth(turbo_next_image, adabins_helper, midas_model, midas_transform, anim_args)
+                        turbo_next_image = anim_frame_warp_3d(turbo_next_image, next_depth, anim_args, keys, tween_frame_idx)
+                turbo_prev_frame_idx = turbo_next_frame_idx = tween_frame_idx
+
+                if turbo_prev_image is not None and tween < 1.0:
+                    img = turbo_prev_image*(1.0-tween) + turbo_next_image*tween
+                else:
+                    img = turbo_next_image
+
+                filename = f"{args.timestring}_{tween_frame_idx:05}.png"
+                cv2.imwrite(os.path.join(args.outdir, filename), cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_RGB2BGR))
+            if turbo_next_image is not None:
+                prev_sample = sample_from_cv2(turbo_next_image)

         # apply transforms to previous frame
         if prev_sample is not None:
-
             if anim_args.animation_mode == '2D':
                 prev_img = anim_frame_warp_2d(sample_to_cv2(prev_sample), args, anim_args, keys, frame_idx)
             else: # '3D'
-                prev_img = anim_frame_warp_3d(sample_to_cv2(prev_sample), anim_args, keys, frame_idx, adabins_helper, midas_model, midas_transform)
+                prev_img_cv2 = sample_to_cv2(prev_sample)
+                depth = predict_depth(prev_img_cv2, adabins_helper, midas_model, midas_transform, anim_args)
+                prev_img = anim_frame_warp_3d(prev_img_cv2, depth, anim_args, keys, frame_idx)

         # apply color matching
         if anim_args.color_coherence != 'None':
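
Note: the in-between frames are a simple linear cross-fade between the warped previous and next diffused frames, with `tween` stepping from 1/N up to 1.0 across the gap. A minimal standalone sketch of that blend, assuming two already-warped float32 frames:

    import numpy as np

    def blend_tween(turbo_prev_image, turbo_next_image, tween: float) -> np.ndarray:
        # linear cross-fade used for cadence in-betweens; tween is in (0, 1]
        if turbo_prev_image is not None and tween < 1.0:
            return turbo_prev_image * (1.0 - tween) + turbo_next_image * tween
        return turbo_next_image
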
@@ -1104,7 +1153,6 @@ def render_animation(args, anim_args):

         # use transformed previous frame as init for current
         args.use_init = True
-        #args.init_sample = noised_sample.half().to(device)
         if half_precision:
             args.init_sample = noised_sample.half().to(device)
         else:
@@ -1122,14 +1170,19 @@ def render_animation(args, anim_args):
             args.init_image = init_frame

         # sample the diffusion model
-        results = generate(args, return_latent=False, return_sample=True)
-        sample, image = results[0], results[1]
-
-        filename = f"{args.timestring}_{frame_idx:05}.png"
-        image.save(os.path.join(args.outdir, filename))
+        sample, image = generate(args, return_latent=False, return_sample=True)
         if not using_vid_init:
             prev_sample = sample
-
+
+        if turbo_steps > 1:
+            turbo_prev_image, turbo_prev_frame_idx = turbo_next_image, turbo_next_frame_idx
+            turbo_next_image, turbo_next_frame_idx = sample_to_cv2(sample, type=np.float32), frame_idx
+            frame_idx += turbo_steps
+        else:
+            filename = f"{args.timestring}_{frame_idx:05}.png"
+            image.save(os.path.join(args.outdir, filename))
+            frame_idx += 1
+
         display.clear_output(wait=True)
         display.display(image)

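Note: with cadence enabled the diffusion model is only sampled every `turbo_steps` frames; the diffused result is cached as the next turbo keyframe (in float32) and the PNGs for the skipped indices, including the keyframes themselves once `tween` reaches 1.0, are written by the in-between loop above. A tiny sketch of which indices hit the sampler:

    # hypothetical 20-frame run with diffusion_cadence = '3' and no resume
    max_frames, turbo_steps = 20, 3
    diffused = list(range(0, max_frames, turbo_steps))  # [0, 3, 6, 9, 12, 15, 18]
    # remaining indices are filled by warping + cross-fading the cached keyframes
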
@@ -1270,7 +1323,7 @@ def render_interpolation(args, anim_args):
 args.strength = max(0.0, min(1.0, args.strength))

 if args.seed == -1:
-    args.seed = random.randint(0, 2**32)
+    args.seed = random.randint(0, 2**32 - 1)
 if not args.use_init:
     args.init_image = None
 if args.sampler == 'plms' and (args.use_init or anim_args.animation_mode != 'None'):
@@ -1307,10 +1360,10 @@ def render_interpolation(args, anim_args):
 # !! "cellView": "form",
 # !! "id": "no2jP8HTMBM0"
 # !! }}
-skip_video_for_run_all = False #@param {type: 'boolean'}
+skip_video_for_run_all = True #@param {type: 'boolean'}
 fps = 12 #@param {type:"number"}
 #@markdown **Manual Settings**
-use_manual_settings = True #@param {type:"boolean"}
+use_manual_settings = False #@param {type:"boolean"}
 image_path = "/content/drive/MyDrive/AI/StableDiffusion/2022-09/20220903000939_%05d.png" #@param {type:"string"}
 mp4_path = "/content/drive/MyDrive/AI/StableDiffusion/2022-09/20220903000939.mp4" #@param {type:"string"}
