From e7506add90100c02950bb246d8def910df3f5eba Mon Sep 17 00:00:00 2001 From: Mahimai Raja J Date: Thu, 28 Sep 2023 10:12:40 +0530 Subject: [PATCH] Added jupyter notebook for demo in colab --- Notebooks/propainter_demo_notebook.ipynb | 827 +++++++++++++++++++++++ README.md | 4 +- 2 files changed, 830 insertions(+), 1 deletion(-) create mode 100644 Notebooks/propainter_demo_notebook.ipynb diff --git a/Notebooks/propainter_demo_notebook.ipynb b/Notebooks/propainter_demo_notebook.ipynb new file mode 100644 index 00000000..e2ad4003 --- /dev/null +++ b/Notebooks/propainter_demo_notebook.ipynb @@ -0,0 +1,827 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "I0suuI_1nOh4" + }, + "source": [ + "
\n", + "\n", + "
\n", + " \n", + " \n", + " \n", + "
\n", + "\n", + "

ProPainter: Improving Propagation and Transformer for Video Inpainting

\n", + "\n", + "
\n", + " Shangchen Zhou \n", + " Chongyi Li \n", + " Kelvin C.K. Chan \n", + " Chen Change Loy\n", + "
\n", + "
\n", + " S-Lab, Nanyang Technological University \n", + "
\n", + "\n", + "
\n", + " ICCV 2023\n", + "
\n", + "\n", + "
\n", + "

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/19uQfU5_Mj9bjrIDsZimVkH7l2P105AYb?usp=sharing)\n", + "

\n", + "
\n", + "\n", + "⭐ If ProPainter is helpful to your projects, please help star this repo. Thanks! 🤗\n", + "\n", + "📖 For more visual results, go checkout our project page\n", + "\n", + "\n", + "---\n", + "\n", + "
\n", + "\n", + "## Update\n", + "- **2023.09.24**: We remove the watermark removal demos officially to prevent the misuse of our work for unethical purposes.\n", + "- **2023.09.21**: Add features for memory-efficient inference. Check our [GPU memory](https://github.com/sczhou/ProPainter#-memory-efficient-inference) requirements. 🚀\n", + "- **2023.09.07**: Our code and model are publicly available. 🐳\n", + "- **2023.09.01**: This repo is created.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19", + "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "execution": { + "iopub.execute_input": "2023-09-27T14:20:29.808418Z", + "iopub.status.busy": "2023-09-27T14:20:29.808042Z", + "iopub.status.idle": "2023-09-27T14:20:33.717092Z", + "shell.execute_reply": "2023-09-27T14:20:33.715986Z", + "shell.execute_reply.started": "2023-09-27T14:20:29.808387Z" + }, + "id": "o9mzDHxbf17y", + "outputId": "f7a0aab9-cf72-47fd-e798-5af4c2ab723b", + "trusted": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Cloning into 'ProPainter'...\n", + "remote: Enumerating objects: 500, done.\u001b[K\n", + "remote: Counting objects: 100% (117/117), done.\u001b[K\n", + "remote: Compressing objects: 100% (72/72), done.\u001b[K\n", + "remote: Total 500 (delta 47), reused 112 (delta 45), pack-reused 383\u001b[K\n", + "Receiving objects: 100% (500/500), 50.33 MiB | 42.08 MiB/s, done.\n", + "Resolving deltas: 100% (49/49), done.\n", + "/content/ProPainter\n" + ] + } + ], + "source": [ + "!git clone https://github.com/sczhou/ProPainter.git\n", + "%cd ProPainter" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true, + "execution": { + "iopub.execute_input": "2023-09-27T14:20:37.407464Z", + "iopub.status.busy": "2023-09-27T14:20:37.407073Z", + "iopub.status.idle": 
pretrain_model_url = 'https://github.com/sczhou/ProPainter/releases/download/v0.1.0/'


def imwrite(img, file_path, params=None, auto_mkdir=True):
    """Write an image via cv2.imwrite, optionally creating the parent directory.

    Args:
        img: image array in the layout cv2.imwrite expects (BGR for color).
        file_path: destination file path.
        params: optional cv2.imwrite encoder parameters.
        auto_mkdir: when True, create the parent directory if missing.

    Returns:
        bool: cv2.imwrite's success flag.
    """
    if auto_mkdir:
        dir_name = os.path.abspath(os.path.dirname(file_path))
        os.makedirs(dir_name, exist_ok=True)
    return cv2.imwrite(file_path, img, params)


def resize_frames(frames, size=None):
    """Resize PIL frames so both dimensions are multiples of 8.

    The model pipeline requires 8-aligned spatial sizes, so the working
    ("process") size is the requested/original size rounded down to the
    nearest multiple of 8 in each dimension.

    Args:
        frames: list of PIL.Image frames.
        size: optional (width, height) target; defaults to the frames' own size.

    Returns:
        (frames, process_size, out_size): possibly-resized frames, the
        8-aligned size actually used, and the requested/original output size.
    """
    if size is not None:
        out_size = size
        process_size = (out_size[0] - out_size[0] % 8, out_size[1] - out_size[1] % 8)
        frames = [f.resize(process_size) for f in frames]
    else:
        out_size = frames[0].size
        process_size = (out_size[0] - out_size[0] % 8, out_size[1] - out_size[1] % 8)
        if out_size != process_size:
            frames = [f.resize(process_size) for f in frames]

    return frames, process_size, out_size


def read_frame_from_videos(frame_root):
    """Load RGB frames either from a video file or a directory of images.

    Args:
        frame_root: path to a video file (mp4/mov/avi, any case) or to a
            directory containing per-frame images.

    Returns:
        (frames, fps, size, video_name): list of RGB PIL frames, the source
        fps (None for image directories), the (width, height) of the first
        frame, and a name derived from the input path.
    """
    if frame_root.endswith(('mp4', 'mov', 'avi', 'MP4', 'MOV', 'AVI')):  # input video path
        # splitext is more robust than the previous hard-coded [:-4] slice
        # (which mis-trimmed names that end in e.g. "mp4" without a dot).
        video_name = os.path.splitext(os.path.basename(frame_root))[0]
        vframes, aframes, info = torchvision.io.read_video(filename=frame_root, pts_unit='sec')  # RGB
        frames = [Image.fromarray(f) for f in vframes.numpy()]
        fps = info['video_fps']
    else:
        video_name = os.path.basename(frame_root)
        frames = []
        for fr in sorted(os.listdir(frame_root)):
            frame = cv2.imread(os.path.join(frame_root, fr))
            # cv2 reads BGR; convert to RGB before wrapping as PIL.
            frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
        fps = None
    size = frames[0].size

    return frames, fps, size, video_name


def binary_mask(mask, th=0.1):
    """Threshold `mask` in place to {0, 1} using cutoff `th` and return it."""
    mask[mask > th] = 1
    mask[mask <= th] = 0
    return mask


def read_mask(mpath, length, size, flow_mask_dilates=8, mask_dilates=5):
    """Read and dilate inpainting masks for every frame.

    Args:
        mpath: a single mask image path or a directory of per-frame masks.
        length: number of video frames; a single mask is repeated to match.
        size: optional (width, height) to resize each mask to (nearest).
        flow_mask_dilates: dilation iterations for the flow-completion masks.
        mask_dilates: dilation iterations for the inpainting masks.

    Returns:
        (flow_masks, masks_dilated): two lists of PIL 'L'-mode images.
    """
    masks_img = []
    masks_dilated = []
    flow_masks = []

    if mpath.endswith(('jpg', 'jpeg', 'png', 'JPG', 'JPEG', 'PNG')):  # input single img path
        masks_img = [Image.open(mpath)]
    else:
        for mp in sorted(os.listdir(mpath)):
            masks_img.append(Image.open(os.path.join(mpath, mp)))

    for mask_img in masks_img:
        if size is not None:
            mask_img = mask_img.resize(size, Image.NEAREST)
        mask_img = np.array(mask_img.convert('L'))

        # Dilate 8 pixels so that all known pixels are trustworthy.
        if flow_mask_dilates > 0:
            flow_mask_img = scipy.ndimage.binary_dilation(mask_img, iterations=flow_mask_dilates).astype(np.uint8)
        else:
            flow_mask_img = binary_mask(mask_img).astype(np.uint8)
        flow_masks.append(Image.fromarray(flow_mask_img * 255))

        if mask_dilates > 0:
            mask_img = scipy.ndimage.binary_dilation(mask_img, iterations=mask_dilates).astype(np.uint8)
        else:
            mask_img = binary_mask(mask_img).astype(np.uint8)
        masks_dilated.append(Image.fromarray(mask_img * 255))

    if len(masks_img) == 1:
        # A single mask is shared across every frame.
        flow_masks = flow_masks * length
        masks_dilated = masks_dilated * length

    return flow_masks, masks_dilated


def extrapolation(video_ori, scale):
    """Prepare frames and masks for video outpainting.

    Args:
        video_ori: list of PIL frames.
        scale: (scale_h, scale_w) field-of-view expansion factors.

    Returns:
        (frames, flow_masks, masks_dilated, (imgW_extr, imgH_extr)):
        frames pasted into the enlarged canvas, per-frame flow and dilated
        masks covering the new border region, and the 8-aligned new size.
    """
    nFrame = len(video_ori)
    imgW, imgH = video_ori[0].size

    # Defines the new FOV, rounded down to multiples of 8.
    imgH_extr = int(scale[0] * imgH)
    imgW_extr = int(scale[1] * imgW)
    imgH_extr = imgH_extr - imgH_extr % 8
    imgW_extr = imgW_extr - imgW_extr % 8
    H_start = int((imgH_extr - imgH) / 2)
    W_start = int((imgW_extr - imgW) / 2)

    # Pastes each original frame into the centre of the enlarged canvas.
    frames = []
    for v in video_ori:
        frame = np.zeros(((imgH_extr, imgW_extr, 3)), dtype=np.uint8)
        frame[H_start: H_start + imgH, W_start: W_start + imgW, :] = v
        frames.append(Image.fromarray(frame))

    # Generates the mask for the missing border region; the flow mask is
    # shrunk by 4 px on sides with enough margin (dilate into known area).
    masks_dilated = []
    flow_masks = []

    dilate_h = 4 if H_start > 10 else 0
    dilate_w = 4 if W_start > 10 else 0
    mask = np.ones(((imgH_extr, imgW_extr)), dtype=np.uint8)

    mask[H_start + dilate_h: H_start + imgH - dilate_h,
         W_start + dilate_w: W_start + imgW - dilate_w] = 0
    flow_masks.append(Image.fromarray(mask * 255))

    mask[H_start: H_start + imgH, W_start: W_start + imgW] = 0
    masks_dilated.append(Image.fromarray(mask * 255))

    # The same masks apply to every frame.
    flow_masks = flow_masks * nFrame
    masks_dilated = masks_dilated * nFrame

    return frames, flow_masks, masks_dilated, (imgW_extr, imgH_extr)


def get_ref_index(mid_neighbor_id, neighbor_ids, length, ref_stride=10, ref_num=-1):
    """Pick non-local reference frame indices spaced `ref_stride` apart.

    Args:
        mid_neighbor_id: centre frame index of the current window.
        neighbor_ids: local frame indices to exclude from the references.
        length: total number of frames.
        ref_stride: spacing between candidate reference frames.
        ref_num: max number of references; -1 samples the whole video.

    Returns:
        list[int]: chosen reference indices.
    """
    ref_index = []
    if ref_num == -1:
        for i in range(0, length, ref_stride):
            if i not in neighbor_ids:
                ref_index.append(i)
    else:
        start_idx = max(0, mid_neighbor_id - ref_stride * (ref_num // 2))
        end_idx = min(length, mid_neighbor_id + ref_stride * (ref_num // 2))
        for i in range(start_idx, end_idx, ref_stride):
            if i not in neighbor_ids:
                if len(ref_index) > ref_num:
                    break
                ref_index.append(i)
    return ref_index
video1_data_url = \"data:video/mp4;base64,\" + video1_base64\n", + " video2_data_url = \"data:video/mp4;base64,\" + video2_base64\n", + "\n", + " gap_width = 33\n", + "\n", + " # Embed videos with a gap in the middle in an HTML container\n", + " html_code = f\"\"\"\n", + "
\n", + " \n", + "
\n", + " \n", + "
def run_propainter(video='inputs/object_removal/bmx-trees',
                   mask='inputs/object_removal/bmx-trees_mask',
                   output='../results',
                   mode='video_inpainting',
                   resize_ratio=1.0,
                   height=-1,
                   width=-1,
                   mask_dilation=4,
                   ref_stride=10,
                   neighbor_length=10,
                   subvideo_length=80,
                   raft_iter=20,
                   save_fps=24,
                   use_half=False,
                   scale_h=None,
                   scale_w=None,
                   save_frames=True):
    """Run the full ProPainter pipeline on a video or frame directory.

    Stages: read frames + masks -> RAFT bidirectional flow -> recurrent
    flow completion -> flow-guided image propagation -> feature
    propagation + transformer inpainting -> save outputs.

    Args:
        video: path to a video file or a directory of frames.
        mask: path to a single mask image or a directory of per-frame masks
            (ignored for 'video_outpainting').
        output: root directory for results; a subfolder named after the
            video is created inside it.
        mode: 'video_inpainting' or 'video_outpainting'.
        resize_ratio: uniform resize factor applied to the input frames.
        height, width: explicit output size; -1 keeps the source size.
        mask_dilation: dilation iterations applied to the masks.
        ref_stride: stride between non-local reference frames.
        neighbor_length: local temporal window length.
        subvideo_length: sub-video length for memory-efficient processing.
        raft_iter: RAFT refinement iterations.
        save_fps: fps used for saved videos when the source fps is unknown
            (i.e. frame-directory inputs).
        use_half: run models in fp16 to save GPU memory.
        scale_h, scale_w: outpainting scales; required for
            'video_outpainting' (previously read from an undefined `args`,
            which raised NameError — now proper parameters).
        save_frames: also write each completed frame as a PNG.

    Returns:
        (masked_in_path, inpaint_out_path): paths of the two saved videos.

    Raises:
        NotImplementedError: for an unknown `mode`.
        AssertionError: outpainting mode without scale_h/scale_w.
    """
    device = get_device()

    frames, fps, size, video_name = read_frame_from_videos(video)
    if not width == -1 and not height == -1:
        size = (width, height)
    if not resize_ratio == 1.0:
        size = (int(resize_ratio * size[0]), int(resize_ratio * size[1]))

    frames, size, out_size = resize_frames(frames, size)

    # Bug fix: fps was unconditionally overwritten with 24, discarding the
    # source video's fps and leaving `save_fps` unused. Fall back to
    # save_fps only when the input was a frame directory (fps is None).
    if fps is None:
        fps = save_fps

    save_root = os.path.join(output, video_name)
    if not os.path.exists(save_root):
        os.makedirs(save_root, exist_ok=True)

    if mode == 'video_inpainting':
        frames_len = len(frames)
        flow_masks, masks_dilated = read_mask(mask, frames_len, size,
                                              flow_mask_dilates=mask_dilation,
                                              mask_dilates=mask_dilation)
        w, h = size
    elif mode == 'video_outpainting':
        assert scale_h is not None and scale_w is not None, 'Please provide a outpainting scale (s_h, s_w).'
        frames, flow_masks, masks_dilated, size = extrapolation(frames, (scale_h, scale_w))
        w, h = size
    else:
        raise NotImplementedError

    # Green-overlay visualisation of the masked regions for the input video.
    masked_frame_for_save = []
    for i in range(len(frames)):
        mask_ = np.expand_dims(np.array(masks_dilated[i]), 2).repeat(3, axis=2) / 255.
        img = np.array(frames[i])
        green = np.zeros([h, w, 3])
        green[:, :, 1] = 255
        alpha = 0.6
        fuse_img = (1 - alpha) * img + alpha * green
        fuse_img = mask_ * fuse_img + (1 - mask_) * img
        masked_frame_for_save.append(fuse_img.astype(np.uint8))

    frames_inp = [np.array(f).astype(np.uint8) for f in frames]
    # Normalise frames to [-1, 1]; masks stay in [0, 1].
    frames = to_tensors()(frames).unsqueeze(0) * 2 - 1
    flow_masks = to_tensors()(flow_masks).unsqueeze(0)
    masks_dilated = to_tensors()(masks_dilated).unsqueeze(0)
    frames, flow_masks, masks_dilated = frames.to(device), flow_masks.to(device), masks_dilated.to(device)

    ##############################################
    # set up RAFT and flow completion model
    ##############################################
    ckpt_path = load_file_from_url(url=os.path.join(pretrain_model_url, 'raft-things.pth'),
                                   model_dir='weights', progress=True, file_name=None)
    fix_raft = RAFT_bi(ckpt_path, device)

    ckpt_path = load_file_from_url(url=os.path.join(pretrain_model_url, 'recurrent_flow_completion.pth'),
                                   model_dir='weights', progress=True, file_name=None)
    fix_flow_complete = RecurrentFlowCompleteNet(ckpt_path)
    for p in fix_flow_complete.parameters():
        p.requires_grad = False
    fix_flow_complete.to(device)
    fix_flow_complete.eval()

    ##############################################
    # set up ProPainter model
    ##############################################
    ckpt_path = load_file_from_url(url=os.path.join(pretrain_model_url, 'ProPainter.pth'),
                                   model_dir='weights', progress=True, file_name=None)
    model = InpaintGenerator(model_path=ckpt_path).to(device)
    model.eval()

    ##############################################
    # ProPainter inference
    ##############################################
    video_length = frames.size(1)
    print(f'\nProcessing: {video_name} [{video_length} frames]...')
    with torch.no_grad():
        # ---- compute flow ----
        # Wider frames need shorter clips to fit RAFT in memory.
        if frames.size(-1) <= 640:
            short_clip_len = 12
        elif frames.size(-1) <= 720:
            short_clip_len = 8
        elif frames.size(-1) <= 1280:
            short_clip_len = 4
        else:
            short_clip_len = 2

        # use fp32 for RAFT
        if frames.size(1) > short_clip_len:
            gt_flows_f_list, gt_flows_b_list = [], []
            for f in range(0, video_length, short_clip_len):
                end_f = min(video_length, f + short_clip_len)
                # Overlap one frame between clips so flows stay contiguous.
                if f == 0:
                    flows_f, flows_b = fix_raft(frames[:, f:end_f], iters=raft_iter)
                else:
                    flows_f, flows_b = fix_raft(frames[:, f - 1:end_f], iters=raft_iter)

                gt_flows_f_list.append(flows_f)
                gt_flows_b_list.append(flows_b)
                torch.cuda.empty_cache()

            gt_flows_f = torch.cat(gt_flows_f_list, dim=1)
            gt_flows_b = torch.cat(gt_flows_b_list, dim=1)
            gt_flows_bi = (gt_flows_f, gt_flows_b)
        else:
            gt_flows_bi = fix_raft(frames, iters=raft_iter)
            torch.cuda.empty_cache()

        if use_half:
            frames, flow_masks, masks_dilated = frames.half(), flow_masks.half(), masks_dilated.half()
            gt_flows_bi = (gt_flows_bi[0].half(), gt_flows_bi[1].half())
            fix_flow_complete = fix_flow_complete.half()
            model = model.half()

        # ---- complete flow ----
        flow_length = gt_flows_bi[0].size(1)
        if flow_length > subvideo_length:
            pred_flows_f, pred_flows_b = [], []
            pad_len = 5
            for f in range(0, flow_length, subvideo_length):
                s_f = max(0, f - pad_len)
                e_f = min(flow_length, f + subvideo_length + pad_len)
                pad_len_s = max(0, f) - s_f
                pad_len_e = e_f - min(flow_length, f + subvideo_length)
                pred_flows_bi_sub, _ = fix_flow_complete.forward_bidirect_flow(
                    (gt_flows_bi[0][:, s_f:e_f], gt_flows_bi[1][:, s_f:e_f]),
                    flow_masks[:, s_f:e_f + 1])
                pred_flows_bi_sub = fix_flow_complete.combine_flow(
                    (gt_flows_bi[0][:, s_f:e_f], gt_flows_bi[1][:, s_f:e_f]),
                    pred_flows_bi_sub,
                    flow_masks[:, s_f:e_f + 1])

                # Trim the padded ends before concatenation.
                pred_flows_f.append(pred_flows_bi_sub[0][:, pad_len_s:e_f - s_f - pad_len_e])
                pred_flows_b.append(pred_flows_bi_sub[1][:, pad_len_s:e_f - s_f - pad_len_e])
                torch.cuda.empty_cache()

            pred_flows_f = torch.cat(pred_flows_f, dim=1)
            pred_flows_b = torch.cat(pred_flows_b, dim=1)
            pred_flows_bi = (pred_flows_f, pred_flows_b)
        else:
            pred_flows_bi, _ = fix_flow_complete.forward_bidirect_flow(gt_flows_bi, flow_masks)
            pred_flows_bi = fix_flow_complete.combine_flow(gt_flows_bi, pred_flows_bi, flow_masks)
            torch.cuda.empty_cache()

        # ---- image propagation ----
        masked_frames = frames * (1 - masks_dilated)
        # Cap the sub-video length at 100 frames for image propagation
        # (the original comment claimed a "minimum", but min() is a cap).
        subvideo_length_img_prop = min(100, subvideo_length)
        if video_length > subvideo_length_img_prop:
            updated_frames, updated_masks = [], []
            pad_len = 10
            for f in range(0, video_length, subvideo_length_img_prop):
                s_f = max(0, f - pad_len)
                e_f = min(video_length, f + subvideo_length_img_prop + pad_len)
                pad_len_s = max(0, f) - s_f
                pad_len_e = e_f - min(video_length, f + subvideo_length_img_prop)

                b, t, _, _, _ = masks_dilated[:, s_f:e_f].size()
                pred_flows_bi_sub = (pred_flows_bi[0][:, s_f:e_f - 1], pred_flows_bi[1][:, s_f:e_f - 1])
                prop_imgs_sub, updated_local_masks_sub = model.img_propagation(masked_frames[:, s_f:e_f],
                                                                               pred_flows_bi_sub,
                                                                               masks_dilated[:, s_f:e_f],
                                                                               'nearest')
                updated_frames_sub = frames[:, s_f:e_f] * (1 - masks_dilated[:, s_f:e_f]) + \
                    prop_imgs_sub.view(b, t, 3, h, w) * masks_dilated[:, s_f:e_f]
                updated_masks_sub = updated_local_masks_sub.view(b, t, 1, h, w)

                updated_frames.append(updated_frames_sub[:, pad_len_s:e_f - s_f - pad_len_e])
                updated_masks.append(updated_masks_sub[:, pad_len_s:e_f - s_f - pad_len_e])
                torch.cuda.empty_cache()

            updated_frames = torch.cat(updated_frames, dim=1)
            updated_masks = torch.cat(updated_masks, dim=1)
        else:
            b, t, _, _, _ = masks_dilated.size()
            prop_imgs, updated_local_masks = model.img_propagation(masked_frames, pred_flows_bi, masks_dilated, 'nearest')
            updated_frames = frames * (1 - masks_dilated) + prop_imgs.view(b, t, 3, h, w) * masks_dilated
            updated_masks = updated_local_masks.view(b, t, 1, h, w)
            torch.cuda.empty_cache()

    ori_frames = frames_inp
    comp_frames = [None] * video_length

    neighbor_stride = neighbor_length // 2
    if video_length > subvideo_length:
        ref_num = subvideo_length // ref_stride
    else:
        ref_num = -1

    # ---- feature propagation + transformer ----
    for f in tqdm(range(0, video_length, neighbor_stride)):
        neighbor_ids = [
            i for i in range(max(0, f - neighbor_stride),
                             min(video_length, f + neighbor_stride + 1))
        ]
        ref_ids = get_ref_index(f, neighbor_ids, video_length, ref_stride, ref_num)
        selected_imgs = updated_frames[:, neighbor_ids + ref_ids, :, :, :]
        selected_masks = masks_dilated[:, neighbor_ids + ref_ids, :, :, :]
        selected_update_masks = updated_masks[:, neighbor_ids + ref_ids, :, :, :]
        selected_pred_flows_bi = (pred_flows_bi[0][:, neighbor_ids[:-1], :, :, :],
                                  pred_flows_bi[1][:, neighbor_ids[:-1], :, :, :])

        with torch.no_grad():
            # 1.0 indicates mask
            l_t = len(neighbor_ids)

            pred_img = model(selected_imgs, selected_pred_flows_bi, selected_masks, selected_update_masks, l_t)

            pred_img = pred_img.view(-1, 3, h, w)

            # Back from [-1, 1] to [0, 255].
            pred_img = (pred_img + 1) / 2
            pred_img = pred_img.cpu().permute(0, 2, 3, 1).numpy() * 255
            binary_masks = masks_dilated[0, neighbor_ids, :, :, :].cpu().permute(
                0, 2, 3, 1).numpy().astype(np.uint8)
            for i in range(len(neighbor_ids)):
                idx = neighbor_ids[i]
                img = np.array(pred_img[i]).astype(np.uint8) * binary_masks[i] \
                    + ori_frames[idx] * (1 - binary_masks[i])
                if comp_frames[idx] is None:
                    comp_frames[idx] = img
                else:
                    # Average overlapping windows.
                    comp_frames[idx] = comp_frames[idx].astype(np.float32) * 0.5 + img.astype(np.float32) * 0.5

                comp_frames[idx] = comp_frames[idx].astype(np.uint8)

        torch.cuda.empty_cache()

    # Save each completed frame (replaces the leftover `if True:` stub).
    if save_frames:
        for idx in range(video_length):
            frame_out = comp_frames[idx]
            frame_out = cv2.resize(frame_out, out_size, interpolation=cv2.INTER_CUBIC)
            frame_out = cv2.cvtColor(frame_out, cv2.COLOR_BGR2RGB)
            img_save_root = os.path.join(save_root, 'frames', str(idx).zfill(4) + '.png')
            imwrite(frame_out, img_save_root)

    masked_frame_for_save = [cv2.resize(f, out_size) for f in masked_frame_for_save]
    comp_frames = [cv2.resize(f, out_size) for f in comp_frames]
    imageio.mimwrite(os.path.join(save_root, 'masked_in.mp4'), masked_frame_for_save, fps=fps, quality=7)
    imageio.mimwrite(os.path.join(save_root, 'inpaint_out.mp4'), comp_frames, fps=fps, quality=7)

    print(f'\nAll results are saved in {save_root}')

    torch.cuda.empty_cache()

    return save_root + '/masked_in.mp4', save_root + '/inpaint_out.mp4'
# Object-removal demo on the bundled DAVIS "bmx-trees" sequence.
video_path = 'inputs/object_removal/bmx-trees'
mask_path = 'inputs/object_removal/bmx-trees_mask'
output_path = '../results'
mode = 'video_inpainting'

# Run the full ProPainter pipeline; it returns the paths of the
# masked-input preview video and the inpainted result video.
mask_result, inpaint_result = run_propainter(
    video=video_path,
    mask=mask_path,
    output=output_path,
    mode=mode,
    use_half=False,
)
\n", + " \n", + "
\n", + " \n", + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "video = display_video(mask_result, inpaint_result)\n", + "video" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DehsWqGClXkr" + }, + "source": [ + "Thank you for using this Colab notebook!\n", + "\n", + "If you found it helpful, please consider giving it a star on GitHub.\n", + "Your support is greatly appreciated!\n", + "\n", + "\n", + "[![GitHub stars](https://img.shields.io/github/stars/sczhou/ProPainter.svg?style=social&label=Star)](https://github.com/sczhou/ProPainter)\n", + "\n", + "Happy coding! 🚀" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/README.md b/README.md index b1a4cdce..8cf7de33 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,9 @@ - + + +[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/19uQfU5_Mj9bjrIDsZimVkH7l2P105AYb?usp=sharing)