From d0d8908e889e4c97e103f82334eeb2108d1195da Mon Sep 17 00:00:00 2001
From: Bin Zhao
Date: Sun, 15 Aug 2021 09:05:54 +0200
Subject: [PATCH] spatio_temporal_aggregation

---
 ltr/actors/segmentation.py                    |  83 +++
 ltr/data/processing.py                        | 139 +++++
 ltr/data/sampler.py                           | 159 ++++++
 ltr/models/lwl/label_encoder.py               |  91 ++++
 ltr/models/lwl/linear_filter.py               |  17 +-
 ltr/models/lwl/sta_net.py                     | 209 ++++++++
 ltr/train_settings/sta/__init__.py            |   0
 ltr/train_settings/sta/sta.py                 | 145 ++++++
 pytracking/evaluation/data.py                 |   7 +
 pytracking/evaluation/multi_object_wrapper.py | 209 ++++++++
 pytracking/evaluation/tracker.py              | 140 ++++-
 pytracking/experiments/myexperiments.py       |  16 +
 pytracking/features/preprocessing_sta.py      | 189 +++++++
 pytracking/notebooks/analyze_results.ipynb    | 243 +++++++--
 pytracking/parameter/sta/__init__.py          |   0
 pytracking/parameter/sta/sta_davis.py         |  65 +++
 pytracking/parameter/sta/sta_ytvos.py         |  65 +++
 pytracking/tracker/sta/__init__.py            |   5 +
 pytracking/tracker/sta/sta.py                 | 480 ++++++++++++++++++
 19 files changed, 2210 insertions(+), 52 deletions(-)
 create mode 100644 ltr/models/lwl/sta_net.py
 create mode 100644 ltr/train_settings/sta/__init__.py
 create mode 100644 ltr/train_settings/sta/sta.py
 create mode 100644 pytracking/features/preprocessing_sta.py
 create mode 100644 pytracking/parameter/sta/__init__.py
 create mode 100644 pytracking/parameter/sta/sta_davis.py
 create mode 100644 pytracking/parameter/sta/sta_ytvos.py
 create mode 100644 pytracking/tracker/sta/__init__.py
 create mode 100644 pytracking/tracker/sta/sta.py

diff --git a/ltr/actors/segmentation.py b/ltr/actors/segmentation.py
index 6f5161dd..bf3ba5d6 100644
--- a/ltr/actors/segmentation.py
+++ b/ltr/actors/segmentation.py
@@ -138,3 +138,86 @@ def __call__(self, data):
         stats['Stats/acc_box_train'] = acc_box/cnt_box
 
         return loss, stats
+
+
+class STAActor(BaseActor):
+    """Actor for training the STA network."""
+    def __init__(self, net, objective, loss_weight=None,
+                 use_focal_loss=False, use_lovasz_loss=False,
+                 detach_pred=True,
+                 num_refinement_iter=3,
+                 disable_backbone_bn=False,
+                 disable_all_bn=False):
+        super().__init__(net, objective)
+        if loss_weight is None:
+            loss_weight = {'segm': 1.0}
+        self.loss_weight = loss_weight
+
+        self.use_focal_loss = use_focal_loss
+        self.use_lovasz_loss = use_lovasz_loss
+        self.detach_pred = detach_pred
+        self.num_refinement_iter = num_refinement_iter
+        self.disable_backbone_bn = disable_backbone_bn
+        self.disable_all_bn = disable_all_bn
+
+    def train(self, mode=True):
+        """ Set whether the network is in train mode.
+        args:
+            mode (True) - Bool specifying whether in training mode.
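Editor's note, not part of the patch: a minimal illustration of what the disable_backbone_bn / disable_all_bn flags handled in train() below amount to. Switching a BatchNorm2d module to eval mode freezes its running statistics while the rest of the model keeps training; the small network here is a stand-in, only the pattern mirrors the actor code.

import torch.nn as nn

net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
net.train()                              # whole model in training mode
for m in net.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.eval()                         # this layer's running mean/var are no longer updated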
+ """ + self.net.train(mode) + + if self.disable_all_bn: + self.net.eval() + elif self.disable_backbone_bn: + for m in self.net.feature_extractor.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + + def __call__(self, data): + """ + args: + data - The input data, should contain the fields 'train_images', 'test_images', 'train_anno', + 'test_proposals', 'proposal_iou', 'test_label', 'train_masks' and 'test_masks' + returns: + loss - the training loss + stats - dict containing detailed losses + """ + segm_pred_bbox, segm_pred = self.net(train_imgs=data['train_images'], + train_bbox=data['train_anno']) + acc = 0 + cnt = 0 + acc_mid = 0 + cnt_mid = 0 + + segm_pred_bbox = segm_pred_bbox.view(-1, 1, *segm_pred_bbox.shape[-2:]) + segm_pred = segm_pred.view(-1, 1, *segm_pred.shape[-2:]) + gt_segm = data['train_masks'] + gt_segm = gt_segm.view(-1, 1, *gt_segm.shape[-2:]) + + loss_segm_bbox = self.loss_weight['segm'] * self.objective['segm'](segm_pred_bbox, gt_segm) + loss_segm = self.loss_weight['segm'] * self.objective['segm'](segm_pred, gt_segm) + + acc_l = [davis_jaccard_measure(torch.sigmoid(rm.detach()).cpu().numpy() > 0.5, lb.cpu().numpy()) for + rm, lb in zip(segm_pred.view(-1, *segm_pred.shape[-2:]), gt_segm.view(-1, *segm_pred.shape[-2:]))] + acc += sum(acc_l) + cnt += len(acc_l) + + acc_l_mid = [davis_jaccard_measure(torch.sigmoid(rm.detach()).cpu().numpy() > 0.5, lb.cpu().numpy()) for + rm, lb in zip(segm_pred_bbox.view(-1, *segm_pred_bbox.shape[-2:]), gt_segm.view(-1, *segm_pred_bbox.shape[-2:]))] + acc_mid += sum(acc_l_mid) + cnt_mid += len(acc_l_mid) + + loss = loss_segm_bbox + loss_segm + + if torch.isinf(loss) or torch.isnan(loss): + raise Exception('ERROR: Loss was nan or inf!!!') + + # Log stats + stats = {'Loss/total': loss.item()} + stats['Loss/segm mid'] = loss_segm_bbox.item() + stats['Loss/segm'] = loss_segm.item() + + stats['Stats/acc_mid'] = acc_mid / cnt_mid + stats['Stats/acc'] = acc / cnt + return loss, stats diff --git a/ltr/data/processing.py b/ltr/data/processing.py index 82757b7f..ec4aa555 100644 --- a/ltr/data/processing.py +++ b/ltr/data/processing.py @@ -776,6 +776,145 @@ def __call__(self, data: TensorDict): return data +class STAProcessing(BaseProcessing): + """ The processing class used for training DiMP. The images are processed in the following way. + First, the target bounding box is jittered by adding some noise. Next, a square region (called search region ) + centered at the jittered target center, and of area search_area_factor^2 times the area of the jittered box is + cropped from the image. The reason for jittering the target box is to avoid learning the bias that the target is + always at the center of the search region. The search region is then resized to a fixed size given by the + argument output_sz. + + """ + + def __init__(self, search_area_factor, output_sz, center_jitter_factor, scale_jitter_factor, crop_type='replicate', + max_scale_change=None, mode='pair', + new_roll=False, *args, **kwargs): + """ + args: + search_area_factor - The size of the search region relative to the target size. + output_sz - An integer, denoting the size to which the search region is resized. The search region is always + square. + center_jitter_factor - A dict containing the amount of jittering to be applied to the target center before + extracting the search region. See _get_jittered_box for how the jittering is done. + scale_jitter_factor - A dict containing the amount of jittering to be applied to the target size before + extracting the search region. 
See _get_jittered_box for how the jittering is done. + crop_type - If 'replicate', the boundary pixels are replicated in case the search region crop goes out of image. + If 'nopad', the search region crop is shifted/shrunk to fit completely inside the image. + mode - Either 'pair' or 'sequence'. If mode='sequence', then output has an extra dimension for frames + """ + super().__init__(*args, **kwargs) + self.search_area_factor = search_area_factor + self.output_sz = output_sz + self.center_jitter_factor = center_jitter_factor + self.scale_jitter_factor = scale_jitter_factor + self.crop_type = crop_type + self.mode = mode + self.max_scale_change = max_scale_change + + self.new_roll = new_roll + + def _get_jittered_box(self, box, mode): + """ Jitter the input box + args: + box - input bounding box + mode - string 'train' or 'test' indicating train or test data + + returns: + torch.Tensor - jittered box + """ + + if self.scale_jitter_factor.get('mode', 'gauss') == 'gauss': + jittered_size = box[2:4] * torch.exp(torch.randn(2) * self.scale_jitter_factor[mode]) + elif self.scale_jitter_factor.get('mode', 'gauss') == 'uniform': + jittered_size = box[2:4] * torch.exp(torch.FloatTensor(2).uniform_(-self.scale_jitter_factor[mode], + self.scale_jitter_factor[mode])) + else: + raise Exception + + max_offset = (jittered_size.prod().sqrt() * torch.tensor(self.center_jitter_factor[mode])).float() + jittered_center = box[0:2] + 0.5 * box[2:4] + max_offset * (torch.rand(2) - 0.5) + + return torch.cat((jittered_center - 0.5 * jittered_size, jittered_size), dim=0) + + def _generate_search_bb(self, boxes_crop, crops, boxes_orig, boxes_jittered): + search_bb = [] + anno_search_bb = [] + for b_crop, im, b_orig, b_jit in zip(boxes_crop, crops, boxes_orig, boxes_jittered): + output_sz = self.output_sz + if isinstance(output_sz, (float, int)): + output_sz = (output_sz, output_sz) + + output_sz = torch.Tensor(output_sz) + + resize_factor = b_crop[-1] / b_orig[-1] + + b_jit_crop_sz = b_jit[2:] * resize_factor + + search_bb_sz = ( + output_sz * (b_jit_crop_sz.prod() / output_sz.prod()).sqrt() * self.search_area_factor).ceil() + search_bb.append(torch.cat((torch.zeros(2), search_bb_sz))) + + b_sh = b_crop.clone() + + anno_search_bb.append(b_sh) + return search_bb, anno_search_bb + + def __call__(self, data: TensorDict): + """ + args: + data - The input data, should contain the following fields: + 'train_images' - + 'train_masks' - + 'train_anno' - + + returns: + TensorDict - output data block with following fields: + 'train_images' - + 'train_masks' - + 'train_anno' - + """ + + if self.transform['joint'] is not None: + data['train_images'], data['train_anno'], data['train_masks'] = self.transform['joint'](image=data['train_images'], bbox=data['train_anno'], mask=data['train_masks']) + + for s in ['train']: + assert self.mode == 'sequence' or len(data[s + '_images']) == 1, \ + "In pair mode, num train/test frames must be 1" + + # Add a uniform noise to the center pos + jittered_anno = [self._get_jittered_box(a, s) for a in data[s + '_anno']] + orig_anno = data[s + '_anno'] + + crops, boxes, mask_crops = prutils.target_image_crop(data[s + '_images'], jittered_anno, + data[s + '_anno'], self.search_area_factor, + self.output_sz, mode=self.crop_type, + max_scale_change=self.max_scale_change, + masks=data[s + '_masks']) + + data[s + '_images'], data[s + '_anno'], data[s + '_masks'] = self.transform[s](image=crops, bbox=boxes, mask=mask_crops, joint=False) + # Generate search_bb + sa_bb, anno_in_sa = 
self._generate_search_bb(boxes, crops, orig_anno, jittered_anno) + + data[s + '_sa_bb'] = sa_bb + data[s + '_anno_in_sa'] = anno_in_sa + + for s in ['train']: + is_distractor = data.get('is_distractor_{}_frame'.format(s), None) + if is_distractor is not None: + for is_dist, box in zip(is_distractor, data[s+'_anno']): + if is_dist: + box[0] = 99999999.9 + box[1] = 99999999.9 + + # Prepare output + if self.mode == 'sequence': + data = data.apply(stack_tensors) + else: + data = data.apply(lambda x: x[0] if isinstance(x, list) else x) + + return data + + class KYSProcessing(BaseProcessing): """ The processing class used for training KYS. The images are processed in the following way. First, the target bounding box is jittered by adding some noise. Next, a square region (called search region ) diff --git a/ltr/data/sampler.py b/ltr/data/sampler.py index 3c8f42f2..fa6f8867 100644 --- a/ltr/data/sampler.py +++ b/ltr/data/sampler.py @@ -377,6 +377,165 @@ def __getitem__(self, index): return self.processing(data) +class STASampler(torch.utils.data.Dataset): + """ Class responsible for sampling frames from training sequences to form batches. Each training sample is a + tuple consisting of i) a set of train frames, used to learn the DiMP classification model and obtain the + modulation vector for IoU-Net, and ii) a set of test frames on which target classification loss for the predicted + DiMP model, and the IoU prediction loss for the IoU-Net is calculated. + + The sampling is done in the following ways. First a dataset is selected at random. Next, a sequence is selected + from that dataset. A base frame is then sampled randomly from the sequence. Next, a set of 'train frames' and + 'test frames' are sampled from the sequence from the range [base_frame_id - max_gap, base_frame_id] and + (base_frame_id, base_frame_id + max_gap] respectively. Only the frames in which the target is visible are sampled. + If enough visible frames are not found, the 'max_gap' is increased gradually till enough frames are found. + + The sampled frames are then passed through the input 'processing' function for the necessary processing- + """ + + def __init__(self, datasets, p_datasets, samples_per_epoch, max_gap, + num_train_frames=1, processing=no_processing, p_reverse=None): + """ + args: + datasets - List of datasets to be used for training + p_datasets - List containing the probabilities by which each dataset will be sampled + samples_per_epoch - Number of training samples per epoch + max_gap - Maximum gap, in frame numbers, between the train frames and the test frames. + num_test_frames - Number of test frames to sample. + num_train_frames - Number of train frames to sample. + processing - An instance of Processing class which performs the necessary processing of the data. 
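Editor's note, not part of the patch: the frame-sampling scheme described above, reduced to a hedged standalone sketch. The gap-growing step of 5 and the use of random.choices mirror __getitem__ below; the function name and the simplified forward-only sampling are illustrative assumptions.

import random

def sample_train_frame_ids(visible, num_train_frames=3, max_gap=100):
    # Pick a visible base frame that leaves room for the remaining frames.
    valid = [i for i, v in enumerate(visible) if v]
    base = random.choice([i for i in valid if i <= len(visible) - num_train_frames])
    # Draw the rest from (base, base + max_gap], widening the window by 5
    # frames at a time until at least one visible frame falls inside it.
    gap = 0
    while gap <= 1000:
        window = [i for i in valid if base < i <= base + max_gap + gap]
        if window:
            return sorted([base] + random.choices(window, k=num_train_frames - 1))
        gap += 5
    raise RuntimeError('No visible frame found within the allowed gap')

For example, sample_train_frame_ids([True] * 120) returns three sorted frame ids such as [17, 40, 93]; repeats are possible because random.choices samples with replacement, matching _sample_visible_ids below.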
+ """ + self.datasets = datasets + + # If p not provided, sample uniformly from all videos + if p_datasets is None: + p_datasets = [len(d) for d in self.datasets] + + # Normalize + p_total = sum(p_datasets) + self.p_datasets = [x/p_total for x in p_datasets] + + self.samples_per_epoch = samples_per_epoch + self.max_gap = max_gap + self.num_train_frames = num_train_frames + self.processing = processing + + self.p_reverse = p_reverse + + def __len__(self): + return self.samples_per_epoch + + def _sample_visible_ids(self, visible, num_ids=1, min_id=None, max_id=None): + """ Samples num_ids frames between min_id and max_id for which target is visible + + args: + visible - 1d Tensor indicating whether target is visible for each frame + num_ids - number of frames to be samples + min_id - Minimum allowed frame number + max_id - Maximum allowed frame number + + returns: + list - List of sampled frame numbers. None if not sufficient visible frames could be found. + """ + if min_id is None or min_id < 0: + min_id = 0 + if max_id is None or max_id > len(visible): + max_id = len(visible) + + valid_ids = [i for i in range(min_id, max_id) if visible[i]] + + # No visible ids + if len(valid_ids) == 0: + return None + + return random.choices(valid_ids, k=num_ids) + + def __getitem__(self, index): + """ + args: + index (int): Index (dataset index) + + returns: + TensorDict - dict containing all the data blocks + """ + + # Select a dataset + # TODO ensure that the dataset can either be used independently, or wrapped with batch sampler + # dataset = self.datasets[index] + dataset = random.choices(self.datasets, self.p_datasets)[0] + + is_video_dataset = dataset.is_video_sequence() + + reverse_sequence = False + if self.p_reverse is not None: + reverse_sequence = random.random() < self.p_reverse + + # Sample a sequence with enough visible frames + enough_visible_frames = False + while not enough_visible_frames: + # Sample a sequence + seq_id = random.randint(0, dataset.get_num_sequences() - 1) + + # Sample frames + seq_info_dict = dataset.get_sequence_info(seq_id) + visible = seq_info_dict['visible'] + + enough_visible_frames = visible.type(torch.int64).sum().item() > 2 * (self.num_train_frames) + + enough_visible_frames = enough_visible_frames or not is_video_dataset + + if is_video_dataset: + train_frame_ids = None + sample_frame_ids = None + gap_increase = 0 + + # Sample train frames + while sample_frame_ids is None: + if gap_increase > 1000: + raise Exception('Frame not found') + + if not reverse_sequence: + base_frame_id = self._sample_visible_ids(visible, num_ids=1, min_id=0, + max_id=len(visible)-self.num_train_frames+1) + + train_frame_ids = base_frame_id + sample_frame_ids = self._sample_visible_ids(visible, min_id=train_frame_ids[0]+1, + max_id=train_frame_ids[0] + self.max_gap + gap_increase, + num_ids=self.num_train_frames-1) + + # Increase gap until a frame is found + gap_increase += 5 + else: + base_frame_id = self._sample_visible_ids(visible, num_ids=1, min_id=self.num_train_frames - 1, + max_id=len(visible)) + + train_frame_ids = base_frame_id + sample_frame_ids = self._sample_visible_ids(visible, min_id=train_frame_ids[0]+1 - self.max_gap - gap_increase, + max_id=train_frame_ids[0], + num_ids=self.num_train_frames-1) + + # Increase gap until a frame is found + gap_increase += 5 + train_frame_ids = train_frame_ids + sample_frame_ids + else: + # In case of image dataset, just repeat the image to generate synthetic video + train_frame_ids = [1]*self.num_train_frames + + train_frame_ids = 
sorted(train_frame_ids, reverse=reverse_sequence) + + train_frames, train_anno, meta_obj = dataset.get_frames(seq_id, train_frame_ids, seq_info_dict) + + train_frames = train_frames[:len(train_frame_ids)] + + train_masks = train_anno['mask'] if 'mask' in train_anno else None + + data = TensorDict({'train_images': train_frames, + 'train_masks': train_masks, + 'train_anno': train_anno['bbox'], + 'dataset': dataset.get_name()}) + + return self.processing(data) + + class KYSSampler(torch.utils.data.Dataset): def __init__(self, datasets, p_datasets, samples_per_epoch, sequence_sample_info, processing=no_processing, sample_occluded_sequences=False): diff --git a/ltr/models/lwl/label_encoder.py b/ltr/models/lwl/label_encoder.py index 702264a5..05029bb9 100644 --- a/ltr/models/lwl/label_encoder.py +++ b/ltr/models/lwl/label_encoder.py @@ -124,3 +124,94 @@ def forward(self, bb, feat, sz): label_enc = label_enc.view(label_shape[0], label_shape[1], *label_enc.shape[-3:]) return label_enc + + +class ResidualDS16FeatSWBox(nn.Module): + def __init__(self, layer_dims, feat_dim, use_final_relu=True, use_gauss=True, use_bn=False, use_sample_w=True): + super().__init__() + + self.use_sample_w = use_sample_w + self.use_gauss = use_gauss + self.conv_block = conv_block(1, layer_dims[0], kernel_size=3, stride=2, padding=1, batch_norm=use_bn) + self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + ds1 = nn.Conv2d(layer_dims[0], layer_dims[1], kernel_size=3, padding=1, stride=2) + self.res1 = BasicBlock(layer_dims[0], layer_dims[1], stride=2, downsample=ds1, use_bn=use_bn) + + ds2 = nn.Conv2d(layer_dims[1], layer_dims[2], kernel_size=3, padding=1, stride=2) + self.res2 = BasicBlock(layer_dims[1], layer_dims[2], stride=2, downsample=ds2, use_bn=use_bn) + + ds3 = nn.Conv2d(layer_dims[2] + feat_dim, layer_dims[3], kernel_size=3, padding=1, stride=1) + self.res3 = BasicBlock(layer_dims[2] + feat_dim, layer_dims[3], stride=1, downsample=ds3, use_bn=use_bn) + + self.label_pred = conv_block(layer_dims[3], layer_dims[4], kernel_size=3, stride=1, padding=1, + relu=use_final_relu) + if self.use_sample_w: + self.samp_w_pred = nn.Conv2d(layer_dims[3], layer_dims[4], kernel_size=3, padding=1, stride=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. 
/ n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + if self.use_sample_w: + self.samp_w_pred.weight.data.fill_(0) + self.samp_w_pred.bias.data.fill_(1) + + def bbox_to_mask(self, bbox, sz): + mask = torch.zeros((bbox.shape[0],1,*sz), dtype=torch.float32, device=bbox.device) + for i, bb in enumerate(bbox): + x1, y1, w, h = list(map(int, bb)) + x1 = int(x1+0.5) + y1 = int(y1+0.5) + h = int(h+0.5) + w = int(w+0.5) + mask[i, :, max(0,y1):(y1+h), max(0,x1):(x1+w)] = 1.0 + return mask + + def bbox_to_gauss(self, bbox, sz): + mask = torch.zeros((bbox.shape[0],1,*sz), dtype=torch.float32, device=bbox.device) + x_max, y_max = sz[-1], sz[-2] + for i, bb in enumerate(bbox): + x1, y1, w, h = list(map(int, bb)) + cx, cy = x1+w/2, y1+h/2 + xcoords = torch.arange(0, x_max).unsqueeze(dim=0).to(bbox.device).float() + ycoords = torch.arange(0, y_max).unsqueeze(dim=0).T.to(bbox.device).float() + d_xcoords = xcoords - cx + d_ycoords = ycoords - cy + dtotsqr = d_xcoords**2/(0.25*w)**2 + d_ycoords**2/(0.25*h)**2 + mask[i,0] = torch.exp(-0.5*dtotsqr) + return mask + + def forward(self, bb, feat, sz): + assert bb.dim() == 3 + num_frames = bb.shape[0] + batch_sz = bb.shape[1] + bb = bb.reshape(-1, 4) + if self.use_gauss: + label_mask = self.bbox_to_gauss(bb, sz[-2:]) + else: + label_mask = self.bbox_to_mask(bb, sz[-2:]) + + label_mask = label_mask.view(-1, 1, *label_mask.shape[-2:]) + + mask_enc = self.pool(self.conv_block(label_mask)) + mask_enc = self.res2(self.res1(mask_enc)) + + feat = feat.view(-1, *feat.shape[-3:]) + feat_mask_enc = torch.cat((mask_enc, feat), dim=1) + out = self.res3(feat_mask_enc) + + label_enc = self.label_pred(out) + label_enc = label_enc.view(num_frames, batch_sz, *label_enc.shape[-3:]) + + sample_w = None + if self.use_sample_w: + sample_w = self.samp_w_pred(out) + sample_w = sample_w.view(num_frames, batch_sz, *sample_w.shape[-3:]) + + # Out dim is (num_seq, num_frames, layer_dims[-1], h, w) + return label_enc, sample_w diff --git a/ltr/models/lwl/linear_filter.py b/ltr/models/lwl/linear_filter.py index 0fed538b..854b92c2 100644 --- a/ltr/models/lwl/linear_filter.py +++ b/ltr/models/lwl/linear_filter.py @@ -23,15 +23,16 @@ def __init__(self, filter_size, filter_initializer, filter_optimizer=None, featu self.filter_dilation_factors = filter_dilation_factors # Init weights - for m in self.feature_extractor.modules(): - if isinstance(m, nn.Conv2d): - n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - m.weight.data.normal_(0, math.sqrt(2. / n)) - if m.bias is not None: + if not feature_extractor is None: + for m in self.feature_extractor.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. 
/ n)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) m.bias.data.zero_() - elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1) - m.bias.data.zero_() def forward(self, train_feat, test_feat, train_label, *args, **kwargs): """ the mask should be 5d""" diff --git a/ltr/models/lwl/sta_net.py b/ltr/models/lwl/sta_net.py new file mode 100644 index 00000000..2968768e --- /dev/null +++ b/ltr/models/lwl/sta_net.py @@ -0,0 +1,209 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as f +from collections import OrderedDict +from ltr.models.layers.blocks import conv_block +import ltr.models.lwl.linear_filter as target_clf +import ltr.models.target_classifier.features as clf_features +import ltr.models.lwl.initializer as seg_initializer +import ltr.models.lwl.label_encoder as seg_label_encoder +import ltr.models.lwl.loss_residual_modules as loss_residual_modules +import ltr.models.lwl.decoder as lwtl_decoder +import ltr.models.backbone as backbones +import ltr.models.backbone.resnet_mrcnn as mrcnn_backbones +import ltr.models.meta.steepestdescent as steepestdescent +from ltr import model_constructor +from pytracking import TensorList + + +class STANet(nn.Module): + def __init__(self, feature_extractor, target_model, target_model_segm, decoder, target_model_input_layer, decoder_input_layers, + label_encoder=None, bbox_encoder=None, segm_encoder = None): + super().__init__() + + self.feature_extractor = feature_extractor # Backbone feature extractor F + self.target_model = target_model # Target model and the few-shot learner + self.target_model_segm = target_model_segm # Target model and the few-shot learner + self.decoder = decoder # Segmentation Decoder + + self.label_encoder = label_encoder # Few-shot label generator and weight predictor + self.bbox_encoder = bbox_encoder # Few-shot label generator and weight predictor + self.segm_encoder = segm_encoder # Few-shot label generator and weight predictor + + self.target_model_input_layer = (target_model_input_layer,) if isinstance(target_model_input_layer, + str) else target_model_input_layer + self.decoder_input_layers = decoder_input_layers + self.output_layers = sorted(list(set(self.target_model_input_layer + self.decoder_input_layers))) + + def forward(self, train_imgs, train_bbox): + num_sequences = train_imgs.shape[1] + num_train_frames = train_imgs.shape[0] + + # Extract backbone features + train_feat = self.extract_backbone_features( + train_imgs.reshape(-1, train_imgs.shape[-3], train_imgs.shape[-2], train_imgs.shape[-1])) + + # Extract classification features + train_feat_clf = self.extract_target_model_features(train_feat) # seq*frames, channels, height, width + + train_bbox_enc, _ = self.label_encoder(train_bbox, train_feat_clf, list(train_imgs.shape[-2:])) + train_mask_enc, train_mask_sw = self.bbox_encoder(train_bbox, train_feat_clf, list(train_imgs.shape[-2:])) + train_feat_clf = train_feat_clf.view(num_train_frames, num_sequences, *train_feat_clf.shape[-3:]) + + _, filter_iter, _ = self.target_model.get_filter(train_feat_clf, train_mask_enc, train_mask_sw) + target_scores = [self.target_model.apply_target_model(f, train_feat_clf) for f in filter_iter] + target_scores_last_iter = target_scores[-1] + coarse_mask = torch.cat((train_bbox_enc, target_scores_last_iter), dim=2) + pred_all, _ = self.decoder(coarse_mask, train_feat, train_imgs.shape[-2:]) + # pred_all = pred_all.detach() + # train_bbox_enc = train_bbox_enc.detach() + + pred_all = 
pred_all.view(num_train_frames, num_sequences, *pred_all.shape[-2:]) + train_segm_enc, train_segm_sw = self.segm_encoder(torch.sigmoid(pred_all), train_feat_clf) + _, filter_iter_segm, _ = self.target_model_segm.get_filter(train_feat_clf, train_segm_enc, train_segm_sw) + target_scores_segm = [self.target_model_segm.apply_target_model(f, train_feat_clf) for f in filter_iter_segm] + target_scores_last_iter_segm = target_scores_segm[-1] + coarse_mask = torch.cat((train_bbox_enc, target_scores_last_iter_segm), dim=2) + pred_all_segm, _ = self.decoder(coarse_mask, train_feat, train_imgs.shape[-2:]) + pred_all_segm = pred_all_segm.view(num_train_frames, num_sequences, *pred_all_segm.shape[-2:]) + + return pred_all, pred_all_segm + + def segment_target_add_bbox_encoder(self, bbox_mask, target_filter, test_feat_clf, test_feat, segm): + # Classification features + assert target_filter.dim() == 5 # seq, filters, ch, h, w + + if segm == False: + target_scores = self.target_model.apply_target_model(target_filter, test_feat_clf) + else: + target_scores = self.target_model_segm.apply_target_model(target_filter, test_feat_clf) + + target_scores = torch.cat((bbox_mask, target_scores), dim=2) + mask_pred, decoder_feat = self.decoder(target_scores, test_feat, + (test_feat_clf.shape[-2]*16, test_feat_clf.shape[-1]*16)) + # Output is 1, 1, h, w + return mask_pred + + # def segment_target(self, target_filter, test_feat_clf, test_feat, segm): + # # Classification features + # assert target_filter.dim() == 5 # seq, filters, ch, h, w + # test_feat_clf = test_feat_clf.view(-1, 1, *test_feat_clf.shape[-3:]) + + # if segm == False: + # target_scores = self.target_model.classify(target_filter, test_feat_clf) + # else: + # target_scores = self.target_model_segm.classify(target_filter, test_feat_clf) + # mask_pred, decoder_feat = self.decoder(target_scores, test_feat, + # (test_feat_clf.shape[-2]*16, test_feat_clf.shape[-1]*16), + # (self.bbreg_decoder_layer, )) + + # bb_pred = None + # if self.bb_regressor is not None: + # bb_pred = self.bb_regressor(decoder_feat[self.bbreg_decoder_layer]) + # bb_pred[:, :2] *= test_feat_clf.shape[-2] * 16 + # bb_pred[:, 2:] *= test_feat_clf.shape[-1] * 16 + # bb_pred = torch.stack((bb_pred[:, 2], bb_pred[:, 0], + # bb_pred[:, 3] - bb_pred[:, 2], + # bb_pred[:, 1] - bb_pred[:, 0]), dim=1) + + # decoder_feat['mask_enc'] = target_scores.view(-1, *target_scores.shape[-3:]) + # aux_mask_pred = {} + # if 'mask_enc_iter' in self.aux_layers.keys(): + # aux_mask_pred['mask_enc_iter'] = \ + # self.aux_layers['mask_enc_iter'](target_scores.view(-1, *target_scores.shape[-3:]), (test_feat_clf.shape[-2]*16, + # test_feat_clf.shape[-1]*16)) + # # Output is 1, 1, h, w + # return mask_pred, bb_pred, aux_mask_pred + + def get_backbone_target_model_features(self, backbone_feat): + feat = OrderedDict({l: backbone_feat[l] for l in self.target_model_input_layer}) + if len(self.target_model_input_layer) == 1: + return feat[self.target_model_input_layer[0]] + return feat + + def extract_target_model_features(self, backbone_feat): + return self.target_model.extract_target_model_features(self.get_backbone_target_model_features(backbone_feat)) + + def extract_backbone_features(self, im, layers=None): + if layers is None: + layers = self.output_layers + return self.feature_extractor(im, layers) + +@model_constructor +def steepest_descent_resnet50(filter_size=1, num_filters=1, optim_iter=3, optim_init_reg=0.01, + backbone_pretrained=False, clf_feat_blocks=1, + clf_feat_norm=True, final_conv=False, + 
out_feature_dim=512, + target_model_input_layer='layer3', + decoder_input_layers = ("layer4", "layer3", "layer2", "layer1",), + detach_length=float('Inf'), + label_encoder_dims=(1, 1), + frozen_backbone_layers=(), + decoder_mdim=64, filter_groups=1, + use_bn_in_label_enc=True, + dilation_factors=None, + backbone_type='imagenet',): + # backbone feature extractor F + if backbone_type == 'imagenet': + backbone_net = backbones.resnet50(pretrained=backbone_pretrained, frozen_layers=frozen_backbone_layers) + elif backbone_type == 'mrcnn': + backbone_net = mrcnn_backbones.resnet50(pretrained=False, frozen_layers=frozen_backbone_layers) + else: + raise Exception + + norm_scale = math.sqrt(1.0 / (out_feature_dim * filter_size * filter_size)) + + layer_channels = backbone_net.out_feature_channels() + + # Extracts features input to the target model + target_model_feature_extractor = clf_features.residual_basic_block( + feature_dim=layer_channels[target_model_input_layer], + num_blocks=clf_feat_blocks, l2norm=clf_feat_norm, + final_conv=final_conv, norm_scale=norm_scale, + out_dim=out_feature_dim) + + initializer = seg_initializer.FilterInitializerZero(filter_size=filter_size, num_filters=num_filters, + feature_dim=out_feature_dim, filter_groups=filter_groups) + initializer_segm = seg_initializer.FilterInitializerZero(filter_size=filter_size, num_filters=num_filters, + feature_dim=out_feature_dim, filter_groups=filter_groups) + + # Few-shot label generator and weight predictor + label_encoder = seg_label_encoder.ResidualDS16FeatSWBox(layer_dims=label_encoder_dims + (num_filters, ), + feat_dim=out_feature_dim, use_final_relu=True, + use_gauss=False) + bbox_encoder = seg_label_encoder.ResidualDS16FeatSWBox(layer_dims=label_encoder_dims + (num_filters, ), + feat_dim=out_feature_dim, use_final_relu=True, + use_gauss=False) + segm_encoder = seg_label_encoder.ResidualDS16SW(layer_dims=label_encoder_dims[:-1] + (num_filters, ), + use_bn=use_bn_in_label_enc) + + # Computes few-shot learning loss + residual_module = loss_residual_modules.LWTLResidual(init_filter_reg=optim_init_reg) + residual_module_segm = loss_residual_modules.LWTLResidual(init_filter_reg=optim_init_reg) + + # Iteratively updates the target model parameters by minimizing the few-shot learning loss + optimizer = steepestdescent.GNSteepestDescent(residual_module=residual_module, num_iter=optim_iter, + detach_length=detach_length, + residual_batch_dim=1, compute_losses=True) + optimizer_segm = steepestdescent.GNSteepestDescent(residual_module=residual_module_segm, num_iter=optim_iter, + detach_length=detach_length, + residual_batch_dim=1, compute_losses=True) + + # Target model and Few-shot learner + target_model = target_clf.LinearFilter(filter_size=filter_size, filter_initializer=initializer, + filter_optimizer=optimizer, feature_extractor=target_model_feature_extractor, + filter_dilation_factors=dilation_factors) + target_model_segm = target_clf.LinearFilter(filter_size=filter_size, filter_initializer=initializer_segm, + filter_optimizer=optimizer_segm, feature_extractor=None, + filter_dilation_factors=dilation_factors) + + # Decoder + decoder_input_layers_channels = {L: layer_channels[L] for L in decoder_input_layers} + + decoder = lwtl_decoder.LWTLDecoder(num_filters*2, decoder_mdim, decoder_input_layers_channels, use_bn=True) + + net = STANet(feature_extractor=backbone_net, target_model=target_model, target_model_segm=target_model_segm, decoder=decoder, + label_encoder=label_encoder, bbox_encoder=bbox_encoder, segm_encoder = segm_encoder, 
+ target_model_input_layer=target_model_input_layer, decoder_input_layers=decoder_input_layers) + return net \ No newline at end of file diff --git a/ltr/train_settings/sta/__init__.py b/ltr/train_settings/sta/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ltr/train_settings/sta/sta.py b/ltr/train_settings/sta/sta.py new file mode 100644 index 00000000..81c70c16 --- /dev/null +++ b/ltr/train_settings/sta/sta.py @@ -0,0 +1,145 @@ +import torch +import os +import torch.optim as optim +from ltr.dataset import YouTubeVOS, Davis +from ltr.data import processing, sampler, LTRLoader +import ltr.models.lwl.sta_net as sta_networks +import ltr.actors.segmentation as segm_actors +from ltr.trainers import LTRTrainer +import ltr.data.transforms as tfm +from ltr import MultiGPU +from ltr.models.loss.segmentation import LovaszSegLoss + + +def run(settings): + settings.description = 'Default train settings with backbone weights fixed. We initialize the backbone ResNet with ' \ + 'pre-trained Mask-RCNN weights. These weights can be obtained from ' \ + 'https://drive.google.com/file/d/12pVHmhqtxaJ151dZrXN1dcgUa7TuAjdA/view?usp=sharing. ' \ + 'Download and save these weights in env_settings.pretrained_networks directory' + settings.batch_size = 4 + settings.num_workers = 8 + settings.multi_gpu = True + settings.print_interval = 1 + settings.normalize_mean = [102.9801, 115.9465, 122.7717] + settings.normalize_std = [1.0, 1.0, 1.0] + + settings.feature_sz = (52, 30) + + # Settings used for generating the image crop input to the network. See documentation of LWTLProcessing class in + # ltr/data/processing.py for details. + settings.output_sz = (settings.feature_sz[0] * 16, settings.feature_sz[1] * 16) # Size of input image crop + settings.search_area_factor = 5.0 + settings.crop_type = 'inside_major' + settings.max_scale_change = None + + settings.center_jitter_factor = {'train': 3, 'test': (5.5, 4.5)} + settings.scale_jitter_factor = {'train': 0.25, 'test': 0.5} + + settings.min_target_area = 500 + + # Datasets + ytvos_train = YouTubeVOS(version="2019", multiobj=False, split='jjtrain') + davis_train = Davis(version='2017', multiobj=False, split='train') + + ytvos_val = YouTubeVOS(version="2019", multiobj=False, split='jjvalid') + + # Data transform + transform_joint = tfm.Transform(tfm.ToBGR(), + tfm.ToGrayscale(probability=0.05), + tfm.RandomHorizontalFlip(probability=0.5)) + + transform_train = tfm.Transform(tfm.ToTensorAndJitter(0.2, normalize=False), + tfm.Normalize(mean=settings.normalize_mean, std=settings.normalize_std)) + + transform_val = tfm.Transform(tfm.ToTensorAndJitter(0.0, normalize=False), + tfm.Normalize(mean=settings.normalize_mean, std=settings.normalize_std)) + + data_processing_train = processing.STAProcessing(search_area_factor=settings.search_area_factor, + output_sz=settings.output_sz, + center_jitter_factor=settings.center_jitter_factor, + scale_jitter_factor=settings.scale_jitter_factor, + mode='sequence', + crop_type=settings.crop_type, + max_scale_change=settings.max_scale_change, + transform=transform_train, + joint_transform=transform_joint, + new_roll=True) + + data_processing_val = processing.STAProcessing(search_area_factor=settings.search_area_factor, + output_sz=settings.output_sz, + center_jitter_factor=settings.center_jitter_factor, + scale_jitter_factor=settings.scale_jitter_factor, + mode='sequence', + crop_type=settings.crop_type, + max_scale_change=settings.max_scale_change, + transform=transform_val, + joint_transform=transform_joint, + 
new_roll=True) + + # Train sampler and loader + dataset_train = sampler.STASampler([ytvos_train, davis_train], [6, 1], + samples_per_epoch=settings.batch_size * 1000, max_gap=100, + num_train_frames=3, + processing=data_processing_train) + dataset_val = sampler.STASampler([ytvos_val], [1], + samples_per_epoch=settings.batch_size * 500, max_gap=100, + num_train_frames=3, + processing=data_processing_val) + + loader_train = LTRLoader('train', dataset_train, training=True, num_workers=settings.num_workers, + stack_dim=1, batch_size=settings.batch_size) + + loader_val = LTRLoader('val', dataset_val, training=False, num_workers=settings.num_workers, + epoch_interval=5, stack_dim=1, batch_size=settings.batch_size) + + # Network + net = sta_networks.steepest_descent_resnet50(filter_size=3, num_filters=16, optim_iter=5, + backbone_pretrained=True, + out_feature_dim=512, + frozen_backbone_layers=['conv1', 'bn1', 'layer1', 'layer2', 'layer3', + 'layer4'], + label_encoder_dims=(16, 32, 64, 128), + use_bn_in_label_enc=False, + clf_feat_blocks=0, + final_conv=True, + backbone_type='mrcnn') + + # Load pre-trained maskrcnn weights + weights_path = os.path.join(settings.env.pretrained_networks, 'e2e_mask_rcnn_R_50_FPN_1x_converted.pkl') + pretrained_weights = torch.load(weights_path) + + net.feature_extractor.load_state_dict(pretrained_weights) + + # Wrap the network for multi GPU training + if settings.multi_gpu: + net = MultiGPU(net, dim=1) + + # Loss function + objective = { + 'segm': LovaszSegLoss(per_image=False), + } + + loss_weight = { + 'segm': 100.0 + } + + actor = segm_actors.STAActor(net=net, objective=objective, loss_weight=loss_weight, + num_refinement_iter=2, disable_all_bn=True) + + # Optimizer + optimizer = optim.Adam([{'params': actor.net.target_model.filter_initializer.parameters(), 'lr': 5e-5}, + {'params': actor.net.target_model.filter_optimizer.parameters(), 'lr': 1e-4}, + {'params': actor.net.target_model.feature_extractor.parameters(), 'lr': 2e-5}, + {'params': actor.net.decoder.parameters(), 'lr': 1e-4}, + {'params': actor.net.label_encoder.parameters(), 'lr': 2e-4}, + {'params': actor.net.bbox_encoder.parameters(), 'lr': 2e-4}, + {'params': actor.net.segm_encoder.parameters(), 'lr': 2e-4}, + {'params': actor.net.target_model_segm.filter_initializer.parameters(), 'lr': 5e-5}, + {'params': actor.net.target_model_segm.filter_optimizer.parameters(), 'lr': 1e-4}], + lr=2e-4) + + lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30, 60], gamma=0.2) + + trainer = LTRTrainer(actor, [loader_train, loader_val], optimizer, settings, lr_scheduler) + + trainer.train(80, load_latest=True, fail_safe=False) diff --git a/pytracking/evaluation/data.py b/pytracking/evaluation/data.py index 6446efb6..5a680bb4 100644 --- a/pytracking/evaluation/data.py +++ b/pytracking/evaluation/data.py @@ -103,6 +103,13 @@ def init_bbox(self, frame_num=0): def init_mask(self, frame_num=0): return self.object_init_data(frame_num=frame_num).get('init_mask') + def get_bbox(self, frame_num, object_id=None): + # print(self.ground_truth_rect, flush=True) + if object_id is not None: + return self.ground_truth_rect[object_id][frame_num] + else: + return self.ground_truth_rect[frame_num] + def get_info(self, keys, frame_num=None): info = dict() for k in keys: diff --git a/pytracking/evaluation/multi_object_wrapper.py b/pytracking/evaluation/multi_object_wrapper.py index 63d62609..2ef7d015 100644 --- a/pytracking/evaluation/multi_object_wrapper.py +++ b/pytracking/evaluation/multi_object_wrapper.py @@ 
-189,3 +189,212 @@ def visdom_draw_tracking(self, image, box, segmentation): self.visdom.register((image, *box), 'Tracking', 1, 'Tracking') else: self.visdom.register((image, *box, segmentation), 'Tracking', 1, 'Tracking') + + +class MultiObjectWrapperSTA: + def __init__(self, base_tracker_class, params, visdom=None, fast_load=False, frame_reader=None): + self.base_tracker_class = base_tracker_class + self.params = params + self.visdom = visdom + self.frame_reader = frame_reader + + self.initialized_ids = [] + self.trackers = OrderedDict() + + self.fast_load = fast_load + if self.fast_load: + self.tracker_copy = self.base_tracker_class(self.params) + if hasattr(self.tracker_copy, 'initialize_features'): + self.tracker_copy.initialize_features() + + def create_tracker(self): + tracker = None + if self.fast_load: + try: + tracker = copy.deepcopy(self.tracker_copy) + except: + pass + if tracker is None: + tracker = self.base_tracker_class(self.params) + tracker.visdom = self.visdom + + tracker.frame_reader = self.frame_reader + return tracker + + def _split_info(self, info): + info_split = OrderedDict() + init_other = OrderedDict() # Init other contains init info for all other objects + for obj_id in info['init_object_ids']: + info_split[obj_id] = dict() + init_other[obj_id] = dict() + info_split[obj_id]['object_ids'] = [obj_id] + info_split[obj_id]['sequence_object_ids'] = info['sequence_object_ids'] + if 'init_bbox' in info: + info_split[obj_id]['init_bbox'] = info['init_bbox'][obj_id] + init_other[obj_id]['init_bbox'] = info['init_bbox'][obj_id] + if 'init_mask' in info: + info_split[obj_id]['init_mask'] = (info['init_mask'] == int(obj_id)).astype(np.uint8) + init_other[obj_id]['init_mask'] = info_split[obj_id]['init_mask'] + for obj_info in info_split.values(): + obj_info['init_other'] = init_other + return info_split + + def _set_defaults(self, tracker_out: dict, defaults=None): + defaults = {} if defaults is None else defaults + + for key, val in defaults.items(): + if tracker_out.get(key) is None: + tracker_out[key] = val + + return tracker_out + + def default_merge(self, out_all): + out_merged = OrderedDict() + + out_first = list(out_all.values())[0] + out_types = out_first.keys() + + # Merge segmentation mask + if 'segmentation' in out_types and out_first['segmentation'] is not None: + # Stack all masks + # If a tracker outputs soft segmentation mask, use that. 
Else use the binary segmentation + segmentation_maps = [out.get('segmentation_soft', out['segmentation']) for out in out_all.values()] + segmentation_maps = np.stack(segmentation_maps) + + obj_ids = np.array([0, *map(int, out_all.keys())], dtype=np.uint8) + segm_threshold = getattr(self.params, 'segmentation_threshold', 0.5) + merged_segmentation = obj_ids[np.where(segmentation_maps.max(axis=0) > segm_threshold, + segmentation_maps.argmax(axis=0) + 1, 0)] + + out_merged['segmentation'] = merged_segmentation + + # Merge other fields + for key in out_types: + if key == 'segmentation': + pass + else: + out_merged[key] = {obj_id: out[key] for obj_id, out in out_all.items()} + + return out_merged + + def merge_outputs(self, out_all): + if hasattr(self.base_tracker_class, 'merge_results'): + out_merged = self.trackers[self.initialized_ids[0]].merge_results(out_all) + else: + out_merged = self.default_merge(out_all) + + return out_merged + + def initialize(self, image, info: dict) -> dict: + self.initialized_ids = [] + self.trackers = OrderedDict() + + if len(info['init_object_ids']) == 0: + return None + + object_ids = info['object_ids'] + + init_info_split = self._split_info(info) + self.trackers = OrderedDict({obj_id: self.create_tracker() for obj_id in object_ids}) + + out_all = OrderedDict() + # Run individual trackers for each object + for obj_id in info['init_object_ids']: + start_time = time.time() + out = self.trackers[obj_id].initialize(image, init_info_split[obj_id]) + if out is None: + out = {} + + init_default = {'target_bbox': init_info_split[obj_id].get('init_bbox'), + 'time': time.time() - start_time, + 'segmentation': init_info_split[obj_id].get('init_mask')} + + out = self._set_defaults(out, init_default) + out_all[obj_id] = out + + self.initialized_ids = info['init_object_ids'].copy() + self.initialized_ids_for_store = info['init_object_ids'].copy() + + # Merge results + out_merged = self.merge_outputs(out_all) + + return out_merged + + def track(self, image, bbox, info: dict = None) -> dict: + if info is None: + info = {} + + prev_output = info.get('previous_output', OrderedDict()) + + if info.get('init_object_ids', False): + init_info_split = self._split_info(info) + for obj_init_info in init_info_split.values(): + obj_init_info['previous_output'] = prev_output + + info['init_other'] = list(init_info_split.values())[0]['init_other'] + + out_all = OrderedDict() + for obj_id in self.initialized_ids: + start_time = time.time() + + out = self.trackers[obj_id].track(image, bbox, info) + + default = {'time': time.time() - start_time} + out = self._set_defaults(out, default) + out_all[obj_id] = out + + # Initialize new + if info.get('init_object_ids', False): + if info['init_object_ids'] != self.initialized_ids: + for obj_id in info['init_object_ids']: + # if not obj_id in self.trackers: + # self.trackers[obj_id] = self.create_tracker() + + start_time = time.time() + self.trackers[obj_id].track(image, bbox, info) + # out = self.trackers[obj_id].initialize(image, init_info_split[obj_id]) + # if out is None: + out = {} + + init_default = {'target_bbox': init_info_split[obj_id].get('init_bbox'), + 'time': time.time() - start_time, + 'segmentation': init_info_split[obj_id].get('init_mask')} + + out = self._set_defaults(out, init_default) + out_all[obj_id] = out + + # if info['init_object_ids'] != self.initialized_ids: + self.initialized_ids.extend(info['init_object_ids']) + + # Merge results + out_merged = self.merge_outputs(out_all) + + return out_merged + + def store_seq(self, image, 
bbox, info): + if info is None: + info = {} + + if info.get('init_object_ids', False): + init_info_split = self._split_info(info) + + for obj_id in self.initialized_ids_for_store: + out = self.trackers[obj_id].store_seq(image, bbox, info) + + # Initialize new + if info.get('init_object_ids', False): + for obj_id in info['init_object_ids']: + if not obj_id in self.trackers: + self.trackers[obj_id] = self.create_tracker() + + start_time = time.time() + out = self.trackers[obj_id].initialize(image, init_info_split[obj_id]) + if out is None: + out = {} + + init_default = {'target_bbox': init_info_split[obj_id].get('init_bbox'), + 'time': time.time() - start_time, + 'segmentation': init_info_split[obj_id].get('init_mask')} + + out = self._set_defaults(out, init_default) + self.initialized_ids_for_store.extend(info['init_object_ids']) diff --git a/pytracking/evaluation/tracker.py b/pytracking/evaluation/tracker.py index 4bc4ad04..5b6afebf 100644 --- a/pytracking/evaluation/tracker.py +++ b/pytracking/evaluation/tracker.py @@ -11,7 +11,7 @@ from pytracking.utils.plotting import draw_figure, overlay_mask from pytracking.utils.convert_vot_anno_to_rect import convert_vot_anno_to_rect from ltr.data.bounding_box_utils import masks_to_bboxes -from pytracking.evaluation.multi_object_wrapper import MultiObjectWrapper +from pytracking.evaluation.multi_object_wrapper import MultiObjectWrapper, MultiObjectWrapperSTA from pathlib import Path import torch @@ -21,6 +21,34 @@ 7: (123, 123, 123), 8: (255, 128, 0), 9: (128, 0, 255)} +class FrameReader: + def __init__(self, seq_info): + self.seq_info = seq_info + self.frame_names = seq_info.frames + self.frames = [None for _ in self.frame_names] + self.num_frames_in_memory = 0 + + def get_frame(self, frame_number): + if self.frames[frame_number] is None: + self.frames[frame_number] = self._read_image(self.frame_names[frame_number]) + self.num_frames_in_memory += 1 + + return self.frames[frame_number] + + def get_init_info(self, frame_number): + return self.seq_info.frame_info(frame_number) + + def get_bbox(self, frame_number, object_id=None): + return self.seq_info.get_bbox(frame_number, object_id) + + def _read_image(self, image_file: str): + im = cv.imread(image_file) + return cv.cvtColor(im, cv.COLOR_BGR2RGB) + + def num_frames(self): + return len(self.frame_names) + + def trackerlist(name: str, parameter_name: str, run_ids = None, display_name: str = None): """Generate list of trackers. 
args: @@ -138,14 +166,24 @@ def run_sequence(self, seq, visualization=None, debug=None, visdom_info=None, mu if multiobj_mode is None: multiobj_mode = getattr(params, 'multiobj_mode', getattr(self.tracker_class, 'multiobj_mode', 'default')) + # TODO only enable ground-truth passing if explicitly allowed in param settings + frame_reader = FrameReader(seq) + if multiobj_mode == 'default' or is_single_object: tracker = self.create_tracker(params) + tracker.frame_reader = frame_reader elif multiobj_mode == 'parallel': - tracker = MultiObjectWrapper(self.tracker_class, params, self.visdom) + if self.name == "sta": + tracker = MultiObjectWrapperSTA(self.tracker_class, params, self.visdom, frame_reader=frame_reader) + else: + tracker = MultiObjectWrapper(self.tracker_class, params, self.visdom) else: raise ValueError('Unknown multi object mode {}'.format(multiobj_mode)) - output = self._track_sequence(tracker, seq, init_info) + if self.name == "sta": + output = self._track_sequence_sta(tracker, frame_reader, init_info) + else: + output = self._track_sequence(tracker, seq, init_info) return output def _track_sequence(self, tracker, seq, init_info): @@ -226,6 +264,102 @@ def _store_outputs(tracker_out: dict, defaults=None): output.pop(key) return output + + def _track_sequence_sta(self, tracker, frame_reader, init_info): + # Define outputs + # Each field in output is a list containing tracker prediction for each frame. + + # In case of single object tracking mode: + # target_bbox[i] is the predicted bounding box for frame i + # time[i] is the processing time for frame i + # segmentation[i] is the segmentation mask for frame i (numpy array) + + # In case of multi object tracking mode: + # target_bbox[i] is an OrderedDict, where target_bbox[i][obj_id] is the predicted box for target obj_id in + # frame i + # time[i] is either the processing time for frame i, or an OrderedDict containing processing times for each + # object in frame i + # segmentation[i] is the multi-label segmentation mask for frame i (numpy array) + + output = {'target_bbox': [], + 'time': [], + 'segmentation': []} + + def _store_outputs(tracker_out: dict, defaults=None): + defaults = {} if defaults is None else defaults + for key in output.keys(): + val = tracker_out.get(key, defaults.get(key, None)) + if key in tracker_out or val is not None: + output[key].append(val) + + # Initialize + image = frame_reader.get_frame(0) + + if tracker.params.visualization and self.visdom is None: + self.visualize(image, init_info.get('init_bbox')) + + start_time = time.time() + out = tracker.initialize(image, init_info) + if out is None: + out = {} + + prev_output = OrderedDict(out) + + # init_default = {'target_bbox': init_info.get('init_bbox'), + # 'time': time.time() - start_time, + # 'segmentation': init_info.get('init_mask')} + + # _store_outputs(out, init_default) + for frame_num in range(1, frame_reader.num_frames()): + image = frame_reader.get_frame(frame_num) + if "object_ids" in init_info.keys(): + bbox = OrderedDict() + for object_id in init_info["object_ids"]: + bbox[object_id] = frame_reader.get_bbox(frame_num, object_id) + else: + bbox = frame_reader.get_bbox(frame_num) + info = frame_reader.get_init_info(frame_num) + tracker.store_seq(image, bbox, info) + + for frame_num in range(0, frame_reader.num_frames()): + while True: + if not self.pause_mode: + break + elif self.step: + self.step = False + break + else: + time.sleep(0.1) + + image = frame_reader.get_frame(frame_num) + + if "object_ids" in init_info.keys(): + bbox = 
OrderedDict() + for object_id in init_info["object_ids"]: + bbox[object_id] = frame_reader.get_bbox(frame_num, object_id) + else: + bbox = frame_reader.get_bbox(frame_num) + + start_time = time.time() + + info = frame_reader.get_init_info(frame_num) + info['previous_output'] = prev_output + + out = tracker.track(image, bbox, info) + prev_output = OrderedDict(out) + _store_outputs(out, {'time': time.time() - start_time}) + + segmentation = out['segmentation'] if 'segmentation' in out else None + if self.visdom is not None: + tracker.visdom_draw_tracking(image, out['target_bbox'], segmentation) + elif tracker.params.visualization: + self.visualize(image, out['target_bbox'], segmentation) + + # for key in ['target_bbox', 'segmentation']: + # if key in output and len(output[key]) <= 1: + # output.pop(key) + + return output def run_video(self, videofilepath, optional_box=None, debug=None, visdom_info=None, save_results=False): """Run the tracker with the vieofile. diff --git a/pytracking/experiments/myexperiments.py b/pytracking/experiments/myexperiments.py index afe2a029..3217bf5f 100644 --- a/pytracking/experiments/myexperiments.py +++ b/pytracking/experiments/myexperiments.py @@ -17,3 +17,19 @@ def uav_test(): dataset = get_dataset('uav') return trackers, dataset + +def sta_ytvos(): + trackers = [] + trackers.extend(trackerlist('sta', 'sta_ytvos', range(0, 1))) + + dataset = get_dataset('yt2019_jjval') + + return trackers, dataset + +def sta_davis(): + trackers = [] + trackers.extend(trackerlist('sta', 'sta_davis', range(0, 1))) + + dataset = get_dataset('dv2017_val') + + return trackers, dataset diff --git a/pytracking/features/preprocessing_sta.py b/pytracking/features/preprocessing_sta.py new file mode 100644 index 00000000..abc3f0e5 --- /dev/null +++ b/pytracking/features/preprocessing_sta.py @@ -0,0 +1,189 @@ +import torch +import torch.nn.functional as F +import numpy as np + + +def numpy_to_torch(a: np.ndarray): + return torch.from_numpy(a).float().permute(2, 0, 1).unsqueeze(0) + +def torch_to_numpy(a: torch.Tensor): + return a.squeeze(0).permute(1,2,0).numpy() + +def sample_patch_transformed(im, bbox, pos, scale, image_sz, transforms, is_mask=False): + """Extract transformed image samples. + args: + im: Image. + bbox: Bounding box. + pos: Center position for extraction. + scale: Image scale to extract features from. + image_sz: Size to resize the image samples to before extraction. + transforms: A set of image transforms to apply. + """ + + # Get image patche + im_patch, _, bbox_patch = sample_patch(im, bbox, pos, scale*image_sz, image_sz, is_mask=is_mask) + + # Apply transforms + im_patches = torch.cat([T(im_patch, is_mask=is_mask) for T in transforms]) + + return im_patches, bbox_patch + + +def sample_patch_multiscale(im, bbox, pos, scales, image_sz, mode: str='replicate', max_scale_change=None): + """Extract image patches at multiple scales. + args: + im: Image. + pos: Center position for extraction. + scales: Image scales to extract image patches from. 
+ image_sz: Size to resize the image samples to + mode: how to treat image borders: 'replicate' (default), 'inside' or 'inside_major' + max_scale_change: maximum allowed scale change when using 'inside' and 'inside_major' mode + """ + if isinstance(scales, (int, float)): + scales = [scales] + + # Get image patches + patch_iter, coord_iter, bbox_iter = zip(*(sample_patch(im, bbox, pos, s*image_sz, image_sz, mode=mode, + max_scale_change=max_scale_change) for s in scales)) + im_patches = torch.cat(list(patch_iter)) + patch_coords = torch.cat(list(coord_iter)) + patch_bboxes = torch.cat(list(bbox_iter)) + + return im_patches, patch_coords, patch_bboxes + + +def sample_patch(im: torch.Tensor, bbox: torch.Tensor, pos: torch.Tensor, sample_sz: torch.Tensor, output_sz: torch.Tensor = None, + mode: str = 'replicate', max_scale_change=None, is_mask=False): + """Sample an image patch. + + args: + im: Image + bbox: Bounding box. + pos: center position of crop + sample_sz: size to crop + output_sz: size to resize to + mode: how to treat image borders: 'replicate' (default), 'inside' or 'inside_major' + max_scale_change: maximum allowed scale change when using 'inside' and 'inside_major' mode + """ + + # if mode not in ['replicate', 'inside']: + # raise ValueError('Unknown border mode \'{}\'.'.format(mode)) + + # copy and convert + posl = pos.long().clone() + + pad_mode = mode + + bbox_clone = bbox.clone() + + # Get new sample size if forced inside the image + if mode == 'inside' or mode == 'inside_major': + pad_mode = 'replicate' + im_sz = torch.Tensor([im.shape[2], im.shape[3]]) + shrink_factor = (sample_sz.float() / im_sz) + if mode == 'inside': + shrink_factor = shrink_factor.max() + elif mode == 'inside_major': + shrink_factor = shrink_factor.min() + shrink_factor.clamp_(min=1, max=max_scale_change) + sample_sz = (sample_sz.float() / shrink_factor).long() + + # Compute pre-downsampling factor + if output_sz is not None: + resize_factor = torch.min(sample_sz.float() / output_sz.float()).item() + df = int(max(int(resize_factor - 0.1), 1)) + else: + df = int(1) + + sz = sample_sz.float() / df # new size + + # Do downsampling + if df > 1: + os = posl % df # offset + posl = (posl - os) / df # new position + im2 = im[..., os[0].item()::df, os[1].item()::df] # downsample + bbox_clone[..., :2] -= os + bbox_clone /= df + else: + im2 = im + # plot_image(torch_to_numpy(im2), bbox_clone, "out_mid.jpg") + # exit() + # compute size to crop + szl = torch.max(sz.round(), torch.Tensor([2])).long() + + # Extract top and bottom coordinates + tl = posl - (szl - 1)/2 + br = posl + szl/2 + 1 + + # Shift the crop to inside + if mode == 'inside' or mode == 'inside_major': + im2_sz = torch.LongTensor([im2.shape[2], im2.shape[3]]) + shift = (-tl).clamp(0) - (br - im2_sz).clamp(0) + tl += shift + br += shift + + outside = ((-tl).clamp(0) + (br - im2_sz).clamp(0)) // 2 + shift = (-tl - outside) * (outside > 0).long() + tl += shift + br += shift + + # Get image patch + # im_patch = im2[...,tl[0].item():br[0].item(),tl[1].item():br[1].item()] + + # Get image patch + if not is_mask: + im_patch = F.pad(im2, (-tl[1].item(), br[1].item() - im2.shape[3], -tl[0].item(), br[0].item() - im2.shape[2]), pad_mode) + else: + im_patch = F.pad(im2, (-tl[1].item(), br[1].item() - im2.shape[3], -tl[0].item(), br[0].item() - im2.shape[2])) + bbox_clone[:, :, 0] -= tl[1].item() + bbox_clone[:, :, 1] -= tl[0].item() + + # Get image coordinates + patch_coord = df * torch.cat((tl, br)).view(1,4) + + if output_sz is None or (im_patch.shape[-2] == 
output_sz[0] and im_patch.shape[-1] == output_sz[1]): + return im_patch.clone(), patch_coord, bbox_clone + + bbox_clone[:, :, :2] /= (torch.tensor(im_patch.shape[-2:]) / output_sz) + bbox_clone[:, :, 2:] /= (torch.tensor(im_patch.shape[-2:]) / output_sz) + # Resample + if not is_mask: + im_patch = F.interpolate(im_patch, output_sz.long().tolist(), mode='bilinear') + else: + im_patch = F.interpolate(im_patch, output_sz.long().tolist(), mode='nearest') + + + return im_patch, patch_coord, bbox_clone + + +def crop_and_resize(im, crop_bb, output_sz, mask=None): + x1 = crop_bb[0] + x2 = crop_bb[0] + crop_bb[2] + + y1 = crop_bb[1] + y2 = crop_bb[1] + crop_bb[3] + + # Crop target + im_crop = F.pad(im, (-x1, x2 - im.shape[-1], -y1, y2 - im.shape[-2]), 'replicate') + im_crop = F.interpolate(im_crop, output_sz.long().tolist(), mode='bilinear') + + if mask is None: + return im_crop + + mask_crop = F.pad(mask, (-x1, x2 - im.shape[-1], -y1, y2 - im.shape[-2])) + mask_crop = F.interpolate(mask_crop, output_sz.long().tolist(), mode='nearest') + + return im_crop, mask_crop + +def plot_image(image, bbox=None, savePath="out.png"): + from PIL import Image, ImageDraw + ##image shape: h, w, 3 + im = Image.fromarray(image.astype(np.uint8)) + drawer = ImageDraw.Draw(im) + ##bbox shape: 1, 1, 4 + shape = bbox.round().squeeze(0).squeeze(0).numpy().astype(np.int) + shape[2:] += shape[:2] + drawer.rectangle([(shape[0],shape[1]),(shape[2], shape[3])], outline="red") + im.save(savePath) + print("image saved") + diff --git a/pytracking/notebooks/analyze_results.ipynb b/pytracking/notebooks/analyze_results.ipynb index 46a8eeaf..a8534cdb 100644 --- a/pytracking/notebooks/analyze_results.ipynb +++ b/pytracking/notebooks/analyze_results.ipynb @@ -2,16 +2,14 @@ "cells": [ { "cell_type": "markdown", - "metadata": {}, "source": [ "# Generating Results on Datasets" - ] + ], + "metadata": {} }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 1, "source": [ "%load_ext autoreload\n", "%autoreload 2\n", @@ -23,21 +21,180 @@ "\n", "sys.path.append('../..')\n", "from pytracking.analysis.plot_results import plot_results, print_results, print_per_sequence_results\n", - "from pytracking.evaluation import Tracker, get_dataset, trackerlist" - ] + "from pytracking.evaluation import Tracker, get_dataset, trackerlist\n", + "from pytracking.analysis.evaluate_vos import evaluate_vos" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 2, + "source": [ + "# evaluate STA results\n", + "trackers = []\n", + "trackers.extend(trackerlist('sta', 'sta_davis_cp', range(0, 1)))\n", + "evaluate_vos(trackers, dataset='dv2017_val')" + ], + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "DAVIS/2017/val loaded.\n", + "1/30: bike-packing: 2 objects\n", + "joint 1: acc 0.611 ┊░▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▄▄▅▅▄▄▄▄▄▄▄▄▅▅▄▄▄▄▅▅▅▅▄▄▅▅▅▅▆▆▆▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅░┊\n", + "joint 2: acc 0.738 ┊░███████▇▇███▇██▇▆▆▆▆▅▅▅▅▅▅▅▅▆▅▅▆▅▆▃▃▃▃▅▃▅▅▃▅▅▅▅▆▅▆▆▆▆▆▆▆▆▆▇▆▇▇▇▆▇▇▆░┊\n", + "final : acc 0.675 (0.675) ┊░▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▅▆▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▄▄▄▄▅▄▅▅▄▅▅▅▅▅▆▆▆▆▆▅▆▆▆▆▆▆▆▆▆▆▆▆▆░┊\n", + "2/30: blackswan: 1 object\n", + "final : acc 0.944 (0.764) ┊░▇█████████████▇█████████████████████████████████░┊\n", + "3/30: bmx-trees: 2 objects\n", + "joint 1: acc 0.428 ┊░▄▄▄▄▄▄▄▃▃▄▄▄▄▄▄▄▄▃▂▄▄▄▃▃▄▃▄▃▃▃▄▄▄▄▄▄▄▃▄▄▄▄▄▄▄▄▄▃▄▄▃▄▄▃▂▃▂▃▂▂▁▃▂▂▂▃▃▃▃▄▃▄▄▄▄▄▄▃░┊\n", + "joint 2: acc 0.727 ┊░▇▇▇▇▆▆▆▆▆▆▆▇▇▇▇▅▆▇▇▆▅▅▇▇▆▆▆▇▇▆▇▇▇▇▆▇▇▇▇▇▇▇▇▇▆▇▆▆▆▆▅▄▃▃▄▄▄▅▄▄▂▂▃▄▅▅▄▃▆▇▇▇▇▇▇▇▇▆░┊\n", + "final : acc 0.578 
(0.690) ┊░▆▅▆▅▅▅▅▅▅▅▅▅▅▅▅▄▅▅▄▅▄▄▅▅▅▄▅▅▅▅▅▅▅▅▅▅▅▅▅▆▅▆▅▅▅▅▅▅▅▅▄▄▃▃▃▄▃▄▃▃▂▂▂▃▃▄▄▃▄▅▅▅▅▅▆▅▅▄░┊\n", + "4/30: breakdance: 1 object\n", + "final : acc 0.884 (0.722) ┊░▇▇▇▇▇▇▇▇▇█▇▇▇▇▇▇▇▇▇▇▇▇▇▇█▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▅▆▇▇▇▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▆▇▇▇▇▇▇▇▇▇▆░┊\n", + "5/30: camel: 1 object\n", + "final : acc 0.929 (0.752) ┊░████▇█▇▇▇█▇▇▇▇▇▇▇█▇███▇▇▇▇▇▇▇▇▇▇▇▇▇██████▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇████▇▇▇▇▇▇▇███▇▇▇▇░┊\n", + "6/30: car-roundabout: 1 object\n", + "final : acc 0.977 (0.780) ┊░██████████████████████████████████████████████████████████████████████▇██░┊\n", + "7/30: car-shadow: 1 object\n", + "final : acc 0.966 (0.801) ┊░██████████████████████████████████████░┊\n", + "8/30: cows: 1 object\n", + "final : acc 0.933 (0.814) ┊░████████▇▇▇▇▇▇▇▇▇▇▇▇▇▇█▇▇██▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇█▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇█████▇▇▇█▇█▇█▇▇▇▇▇▇▇█▇▇▇▇▇▇░┊\n", + "9/30: dance-twirl: 1 object\n", + "final : acc 0.850 (0.817) ┊░▇▇▇▇▇▇▇▇▇▇▇▇▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▆▆▆▆▆▇▆▆▆▆▆▆▆▆▆▇▆▆▆▆▆▆▆▇▇▇▆▇▇▇▇▇▇▇▇▇▇██▇▇▇▇▇▇███▇▇▇▇▇▇▇▇░┊\n", + "10/30: dog: 1 object\n", + "final : acc 0.930 (0.826) ┊░▇▇▇▇▇▇▇▇▇█▇▇▇▇▇▇▇▇██████▇██▇▇▇▇▇▇▇█▇▇█▇██████▇▇███▇▇▇▇▇▇▇▇░┊\n", + "11/30: dogs-jump: 3 objects\n", + "joint 1: acc 0.885 ┊░▇▇▇▇▇▇▇▄▁▁▄▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇██▇▇██▇████▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇█▇░┊\n", + "joint 2: acc 0.926 ┊░██▇█▇▇▇▇▇▇█▇▇████▇███████▇█▇▇██▇▇▇▇███▇▇█▇▇▇▇▇▇▇▇▇▇▇█▇▇▇▇▇▇▇▇▇▇▇░┊\n", + "joint 3: acc 0.904 ┊░█▇▇▇▇▇▆▆▅▅▆▆▆▇▇▇███████▇▇██▇▇▇▇▇▇▇██▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇█▇████████▇█░┊\n", + "final : acc 0.905 (0.842) ┊░█▇▇▇▇▇▇▆▄▄▆▇▇▇▇▇█▇██▇███▇▇▇▇▇▇█▇▇▇██▇█▇██▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇█▇▇▇▇▇▇▇░┊\n", + "12/30: drift-chicane: 1 object\n", + "final : acc 0.908 (0.846) ┊░▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇██▇███▇▇░┊\n", + "13/30: drift-straight: 1 object\n", + "final : acc 0.941 (0.852) ┊░█▇█████████████████████████▇█▇█▇█▇▇▇▇█▇███▇▇▇▇▇▇░┊\n", + "14/30: goat: 1 object\n", + "final : acc 0.885 (0.854) ┊░▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇░┊\n", + "15/30: gold-fish: 5 objects\n", + "joint 1: acc 0.847 ┊░▇▇▇▇▇▇▇▇▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▆▇▇▆▇▇▇▇▇▇▇▇▇▇▇▇▇▆▇▇▇▇▆▆▇▅▇▇▇▇▅▅▆▆▆▇░┊\n", + "joint 2: acc 0.835 ┊░▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▆▆▆▆▆▆▆▆▇▆▆▆▆▆▆▆▇▇▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇░┊\n", + "joint 3: acc 0.871 ┊░▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▆▆▆▆▆░┊\n", + "joint 4: acc 0.893 ┊░▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇█▇▇▇▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇░┊\n", + "joint 5: acc 0.900 ┊░▇▇▇██▇██████▇█▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇░┊\n", + "final : acc 0.869 (0.857) ┊░▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▆▆▇▇▇▇░┊\n", + "16/30: horsejump-high: 2 objects\n", + "joint 1: acc 0.790 ┊░▆▆▆▆▆▇▆▆▆▆▆▆▇▇▆▆▆▆▆▆▆▆▆▇▆▆▆▆▆▇▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆░┊\n", + "joint 2: acc 0.835 ┊░▇▇▇▇▇▇▇▇▇▇▇▇▇▆▇▆▆▇▇▆▆▇▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▆▆▆▇▇▇▇░┊\n", + "final : acc 0.813 (0.853) ┊░▇▆▇▇▇▇▇▆▇▇▆▆▇▇▆▆▆▆▆▆▆▇▆▆▆▇▇▆▆▇▆▇▇▇▇▇▇▇▆▇▆▆▆▆▆▆▇▇░┊\n", + "17/30: india: 3 objects\n", + "joint 1: acc 0.919 ┊░▇▇▇▇▇▇▇▇▇▇▇▇▇█▇▇▇▇▇▇██▇▇██▇▇▇▇▇▇▇▇▇▇█▆▆▄▇▇█▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▆██████████████████░┊\n", + "joint 2: acc 0.822 ┊░▇▇▇▇▇▇▇▇▇▇█▇▇████▇█▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▆▄▇▇▇▇▇▇▇▇▆▇▆▅▆▅██████████████▅▆▅▆▅▆▄▄▄▁▁▁▁▁░┊\n", + "joint 3: acc 0.834 ┊░█▇██▇▇██▇▇▇▇▇▇█▇█▇▇▇▇▇▇▇▇▇▇▇▇▆▇▇▆▄█▆▇▇▇▇▆▇▆▇▂█████████▇▇▇▅▁▂▄▄▅▆▆▆▆▆▆▆▆▆▆▆▆▇▆▆▆░┊\n", + "final : acc 0.858 (0.854) ┊░▇▇▇▇▇▇▇▇▇▇▇▇▇▇█▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▆▇▆▆▇▇▆▇▇▇▇▆▇▇▇▇▇▇███▇▇▇▇▅▆▆▇▇▇▇▆▇▆▆▆▇▆▆▆▅▅▅▅▅░┊\n", + "18/30: judo: 2 objects\n", + "joint 1: acc 0.866 ┊░▇▇▇▇▇▇▇▇▇▇▇▇▇▇▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇░┊\n", + "joint 2: acc 0.837 ┊░▇▇▇▇▇▇▇▇▇▇▇▇▇▆▆▆▆▆▇▇▇▇▇▆▆▇▆▇▇▇▇▇░┊\n", + "final 
: acc 0.851 (0.854) ┊░▇▇▇▇▇▇▇▇▇▇▇▇▇▇▆▆▆▆▇▇▇▇▇▇▆▇▇▇▇▇▇▇░┊\n", + "19/30: kite-surf: 3 objects\n", + "joint 1: acc 0.221 ┊░▂▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▂▁▂▁▁▁▁▂▂▂░┊\n", + "joint 2: acc 0.584 ┊░▁▄▃▄▅▄▅▅▅▅▄▃▄▂▃█████▆▆▆▆▆▅▅▄▄▁██████ ▁▁▂ ▂█▂▄▁▅▄░┊\n", + "joint 3: acc 0.702 ┊░▅▆▅▆▆▆▆▅▆▆▆▆▆▆▅▆▆▆▅▅▅▆▆▆▆▆▆▆▅▆▆▆▆▆▆▅▅▅▅▅▆▆▆▆▅▆▆▆░┊\n", + "final : acc 0.502 (0.822) ┊░▃▄▄▄▄▄▄▄▄▄▄▄▄▃▄▅▅▅▅▅▄▄▄▄▅▄▄▄▄▃▅▅▅▅▅▅▂▂▂▃▃▃▅▃▄▃▄▄░┊\n", + "20/30: lab-coat: 5 objects\n", + "joint 1: acc 0.598 ┊░▃▄▂▄▄▄▃▃▃▄▂ ▂▅▆▅▆▆▆▅▆▆▅▅▄▅▆▅▅▅▃▆▇▆▆▇▆▇▆▇▆▆▇▇░┊\n", + "joint 2: acc 0.553 ┊░▁▁▁ ▃▄▃▃▄▄▅▆▆▆▆▅▅▄▅▅▅▅▃▄▃▃▄▄▄▄▆▄▅▅▅▆▆▆▆▆▅▅▆▅▅░┊\n", + "joint 3: acc 0.944 ┊░████████████████████▇██▇▇▇██▇▇▇██████████████░┊\n", + "joint 4: acc 0.854 ┊░▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▆▇▇▇▅▅▆▆▆▆▅▅▇█▇▆▆███▇▇███▇▇███░┊\n", + "joint 5: acc 0.910 ┊░▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇██▇██▇▇▇▇█▇████████░┊\n", + "final : acc 0.772 (0.815) ┊░▅▅▅▅▆▆▆▆▆▆▆▆▆▆▇▆▆▆▆▆▆▆▆▆▆▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇░┊\n", + "21/30: libby: 1 object\n", + "final : acc 0.886 (0.817) ┊░▇▇▇▇▇▇████▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▆▇▇▇▇▇▇▇▇▇▇▆▇▆░┊\n", + "22/30: loading: 3 objects\n", + "joint 1: acc 0.963 ┊░████████████████████████████████████████████████░┊\n", + "joint 2: acc 0.824 ┊░▇▆▇▇▇▇▇▇▆▇▆▆▆▆▇▆▇▇▆▆▇▇▇▇▇▇▇▇▅▇▆▇▆▆▇▇▇▇▅▇▇▇▅▆▇▇▇▇░┊\n", + "joint 3: acc 0.901 ┊░██████▇▇▇█▇▇█▇███▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▆▇▆▇▇▇▇▇░┊\n", + "final : acc 0.896 (0.823) ┊░▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▆▇▇▇▆▇▇▇▇▇░┊\n", + "23/30: mbike-trick: 2 objects\n", + "joint 1: acc 0.816 ┊░▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▇▇▇▇▆▆▇▇▇▇▇▇▇▇▇▆▇▇▇▇▇▆▆▇▇▇▇▇▆▇▆▆▆▇▇▇▇▇▇▇▇▇▆▇▇▇▇▇▆▇▆▆▆▆▇▇▇▇▇▆▇░┊\n", + "joint 2: acc 0.801 ┊░▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▅▅▅▅▅▅▅▅▆▆▆▆▅▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇░┊\n", + "final : acc 0.809 (0.822) ┊░▇▆▆▆▇▇▇▇▇▇▆▇▇▆▇▇▇▇▇▇▆▇▇▇▇▇▇▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▇▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇░┊\n", + "24/30: motocross-jump: 2 objects\n", + "joint 1: acc 0.818 ┊░▅▆▅▇▇▆▆▇▇▇▇▇▇▇▇▇▇▆▄▇▇█▇▅▆█▆██▆▆▆▆▇▇▆▆▆░┊\n", + "joint 2: acc 0.834 ┊░▆▇▆▇▇▇▇▇▇▇▇▇▆▇▆▆▇▇▇▇▇▇▇▆▆▇▇▇▇▆▇▇▆▇▇▆▆▆░┊\n", + "final : acc 0.826 (0.822) ┊░▅▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▆▅▇▇▇▇▆▆▇▆▇▇▆▇▆▆▇▇▆▆▆░┊\n", + "25/30: paragliding-launch: 3 objects\n", + "joint 1: acc 0.755 ┊░▆▆▄▅▆▅▅▅▅▆▆▆▆▆▇▇▆▇▇▇▇▆▆▆▇▇▆▆▇▇▇▇▇▇▇▆▇▆▆▇▇▆▇▇▇▆▆▆▆▅▆▆▅▅▆▅▅▅▅▅▅▆▆▆▅▆▆▆▆▆▆▆▆▆▆▆▆▆░┊\n", + "joint 2: acc 0.632 ┊░▅▅▅▅▅▅▅▅▅▅▅▅▅▅▆▅▅▅▅▅▆▅▅▅▆▅▅▅▅▆▆▆▆▆▆▆▆▅▅▅▆▅▆▆▆▆▆▆▆▅▆▅▅▅▄▄▄▄▅▅▅▅▅▅▅▄▅▄▅▄▄▄▄▅▄▄▄▃░┊\n", + "joint 3: acc 0.084 ┊░▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▁▁▁▁▁▁ ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▁▁▁ ▁ ░┊\n", + "final : acc 0.490 (0.802) ┊░▄▄▃▄▄▄▄▃▄▄▄▄▄▄▅▄▄▅▄▄▄▄▄▄▄▄▄▄▄▅▅▅▅▅▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▃▄▄▃▄▃▃▃▃▃▃▄▄▃▄▃▄▄▄▄▃▄▄▄▄▄▃▃▃░┊\n", + "26/30: parkour: 1 object\n", + "final : acc 0.931 (0.805) ┊░▇▇▇▇▇▇▇▇▇▇▇▇▇▇███▇█▇▇▇▇▇█▇▇▇▇▇▇▇█▇▇▇█▇██▇███████████▇▇█▇▇▇███▇█▇████▇▇███████▇▇▇▇▇▇▇▇▇▇▇█▇▇▇▇▇▇▇▇▇░┊\n", + "27/30: pigs: 3 objects\n", + "joint 1: acc 0.913 ┊░▇█████▇▇▇▇▇▇▇▇▇█▇▇█▇██████▇▇▇▇▇▇▇▇▇▇▇█▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇░┊\n", + "joint 2: acc 0.704 ┊░▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▆▇▇▆▆▇▇▇▇▇▆▆▆▅▆▅▅▅▂▄▄▄▆▆▆▅▅▅▄▄▃▃▂ ▁▃▃▄▅▆▆▇▇▇▇▇░┊\n", + "joint 3: acc 0.933 ┊░███▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇████████▇▇▇▇▇▇█▇█████████▇▇▇▇▇▇▇▇▇▇▇▇▇▇██▇███████████▇███▇▇░┊\n", + "final : acc 0.850 (0.807) ┊░▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇████▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▆▇▅▆▆▆▇▇▇▆▆▇▆▆▆▆▆▅▅▅▅▅▆▆▆▇▇▇▇▇▇▇▇░┊\n", + "28/30: scooter-black: 2 objects\n", + "joint 1: acc 0.815 ┊░▆▆▆▆▆▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▆▇▇▇░┊\n", + "joint 2: acc 0.875 ┊░▇▇▇▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇░┊\n", + "final : acc 0.845 (0.808) ┊░▆▆▆▆▇▆▇▇▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇░┊\n", + "29/30: shooting: 3 objects\n", + "joint 1: acc 0.790 ┊░▅▆▅▆▅▆▅▆▆▆▅▁▄▇▇▆▇▇▇▇▇▇▇▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇░┊\n", + "joint 2: acc 0.886 ┊░▆▇▇▇▆▆▇▇▇▇▇▆▇█▇▇█████▆████▆▇▇▇▇▇▇▇█▇█▇░┊\n", + "joint 3: acc 0.870 
┊░▆▆▇▇▇█▇██▇▇▇▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇█▇▇▇▇▇░┊\n", + "final : acc 0.849 (0.811) ┊░▆▆▆▇▆▇▇▇▇▇▆▄▆▇▇▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇░┊\n", + "30/30: soapbox: 3 objects\n", + "joint 1: acc 0.857 ┊░▆▇▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▇▇▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▆░┊\n", + "joint 2: acc 0.821 ┊░▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▇▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▆▆▇▇▇▇▇▇▇▇▆▆▆▆▆▆▆░┊\n", + "joint 3: acc 0.824 ┊░▅▆▅▅▅▆▅▅▆▆▅▅▅▆▅▆▆▆▆▆▆▆▅▅▅▆▆▅▆▆▇▇▆▆▇▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇█████████████▇▇▇▇▇▇▇░┊\n", + "final : acc 0.834 (0.812) ┊░▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▆▇▇▇▆░┊\n", + "J: 0.812, recall: 0.926, decay: 0.002\n", + "\n", + " | J-Mean | J-Recall | J-Decay |\n", + "sta sta_davis_cp_000 | 81.17 | 92.56 | 0.15 |\n", + "\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "([ J-Mean J-Recall J-Decay\n", + " 0 0.811713 0.925566 0.001533],\n", + " [ Sequence J-Mean J-Recall J-Decay\n", + " 0 bike-packing_1 0.610973 1.000000 -0.049671\n", + " 1 bike-packing_2 0.738134 0.910448 0.146141\n", + " 2 blackswan_1 0.944352 1.000000 -0.001557\n", + " 3 bmx-trees_1 0.427937 0.128205 0.093377\n", + " 4 bmx-trees_2 0.727114 0.884615 0.159450\n", + " .. ... ... ... ...\n", + " 56 shooting_2 0.886161 1.000000 -0.086682\n", + " 57 shooting_3 0.869925 1.000000 -0.035461\n", + " 58 soapbox_1 0.857451 1.000000 -0.118126\n", + " 59 soapbox_2 0.820919 1.000000 -0.062378\n", + " 60 soapbox_3 0.823770 1.000000 -0.246549\n", + " \n", + " [61 rows x 4 columns]])" + ] + }, + "metadata": {}, + "execution_count": 2 + } + ], + "metadata": {} }, { "cell_type": "markdown", - "metadata": {}, "source": [ "## Plots for OTB, NFS and UAV" - ] + ], + "metadata": {} }, { "cell_type": "code", "execution_count": null, - "metadata": {}, - "outputs": [], "source": [ "trackers = []\n", "trackers.extend(trackerlist('atom', 'default', range(0,5), 'ATOM'))\n", @@ -61,20 +218,20 @@ "dataset = get_dataset('otb', 'nfs', 'uav')\n", "plot_results(trackers, dataset, 'OTB+NFS+UAV', merge_results=True, plot_types=('success', 'prec'), \n", " skip_missing_seq=False, force_evaluation=True, plot_bin_gap=0.05, exclude_invalid_frames=False)" - ] + ], + "outputs": [], + "metadata": {} }, { "cell_type": "markdown", - "metadata": {}, "source": [ "## Plots for LaSOT" - ] + ], + "metadata": {} }, { "cell_type": "code", "execution_count": null, - "metadata": {}, - "outputs": [], "source": [ "trackers = []\n", "trackers.extend(trackerlist('atom', 'default', range(0,5), 'ATOM'))\n", @@ -86,20 +243,20 @@ "dataset = get_dataset('lasot')\n", "plot_results(trackers, dataset, 'LaSOT', merge_results=True, plot_types=('success'), \n", " skip_missing_seq=False, force_evaluation=True, plot_bin_gap=0.05)" - ] + ], + "outputs": [], + "metadata": {} }, { "cell_type": "markdown", - "metadata": {}, "source": [ "## Tables for OTB, NFS, UAV and LaSOT" - ] + ], + "metadata": {} }, { "cell_type": "code", "execution_count": null, - "metadata": {}, - "outputs": [], "source": [ "trackers = []\n", "trackers.extend(trackerlist('atom', 'default', range(0,5), 'ATOM'))\n", @@ -122,70 +279,71 @@ "\n", "dataset = get_dataset('lasot')\n", "print_results(trackers, dataset, 'LaSOT', merge_results=True, plot_types=('success', 'prec', 'norm_prec'))" - ] + ], + "outputs": [], + "metadata": {} }, { "cell_type": "markdown", - "metadata": {}, "source": [ "## Filtered per-sequence results" - ] + ], + "metadata": {} }, { "cell_type": "code", "execution_count": null, - 
"metadata": {}, - "outputs": [], "source": [ "# Print per sequence results for sequences where all trackers fail, i.e. all trackers have average overlap in percentage of less than 10.0\n", "filter_criteria = {'mode': 'ao_max', 'threshold': 10.0}\n", "dataset = get_dataset('otb', 'nfs', 'uav')\n", "print_per_sequence_results(trackers, dataset, 'OTB+NFS+UAV', merge_results=True, filter_criteria=filter_criteria, force_evaluation=False)" - ] + ], + "outputs": [], + "metadata": {} }, { "cell_type": "code", "execution_count": null, - "metadata": {}, - "outputs": [], "source": [ "# Print per sequence results for sequences where at least one tracker fails, i.e. a tracker has average overlap in percentage of less than 10.0\n", "filter_criteria = {'mode': 'ao_min', 'threshold': 10.0}\n", "dataset = get_dataset('otb', 'nfs', 'uav')\n", "print_per_sequence_results(trackers, dataset, 'OTB+NFS+UAV', merge_results=True, filter_criteria=filter_criteria, force_evaluation=False)" - ] + ], + "outputs": [], + "metadata": {} }, { "cell_type": "code", "execution_count": null, - "metadata": {}, - "outputs": [], "source": [ "# Print per sequence results for sequences where the trackers have differing behavior.\n", "# i.e. average overlap in percentage for different trackers on a sequence differ by at least 40.0\n", "filter_criteria = {'mode': 'delta_ao', 'threshold': 40.0}\n", "dataset = get_dataset('otb', 'nfs', 'uav')\n", "print_per_sequence_results(trackers, dataset, 'OTB+NFS+UAV', merge_results=True, filter_criteria=filter_criteria, force_evaluation=False)" - ] + ], + "outputs": [], + "metadata": {} }, { "cell_type": "code", "execution_count": null, - "metadata": {}, - "outputs": [], "source": [ "# Print per sequence results for all sequences\n", "filter_criteria = None\n", "dataset = get_dataset('otb', 'nfs', 'uav')\n", "print_per_sequence_results(trackers, dataset, 'OTB+NFS+UAV', merge_results=True, filter_criteria=filter_criteria, force_evaluation=False)" - ] + ], + "outputs": [], + "metadata": {} } ], "metadata": { "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" + "name": "python3", + "display_name": "Python 3.7.6 64-bit ('pytracking': conda)" }, "language_info": { "codemirror_mode": { @@ -197,9 +355,12 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.2" + "version": "3.7.6" + }, + "interpreter": { + "hash": "75b3c4f0cb80b41d7624071a965f226bc66105853f746c5b75437c4d24e84699" } }, "nbformat": 4, "nbformat_minor": 4 -} +} \ No newline at end of file diff --git a/pytracking/parameter/sta/__init__.py b/pytracking/parameter/sta/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pytracking/parameter/sta/sta_davis.py b/pytracking/parameter/sta/sta_davis.py new file mode 100644 index 00000000..f7c6b51a --- /dev/null +++ b/pytracking/parameter/sta/sta_davis.py @@ -0,0 +1,65 @@ +from pytracking.utils import TrackerParams +from pytracking.features.net_wrappers import NetWithBackbone + +def parameters(): + params = TrackerParams() + + params.debug = 0 + params.visualization = False + + params.seg_to_bb_mode = 'var' + params.max_scale_change = (0.95, 1.1) + params.min_mask_area = 100 + + params.return_raw_scores = True + params.use_gpu = True + + params.image_sample_size = (30 * 16, 52 * 16) + params.search_area_scale = 4.0 + params.border_mode = 'inside_major' + params.patch_max_scale_change = None + + # Learning parameters + params.use_merged_mask_for_memory = True + params.sample_memory_size = 32 + 
params.learning_rate = 0.1 + params.init_samples_minimum_weight = 0 + params.train_skipping = 1 + params.train_sample_interval = 1 + params.test_sample_interval = 5 + params.test_num_frames = 9 + + # Net optimization params + params.update_classifier = True + params.net_opt_iter = 5 + params.net_opt_update_iter = 15 + params.net_opt_hn_iter = 1 + + # Detection parameters + params.window_output = False + + # Init augmentation parameters + params.use_augmentation = False + params.augmentation = {} + + params.augmentation_expansion_factor = 2 + params.random_shift_factor = 1/3 + + # IoUnet parameters + params.use_iou_net = False # Use the augmented samples to compute the modulation vector + + params.net = NetWithBackbone(net_path='STANet_ep0080.pth.tar', + use_gpu=params.use_gpu, + image_format='bgr255', + mean=[102.9801, 115.9465, 122.7717], + std=[1.0, 1.0, 1.0] + ) + + params.vot_anno_conversion_type = 'preserve_area' + + return params + + + + + diff --git a/pytracking/parameter/sta/sta_ytvos.py b/pytracking/parameter/sta/sta_ytvos.py new file mode 100644 index 00000000..12dc4116 --- /dev/null +++ b/pytracking/parameter/sta/sta_ytvos.py @@ -0,0 +1,65 @@ +from pytracking.utils import TrackerParams +from pytracking.features.net_wrappers import NetWithBackbone + +def parameters(): + params = TrackerParams() + + params.debug = 0 + params.visualization = False + + params.seg_to_bb_mode = 'var' + params.max_scale_change = (0.95, 1.1) + params.min_mask_area = 100 + + params.return_raw_scores = True + params.use_gpu = True + + params.image_sample_size = (30 * 16, 52 * 16) + params.search_area_scale = 4.0 + params.border_mode = 'inside_major' + params.patch_max_scale_change = None + + # Learning parameters + params.use_merged_mask_for_memory = True + params.sample_memory_size = 32 + params.learning_rate = 0.1 + params.init_samples_minimum_weight = 0 + params.train_skipping = 1 + params.train_sample_interval = 1 + params.test_sample_interval = 3 + params.test_num_frames = 9 + + # Net optimization params + params.update_classifier = True + params.net_opt_iter = 5 + params.net_opt_update_iter = 15 + params.net_opt_hn_iter = 1 + + # Detection parameters + params.window_output = False + + # Init augmentation parameters + params.use_augmentation = False + params.augmentation = {} + + params.augmentation_expansion_factor = 2 + params.random_shift_factor = 1/3 + + # IoUnet parameters + params.use_iou_net = False # Use the augmented samples to compute the modulation vector + + params.net = NetWithBackbone(net_path='STANet_ep0080.pth.tar', + use_gpu=params.use_gpu, + image_format='bgr255', + mean=[102.9801, 115.9465, 122.7717], + std=[1.0, 1.0, 1.0] + ) + + params.vot_anno_conversion_type = 'preserve_area' + + return params + + + + + diff --git a/pytracking/tracker/sta/__init__.py b/pytracking/tracker/sta/__init__.py new file mode 100644 index 00000000..ce36598c --- /dev/null +++ b/pytracking/tracker/sta/__init__.py @@ -0,0 +1,5 @@ +from .sta import STA + + +def get_tracker_class(): + return STA \ No newline at end of file diff --git a/pytracking/tracker/sta/sta.py b/pytracking/tracker/sta/sta.py new file mode 100644 index 00000000..4c16f297 --- /dev/null +++ b/pytracking/tracker/sta/sta.py @@ -0,0 +1,480 @@ +from pytracking.tracker.base import BaseTracker +import torch +import torch.nn.functional as F +import numpy as np +import math +import time +from pytracking import TensorList +from pytracking.features.preprocessing_sta import numpy_to_torch, torch_to_numpy +from pytracking.features.preprocessing_sta 
import sample_patch_multiscale, sample_patch_transformed, sample_patch +from pytracking.features import augmentation +from collections import OrderedDict + +def plot_image(image, savePath="out.png"): + from PIL import Image, ImageDraw + ##image shape: h, w, 1 + im = Image.fromarray(image.astype(np.uint8)) + drawer = ImageDraw.Draw(im) + im.save(savePath) + print("image saved") + +class STA(BaseTracker): + multiobj_mode = 'parallel' + + def predicts_segmentation_mask(self): + return True + + def initialize_features(self): + if not getattr(self, 'features_initialized', False): + self.params.net.initialize() + self.features_initialized = True + + def initialize(self, image, info: dict) -> dict: + # Learn the initial target model. Initialize memory etc. + self.frame_num = 0 + if not self.params.has('device'): + self.params.device = 'cuda' if self.params.use_gpu else 'cpu' + + # Initialize network + self.initialize_features() + + # The segmentation network + self.net = self.params.net + + # Time initialization + tic = time.time() + + # Get target position and size + state = info['init_bbox'] + init_bbox = info.get('init_bbox', None) + + if init_bbox is not None: + # shape 1, 1, 4 (frames, seq, 4) + init_bbox = torch.tensor(init_bbox).unsqueeze(0).unsqueeze(0).float() + + # Set target center and target size + self.pos = torch.Tensor([state[1] + (state[3] - 1)/2, state[0] + (state[2] - 1)/2]) + self.target_sz = torch.Tensor([state[3], state[2]]) + + # Get object ids + self.object_id = info.get('object_ids', [None])[0] + self.id_str = '' if self.object_id is None else ' {}'.format(self.object_id) + + # Set sizes + sz = self.params.image_sample_size + self.img_sample_sz = torch.Tensor([sz, sz] if isinstance(sz, int) else sz) + self.img_support_sz = self.img_sample_sz + + # Set search area. 
+ self.search_area_scale = [self.params.search_area_scale] if isinstance(self.params.search_area_scale, (int, float)) else self.params.search_area_scale + search_area = [torch.prod(self.target_sz * s).item() for s in self.search_area_scale] + self.target_scale = [math.sqrt(s) / self.img_sample_sz.prod().sqrt() for s in search_area] + + # Convert image + im = numpy_to_torch(image) + + # Extract and transform sample + self.feature_sz = self.img_sample_sz / 16 + ksz = self.net.target_model.filter_size + self.kernel_size = torch.Tensor([ksz, ksz] if isinstance(ksz, (int, float)) else ksz) + self.output_sz = self.feature_sz + (self.kernel_size + 1)%2 + + im_patches_ms = [] + init_bboxes_ms = [] + for p, t in zip(self.get_centered_sample_pos(), self.target_scale): + _, _, im_patches, init_bboxes = self.extract_backbone_features(im, init_bbox, p, t, self.img_sample_sz) + im_patches_ms.append(im_patches.unsqueeze(0)) + init_bboxes_ms.append(init_bboxes.unsqueeze(0)) + + self.init_memory(im_patches_ms, init_bboxes_ms) + out = {'time': time.time() - tic} + + # If object is visible in the i-th aved frame + self.visible = [1] + + return out + + def store_seq(self, image, bbox, info): + if self.object_id is None: + bbox = bbox + else: + bbox = bbox[self.object_id] + if bbox[0] == -1 and bbox[1] == -1 and bbox[2] == -1 and bbox[3] == -1: + self.visible.append(0) + else: + self.visible.append(1) + + self.pos = torch.Tensor([bbox[1] + (bbox[3] - 1)/2, bbox[0] + (bbox[2] - 1)/2]) + self.target_sz = torch.Tensor([bbox[3], bbox[2]]) + bbox = torch.tensor(bbox).unsqueeze(0).unsqueeze(0).float() + + search_area = [torch.prod(self.target_sz * s).item() for s in self.search_area_scale] + self.target_scale = [math.sqrt(s) / self.img_sample_sz.prod().sqrt() for s in search_area] + + # Convert image + im = numpy_to_torch(image) + + # Extract backbone features + im_patches_ms = [] + init_bboxes_ms = [] + for p, t in zip(self.get_centered_sample_pos(), self.target_scale): + _, _, im_patches, patch_bboxes = self.extract_backbone_features(im, bbox, p, t, self.img_sample_sz) + im_patches_ms.append(im_patches.unsqueeze(0)) + init_bboxes_ms.append(patch_bboxes.unsqueeze(0)) + + # Update the tracker memory + self.update_memory(im_patches_ms, init_bboxes_ms) + + def track(self, image, bbox, info: dict = None) -> dict: + self.debug_info = {} + + self.frame_num += 1 + self.debug_info['frame_num'] = self.frame_num + + if self.object_id is None: + bbox = bbox + else: + bbox = bbox[self.object_id] + + if bbox[0] == -1 and bbox[1] == -1 and bbox[2] == -1 and bbox[3] == -1: + segmentation_mask_im = np.full(image.shape[:2], 0) + segmentation_output = np.full(image.shape[:2], -100.0) + if self.object_id is None: + segmentation_output = 1 / (1 + np.exp(-segmentation_output)) + out = {'segmentation': segmentation_mask_im, 'target_bbox': bbox, + 'segmentation_raw': segmentation_output} + return out + + self.pos = torch.Tensor([bbox[1] + (bbox[3] - 1)/2, bbox[0] + (bbox[2] - 1)/2]) + self.target_sz = torch.Tensor([bbox[3], bbox[2]]) + bbox = torch.tensor(bbox).unsqueeze(0).unsqueeze(0).float() + + search_area = [torch.prod(self.target_sz * s).item() for s in self.search_area_scale] + self.target_scale = [math.sqrt(s) / self.img_sample_sz.prod().sqrt() for s in search_area] + + # ********************************************************************** # + # ---------- Predict segmentation mask for the current frame ----------- # + # ********************************************************************** # + + # Convert image + im = 
numpy_to_torch(image) + + segmentation_scores_im_ms = [] + for i, (p, t) in enumerate(zip(self.get_centered_sample_pos(), self.target_scale)): + _, sample_coords, im_patches, patch_bboxes = self.extract_backbone_features(im, bbox, p, t, self.img_sample_sz) + + # predict segmentation masks + segmentation_scores = self.update_target_model(im_patches, patch_bboxes, i) + + # Location of sample + sample_pos, sample_scale = self.get_sample_location(sample_coords) + + # Get the segmentation scores for the full image. + # Regions outside the search region are assigned low scores (-100) + segmentation_scores_im_ms.append(self.convert_scores_crop_to_image(segmentation_scores, im, sample_scale, sample_pos)) + segmentation_scores_im_ms = torch.stack(segmentation_scores_im_ms, dim=0) + segmentation_scores_im = torch.mean(segmentation_scores_im_ms, dim=0) + + bbox = bbox.round().squeeze(0).squeeze(0).numpy().astype(np.int) + segmentation_scores_im[..., :bbox[0]] = -100 + segmentation_scores_im[..., bbox[0]+bbox[2]:] = -100 + segmentation_scores_im[..., :bbox[1], :] = -100 + segmentation_scores_im[..., bbox[1]+bbox[3]:, :] = -100 + + segmentation_mask_im = (segmentation_scores_im > 0.0).float() # Binary segmentation mask + segmentation_prob_im = torch.sigmoid(segmentation_scores_im) # Probability of being target at each pixel + + # ************************************************************************ # + # ---------- Output estimated segmentation mask and target box ----------- # + # ************************************************************************ # + + # Get target box from the predicted segmentation + pred_pos, pred_target_sz = self.get_target_state(segmentation_prob_im.squeeze()) + new_state = torch.cat((pred_pos[[1, 0]] - (pred_target_sz[[1, 0]] - 1) / 2, pred_target_sz[[1, 0]])) + output_state = new_state.tolist() + + if self.object_id is None: + # In single object mode, no merge called. Hence return the probabilities + segmentation_output = segmentation_prob_im + else: + # In multi-object mode, return raw scores + segmentation_output = segmentation_scores_im + + segmentation_mask_im = segmentation_mask_im.view(*segmentation_mask_im.shape[-2:]).cpu().numpy() + segmentation_output = segmentation_output.cpu().numpy() + + if self.visdom is not None: + self.visdom.register(segmentation_scores_im, 'heatmap', 2, 'Seg Scores' + self.id_str) + self.visdom.register(self.debug_info, 'info_dict', 1, 'Status') + + out = {'segmentation': segmentation_mask_im, 'target_bbox': output_state, + 'segmentation_raw': segmentation_output} + return out + + def merge_results(self, out_all): + """ Merges the predictions of individual targets""" + out_merged = OrderedDict() + + obj_ids = list(out_all.keys()) + + # Merge segmentation scores using the soft-aggregation approach from RGMP + segmentation_scores = [] + for id in obj_ids: + if 'segmentation_raw' in out_all[id].keys(): + segmentation_scores.append(out_all[id]['segmentation_raw']) + else: + # If 'segmentation_raw' is not present, then this is the initial frame for the target. Convert the + # GT Segmentation mask to raw scores (assign 100 to target region, -100 to background) + segmentation_scores.append((out_all[id]['segmentation'] - 0.5) * 200.0) + + segmentation_scores = np.stack(segmentation_scores) + segmentation_scores = torch.from_numpy(segmentation_scores).float() + segmentation_prob = torch.sigmoid(segmentation_scores) + + # Obtain seg. 
probability and scores for background label + eps = 1e-7 + bg_p = torch.prod(1 - segmentation_prob, dim=0).clamp(eps, 1.0 - eps) # bg prob + bg_score = (bg_p / (1.0 - bg_p)).log() + + segmentation_scores_all = torch.cat((bg_score.unsqueeze(0), segmentation_scores), dim=0) + + out = [] + for s in segmentation_scores_all: + s_out = 1.0 / (segmentation_scores_all - s.unsqueeze(0)).exp().sum(dim=0) + out.append(s_out) + + segmentation_maps_t_agg = torch.stack(out, dim=0) + segmentation_maps_np_agg = segmentation_maps_t_agg.numpy() + + # Obtain segmentation mask + obj_ids_all = np.array([0, *map(int, obj_ids)], dtype=np.uint8) + merged_segmentation = obj_ids_all[segmentation_maps_np_agg.argmax(axis=0)] + + out_merged['segmentation'] = merged_segmentation + out_merged['segmentation_raw'] = OrderedDict({key: segmentation_maps_np_agg[i + 1] + for i, key in enumerate(obj_ids)}) + + # target_bbox + out_first = list(out_all.values())[0] + out_types = out_first.keys() + + for key in out_types: + if 'segmentation' in key: + pass + elif 'target_bbox' in key: + # Update the target box using the merged segmentation mask + merged_boxes = {} + for obj_id, out in out_all.items(): + segmentation_prob = torch.from_numpy(out_merged['segmentation_raw'][obj_id]) + pred_pos, pred_target_sz = self.get_target_state(segmentation_prob) + new_state = torch.cat((pred_pos[[1, 0]] - (pred_target_sz[[1, 0]] - 1) / 2, pred_target_sz[[1, 0]])) + merged_boxes[obj_id] = new_state.tolist() + out_merged['target_bbox'] = merged_boxes + else: + # For fields other than segmentation predictions or target box, only convert the data structure + out_merged[key] = {obj_id: out[key] for obj_id, out in out_all.items()} + + return out_merged + + def get_target_state(self, segmentation_prob_im): + """ Estimate target bounding box using the predicted segmentation probabilities """ + + # If predicted mask area is too small, target might be occluded. In this case, just return prev. box + if segmentation_prob_im.sum() < self.params.get('min_mask_area', -10): + return self.pos, self.target_sz + + if self.params.get('seg_to_bb_mode') == 'var': + # Target center is the center of mass of the predicted per-pixel seg. probability scores + prob_sum = segmentation_prob_im.sum() + e_y = torch.sum(segmentation_prob_im.sum(dim=-1) * + torch.arange(segmentation_prob_im.shape[-2], dtype=torch.float32)) / prob_sum + e_x = torch.sum(segmentation_prob_im.sum(dim=-2) * + torch.arange(segmentation_prob_im.shape[-1], dtype=torch.float32)) / prob_sum + + # Target size is obtained using the variance of the seg. 
probability scores + e_h = torch.sum(segmentation_prob_im.sum(dim=-1) * + (torch.arange(segmentation_prob_im.shape[-2], dtype=torch.float32) - e_y)**2) / prob_sum + e_w = torch.sum(segmentation_prob_im.sum(dim=-2) * + (torch.arange(segmentation_prob_im.shape[-1], dtype=torch.float32) - e_x)**2) / prob_sum + + sz_factor = self.params.get('seg_to_bb_sz_factor', 4) + return torch.Tensor([e_y, e_x]), torch.Tensor([e_h.sqrt() * sz_factor, e_w.sqrt() * sz_factor]) + else: + raise Exception('Unknown seg_to_bb_mode mode {}'.format(self.params.get('seg_to_bb_mode'))) + + def get_sample_location(self, sample_coord): + """Get the location of the extracted sample.""" + sample_coord = sample_coord.float() + sample_pos = 0.5*(sample_coord[:2] + sample_coord[2:] - 1) + sample_scales = ((sample_coord[2:] - sample_coord[:2]) / self.img_sample_sz).prod().sqrt() + return sample_pos, sample_scales + + def get_centered_sample_pos(self): + """Get the center position for the new sample. Make sure the target is correctly centered.""" + return [self.pos + ((self.feature_sz + self.kernel_size) % 2) * t * \ + self.img_support_sz / (2*self.feature_sz) for t in self.target_scale] + + def convert_scores_crop_to_image(self, segmentation_scores, im, sample_scale, sample_pos): + """ Obtain segmentation scores for the full image using the scores for the search region crop. This is done by + assigning a low score (-100) for image regions outside the search region """ + + # Resize the segmention scores to match the image scale + segmentation_scores_re = F.interpolate(segmentation_scores, scale_factor=sample_scale.item(), mode='bilinear') + segmentation_scores_re = segmentation_scores_re.view(*segmentation_scores_re.shape[-2:]) + + # Regions outside search area get very low score + segmentation_scores_im = torch.ones(im.shape[-2:], dtype=segmentation_scores_re.dtype) * (-100.0) + + # Find the co-ordinates of the search region in the image scale + r1 = int(sample_pos[0].item() - 0.5*segmentation_scores_re.shape[-2]) + c1 = int(sample_pos[1].item() - 0.5*segmentation_scores_re.shape[-1]) + + r2 = r1 + segmentation_scores_re.shape[-2] + c2 = c1 + segmentation_scores_re.shape[-1] + + r1_pad = max(0, -r1) + c1_pad = max(0, -c1) + + r2_pad = max(r2 - im.shape[-2], 0) + c2_pad = max(c2 - im.shape[-1], 0) + + # Copy the scores for the search region + shape = segmentation_scores_re.shape + segmentation_scores_im[r1 + r1_pad:r2 - r2_pad, c1 + c1_pad:c2 - c2_pad] = \ + segmentation_scores_re[r1_pad:shape[0] - r2_pad, c1_pad:shape[1] - c2_pad] + + return segmentation_scores_im + + def segment_target(self, bbox_mask, sample_tm_feat, sample_x, segm=False): + with torch.no_grad(): + segmentation_scores = self.net.segment_target_add_bbox_encoder(bbox_mask, self.target_filter, sample_tm_feat, sample_x, segm) + + return segmentation_scores + + def extract_backbone_features(self, im: torch.Tensor, bbox, pos: torch.Tensor, scale, sz: torch.Tensor): + im_patches, patch_coords, patch_bboxes = sample_patch_multiscale(im, bbox, pos, scale.unsqueeze(0), sz, + mode=self.params.get('border_mode', 'replicate'), + max_scale_change=self.params.get('patch_max_scale_change', None)) + with torch.no_grad(): + backbone_feat = self.net.extract_backbone(im_patches) + return backbone_feat, patch_coords[0], im_patches[0], patch_bboxes[0].to(self.params.device) + + def get_target_model_features(self, backbone_feat): + """ Extract features input to the target model""" + with torch.no_grad(): + return self.net.extract_target_model_features(backbone_feat) + + def 
init_memory(self, train_x, bboxes): + """ Initialize the sample memory used to update the target model """ + # Initialize memory + self.training_samples = train_x + self.target_bboxes = bboxes + self.num_stored_samples = [x.shape[0] for x in self.training_samples] + + def update_memory(self, sample_x, bboxes): + """ Add a new sample to the memory""" + for i in range(len(self.training_samples)): + self.training_samples[i] = torch.cat([self.training_samples[i], sample_x[i]], dim=0) + self.target_bboxes[i] = torch.cat([self.target_bboxes[i], bboxes[i]], dim=0) + self.num_stored_samples = [x.shape[0] for x in self.training_samples] + + def update_target_model(self, train_x, bbox, i, learning_rate=None): + # Set flags and learning rate + if learning_rate is None: + learning_rate = self.params.learning_rate + + # Decide the number of iterations to run + num_iter = 0 + if (self.frame_num - 1) % self.params.train_skipping == 0: + num_iter = self.params.get('net_opt_update_iter', None) + + if num_iter > 0: + samples = train_x.unsqueeze(0) + bboxes = bbox.unsqueeze(0) + + sample_interval = self.params.get('test_sample_interval', 5) + test_num_frames = self.params.get('test_num_frames', 1) + num_append_left = (test_num_frames-1)//2 + num_append_right = (test_num_frames-1)//2 + + while self.frame_num-1-num_append_left*sample_interval < 0: + num_append_left -= 1 + num_append_right += 1 + while self.frame_num-1+num_append_right*sample_interval >= self.num_stored_samples[0]: + num_append_left += 1 + num_append_right -= 1 + + if self.params.get('casual', False) == True: + num_append_left += num_append_right + num_append_right = 0 + + # add samples in the past + ind = self.frame_num-1-sample_interval + while ind >= 0 and num_append_left > 0: + if self.visible[ind]: + samples = torch.cat((samples, self.training_samples[i][ind:ind+1]), dim=0) + bboxes = torch.cat((bboxes, self.target_bboxes[i][ind:ind+1]), dim=0) + ind -= sample_interval + num_append_left -= 1 + + # add samples in the future + ind = self.frame_num-1+sample_interval + while ind < self.num_stored_samples[0] and num_append_right > 0: + if self.visible[ind]: + samples = torch.cat((samples, self.training_samples[i][ind:ind+1]), dim=0) + bboxes = torch.cat((bboxes, self.target_bboxes[i][ind:ind+1]), dim=0) + ind += sample_interval + num_append_right -= 1 + + # print("sample shape", samples.shape, flush=True) + with torch.no_grad(): + backbone_feat = self.net.extract_backbone(samples) + test_x = self.get_target_model_features(backbone_feat) + train_bbox_enc, _ = self.net.label_encoder(bboxes, test_x, list(self.params.image_sample_size)) + + few_shot_label, few_shot_sw = self.net.bbox_encoder(bboxes, test_x.unsqueeze(1), list(self.params.image_sample_size)) + self.target_filter, _, _ = self.net.target_model.get_filter(test_x.unsqueeze(1), few_shot_label, few_shot_sw, + num_iter=num_iter) + segmentation_scores = self.segment_target(train_bbox_enc, test_x, backbone_feat) + train_segm_enc, train_segm_sw = self.net.segm_encoder(torch.sigmoid(segmentation_scores), test_x.unsqueeze(1)) + + # print("train_segm_enc shape", train_segm_enc.shape, flush=True) + # for i in range(train_segm_enc.shape[2]): + # vis_embedding = train_segm_enc[0:1,0,i:i+1,...] 
+ # vis_embedding = F.interpolate(vis_embedding, size=self.params.image_sample_size, mode='bilinear') + # vis_embedding = vis_embedding[0,0,...].cpu().numpy() + # print("vis_embedding shape", vis_embedding.shape, flush=True) + # print("max value", np.max(vis_embedding), flush=True) + # print("min value", np.min(vis_embedding), flush=True) + # vis_embedding = vis_embedding / np.max(vis_embedding) * 255.0 + # plot_image(vis_embedding, savePath="output"+str(i)+".jpg") + # exit() + + self.target_filter, _, _ = self.net.target_model_segm.get_filter(test_x.unsqueeze(1), train_segm_enc, train_segm_sw, + num_iter=num_iter) + segmentation_scores = self.segment_target(train_bbox_enc, test_x, backbone_feat, True) + segmentation_score = segmentation_scores[0:1] + + augs = self.params.augmentation if self.params.get('use_augmentation', True) else {} + if 'fliplr' in augs: + sample_width = samples.shape[3] + samples = samples.flip((3)) + bboxes[:, :, 0] = sample_width - bboxes[:, :, 0] - bboxes[:, :, 2] + + backbone_feat = self.net.extract_backbone(samples) + test_x = self.get_target_model_features(backbone_feat) + train_bbox_enc, _ = self.net.label_encoder(bboxes, test_x, list(self.params.image_sample_size)) + + few_shot_label, few_shot_sw = self.net.bbox_encoder(bboxes, test_x.unsqueeze(1), list(self.params.image_sample_size)) + self.target_filter, _, _ = self.net.target_model.get_filter(test_x.unsqueeze(1), few_shot_label, few_shot_sw, + num_iter=num_iter) + segmentation_scores = self.segment_target(train_bbox_enc, test_x, backbone_feat) + + train_segm_enc, train_segm_sw = self.net.segm_encoder(torch.sigmoid(segmentation_scores), test_x.unsqueeze(1)) + self.target_filter, _, _ = self.net.target_model_segm.get_filter(test_x.unsqueeze(1), train_segm_enc, train_segm_sw, + num_iter=num_iter) + segmentation_scores = self.segment_target(train_bbox_enc, test_x, backbone_feat, True) + segmentation_score_flip = segmentation_scores[0:1] + segmentation_score = (segmentation_score + segmentation_score_flip.flip((3))) / 2.0 + + return segmentation_score
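As a reading aid for the two denser pieces of pytracking/tracker/sta/sta.py above, namely the RGMP-style soft-aggregation in merge_results and the variance-based box estimate in the 'var' branch of get_target_state, the following is a minimal, self-contained sketch of the same computations. It is not part of the patch; the names soft_aggregate and box_from_mask, as well as the dummy tensor shapes in the usage example, are illustrative only.

import torch


def soft_aggregate(scores, eps=1e-7):
    # scores: (K, H, W) raw segmentation logits, one channel per object.
    # The background probability is the product of the per-object "not target"
    # probabilities; it is converted back to a logit and all K+1 logits are
    # normalised jointly (equivalent to the per-channel loop in merge_results).
    prob = torch.sigmoid(scores)
    bg_prob = torch.prod(1.0 - prob, dim=0).clamp(eps, 1.0 - eps)
    bg_score = (bg_prob / (1.0 - bg_prob)).log()
    all_scores = torch.cat((bg_score.unsqueeze(0), scores), dim=0)   # (K+1, H, W)
    return torch.softmax(all_scores, dim=0)                          # channel 0 = background


def box_from_mask(prob, sz_factor=4.0):
    # prob: (H, W) per-pixel target probabilities.
    # Centre = centre of mass of the probabilities; size = sz_factor times the
    # standard deviation along each axis, as in get_target_state. Returns [x, y, w, h].
    prob_sum = prob.sum()
    ys = torch.arange(prob.shape[-2], dtype=torch.float32)
    xs = torch.arange(prob.shape[-1], dtype=torch.float32)
    e_y = (prob.sum(dim=-1) * ys).sum() / prob_sum
    e_x = (prob.sum(dim=-2) * xs).sum() / prob_sum
    h = ((prob.sum(dim=-1) * (ys - e_y) ** 2).sum() / prob_sum).sqrt() * sz_factor
    w = ((prob.sum(dim=-2) * (xs - e_x) ** 2).sum() / prob_sum).sqrt() * sz_factor
    return torch.stack([e_x - (w - 1) / 2, e_y - (h - 1) / 2, w, h])


if __name__ == '__main__':
    logits = torch.randn(3, 240, 427)     # dummy logits for three objects
    probs = soft_aggregate(logits)        # (4, 240, 427)
    label_map = probs.argmax(dim=0)       # merged label map, 0 = background
                                          # (assumes object ids are 1..K)
    box_obj1 = box_from_mask(probs[1])    # box estimate for object id 1
    print(label_map.shape, box_obj1)

The per-channel normalisation loop in merge_results, which computes 1 / sum_j exp(score_j - score_s) for every label s, is algebraically a softmax over the background logit and the K object logits, which is what the sketch uses directly.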