forked from finger-monkey/CMPS
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path: utils.py
More file actions
67 lines (56 loc) · 2.15 KB
/
utils.py
File metadata and controls
67 lines (56 loc) · 2.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import os
import os.path as osp
import argparse
import sys
import torch
from torch.utils.data import DataLoader
from reid import models
from torch.nn import functional as F
from reid import datasets
from MI_SGD import MI_SGD,keepGradUpdate
from test import test
from reid.utils.data import transforms as T
from torchvision.transforms import Resize
from reid.utils.data.preprocessor import Preprocessor
from reid.evaluators import Evaluator
from torch.optim.optimizer import Optimizer, required
import random
import numpy as np
import math
from reid.evaluators import extract_features
from reid.utils.meters import AverageMeter
import torchvision
import faiss
from torchvision import transforms
import logging
def get_data(sourceName, split_id, data_dir, height, width,
             batch_size, workers, combine):
    """Create the source re-ID dataset and two DataLoaders over its trainval split.

    Both loaders iterate the same samples in the same (unshuffled) order;
    they differ only in the transform: the first applies RandomGrayscale(p=0.2)
    before ToTensor, the second resizes and converts only.

    Returns:
        (sourceSet, sourceSet, num_classes, class_tgt, loader_gray, loader_plain)
        — the dataset is returned twice (source and "target" are the same set
        here), and class_tgt equals num_classes for the same reason.
    """
    dataset_root = osp.join(data_dir, sourceName)
    sourceSet = datasets.create(sourceName, dataset_root,
                                num_val=0.1, split_id=split_id)

    def _id_count(ds):
        # combine=True counts train+val identities, otherwise train-only.
        return ds.num_trainval_ids if combine else ds.num_train_ids

    num_classes = _id_count(sourceSet)
    class_tgt = _id_count(sourceSet)  # target set is the source set itself

    grayscale_transform = T.Compose([
        Resize((height, width)),
        transforms.RandomGrayscale(p=0.2),
        T.ToTensor(),
    ])
    plain_transform = T.Compose([
        Resize((height, width)),
        T.ToTensor(),
    ])

    def _make_loader(transform):
        # Deterministic order (shuffle=False) — presumably these loaders feed
        # feature-extraction passes where sample order must be stable.
        return DataLoader(
            Preprocessor(sourceSet.trainval, root=sourceSet.images_dir,
                         transform=transform),
            batch_size=batch_size, num_workers=workers,
            shuffle=False, pin_memory=True)

    return (sourceSet, sourceSet, num_classes, class_tgt,
            _make_loader(grayscale_transform), _make_loader(plain_transform))
def calDist(qFeat, gFeat):
    """Compute the pairwise squared Euclidean distance matrix.

    Args:
        qFeat: query features, tensor whose first dim is m; flattened to (m, d).
        gFeat: gallery features, tensor whose first dim is n; flattened to (n, d).

    Returns:
        (m, n) tensor where entry (i, j) = ||q_i - g_j||^2
        (may contain tiny negative values from float round-off).
    """
    m, n = qFeat.size(0), gFeat.size(0)
    x = qFeat.view(m, -1)
    y = gFeat.view(n, -1)
    # ||x_i||^2 + ||y_j||^2 expanded into an (m, n) grid.
    dist_m = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \
             torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
    # Subtract 2 * x @ y^T in place. The original positional form
    # addmm_(1, -2, x, y.t()) is the long-deprecated beta/alpha signature,
    # removed in modern PyTorch — keyword arguments are required now.
    dist_m.addmm_(x, y.t(), beta=1, alpha=-2)
    return dist_m