diff --git a/configs/train_propainter.json b/configs/train_propainter.json index c0c29ba7..cd46bf24 100644 --- a/configs/train_propainter.json +++ b/configs/train_propainter.json @@ -1,48 +1,47 @@ { - "seed": 2023, - "save_dir": "experiments_model/", - "train_data_loader": { - "name": "youtube-vos", - "video_root": "your_video_root", - "flow_root": "your_flow_root", - "w": 432, - "h": 240, - "num_local_frames": 10, - "num_ref_frames": 6, - "load_flow": 0 - }, - "losses": { - "hole_weight": 1, - "valid_weight": 1, - "flow_weight": 1, - "adversarial_weight": 0.01, - "GAN_LOSS": "hinge", - "perceptual_weight": 0 - }, - "model": { - "net": "propainter", - "no_dis": 0, - "load_d": 1, - "interp_mode": "nearest" - }, - "trainer": { - "version": "trainer", - "type": "Adam", - "beta1": 0, - "beta2": 0.99, - "lr": 1e-4, - "batch_size": 8, - "num_workers": 8, - "num_prefetch_queue": 8, - "log_freq": 100, - "save_freq": 1e4, - "iterations": 700e3, - "scheduler": { - "type": "MultiStepLR", - "milestones": [ - 400e3 - ], - "gamma": 0.1 - } + "seed": 2023, + "save_dir": "experiments_model/", + "train_data_loader": { + "name": "davis", + "video_root": "your_video_root", + "flow_root": "your_flow_root", + "w": 432, + "h": 240, + "num_local_frames": 10, + "num_ref_frames": 6, + "load_flow": 0 + }, + "losses": { + "GAN_LOSS": "hinge", + "hole_weight": 1.0, + "valid_weight": 1.0, + "perceptual_weight": 0.1, + "adversarial_weight": 0.01, + "ffl_weight": 0.05, + "ffl_alpha": 1.0 + }, + "model": { + "net": "propainter", + "no_dis": 0, + "load_d": 1, + "interp_mode": "nearest" + }, + "trainer": { + "version": "trainer", + "type": "Adam", + "beta1": 0, + "beta2": 0.99, + "lr": 1e-4, + "batch_size": 8, + "num_workers": 8, + "num_prefetch_queue": 8, + "log_freq": 100, + "save_freq": 1e4, + "iterations": 700e3, + "scheduler": { + "type": "MultiStepLR", + "milestones": [400e3], + "gamma": 0.1 } -} \ No newline at end of file + } +} diff --git a/core/loss.py b/core/loss.py index 
b1d94d0c..f0276af9 100644 --- a/core/loss.py +++ b/core/loss.py @@ -1,8 +1,10 @@ +import lpips import torch import torch.nn as nn -import lpips + from model.vgg_arch import VGGFeatureExtractor + class PerceptualLoss(nn.Module): """Perceptual loss with commonly used style loss. @@ -26,14 +28,16 @@ class PerceptualLoss(nn.Module): criterion (str): Criterion used for perceptual loss. Default: 'l1'. """ - def __init__(self, - layer_weights, - vgg_type='vgg19', - use_input_norm=True, - range_norm=False, - perceptual_weight=1.0, - style_weight=0., - criterion='l1'): + def __init__( + self, + layer_weights, + vgg_type="vgg19", + use_input_norm=True, + range_norm=False, + perceptual_weight=1.0, + style_weight=0.0, + criterion="l1", + ): super(PerceptualLoss, self).__init__() self.perceptual_weight = perceptual_weight self.style_weight = style_weight @@ -42,19 +46,20 @@ def __init__(self, layer_name_list=list(layer_weights.keys()), vgg_type=vgg_type, use_input_norm=use_input_norm, - range_norm=range_norm) + range_norm=range_norm, + ) self.criterion_type = criterion - if self.criterion_type == 'l1': + if self.criterion_type == "l1": self.criterion = torch.nn.L1Loss() - elif self.criterion_type == 'l2': + elif self.criterion_type == "l2": self.criterion = torch.nn.L2loss() - elif self.criterion_type == 'mse': - self.criterion = torch.nn.MSELoss(reduction='mean') - elif self.criterion_type == 'fro': + elif self.criterion_type == "mse": + self.criterion = torch.nn.MSELoss(reduction="mean") + elif self.criterion_type == "fro": self.criterion = None else: - raise NotImplementedError(f'{criterion} criterion has not been supported.') + raise NotImplementedError(f"{criterion} criterion has not been supported.") def forward(self, x, gt): """Forward function. 
@@ -74,10 +79,16 @@ def forward(self, x, gt): if self.perceptual_weight > 0: percep_loss = 0 for k in x_features.keys(): - if self.criterion_type == 'fro': - percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k] + if self.criterion_type == "fro": + percep_loss += ( + torch.norm(x_features[k] - gt_features[k], p="fro") + * self.layer_weights[k] + ) else: - percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k] + percep_loss += ( + self.criterion(x_features[k], gt_features[k]) + * self.layer_weights[k] + ) percep_loss *= self.perceptual_weight else: percep_loss = None @@ -86,12 +97,23 @@ def forward(self, x, gt): if self.style_weight > 0: style_loss = 0 for k in x_features.keys(): - if self.criterion_type == 'fro': - style_loss += torch.norm( - self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k] + if self.criterion_type == "fro": + style_loss += ( + torch.norm( + self._gram_mat(x_features[k]) + - self._gram_mat(gt_features[k]), + p="fro", + ) + * self.layer_weights[k] + ) else: - style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat( - gt_features[k])) * self.layer_weights[k] + style_loss += ( + self.criterion( + self._gram_mat(x_features[k]), + self._gram_mat(gt_features[k]), + ) + * self.layer_weights[k] + ) style_loss *= self.style_weight else: style_loss = None @@ -113,11 +135,14 @@ def _gram_mat(self, x): gram = features.bmm(features_t) / (c * h * w) return gram + class LPIPSLoss(nn.Module): - def __init__(self, - loss_weight=1.0, - use_input_norm=True, - range_norm=False,): + def __init__( + self, + loss_weight=1.0, + use_input_norm=True, + range_norm=False, + ): super(LPIPSLoss, self).__init__() self.perceptual = lpips.LPIPS(net="vgg", spatial=False).eval() self.loss_weight = loss_weight @@ -126,16 +151,20 @@ def __init__(self, if self.use_input_norm: # the mean is for image with range [0, 1] - self.register_buffer('mean', 
torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) + self.register_buffer( + "mean", torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) + ) # the std is for image with range [0, 1] - self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) + self.register_buffer( + "std", torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) + ) def forward(self, pred, target): if self.range_norm: - pred = (pred + 1) / 2 + pred = (pred + 1) / 2 target = (target + 1) / 2 if self.use_input_norm: - pred = (pred - self.mean) / self.std + pred = (pred - self.mean) / self.std target = (target - self.mean) / self.std lpips_loss = self.perceptual(target.contiguous(), pred.contiguous()) return self.loss_weight * lpips_loss.mean(), None @@ -146,27 +175,25 @@ class AdversarialLoss(nn.Module): Adversarial loss https://arxiv.org/abs/1711.10337 """ - def __init__(self, - type='nsgan', - target_real_label=1.0, - target_fake_label=0.0): + + def __init__(self, type="nsgan", target_real_label=1.0, target_fake_label=0.0): r""" type = nsgan | lsgan | hinge """ super(AdversarialLoss, self).__init__() self.type = type - self.register_buffer('real_label', torch.tensor(target_real_label)) - self.register_buffer('fake_label', torch.tensor(target_fake_label)) + self.register_buffer("real_label", torch.tensor(target_real_label)) + self.register_buffer("fake_label", torch.tensor(target_fake_label)) - if type == 'nsgan': + if type == "nsgan": self.criterion = nn.BCELoss() - elif type == 'lsgan': + elif type == "lsgan": self.criterion = nn.MSELoss() - elif type == 'hinge': + elif type == "hinge": self.criterion = nn.ReLU() def __call__(self, outputs, is_real, is_disc=None): - if self.type == 'hinge': + if self.type == "hinge": if is_disc: if is_real: outputs = -outputs @@ -174,7 +201,109 @@ def __call__(self, outputs, is_real, is_disc=None): else: return (-outputs).mean() else: - labels = (self.real_label - if is_real else self.fake_label).expand_as(outputs) + labels = 
# =============================================================================
# Focal Frequency Loss (Jiang et al., ICCV 2021)
# Applied ONLY to the high-frequency residual of the HOLE region.
# This avoids inter-frequency gradient conflict as described in DRCN (IJCAI 2025).
# =============================================================================
import torch.fft
import torch.nn.functional as F


def _gaussian_blur_for_ffl(x, kernel_size=5, sigma=1.0):
    """Depthwise Gaussian blur: returns the low-frequency component of ``x``.

    Args:
        x (Tensor): (B, C, H, W) floating-point tensor.
        kernel_size (int): Side of the square Gaussian kernel. Must be a
            positive odd number so that ``padding = kernel_size // 2`` keeps
            the spatial size unchanged.
        sigma (float): Standard deviation of the Gaussian.

    Returns:
        Tensor: blurred tensor with the same shape and dtype as ``x``.

    Raises:
        ValueError: if ``kernel_size`` is not a positive odd integer.
    """
    if kernel_size < 1 or kernel_size % 2 == 0:
        raise ValueError(
            f"kernel_size must be a positive odd integer, got {kernel_size}."
        )
    channels = x.shape[1]
    # Build the kernel in float32 for accuracy, then match the input dtype so
    # conv2d does not fail on double / half inputs.
    coords = (
        torch.arange(kernel_size, device=x.device, dtype=torch.float32)
        - kernel_size // 2
    )
    g = torch.exp(-(coords**2) / (2 * sigma**2))
    g = g / g.sum()
    kernel = (g[:, None] * g[None, :]).to(dtype=x.dtype)  # (k, k)
    kernel = kernel.expand(channels, 1, kernel_size, kernel_size)  # (C, 1, k, k)
    # groups=C -> every channel is blurred independently (depthwise conv).
    return F.conv2d(x, kernel, padding=kernel_size // 2, groups=channels)


def laplacian_decompose_ffl(x, kernel_size=5, sigma=1.0):
    """Split a frame into low-frequency and high-frequency components.

    Args:
        x (Tensor): (B, C, H, W) float tensor, any value range.
        kernel_size (int): Gaussian kernel size (positive, odd).
        sigma (float): Gaussian sigma.

    Returns:
        low (Tensor): blurred / low-frequency version of ``x``.
        high (Tensor): residual ``x - low``; ``low + high`` reconstructs ``x``.
    """
    low = _gaussian_blur_for_ffl(x, kernel_size, sigma)
    high = x - low
    return low, high


class FocalFrequencyLoss(nn.Module):
    """Focal Frequency Loss computed on the HIGH-FREQUENCY component only.

    Each frequency is weighted by how badly it is currently reconstructed.
    The weight is ``sqrt(dist) ** alpha``, normalised by the maximum of each
    (sample, channel) spectrum and clamped to [0, 1], so it stays bounded no
    matter how large the spectral error is. It is detached: it scales the
    gradient but is not itself optimised.

    Args:
        loss_weight (float): Scalar multiplier on the final loss. Default: 1.0.
            (Controlled externally via config ffl_weight.)
        alpha (float): Focusing exponent applied to the spectral distance.
            Higher = more aggressive focus on poorly-reconstructed
            frequencies. Default: 1.0.
        kernel_size (int): Gaussian kernel size for decomposition. Default: 5.
        sigma (float): Gaussian sigma for decomposition. Default: 1.0.
    """

    def __init__(self, loss_weight=1.0, alpha=1.0, kernel_size=5, sigma=1.0):
        super(FocalFrequencyLoss, self).__init__()
        self.loss_weight = loss_weight
        self.alpha = alpha
        self.kernel_size = kernel_size
        self.sigma = sigma

    def _focal_freq_core(self, pred_freq, real_freq):
        """Focal-weighted frequency-domain squared distance.

        Args:
            pred_freq, real_freq: (..., H, W, 2) tensors holding stacked
                real / imaginary parts of the spectra.

        Returns:
            Scalar tensor: mean of ``weight * squared_distance``.
        """
        diff = (pred_freq - real_freq) ** 2
        freq_dist = diff[..., 0] + diff[..., 1]  # real^2 + imag^2

        # Bounded focal weight. An exponential of the distance overflows to
        # inf for large errors and poisons the loss with inf / NaN.
        weight = torch.sqrt(freq_dist.detach()) ** self.alpha
        peak = weight.amax(dim=(-2, -1), keepdim=True)
        weight = weight / peak
        # 0 / 0 (perfect reconstruction) and inf / inf give NaN -> weight 0.
        weight = torch.nan_to_num(weight, nan=0.0, posinf=1.0, neginf=0.0)
        weight = weight.clamp(0.0, 1.0)
        return torch.mean(weight * freq_dist)

    def forward(self, pred, target):
        """Compute the loss between two (already masked) image batches.

        Args:
            pred (Tensor): predicted hole region (B, C, H, W), range [-1, 1].
            target (Tensor): ground-truth hole region (B, C, H, W), range [-1, 1].

        NOTE: Both pred and target should already be masked (multiplied by
        the mask) before being passed in. Do NOT pass the full frame here.

        Returns:
            Scalar loss tensor, already multiplied by ``loss_weight``.
        """
        # Step 1: decompose and keep only the high-frequency residual.
        _, pred_high = laplacian_decompose_ffl(pred, self.kernel_size, self.sigma)
        _, target_high = laplacian_decompose_ffl(target, self.kernel_size, self.sigma)

        # Step 2: 2-D FFT over the spatial dimensions.
        pred_f = torch.fft.fft2(pred_high, norm="ortho")
        target_f = torch.fft.fft2(target_high, norm="ortho")

        # Step 3: shift DC to the centre; stack real/imag -> (..., H, W, 2).
        pred_f = torch.fft.fftshift(pred_f, dim=(-2, -1))
        target_f = torch.fft.fftshift(target_f, dim=(-2, -1))
        pred_f = torch.stack([pred_f.real, pred_f.imag], dim=-1)
        target_f = torch.stack([target_f.real, target_f.imag], dim=-1)

        # Step 4: focal-weighted loss times the overall weight.
        return self._focal_freq_core(pred_f, target_f) * self.loss_weight
MultiStepRestartLR +from core.prefetch_dataloader import CPUPrefetcher, PrefetchDataLoader +from model.modules.flow_comp_raft import EdgeLoss, FlowLoss, RAFT_bi from model.recurrent_flow_completion import RecurrentFlowCompleteNet - from RAFT.utils.flow_viz_pt import flow_to_image @@ -28,103 +28,119 @@ def __init__(self, config): self.config = config self.epoch = 0 self.iteration = 0 - self.num_local_frames = config['train_data_loader']['num_local_frames'] - self.num_ref_frames = config['train_data_loader']['num_ref_frames'] + self.num_local_frames = config["train_data_loader"]["num_local_frames"] + self.num_ref_frames = config["train_data_loader"]["num_ref_frames"] # setup data set and data loader - self.train_dataset = TrainDataset(config['train_data_loader']) + self.train_dataset = TrainDataset(config["train_data_loader"]) self.train_sampler = None - self.train_args = config['trainer'] - if config['distributed']: + self.train_args = config["trainer"] + if config["distributed"]: self.train_sampler = DistributedSampler( self.train_dataset, - num_replicas=config['world_size'], - rank=config['global_rank']) + num_replicas=config["world_size"], + rank=config["global_rank"], + ) dataloader_args = dict( dataset=self.train_dataset, - batch_size=self.train_args['batch_size'] // config['world_size'], + batch_size=self.train_args["batch_size"] // config["world_size"], shuffle=(self.train_sampler is None), - num_workers=self.train_args['num_workers'], + num_workers=self.train_args["num_workers"], sampler=self.train_sampler, - drop_last=True) + drop_last=True, + ) - self.train_loader = PrefetchDataLoader(self.train_args['num_prefetch_queue'], **dataloader_args) + self.train_loader = PrefetchDataLoader( + self.train_args["num_prefetch_queue"], **dataloader_args + ) self.prefetcher = CPUPrefetcher(self.train_loader) # set loss functions - self.adversarial_loss = AdversarialLoss(type=self.config['losses']['GAN_LOSS']) - self.adversarial_loss = 
self.adversarial_loss.to(self.config['device']) + self.adversarial_loss = AdversarialLoss(type=self.config["losses"]["GAN_LOSS"]) + self.adversarial_loss = self.adversarial_loss.to(self.config["device"]) self.l1_loss = nn.L1Loss() # self.perc_loss = PerceptualLoss( - # layer_weights={'conv3_4': 0.25, 'conv4_4': 0.25, 'conv5_4': 0.5}, + # layer_weights={'conv3_4': 0.25, 'conv4_4': 0.25, 'conv5_4': 0.5}, # use_input_norm=True, # range_norm=True, # criterion='l1' # ).to(self.config['device']) - if self.config['losses']['perceptual_weight'] > 0: - self.perc_loss = LPIPSLoss(use_input_norm=True, range_norm=True).to(self.config['device']) - + if self.config["losses"]["perceptual_weight"] > 0: + self.perc_loss = LPIPSLoss(use_input_norm=True, range_norm=True).to( + self.config["device"] + ) + # ── Focal Frequency Loss — instantiated only if weight > 0 ─────────── + if self.config["losses"].get("ffl_weight", 0) > 0: + self.ffl_loss = FocalFrequencyLoss( + loss_weight=self.config["losses"]["ffl_weight"], + alpha=self.config["losses"].get("ffl_alpha", 1.0), + ).to(self.config["device"]) + # ───────────────────────────────────────────────────────────────────── # self.flow_comp_loss = FlowCompletionLoss().to(self.config['device']) # self.flow_comp_loss = FlowCompletionLoss(self.config['device']) # set raft - self.fix_raft = RAFT_bi(device = self.config['device']) - self.fix_flow_complete = RecurrentFlowCompleteNet('weights/recurrent_flow_completion.pth') + self.fix_raft = RAFT_bi(device=self.config["device"]) + self.fix_flow_complete = RecurrentFlowCompleteNet( + "weights/recurrent_flow_completion.pth" + ) for p in self.fix_flow_complete.parameters(): p.requires_grad = False - self.fix_flow_complete.to(self.config['device']) + self.fix_flow_complete.to(self.config["device"]) self.fix_flow_complete.eval() # self.flow_loss = FlowLoss() # setup models including generator and discriminator - net = importlib.import_module('model.' 
+ config['model']['net']) + net = importlib.import_module("model." + config["model"]["net"]) self.netG = net.InpaintGenerator() # print(self.netG) - self.netG = self.netG.to(self.config['device']) - if not self.config['model'].get('no_dis', False): - if self.config['model'].get('dis_2d', False): + self.netG = self.netG.to(self.config["device"]) + if not self.config["model"].get("no_dis", False): + if self.config["model"].get("dis_2d", False): self.netD = net.Discriminator_2D( - in_channels=3, - use_sigmoid=config['losses']['GAN_LOSS'] != 'hinge') + in_channels=3, use_sigmoid=config["losses"]["GAN_LOSS"] != "hinge" + ) else: - self.netD = net.Discriminator( - in_channels=3, - use_sigmoid=config['losses']['GAN_LOSS'] != 'hinge') - self.netD = self.netD.to(self.config['device']) - - self.interp_mode = self.config['model']['interp_mode'] + self.netD = net.Discriminator( + in_channels=3, use_sigmoid=config["losses"]["GAN_LOSS"] != "hinge" + ) + self.netD = self.netD.to(self.config["device"]) + + self.interp_mode = self.config["model"]["interp_mode"] # setup optimizers and schedulers self.setup_optimizers() self.setup_schedulers() self.load() - if config['distributed']: - self.netG = DDP(self.netG, - device_ids=[self.config['local_rank']], - output_device=self.config['local_rank'], - broadcast_buffers=True, - find_unused_parameters=True) - if not self.config['model']['no_dis']: - self.netD = DDP(self.netD, - device_ids=[self.config['local_rank']], - output_device=self.config['local_rank'], - broadcast_buffers=True, - find_unused_parameters=False) + if config["distributed"]: + self.netG = DDP( + self.netG, + device_ids=[self.config["local_rank"]], + output_device=self.config["local_rank"], + broadcast_buffers=True, + find_unused_parameters=True, + ) + if not self.config["model"]["no_dis"]: + self.netD = DDP( + self.netD, + device_ids=[self.config["local_rank"]], + output_device=self.config["local_rank"], + broadcast_buffers=True, + find_unused_parameters=False, + ) # set 
def setup_schedulers(self):
    """Set up learning-rate schedulers for the generator (and discriminator).

    Reads ``config["trainer"]["scheduler"]`` and creates ``self.scheG`` and,
    unless ``config["model"]["no_dis"]`` is set, ``self.scheD``.

    Raises:
        NotImplementedError: if the configured scheduler type is unknown.
    """
    scheduler_opt = self.config["trainer"]["scheduler"]
    # Read, do not pop: the config dict is shared, and removing "type" would
    # break a second call or any later dump / re-use of the config.
    scheduler_type = scheduler_opt["type"]
    use_dis = not self.config["model"]["no_dis"]

    if scheduler_type in ["MultiStepLR", "MultiStepRestartLR"]:
        self.scheG = MultiStepRestartLR(
            self.optimG,
            milestones=scheduler_opt["milestones"],
            gamma=scheduler_opt["gamma"],
        )
        if use_dis:
            self.scheD = MultiStepRestartLR(
                self.optimD,
                milestones=scheduler_opt["milestones"],
                gamma=scheduler_opt["gamma"],
            )
    elif scheduler_type == "CosineAnnealingRestartLR":
        self.scheG = CosineAnnealingRestartLR(
            self.optimG,
            periods=scheduler_opt["periods"],
            restart_weights=scheduler_opt["restart_weights"],
            eta_min=scheduler_opt["eta_min"],
        )
        if use_dis:
            self.scheD = CosineAnnealingRestartLR(
                self.optimD,
                periods=scheduler_opt["periods"],
                restart_weights=scheduler_opt["restart_weights"],
                eta_min=scheduler_opt["eta_min"],
            )
    else:
        raise NotImplementedError(
            f"Scheduler {scheduler_type} is not implemented yet."
        )
f"gen_{int(latest_epoch):06d}.pth") + dis_path = os.path.join(model_path, f"dis_{int(latest_epoch):06d}.pth") + opt_path = os.path.join(model_path, f"opt_{int(latest_epoch):06d}.pth") + + if self.config["global_rank"] == 0: + print(f"Loading model from {gen_path}...") + dataG = torch.load(gen_path, map_location=self.config["device"]) self.netG.load_state_dict(dataG) - if not self.config['model']['no_dis'] and self.config['model']['load_d']: - dataD = torch.load(dis_path, map_location=self.config['device']) + if not self.config["model"]["no_dis"] and self.config["model"]["load_d"]: + dataD = torch.load(dis_path, map_location=self.config["device"]) self.netD.load_state_dict(dataD) - data_opt = torch.load(opt_path, map_location=self.config['device']) - self.optimG.load_state_dict(data_opt['optimG']) + data_opt = torch.load(opt_path, map_location=self.config["device"]) + self.optimG.load_state_dict(data_opt["optimG"]) # self.scheG.load_state_dict(data_opt['scheG']) - if not self.config['model']['no_dis'] and self.config['model']['load_d']: - self.optimD.load_state_dict(data_opt['optimD']) + if not self.config["model"]["no_dis"] and self.config["model"]["load_d"]: + self.optimD.load_state_dict(data_opt["optimD"]) # self.scheD.load_state_dict(data_opt['scheD']) - self.epoch = data_opt['epoch'] - self.iteration = data_opt['iteration'] + self.epoch = data_opt["epoch"] + self.iteration = data_opt["iteration"] else: - gen_path = self.config['trainer'].get('gen_path', None) - dis_path = self.config['trainer'].get('dis_path', None) - opt_path = self.config['trainer'].get('opt_path', None) + gen_path = self.config["trainer"].get("gen_path", None) + dis_path = self.config["trainer"].get("dis_path", None) + opt_path = self.config["trainer"].get("opt_path", None) if gen_path is not None: - if self.config['global_rank'] == 0: - print(f'Loading Gen-Net from {gen_path}...') - dataG = torch.load(gen_path, map_location=self.config['device']) + if self.config["global_rank"] == 0: + 
print(f"Loading Gen-Net from {gen_path}...") + dataG = torch.load(gen_path, map_location=self.config["device"]) self.netG.load_state_dict(dataG) - - if dis_path is not None and not self.config['model']['no_dis'] and self.config['model']['load_d']: - if self.config['global_rank'] == 0: - print(f'Loading Dis-Net from {dis_path}...') - dataD = torch.load(dis_path, map_location=self.config['device']) + + if ( + dis_path is not None + and not self.config["model"]["no_dis"] + and self.config["model"]["load_d"] + ): + if self.config["global_rank"] == 0: + print(f"Loading Dis-Net from {dis_path}...") + dataD = torch.load(dis_path, map_location=self.config["device"]) self.netD.load_state_dict(dataD) if opt_path is not None: - data_opt = torch.load(opt_path, map_location=self.config['device']) - self.optimG.load_state_dict(data_opt['optimG']) - self.scheG.load_state_dict(data_opt['scheG']) - if not self.config['model']['no_dis'] and self.config['model']['load_d']: - self.optimD.load_state_dict(data_opt['optimD']) - self.scheD.load_state_dict(data_opt['scheD']) + data_opt = torch.load(opt_path, map_location=self.config["device"]) + self.optimG.load_state_dict(data_opt["optimG"]) + self.scheG.load_state_dict(data_opt["scheG"]) + if ( + not self.config["model"]["no_dis"] + and self.config["model"]["load_d"] + ): + self.optimD.load_state_dict(data_opt["optimD"]) + self.scheD.load_state_dict(data_opt["scheD"]) else: - if self.config['global_rank'] == 0: - print('Warnning: There is no trained model found.' - 'An initialized model will be used.') + if self.config["global_rank"] == 0: + print( + "Warnning: There is no trained model found." + "An initialized model will be used." 
def save(self, it):
    """Save model weights and optimizer/scheduler state for iteration ``it``.

    Only the process with ``config["global_rank"] == 0`` writes anything.
    Files written into ``config["save_dir"]``:
      * ``gen_{it:06d}.pth`` - generator state dict
      * ``dis_{it:06d}.pth`` - discriminator state dict (skipped when
        ``config["model"]["no_dis"]`` is set)
      * ``opt_{it:06d}.pth`` - epoch, iteration, optimizer and scheduler states
      * ``latest.ckpt``      - the zero-padded iteration number

    Args:
        it (int): iteration number used in the checkpoint file names.
    """
    if self.config["global_rank"] != 0:
        return

    save_dir = self.config["save_dir"]
    use_dis = not self.config["model"]["no_dis"]
    os.makedirs(save_dir, exist_ok=True)

    # configure path
    gen_path = os.path.join(save_dir, f"gen_{it:06d}.pth")
    dis_path = os.path.join(save_dir, f"dis_{it:06d}.pth")
    opt_path = os.path.join(save_dir, f"opt_{it:06d}.pth")
    print(f"\nsaving model to {gen_path} ...")

    # remove .module for saving (parallel wrappers keep the net in .module)
    wrapped = isinstance(self.netG, (torch.nn.DataParallel, DDP))
    netG = self.netG.module if wrapped else self.netG

    # save checkpoints
    torch.save(netG.state_dict(), gen_path)
    opt_state = {
        "epoch": self.epoch,
        "iteration": self.iteration,
        "optimG": self.optimG.state_dict(),
    }
    if use_dis:
        netD = self.netD.module if wrapped else self.netD
        torch.save(netD.state_dict(), dis_path)
        opt_state["optimD"] = self.optimD.state_dict()
    opt_state["scheG"] = self.scheG.state_dict()
    if use_dis:
        opt_state["scheD"] = self.scheD.state_dict()
    torch.save(opt_state, opt_path)

    # Write the marker directly instead of shelling out to `echo`: this is
    # safe for paths with spaces / shell metacharacters and is portable.
    latest_path = os.path.join(save_dir, "latest.ckpt")
    with open(latest_path, "w") as f:
        f.write(f"{it:06d}\n")
# get gt optical flow - if flows_f[0] == 'None' or flows_b[0] == 'None': + if flows_f[0] == "None" or flows_b[0] == "None": gt_flows_bi = self.fix_raft(gt_local_frames) else: gt_flows_bi = (flows_f.to(device), flows_b.to(device)) # ---- complete flow ---- - pred_flows_bi, _ = self.fix_flow_complete.forward_bidirect_flow(gt_flows_bi, local_masks) - pred_flows_bi = self.fix_flow_complete.combine_flow(gt_flows_bi, pred_flows_bi, local_masks) + pred_flows_bi, _ = self.fix_flow_complete.forward_bidirect_flow( + gt_flows_bi, local_masks + ) + pred_flows_bi = self.fix_flow_complete.combine_flow( + gt_flows_bi, pred_flows_bi, local_masks + ) # pred_flows_bi = gt_flows_bi # ---- image propagation ---- - prop_imgs, updated_local_masks = self.netG.module.img_propagation(masked_local_frames, pred_flows_bi, local_masks, interpolation=self.interp_mode) + prop_imgs, updated_local_masks = self.netG.module.img_propagation( + masked_local_frames, + pred_flows_bi, + local_masks, + interpolation=self.interp_mode, + ) updated_masks = masks.clone() updated_masks[:, :l_t, ...] = updated_local_masks.view(b, l_t, 1, h, w) updated_frames = masked_frames.clone() - prop_local_frames = gt_local_frames * (1-local_masks) + prop_imgs.view(b, l_t, 3, h, w) * local_masks # merge + prop_local_frames = ( + gt_local_frames * (1 - local_masks) + + prop_imgs.view(b, l_t, 3, h, w) * local_masks + ) # merge updated_frames[:, :l_t, ...] = prop_local_frames # ---- feature propagation + Transformer ---- - pred_imgs = self.netG(updated_frames, pred_flows_bi, masks, updated_masks, l_t) + pred_imgs = self.netG( + updated_frames, pred_flows_bi, masks, updated_masks, l_t + ) pred_imgs = pred_imgs.view(b, -1, c, h, w) # get the local frames pred_local_frames = pred_imgs[:, :l_t, ...] - comp_local_frames = gt_local_frames * (1. - local_masks) + pred_local_frames * local_masks - comp_imgs = frames * (1. 
- masks) + pred_imgs * masks + comp_local_frames = ( + gt_local_frames * (1.0 - local_masks) + pred_local_frames * local_masks + ) + comp_imgs = frames * (1.0 - masks) + pred_imgs * masks gen_loss = 0 dis_loss = 0 + ffl = None # optimize net_g - if not self.config['model']['no_dis']: + if not self.config["model"]["no_dis"]: for p in self.netD.parameters(): p.requires_grad = False @@ -401,33 +451,55 @@ def _train_epoch(self, pbar): # generator l1 loss hole_loss = self.l1_loss(pred_imgs * masks, frames * masks) - hole_loss = hole_loss / torch.mean(masks) * self.config['losses']['hole_weight'] + hole_loss = ( + hole_loss / torch.mean(masks) * self.config["losses"]["hole_weight"] + ) gen_loss += hole_loss - self.add_summary(self.gen_writer, 'loss/hole_loss', hole_loss.item()) + self.add_summary(self.gen_writer, "loss/hole_loss", hole_loss.item()) valid_loss = self.l1_loss(pred_imgs * (1 - masks), frames * (1 - masks)) - valid_loss = valid_loss / torch.mean(1-masks) * self.config['losses']['valid_weight'] + valid_loss = ( + valid_loss + / torch.mean(1 - masks) + * self.config["losses"]["valid_weight"] + ) gen_loss += valid_loss - self.add_summary(self.gen_writer, 'loss/valid_loss', valid_loss.item()) + self.add_summary(self.gen_writer, "loss/valid_loss", valid_loss.item()) # perceptual loss - if self.config['losses']['perceptual_weight'] > 0: - perc_loss = self.perc_loss(pred_imgs.view(-1,3,h,w), frames.view(-1,3,h,w))[0] * self.config['losses']['perceptual_weight'] + if self.config["losses"]["perceptual_weight"] > 0: + perc_loss = ( + self.perc_loss( + pred_imgs.view(-1, 3, h, w), frames.view(-1, 3, h, w) + )[0] + * self.config["losses"]["perceptual_weight"] + ) gen_loss += perc_loss - self.add_summary(self.gen_writer, 'loss/perc_loss', perc_loss.item()) - + self.add_summary(self.gen_writer, "loss/perc_loss", perc_loss.item()) + + # ── focal frequency loss — hole region only ─────────────────────── + if self.config["losses"].get("ffl_weight", 0) > 0: + # Apply ONLY to 
the masked/hole region (same as hole_loss). + # Unmasked regions are already perfect copies of the input; + # including them would pollute the gradient signal. + pred_hole = (pred_imgs * masks).view(-1, 3, h, w) + gt_hole = (frames * masks).view(-1, 3, h, w) + ffl = self.ffl_loss(pred_hole, gt_hole) + gen_loss += ffl + self.add_summary(self.gen_writer, "loss/ffl_loss", ffl.item()) + # ────────────────────────────────────────────────────────────────── # gan loss - if not self.config['model']['no_dis']: + if not self.config["model"]["no_dis"]: # generator adversarial loss gen_clip = self.netD(comp_imgs) gan_loss = self.adversarial_loss(gen_clip, True, False) - gan_loss = gan_loss * self.config['losses']['adversarial_weight'] + gan_loss = gan_loss * self.config["losses"]["adversarial_weight"] gen_loss += gan_loss - self.add_summary(self.gen_writer, 'loss/gan_loss', gan_loss.item()) + self.add_summary(self.gen_writer, "loss/gan_loss", gan_loss.item()) gen_loss.backward() self.optimG.step() - if not self.config['model']['no_dis']: + if not self.config["model"]["no_dis"]: # optimize net_d for p in self.netD.parameters(): p.requires_grad = True @@ -439,8 +511,12 @@ def _train_epoch(self, pbar): dis_real_loss = self.adversarial_loss(real_clip, True, True) dis_fake_loss = self.adversarial_loss(fake_clip, False, True) dis_loss += (dis_real_loss + dis_fake_loss) / 2 - self.add_summary(self.dis_writer, 'loss/dis_vid_real', dis_real_loss.item()) - self.add_summary(self.dis_writer, 'loss/dis_vid_fake', dis_fake_loss.item()) + self.add_summary( + self.dis_writer, "loss/dis_vid_real", dis_real_loss.item() + ) + self.add_summary( + self.dis_writer, "loss/dis_vid_fake", dis_fake_loss.item() + ) dis_loss.backward() self.optimD.step() @@ -450,60 +526,125 @@ def _train_epoch(self, pbar): if self.iteration % 200 == 0: # img to cpu t = 0 - gt_local_frames_cpu = ((gt_local_frames.view(b,-1,3,h,w) + 1)/2.0).cpu() - masked_local_frames = ((masked_local_frames.view(b,-1,3,h,w) + 
1)/2.0).cpu() - prop_local_frames_cpu = ((prop_local_frames.view(b,-1,3,h,w) + 1)/2.0).cpu() - pred_local_frames_cpu = ((pred_local_frames.view(b,-1,3,h,w) + 1)/2.0).cpu() - img_results = torch.cat([masked_local_frames[0][t], gt_local_frames_cpu[0][t], - prop_local_frames_cpu[0][t], pred_local_frames_cpu[0][t]], 1) - img_results = torchvision.utils.make_grid(img_results, nrow=1, normalize=True) + gt_local_frames_cpu = ( + (gt_local_frames.view(b, -1, 3, h, w) + 1) / 2.0 + ).cpu() + masked_local_frames = ( + (masked_local_frames.view(b, -1, 3, h, w) + 1) / 2.0 + ).cpu() + prop_local_frames_cpu = ( + (prop_local_frames.view(b, -1, 3, h, w) + 1) / 2.0 + ).cpu() + pred_local_frames_cpu = ( + (pred_local_frames.view(b, -1, 3, h, w) + 1) / 2.0 + ).cpu() + img_results = torch.cat( + [ + masked_local_frames[0][t], + gt_local_frames_cpu[0][t], + prop_local_frames_cpu[0][t], + pred_local_frames_cpu[0][t], + ], + 1, + ) + img_results = torchvision.utils.make_grid( + img_results, nrow=1, normalize=True + ) if self.gen_writer is not None: - self.gen_writer.add_image(f'img/img:inp-gt-res-{t}', img_results, self.iteration) + self.gen_writer.add_image( + f"img/img:inp-gt-res-{t}", img_results, self.iteration + ) t = 5 if masked_local_frames.shape[1] > 5: - img_results = torch.cat([masked_local_frames[0][t], gt_local_frames_cpu[0][t], - prop_local_frames_cpu[0][t], pred_local_frames_cpu[0][t]], 1) - img_results = torchvision.utils.make_grid(img_results, nrow=1, normalize=True) + img_results = torch.cat( + [ + masked_local_frames[0][t], + gt_local_frames_cpu[0][t], + prop_local_frames_cpu[0][t], + pred_local_frames_cpu[0][t], + ], + 1, + ) + img_results = torchvision.utils.make_grid( + img_results, nrow=1, normalize=True + ) if self.gen_writer is not None: - self.gen_writer.add_image(f'img/img:inp-gt-res-{t}', img_results, self.iteration) + self.gen_writer.add_image( + f"img/img:inp-gt-res-{t}", img_results, self.iteration + ) # flow to cpu gt_flows_forward_cpu = 
flow_to_image(gt_flows_bi[0][0]).cpu() - masked_flows_forward_cpu = (gt_flows_forward_cpu[0] * (1-local_masks[0][0].cpu())).to(gt_flows_forward_cpu) + masked_flows_forward_cpu = ( + gt_flows_forward_cpu[0] * (1 - local_masks[0][0].cpu()) + ).to(gt_flows_forward_cpu) pred_flows_forward_cpu = flow_to_image(pred_flows_bi[0][0]).cpu() - flow_results = torch.cat([gt_flows_forward_cpu[0], masked_flows_forward_cpu, pred_flows_forward_cpu[0]], 1) + flow_results = torch.cat( + [ + gt_flows_forward_cpu[0], + masked_flows_forward_cpu, + pred_flows_forward_cpu[0], + ], + 1, + ) if self.gen_writer is not None: - self.gen_writer.add_image('img/flow:gt-pred', flow_results, self.iteration) + self.gen_writer.add_image( + "img/flow:gt-pred", flow_results, self.iteration + ) # console logs - if self.config['global_rank'] == 0: + if self.config["global_rank"] == 0: pbar.update(1) - if not self.config['model']['no_dis']: - pbar.set_description((f"d: {dis_loss.item():.3f}; " - f"hole: {hole_loss.item():.3f}; " - f"valid: {valid_loss.item():.3f}")) + if not self.config["model"]["no_dis"]: + _ffl_val = ( + ffl.item() + if self.config["losses"].get("ffl_weight", 0) > 0 + else 0.0 + ) + pbar.set_description( + ( + f"d: {dis_loss.item():.3f}; " + f"hole: {hole_loss.item():.3f}; " + f"valid: {valid_loss.item():.3f}; " + f"ffl: {_ffl_val:.4f}" + ) + ) else: - pbar.set_description((f"hole: {hole_loss.item():.3f}; " - f"valid: {valid_loss.item():.3f}")) - - if self.iteration % self.train_args['log_freq'] == 0: - if not self.config['model']['no_dis']: - logging.info(f"[Iter {self.iteration}] " - f"d: {dis_loss.item():.4f}; " - f"hole: {hole_loss.item():.4f}; " - f"valid: {valid_loss.item():.4f}") + pbar.set_description( + ( + f"hole: {hole_loss.item():.3f}; " + f"valid: {valid_loss.item():.3f}" + ) + ) + + if self.iteration % self.train_args["log_freq"] == 0: + if not self.config["model"]["no_dis"]: + _ffl_val = ( + ffl.item() + if self.config["losses"].get("ffl_weight", 0) > 0 + else 0.0 + ) + 
logging.info( + f"[Iter {self.iteration}] " + f"d: {dis_loss.item():.4f}; " + f"hole: {hole_loss.item():.4f}; " + f"valid: {valid_loss.item():.4f}; " + f"ffl: {_ffl_val:.4f}" + ) else: - logging.info(f"[Iter {self.iteration}] " - f"hole: {hole_loss.item():.4f}; " - f"valid: {valid_loss.item():.4f}") + logging.info( + f"[Iter {self.iteration}] " + f"hole: {hole_loss.item():.4f}; " + f"valid: {valid_loss.item():.4f}" + ) # saving models - if self.iteration % self.train_args['save_freq'] == 0: + if self.iteration % self.train_args["save_freq"] == 0: self.save(int(self.iteration)) - if self.iteration > self.train_args['iterations']: + if self.iteration > self.train_args["iterations"]: break - train_data = self.prefetcher.next() \ No newline at end of file + train_data = self.prefetcher.next() diff --git a/datasets/davis/test.json b/datasets/davis/test.json index 54875df4..c58ae491 100644 --- a/datasets/davis/test.json +++ b/datasets/davis/test.json @@ -1 +1,11 @@ -{"bear": 82, "blackswan": 50, "bmx-bumps": 90, "bmx-trees": 80, "boat": 75, "breakdance": 84, "breakdance-flare": 71, "bus": 80, "camel": 90, "car-roundabout": 75, "car-shadow": 40, "car-turn": 80, "cows": 104, "dance-jump": 60, "dance-twirl": 90, "dog": 60, "dog-agility": 25, "drift-chicane": 52, "drift-straight": 50, "drift-turn": 64, "elephant": 80, "flamingo": 80, "goat": 90, "hike": 80, "hockey": 75, "horsejump-high": 50, "horsejump-low": 60, "kite-surf": 50, "kite-walk": 80, "libby": 49, "lucia": 70, "mallard-fly": 70, "mallard-water": 80, "motocross-bumps": 60, "motocross-jump": 40, "motorbike": 43, "paragliding": 70, "paragliding-launch": 80, "parkour": 100, "rhino": 90, "rollerblade": 35, "scooter-black": 43, "scooter-gray": 75, "soapbox": 99, "soccerball": 48, "stroller": 91, "surf": 55, "swing": 60, "tennis": 70, "train": 80} \ No newline at end of file +{ + "drift-straight": 50, + "stroller": 91, + "classic-car": 63, + "mbike-trick": 79, + "bear": 82, + "schoolgirls": 80, + "dance-jump": 60, + 
"drift-turn": 64, + "paragliding": 70 +} \ No newline at end of file diff --git a/datasets/davis/train.json b/datasets/davis/train.json index 3f63b2d9..1a328d42 100644 --- a/datasets/davis/train.json +++ b/datasets/davis/train.json @@ -1 +1,83 @@ -{"baseball": 90, "basketball-game": 77, "bears-ball": 78, "bmx-rider": 85, "butterfly": 80, "car-competition": 66, "cat": 52, "chairlift": 99, "circus": 73, "city-ride": 70, "crafting": 45, "curling": 76, "dog-competition": 85, "dolphins-show": 74, "dribbling": 49, "drone-flying": 70, "ducks": 75, "elephant-hyenas": 55, "giraffes": 88, "gym-ball": 69, "helicopter-landing": 77, "horse-race": 80, "horses-kids": 78, "hurdles-race": 55, "ice-hockey": 52, "jet-ski": 83, "juggling-selfie": 78, "kayak-race": 63, "kids-robot": 75, "landing": 35, "luggage": 83, "mantaray": 73, "marbles": 70, "mascot": 78, "mermaid": 78, "monster-trucks": 99, "motorbike-indoors": 79, "motorbike-race": 88, "music-band": 87, "obstacles": 81, "obstacles-race": 48, "peacock": 75, "plane-exhibition": 73, "puppet": 100, "robot-battle": 85, "robotic-arm": 82, "rodeo": 85, "sea-turtle": 90, "skydiving-jumping": 75, "snowboard-race": 75, "snowboard-sand": 55, "surfer": 80, "swimmer": 86, "table-tennis": 70, "tram": 84, "trucks-race": 78, "twist-dance": 83, "volleyball-beach": 73, "water-slide": 88, "weightlifting": 90} \ No newline at end of file +{ + "night-race": 46, + "bmx-bumps": 90, + "dog": 60, + "motorbike": 43, + "lady-running": 65, + "breakdance-flare": 71, + "dog-agility": 25, + "bike-packing": 69, + "hockey": 75, + "miami-surf": 70, + "loading": 50, + "horsejump-low": 60, + "flamingo": 80, + "libby": 49, + "car-shadow": 40, + "skate-park": 80, + "koala": 100, + "walking": 72, + "rallye": 50, + "train": 80, + "boxing-fisheye": 87, + "paragliding-launch": 80, + "scooter-gray": 75, + "snowboard": 66, + "sheep": 68, + "crossing": 52, + "tractor-sand": 76, + "disc-jockey": 76, + "motocross-jump": 40, + "upside-down": 65, + "scooter-black": 43, + 
"parkour": 100, + "dance-twirl": 90, + "lab-coat": 47, + "color-run": 84, + "camel": 90, + "kid-football": 68, + "longboard": 52, + "rollerblade": 35, + "car-roundabout": 75, + "horsejump-high": 50, + "tennis": 70, + "kite-surf": 50, + "kite-walk": 80, + "drift-chicane": 52, + "rhino": 90, + "goat": 90, + "cows": 104, + "swing": 60, + "shooting": 40, + "lindy-hop": 73, + "car-turn": 80, + "motocross-bumps": 60, + "surf": 55, + "drone": 91, + "planes-water": 38, + "gold-fish": 78, + "dogs-jump": 66, + "boat": 75, + "blackswan": 50, + "dog-gooses": 86, + "varanus-cage": 67, + "india": 81, + "tuk-tuk": 59, + "dogs-scale": 83, + "mallard-fly": 70, + "hike": 80, + "dancing": 62, + "pigs": 79, + "cat-girl": 89, + "judo": 34, + "bmx-trees": 80, + "stunt": 71, + "mallard-water": 80, + "elephant": 80, + "breakdance": 84, + "scooter-board": 91, + "bus": 80, + "soapbox": 99, + "lucia": 70, + "soccerball": 48 +} diff --git a/inputs/object_removal/bmx-trees/ground_truth.mp4 b/inputs/object_removal/bmx-trees/ground_truth.mp4 new file mode 100644 index 00000000..6c9b5162 Binary files /dev/null and b/inputs/object_removal/bmx-trees/ground_truth.mp4 differ diff --git a/inputs/object_removal/tennis/ground_truth-01.mp4 b/inputs/object_removal/tennis/ground_truth-01.mp4 new file mode 100644 index 00000000..f337dbde Binary files /dev/null and b/inputs/object_removal/tennis/ground_truth-01.mp4 differ diff --git a/inputs/object_removal/tennis/ground_truth.mp4 b/inputs/object_removal/tennis/ground_truth.mp4 new file mode 100644 index 00000000..4f98144e Binary files /dev/null and b/inputs/object_removal/tennis/ground_truth.mp4 differ