274 changes: 132 additions & 142 deletions inference_propainter.py
@@ -261,10 +261,10 @@ def get_ref_index(mid_neighbor_id, neighbor_ids, length, ref_stride=10, ref_num=
masked_frame_for_save.append(fuse_img.astype(np.uint8))

frames_inp = [np.array(f).astype(np.uint8) for f in frames]
frames = to_tensors()(frames).unsqueeze(0) * 2 - 1
flow_masks = to_tensors()(flow_masks).unsqueeze(0)
masks_dilated = to_tensors()(masks_dilated).unsqueeze(0)
frames, flow_masks, masks_dilated = frames.to(device), flow_masks.to(device), masks_dilated.to(device)
# Keep frames on CPU - will move to GPU in chunks
frames_cpu = frames
flow_masks_cpu = flow_masks
masks_dilated_cpu = masks_dilated


##############################################
@@ -293,163 +293,153 @@ def get_ref_index(mid_neighbor_id, neighbor_ids, length, ref_stride=10, ref_num=


##############################################
# ProPainter inference
# ProPainter inference with chunking
##############################################
video_length = frames.size(1)
video_length = len(frames_cpu)
print(f'\nProcessing: {video_name} [{video_length} frames]...')
with torch.no_grad():
# ---- compute flow ----
if frames.size(-1) <= 640:
short_clip_len = 12
elif frames.size(-1) <= 720:
short_clip_len = 8
elif frames.size(-1) <= 1280:
short_clip_len = 4
else:
short_clip_len = 2

# Determine chunk size and overlap
chunk_size = args.subvideo_length
overlap = 5 # frames overlap on each side for smooth blending

# Store results
comp_frames = [None] * video_length
comp_weights = [0] * video_length # for blending overlaps

# Process video in chunks
for chunk_start in tqdm(range(0, video_length, chunk_size - overlap), desc="Processing chunks"):
chunk_end = min(chunk_start + chunk_size, video_length)

# use fp32 for RAFT
if frames.size(1) > short_clip_len:
gt_flows_f_list, gt_flows_b_list = [], []
for f in range(0, video_length, short_clip_len):
end_f = min(video_length, f + short_clip_len)
if f == 0:
flows_f, flows_b = fix_raft(frames[:,f:end_f], iters=args.raft_iter)
else:
flows_f, flows_b = fix_raft(frames[:,f-1:end_f], iters=args.raft_iter)

gt_flows_f_list.append(flows_f)
gt_flows_b_list.append(flows_b)
torch.cuda.empty_cache()

gt_flows_f = torch.cat(gt_flows_f_list, dim=1)
gt_flows_b = torch.cat(gt_flows_b_list, dim=1)
gt_flows_bi = (gt_flows_f, gt_flows_b)
else:
gt_flows_bi = fix_raft(frames, iters=args.raft_iter)
torch.cuda.empty_cache()


if use_half:
frames, flow_masks, masks_dilated = frames.half(), flow_masks.half(), masks_dilated.half()
gt_flows_bi = (gt_flows_bi[0].half(), gt_flows_bi[1].half())
fix_flow_complete = fix_flow_complete.half()
model = model.half()

# Extract chunk with overlap
chunk_frames = frames_cpu[chunk_start:chunk_end]
chunk_flow_masks = flow_masks_cpu[chunk_start:chunk_end]
chunk_masks_dilated = masks_dilated_cpu[chunk_start:chunk_end]
chunk_frames_inp = frames_inp[chunk_start:chunk_end]

chunk_len = len(chunk_frames)
if chunk_len == 0:
break

print(f'\nProcessing chunk: frames {chunk_start} to {chunk_end-1} ({chunk_len} frames)')

# Convert chunk to tensors and move to GPU
frames = to_tensors()(chunk_frames).unsqueeze(0) * 2 - 1
flow_masks = to_tensors()(chunk_flow_masks).unsqueeze(0)
masks_dilated = to_tensors()(chunk_masks_dilated).unsqueeze(0)
frames = frames.to(device)
flow_masks = flow_masks.to(device)
masks_dilated = masks_dilated.to(device)

# ---- complete flow ----
flow_length = gt_flows_bi[0].size(1)
if flow_length > args.subvideo_length:
pred_flows_f, pred_flows_b = [], []
pad_len = 5
for f in range(0, flow_length, args.subvideo_length):
s_f = max(0, f - pad_len)
e_f = min(flow_length, f + args.subvideo_length + pad_len)
pad_len_s = max(0, f) - s_f
pad_len_e = e_f - min(flow_length, f + args.subvideo_length)
pred_flows_bi_sub, _ = fix_flow_complete.forward_bidirect_flow(
(gt_flows_bi[0][:, s_f:e_f], gt_flows_bi[1][:, s_f:e_f]),
flow_masks[:, s_f:e_f+1])
pred_flows_bi_sub = fix_flow_complete.combine_flow(
(gt_flows_bi[0][:, s_f:e_f], gt_flows_bi[1][:, s_f:e_f]),
pred_flows_bi_sub,
flow_masks[:, s_f:e_f+1])

pred_flows_f.append(pred_flows_bi_sub[0][:, pad_len_s:e_f-s_f-pad_len_e])
pred_flows_b.append(pred_flows_bi_sub[1][:, pad_len_s:e_f-s_f-pad_len_e])
with torch.no_grad():
# ---- compute flow ----
if frames.size(-1) <= 640:
short_clip_len = 12
elif frames.size(-1) <= 720:
short_clip_len = 8
elif frames.size(-1) <= 1280:
short_clip_len = 4
else:
short_clip_len = 2

# use fp32 for RAFT
if frames.size(1) > short_clip_len:
gt_flows_f_list, gt_flows_b_list = [], []
for f in range(0, chunk_len, short_clip_len):
end_f = min(chunk_len, f + short_clip_len)
if f == 0:
flows_f, flows_b = fix_raft(frames[:,f:end_f], iters=args.raft_iter)
else:
flows_f, flows_b = fix_raft(frames[:,f-1:end_f], iters=args.raft_iter)

gt_flows_f_list.append(flows_f)
gt_flows_b_list.append(flows_b)
torch.cuda.empty_cache()

gt_flows_f = torch.cat(gt_flows_f_list, dim=1)
gt_flows_b = torch.cat(gt_flows_b_list, dim=1)
gt_flows_bi = (gt_flows_f, gt_flows_b)
else:
gt_flows_bi = fix_raft(frames, iters=args.raft_iter)
torch.cuda.empty_cache()

pred_flows_f = torch.cat(pred_flows_f, dim=1)
pred_flows_b = torch.cat(pred_flows_b, dim=1)
pred_flows_bi = (pred_flows_f, pred_flows_b)
else:

if use_half:
frames, flow_masks, masks_dilated = frames.half(), flow_masks.half(), masks_dilated.half()
gt_flows_bi = (gt_flows_bi[0].half(), gt_flows_bi[1].half())
if chunk_start == 0: # Only convert models once
fix_flow_complete = fix_flow_complete.half()
model = model.half()

# ---- complete flow ----
pred_flows_bi, _ = fix_flow_complete.forward_bidirect_flow(gt_flows_bi, flow_masks)
pred_flows_bi = fix_flow_complete.combine_flow(gt_flows_bi, pred_flows_bi, flow_masks)
torch.cuda.empty_cache()


# ---- image propagation ----
masked_frames = frames * (1 - masks_dilated)
subvideo_length_img_prop = min(100, args.subvideo_length) # cap the sub-video length for image propagation at 100 frames
if video_length > subvideo_length_img_prop:
updated_frames, updated_masks = [], []
pad_len = 10
for f in range(0, video_length, subvideo_length_img_prop):
s_f = max(0, f - pad_len)
e_f = min(video_length, f + subvideo_length_img_prop + pad_len)
pad_len_s = max(0, f) - s_f
pad_len_e = e_f - min(video_length, f + subvideo_length_img_prop)

b, t, _, _, _ = masks_dilated[:, s_f:e_f].size()
pred_flows_bi_sub = (pred_flows_bi[0][:, s_f:e_f-1], pred_flows_bi[1][:, s_f:e_f-1])
prop_imgs_sub, updated_local_masks_sub = model.img_propagation(masked_frames[:, s_f:e_f],
pred_flows_bi_sub,
masks_dilated[:, s_f:e_f],
'nearest')
updated_frames_sub = frames[:, s_f:e_f] * (1 - masks_dilated[:, s_f:e_f]) + \
prop_imgs_sub.view(b, t, 3, h, w) * masks_dilated[:, s_f:e_f]
updated_masks_sub = updated_local_masks_sub.view(b, t, 1, h, w)

updated_frames.append(updated_frames_sub[:, pad_len_s:e_f-s_f-pad_len_e])
updated_masks.append(updated_masks_sub[:, pad_len_s:e_f-s_f-pad_len_e])
torch.cuda.empty_cache()

updated_frames = torch.cat(updated_frames, dim=1)
updated_masks = torch.cat(updated_masks, dim=1)
else:
# ---- image propagation ----
masked_frames = frames * (1 - masks_dilated)
b, t, _, _, _ = masks_dilated.size()
prop_imgs, updated_local_masks = model.img_propagation(masked_frames, pred_flows_bi, masks_dilated, 'nearest')
updated_frames = frames * (1 - masks_dilated) + prop_imgs.view(b, t, 3, h, w) * masks_dilated
updated_masks = updated_local_masks.view(b, t, 1, h, w)
torch.cuda.empty_cache()


ori_frames = frames_inp
comp_frames = [None] * video_length

neighbor_stride = args.neighbor_length // 2
if video_length > args.subvideo_length:
ref_num = args.subvideo_length // args.ref_stride
else:
ref_num = -1

# ---- feature propagation + transformer ----
for f in tqdm(range(0, video_length, neighbor_stride)):
neighbor_ids = [
i for i in range(max(0, f - neighbor_stride),
min(video_length, f + neighbor_stride + 1))
]
ref_ids = get_ref_index(f, neighbor_ids, video_length, args.ref_stride, ref_num)
selected_imgs = updated_frames[:, neighbor_ids + ref_ids, :, :, :]
selected_masks = masks_dilated[:, neighbor_ids + ref_ids, :, :, :]
selected_update_masks = updated_masks[:, neighbor_ids + ref_ids, :, :, :]
selected_pred_flows_bi = (pred_flows_bi[0][:, neighbor_ids[:-1], :, :, :], pred_flows_bi[1][:, neighbor_ids[:-1], :, :, :])

with torch.no_grad():
# 1.0 indicates mask
l_t = len(neighbor_ids)

# pred_img = selected_imgs # results of image propagation
pred_img = model(selected_imgs, selected_pred_flows_bi, selected_masks, selected_update_masks, l_t)
# ---- feature propagation + transformer ----
neighbor_stride = args.neighbor_length // 2
ref_num = chunk_size // args.ref_stride if chunk_len > chunk_size // 2 else -1

for f in range(0, chunk_len, neighbor_stride):
neighbor_ids = [
i for i in range(max(0, f - neighbor_stride),
min(chunk_len, f + neighbor_stride + 1))
]
ref_ids = get_ref_index(f, neighbor_ids, chunk_len, args.ref_stride, ref_num)
selected_imgs = updated_frames[:, neighbor_ids + ref_ids, :, :, :]
selected_masks = masks_dilated[:, neighbor_ids + ref_ids, :, :, :]
selected_update_masks = updated_masks[:, neighbor_ids + ref_ids, :, :, :]
selected_pred_flows_bi = (pred_flows_bi[0][:, neighbor_ids[:-1], :, :, :], pred_flows_bi[1][:, neighbor_ids[:-1], :, :, :])

pred_img = pred_img.view(-1, 3, h, w)

pred_img = (pred_img + 1) / 2
pred_img = pred_img.cpu().permute(0, 2, 3, 1).numpy() * 255
binary_masks = masks_dilated[0, neighbor_ids, :, :, :].cpu().permute(
0, 2, 3, 1).numpy().astype(np.uint8)
for i in range(len(neighbor_ids)):
idx = neighbor_ids[i]
img = np.array(pred_img[i]).astype(np.uint8) * binary_masks[i] \
+ ori_frames[idx] * (1 - binary_masks[i])
if comp_frames[idx] is None:
comp_frames[idx] = img
else:
comp_frames[idx] = comp_frames[idx].astype(np.float32) * 0.5 + img.astype(np.float32) * 0.5
with torch.no_grad():
l_t = len(neighbor_ids)
pred_img = model(selected_imgs, selected_pred_flows_bi, selected_masks, selected_update_masks, l_t)
pred_img = pred_img.view(-1, 3, h, w)
pred_img = (pred_img + 1) / 2
pred_img = pred_img.cpu().permute(0, 2, 3, 1).numpy() * 255
binary_masks = masks_dilated[0, neighbor_ids, :, :, :].cpu().permute(
0, 2, 3, 1).numpy().astype(np.uint8)

for i in range(len(neighbor_ids)):
idx = neighbor_ids[i]
global_idx = chunk_start + idx

img = np.array(pred_img[i]).astype(np.uint8) * binary_masks[i] \
+ chunk_frames_inp[idx] * (1 - binary_masks[i])

# Compute blending weight based on position in chunk
weight = 1.0
if overlap > 0:
# Reduce weight near chunk boundaries for smooth blending
if idx < overlap and chunk_start > 0:
weight = idx / overlap
elif idx >= chunk_len - overlap and chunk_end < video_length:
weight = (chunk_len - idx) / overlap

if comp_frames[global_idx] is None:
comp_frames[global_idx] = img.astype(np.float32) * weight
comp_weights[global_idx] = weight
else:
comp_frames[global_idx] = comp_frames[global_idx] + img.astype(np.float32) * weight
comp_weights[global_idx] += weight

torch.cuda.empty_cache()

# Free GPU memory for this chunk
del frames, flow_masks, masks_dilated, gt_flows_bi, pred_flows_bi
del masked_frames, prop_imgs, updated_frames, updated_masks
torch.cuda.empty_cache()

# Normalize blended frames
for i in range(video_length):
if comp_weights[i] > 0:
comp_frames[i] = (comp_frames[i] / comp_weights[i]).astype(np.uint8)
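For review context, a minimal sketch (not part of the patch) of how the linear ramp weights above behave across a chunk boundary; the example values chunk_size = 80, overlap = 5, video length 155, and the helper name chunk_weight are illustrative assumptions, not code from this PR. The ramps of the two chunks that share an overlapping frame sum to 1, and the normalization loop above also folds in the repeated writes made by overlapping neighbor windows within a chunk.

# Illustrative sketch only: reproduces the ramp rule from the blending loop above.
def chunk_weight(idx, chunk_len, chunk_start, chunk_end, video_length, overlap=5):
    # Fade in at the start of a chunk (unless it is the first chunk) and
    # fade out at the end (unless it is the last chunk); full weight elsewhere.
    weight = 1.0
    if overlap > 0:
        if idx < overlap and chunk_start > 0:
            weight = idx / overlap
        elif idx >= chunk_len - overlap and chunk_end < video_length:
            weight = (chunk_len - idx) / overlap
    return weight

# Example: 155-frame video, chunks start at 0, 75, 150 (stride = chunk_size - overlap).
# Frames 75-79 are produced by both chunk 0 [0:80] and chunk 1 [75:155].
for g in range(75, 80):
    w0 = chunk_weight(g, 80, 0, 80, 155)         # fade-out ramp of chunk 0
    w1 = chunk_weight(g - 75, 80, 75, 155, 155)  # fade-in ramp of chunk 1
    print(g, w0, w1, w0 + w1)                    # w0 + w1 == 1.0 across the overlap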

# save each frame
if args.save_frames: