diff --git a/inference_propainter.py b/inference_propainter.py
index 4d7f92f3..90470268 100644
--- a/inference_propainter.py
+++ b/inference_propainter.py
@@ -261,10 +261,10 @@ def get_ref_index(mid_neighbor_id, neighbor_ids, length, ref_stride=10, ref_num=
             masked_frame_for_save.append(fuse_img.astype(np.uint8))
 
     frames_inp = [np.array(f).astype(np.uint8) for f in frames]
-    frames = to_tensors()(frames).unsqueeze(0) * 2 - 1
-    flow_masks = to_tensors()(flow_masks).unsqueeze(0)
-    masks_dilated = to_tensors()(masks_dilated).unsqueeze(0)
-    frames, flow_masks, masks_dilated = frames.to(device), flow_masks.to(device), masks_dilated.to(device)
+    # Keep frames on CPU - will move to GPU in chunks
+    frames_cpu = frames
+    flow_masks_cpu = flow_masks
+    masks_dilated_cpu = masks_dilated
 
     ##############################################
@@ -293,163 +293,153 @@
     ##############################################
-    # ProPainter inference
+    # ProPainter inference with chunking
     ##############################################
-    video_length = frames.size(1)
+    video_length = len(frames_cpu)
     print(f'\nProcessing: {video_name} [{video_length} frames]...')
-    with torch.no_grad():
-        # ---- compute flow ----
-        if frames.size(-1) <= 640:
-            short_clip_len = 12
-        elif frames.size(-1) <= 720:
-            short_clip_len = 8
-        elif frames.size(-1) <= 1280:
-            short_clip_len = 4
-        else:
-            short_clip_len = 2
+
+    # Determine chunk size and overlap
+    chunk_size = args.subvideo_length
+    overlap = 5 # frames overlap on each side for smooth blending
+
+    # Store results
+    comp_frames = [None] * video_length
+    comp_weights = [0] * video_length # for blending overlaps
+
+    # Process video in chunks
+    for chunk_start in tqdm(range(0, video_length, chunk_size - overlap), desc="Processing chunks"):
+        chunk_end = min(chunk_start + chunk_size, video_length)
-        # use fp32 for RAFT
-        if frames.size(1) > short_clip_len:
-            gt_flows_f_list, gt_flows_b_list = [], []
-            for f in range(0, video_length, short_clip_len):
-                end_f = min(video_length, f + short_clip_len)
-                if f == 0:
-                    flows_f, flows_b = fix_raft(frames[:,f:end_f], iters=args.raft_iter)
-                else:
-                    flows_f, flows_b = fix_raft(frames[:,f-1:end_f], iters=args.raft_iter)
-
-                gt_flows_f_list.append(flows_f)
-                gt_flows_b_list.append(flows_b)
-                torch.cuda.empty_cache()
-
-            gt_flows_f = torch.cat(gt_flows_f_list, dim=1)
-            gt_flows_b = torch.cat(gt_flows_b_list, dim=1)
-            gt_flows_bi = (gt_flows_f, gt_flows_b)
-        else:
-            gt_flows_bi = fix_raft(frames, iters=args.raft_iter)
-            torch.cuda.empty_cache()
-
-
-        if use_half:
-            frames, flow_masks, masks_dilated = frames.half(), flow_masks.half(), masks_dilated.half()
-            gt_flows_bi = (gt_flows_bi[0].half(), gt_flows_bi[1].half())
-            fix_flow_complete = fix_flow_complete.half()
-            model = model.half()
-
+        # Extract chunk with overlap
+        chunk_frames = frames_cpu[chunk_start:chunk_end]
+        chunk_flow_masks = flow_masks_cpu[chunk_start:chunk_end]
+        chunk_masks_dilated = masks_dilated_cpu[chunk_start:chunk_end]
+        chunk_frames_inp = frames_inp[chunk_start:chunk_end]
+
+        chunk_len = len(chunk_frames)
+        if chunk_len == 0:
+            break
+
+        print(f'\nProcessing chunk: frames {chunk_start} to {chunk_end-1} ({chunk_len} frames)')
+
+        # Convert chunk to tensors and move to GPU
+        frames = to_tensors()(chunk_frames).unsqueeze(0) * 2 - 1
+        flow_masks = to_tensors()(chunk_flow_masks).unsqueeze(0)
+        masks_dilated = to_tensors()(chunk_masks_dilated).unsqueeze(0)
+        frames = frames.to(device)
+        flow_masks = flow_masks.to(device)
+        masks_dilated = masks_dilated.to(device)
-        # ---- complete flow ----
-        flow_length = gt_flows_bi[0].size(1)
-        if flow_length > args.subvideo_length:
-            pred_flows_f, pred_flows_b = [], []
-            pad_len = 5
-            for f in range(0, flow_length, args.subvideo_length):
-                s_f = max(0, f - pad_len)
-                e_f = min(flow_length, f + args.subvideo_length + pad_len)
-                pad_len_s = max(0, f) - s_f
-                pad_len_e = e_f - min(flow_length, f + args.subvideo_length)
-                pred_flows_bi_sub, _ = fix_flow_complete.forward_bidirect_flow(
-                    (gt_flows_bi[0][:, s_f:e_f], gt_flows_bi[1][:, s_f:e_f]),
-                    flow_masks[:, s_f:e_f+1])
-                pred_flows_bi_sub = fix_flow_complete.combine_flow(
-                    (gt_flows_bi[0][:, s_f:e_f], gt_flows_bi[1][:, s_f:e_f]),
-                    pred_flows_bi_sub,
-                    flow_masks[:, s_f:e_f+1])
-
-                pred_flows_f.append(pred_flows_bi_sub[0][:, pad_len_s:e_f-s_f-pad_len_e])
-                pred_flows_b.append(pred_flows_bi_sub[1][:, pad_len_s:e_f-s_f-pad_len_e])
+        with torch.no_grad():
+            # ---- compute flow ----
+            if frames.size(-1) <= 640:
+                short_clip_len = 12
+            elif frames.size(-1) <= 720:
+                short_clip_len = 8
+            elif frames.size(-1) <= 1280:
+                short_clip_len = 4
+            else:
+                short_clip_len = 2
+
+            # use fp32 for RAFT
+            if frames.size(1) > short_clip_len:
+                gt_flows_f_list, gt_flows_b_list = [], []
+                for f in range(0, chunk_len, short_clip_len):
+                    end_f = min(chunk_len, f + short_clip_len)
+                    if f == 0:
+                        flows_f, flows_b = fix_raft(frames[:,f:end_f], iters=args.raft_iter)
+                    else:
+                        flows_f, flows_b = fix_raft(frames[:,f-1:end_f], iters=args.raft_iter)
+
+                    gt_flows_f_list.append(flows_f)
+                    gt_flows_b_list.append(flows_b)
+                    torch.cuda.empty_cache()
+
+                gt_flows_f = torch.cat(gt_flows_f_list, dim=1)
+                gt_flows_b = torch.cat(gt_flows_b_list, dim=1)
+                gt_flows_bi = (gt_flows_f, gt_flows_b)
+            else:
+                gt_flows_bi = fix_raft(frames, iters=args.raft_iter)
                 torch.cuda.empty_cache()
-
-            pred_flows_f = torch.cat(pred_flows_f, dim=1)
-            pred_flows_b = torch.cat(pred_flows_b, dim=1)
-            pred_flows_bi = (pred_flows_f, pred_flows_b)
-        else:
+
+            if use_half:
+                frames, flow_masks, masks_dilated = frames.half(), flow_masks.half(), masks_dilated.half()
+                gt_flows_bi = (gt_flows_bi[0].half(), gt_flows_bi[1].half())
+                if chunk_start == 0: # Only convert models once
+                    fix_flow_complete = fix_flow_complete.half()
+                    model = model.half()
+
+            # ---- complete flow ----
             pred_flows_bi, _ = fix_flow_complete.forward_bidirect_flow(gt_flows_bi, flow_masks)
             pred_flows_bi = fix_flow_complete.combine_flow(gt_flows_bi, pred_flows_bi, flow_masks)
             torch.cuda.empty_cache()
-
-        # ---- image propagation ----
-        masked_frames = frames * (1 - masks_dilated)
-        subvideo_length_img_prop = min(100, args.subvideo_length) # ensure a minimum of 100 frames for image propagation
-        if video_length > subvideo_length_img_prop:
-            updated_frames, updated_masks = [], []
-            pad_len = 10
-            for f in range(0, video_length, subvideo_length_img_prop):
-                s_f = max(0, f - pad_len)
-                e_f = min(video_length, f + subvideo_length_img_prop + pad_len)
-                pad_len_s = max(0, f) - s_f
-                pad_len_e = e_f - min(video_length, f + subvideo_length_img_prop)
-
-                b, t, _, _, _ = masks_dilated[:, s_f:e_f].size()
-                pred_flows_bi_sub = (pred_flows_bi[0][:, s_f:e_f-1], pred_flows_bi[1][:, s_f:e_f-1])
-                prop_imgs_sub, updated_local_masks_sub = model.img_propagation(masked_frames[:, s_f:e_f],
-                                                                               pred_flows_bi_sub,
-                                                                               masks_dilated[:, s_f:e_f],
-                                                                               'nearest')
-                updated_frames_sub = frames[:, s_f:e_f] * (1 - masks_dilated[:, s_f:e_f]) + \
-                                     prop_imgs_sub.view(b, t, 3, h, w) * masks_dilated[:, s_f:e_f]
-                updated_masks_sub = updated_local_masks_sub.view(b, t, 1, h, w)
-
-                updated_frames.append(updated_frames_sub[:, pad_len_s:e_f-s_f-pad_len_e])
-                updated_masks.append(updated_masks_sub[:, pad_len_s:e_f-s_f-pad_len_e])
-                torch.cuda.empty_cache()
-
-            updated_frames = torch.cat(updated_frames, dim=1)
-            updated_masks = torch.cat(updated_masks, dim=1)
-        else:
+            # ---- image propagation ----
+            masked_frames = frames * (1 - masks_dilated)
             b, t, _, _, _ = masks_dilated.size()
             prop_imgs, updated_local_masks = model.img_propagation(masked_frames, pred_flows_bi, masks_dilated, 'nearest')
             updated_frames = frames * (1 - masks_dilated) + prop_imgs.view(b, t, 3, h, w) * masks_dilated
             updated_masks = updated_local_masks.view(b, t, 1, h, w)
             torch.cuda.empty_cache()
-
-
-        ori_frames = frames_inp
-        comp_frames = [None] * video_length
-
-        neighbor_stride = args.neighbor_length // 2
-        if video_length > args.subvideo_length:
-            ref_num = args.subvideo_length // args.ref_stride
-        else:
-            ref_num = -1
-
-        # ---- feature propagation + transformer ----
-        for f in tqdm(range(0, video_length, neighbor_stride)):
-            neighbor_ids = [
-                i for i in range(max(0, f - neighbor_stride),
-                                 min(video_length, f + neighbor_stride + 1))
-            ]
-            ref_ids = get_ref_index(f, neighbor_ids, video_length, args.ref_stride, ref_num)
-            selected_imgs = updated_frames[:, neighbor_ids + ref_ids, :, :, :]
-            selected_masks = masks_dilated[:, neighbor_ids + ref_ids, :, :, :]
-            selected_update_masks = updated_masks[:, neighbor_ids + ref_ids, :, :, :]
-            selected_pred_flows_bi = (pred_flows_bi[0][:, neighbor_ids[:-1], :, :, :], pred_flows_bi[1][:, neighbor_ids[:-1], :, :, :])
-            with torch.no_grad():
-                # 1.0 indicates mask
-                l_t = len(neighbor_ids)
-
-                # pred_img = selected_imgs # results of image propagation
-                pred_img = model(selected_imgs, selected_pred_flows_bi, selected_masks, selected_update_masks, l_t)
+        # ---- feature propagation + transformer ----
+        neighbor_stride = args.neighbor_length // 2
+        ref_num = chunk_size // args.ref_stride if chunk_len > chunk_size // 2 else -1
+
+        for f in range(0, chunk_len, neighbor_stride):
+            neighbor_ids = [
+                i for i in range(max(0, f - neighbor_stride),
+                                 min(chunk_len, f + neighbor_stride + 1))
+            ]
+            ref_ids = get_ref_index(f, neighbor_ids, chunk_len, args.ref_stride, ref_num)
+            selected_imgs = updated_frames[:, neighbor_ids + ref_ids, :, :, :]
+            selected_masks = masks_dilated[:, neighbor_ids + ref_ids, :, :, :]
+            selected_update_masks = updated_masks[:, neighbor_ids + ref_ids, :, :, :]
+            selected_pred_flows_bi = (pred_flows_bi[0][:, neighbor_ids[:-1], :, :, :], pred_flows_bi[1][:, neighbor_ids[:-1], :, :, :])
-                pred_img = pred_img.view(-1, 3, h, w)
-
-                pred_img = (pred_img + 1) / 2
-                pred_img = pred_img.cpu().permute(0, 2, 3, 1).numpy() * 255
-                binary_masks = masks_dilated[0, neighbor_ids, :, :, :].cpu().permute(
-                    0, 2, 3, 1).numpy().astype(np.uint8)
-                for i in range(len(neighbor_ids)):
-                    idx = neighbor_ids[i]
-                    img = np.array(pred_img[i]).astype(np.uint8) * binary_masks[i] \
-                        + ori_frames[idx] * (1 - binary_masks[i])
-                    if comp_frames[idx] is None:
-                        comp_frames[idx] = img
-                    else:
-                        comp_frames[idx] = comp_frames[idx].astype(np.float32) * 0.5 + img.astype(np.float32) * 0.5
+            with torch.no_grad():
+                l_t = len(neighbor_ids)
+                pred_img = model(selected_imgs, selected_pred_flows_bi, selected_masks, selected_update_masks, l_t)
+                pred_img = pred_img.view(-1, 3, h, w)
+                pred_img = (pred_img + 1) / 2
+                pred_img = pred_img.cpu().permute(0, 2, 3, 1).numpy() * 255
+                binary_masks = masks_dilated[0, neighbor_ids, :, :, :].cpu().permute(
+                    0, 2, 3, 1).numpy().astype(np.uint8)
+
+                for i in range(len(neighbor_ids)):
+                    idx = neighbor_ids[i]
+                    global_idx = chunk_start + idx
-                    comp_frames[idx] = comp_frames[idx].astype(np.uint8)
+                    img = np.array(pred_img[i]).astype(np.uint8) * binary_masks[i] \
+                        + chunk_frames_inp[idx] * (1 - binary_masks[i])
+
+                    # Compute blending weight based on position in chunk
+                    weight = 1.0
+                    if overlap > 0:
+                        # Reduce weight near chunk boundaries for smooth blending
+                        if idx < overlap and chunk_start > 0:
+                            weight = idx / overlap
+                        elif idx >= chunk_len - overlap and chunk_end < video_length:
+                            weight = (chunk_len - idx) / overlap
+
+                    if comp_frames[global_idx] is None:
+                        comp_frames[global_idx] = img.astype(np.float32) * weight
+                        comp_weights[global_idx] = weight
+                    else:
+                        comp_frames[global_idx] = comp_frames[global_idx] + img.astype(np.float32) * weight
+                        comp_weights[global_idx] += weight
+
+            torch.cuda.empty_cache()
+        # Free GPU memory for this chunk
+        del frames, flow_masks, masks_dilated, gt_flows_bi, pred_flows_bi
+        del masked_frames, prop_imgs, updated_frames, updated_masks
         torch.cuda.empty_cache()
+
+    # Normalize blended frames
+    for i in range(video_length):
+        if comp_weights[i] > 0:
+            comp_frames[i] = (comp_frames[i] / comp_weights[i]).astype(np.uint8)
 
     # save each frame
     if args.save_frames:
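
Below is a minimal standalone sketch of the blending scheme this patch introduces (the chunk schedule, the linear crossfade weights, and the final normalization), which may help when reviewing the weight bookkeeping. All values are synthetic: `video_length = 100` and `chunk_size = 40` stand in for the real clip length and `args.subvideo_length`, and the frames are dummy 4x4 arrays. Each frame is visited once per chunk here, whereas the real loop also revisits frames across neighbor windows; the same accumulate-then-normalize bookkeeping covers both cases.

```python
import numpy as np

video_length = 100   # synthetic clip length
chunk_size = 40      # stand-in for args.subvideo_length
overlap = 5
h, w = 4, 4          # tiny dummy frames

comp_frames = [None] * video_length
comp_weights = [0.0] * video_length

for chunk_start in range(0, video_length, chunk_size - overlap):
    chunk_end = min(chunk_start + chunk_size, video_length)
    chunk_len = chunk_end - chunk_start

    for idx in range(chunk_len):
        global_idx = chunk_start + idx
        # Stand-in for one inpainted output frame of this chunk.
        img = np.full((h, w, 3), float(global_idx), dtype=np.float32)

        # Linear crossfade: ramp up over the first `overlap` frames
        # (except in the first chunk) and ramp down over the last
        # `overlap` frames (except in the last chunk).
        weight = 1.0
        if idx < overlap and chunk_start > 0:
            weight = idx / overlap
        elif idx >= chunk_len - overlap and chunk_end < video_length:
            weight = (chunk_len - idx) / overlap

        if comp_frames[global_idx] is None:
            comp_frames[global_idx] = img * weight
            comp_weights[global_idx] = weight
        else:
            comp_frames[global_idx] += img * weight
            comp_weights[global_idx] += weight

# Normalization recovers a convex combination of the chunks covering
# each frame; the ramps of adjacent chunks sum to 1 on shared frames.
for i in range(video_length):
    assert comp_weights[i] > 0, f"frame {i} was never covered"
    comp_frames[i] = (comp_frames[i] / comp_weights[i]).astype(np.uint8)

print(min(comp_weights), max(comp_weights))  # -> 1.0 1.0
```

Because the chunk stride is `chunk_size - overlap`, the down-ramp `(chunk_len - idx) / overlap` of one chunk and the up-ramp `idx / overlap` of the next sum to exactly 1 on every shared frame, so no frame is over- or under-weighted before normalization.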