From 04273c041448cd42aae37591713d10dae74998a4 Mon Sep 17 00:00:00 2001 From: ASIjacket <26643085+ASIjacket@users.noreply.github.com> Date: Sat, 31 Dec 2022 13:24:25 +0800 Subject: [PATCH] Update colab-demo.ipynb MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 实现视频处理的调用,使用 RunModel 变量控制图片模式和视频模式 优化各个参数输入,统一为变量 增加路径创建判断 --- Real-CUGAN/colab-demo.ipynb | 65 ++++++++++++++++++++++++++++++++----- 1 file changed, 56 insertions(+), 9 deletions(-) diff --git a/Real-CUGAN/colab-demo.ipynb b/Real-CUGAN/colab-demo.ipynb index 92b82c4..73f6859 100644 --- a/Real-CUGAN/colab-demo.ipynb +++ b/Real-CUGAN/colab-demo.ipynb @@ -24,10 +24,19 @@ "PendingPath=ROOTPATH+\"Real-CUGAN/pending/\" # input dir\n", "FinishPath=ROOTPATH+\"Real-CUGAN/finish/\" # output dir\n", "ModelName=\"up2x-latest-no-denoise.pth\" # default model\n", - "Tile=4 #{0,1,2,3,4,auto}; the larger the number, the smaller the memory consumption\n", + "Tile=5 #{0,1,2,3,4,auto}; the larger the number, the smaller the memory consumption\n", + "Scale=2\n", + "NT=2 #线程数\n", + "CacheModel=0\n", + "RunModel=0 #运行模式 0=image 1=video\n", + "DeviceBase=\"cuda:0\"\n", + "Half=True\n", + "\n", "\n", "# initialize environment\n", "!pip install torch opencv-python\n", + "!pip install imageio-ffmpeg\n", + "!pip install imageio==2.4.1\n", "!git clone https://github.com/bilibili/ailab.git\n", "from google.colab import drive\n", "drive.mount('/content/gdrive')" @@ -55,9 +64,13 @@ "outputs": [], "source": [ "import os\n", - "os.mkdir(ModelPath)\n", - "os.mkdir(PendingPath)\n", - "os.mkdir(FinishPath)\n", + "if not os.path.exists(ModelPath):\n", + " os.mkdir(ModelPath)\n", + "if not os.path.exists(PendingPath):\n", + " os.mkdir(PendingPath)\n", + "if not os.path.exists(FinishPath):\n", + " os.mkdir(FinishPath)\n", + "\n", "!cp -r /content/gdrive/MyDrive/updated_weights/* /content/ailab/Real-CUGAN/model/\n", "fileNames = os.listdir(PendingPath)\n", "print(\"Pending images:\")\n", @@ -102,17 +115,51 @@ "import os,sys,cv2\n", "import numpy as np\n", "from upcunet_v3 import RealWaifuUpScaler\n", + "from inference_video import VideoRealWaifuUpScaler\n", + "from inference_video import UpScalerMT\n", + "from multiprocessing import Queue\n", "from time import time as ttime \n", "fileNames = os.listdir(PendingPath)\n", "print(f\"using model {ModelPath+ModelName}\")\n", - "upscaler = RealWaifuUpScaler(2, ModelPath+ModelName, half=True, device=\"cuda:0\")\n", + "\n", + "upscaler = RealWaifuUpScaler(Scale, ModelPath+ModelName, half=Half, device=DeviceBase)\n", + "\n", + "# Overwrite VideoRealWaifuUpScaler\n", + "class VideoRealWaifuUpScaler_kai(VideoRealWaifuUpScaler):\n", + " def __init__(self):\n", + " self.nt = NT\n", + " self.n_gpu=1 # 每块GPU开nt个进程\n", + " self.scale=Scale\n", + " self.encode_params = ['-crf', '21', '-preset', 'medium']\n", + " self.p_sleep=(0.005,0.012)\n", + " self.decode_sleep=0.002\n", + " self.half=Half\n", + " self.tile=Tile\n", + " self.cache_mode=CacheModel\n", + " self.alpha=1\n", + " self.device=DeviceBase\n", + "\n", + " video_model = RealWaifuUpScaler(Scale, ModelPath+ModelName, Half, DeviceBase)\n", + " self.inp_q = Queue(NT * 2)\n", + " self.res_q = Queue(NT * 2)\n", + " for _ in range(NT):\n", + " upscaler = UpScalerMT(self.inp_q, self.res_q, self.device, video_model, self.p_sleep, self.nt, self.tile, self.cache_mode, self.alpha)\n", + " upscaler.start()\n", + "\n", + "video_upscaler=VideoRealWaifuUpScaler_kai()\n", + "\n", "t0 = ttime()\n", "for i in fileNames:\n", " torch.cuda.empty_cache()\n", " try:\n", - " img = cv2.imread(PendingPath+i)[:, :, [2, 1, 0]]\n", - " result = upscaler(img,tile_mode=5,cache_mode=2,alpha=1)\n", - " cv2.imwrite(FinishPath+i,result[:, :, ::-1])\n", + " if RunModel == 0:\n", + " img = cv2.imread(PendingPath+i)[:, :, [2, 1, 0]]\n", + " result = upscaler(img,tile_mode=5,cache_mode=2,alpha=1)\n", + " cv2.imwrite(FinishPath+i,result[:, :, ::-1])\n", + " elif RunModel == 1:\n", + " video_upscaler(PendingPath+i, FinishPath+i)\n", + " else:\n", + " print(\"Warning: Error RunModel Input\")\n", " except RuntimeError as e:\n", " print (i+\" FAILED\")\n", " print (e)\n", @@ -120,7 +167,7 @@ " print(i+\" DONE\")\n", " os.remove(PendingPath+i)\n", "t1 = ttime()\n", - "print(\"time_spent\", t1 - t0)" + "print(\"time_spent\", t1 - t0)\n" ] }, {