From 23bf57b41fee68fde3471a86d4533c660f9380bf Mon Sep 17 00:00:00 2001 From: YFaris <45187161+YFaris@users.noreply.github.com> Date: Wed, 31 May 2023 05:49:16 +0800 Subject: [PATCH] Update folder_data_set_loader.py --- .../core/loaders/folder_data_set_loader.py | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/core/linnaeus/core/loaders/folder_data_set_loader.py b/core/linnaeus/core/loaders/folder_data_set_loader.py index d377b1f..ad0530d 100644 --- a/core/linnaeus/core/loaders/folder_data_set_loader.py +++ b/core/linnaeus/core/loaders/folder_data_set_loader.py @@ -7,6 +7,42 @@ from linnaeus.core.data_augmentation import preprocessing +class FolderDataSetLoaderYoloFormat(Dataset): + def __init__(self): + """ + self-defined dataset for augmented dataset in yolo format + """ + super(FolderDataSetLoaderYoloFormat, self).__init__() + # root dir of images + self.root_images = "./Data/" + # root dir of annotations + self.root_labels = "./Annotations/" + # obtain the names of labels + self.labels = os.listdir(self.root_labels) + + def __getitem__(self, index): + """ + according to index obtain image and its label + :param index: + :return: + """ + image, bbox, category = get_voc_label("/home/wzl/VOC/VOC2007/VOCdevkit/VOC2007/Annotations/" + self.labels[index]) + # obtain template, search + template, search, bbox, mapping = transform(image, bbox) + bbox = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2, bbox[2] - bbox[0], bbox[3] - bbox[1]] + # process template and search + torch_template = torch.from_numpy(np.transpose(cv2.cvtColor(template, cv2.COLOR_BGR2RGB), (2, 0, 1))) + torch_search = torch.from_numpy(np.transpose(cv2.cvtColor(search, cv2.COLOR_BGR2RGB), (2, 0, 1))) + # color jitter + torch_template = self.color_jitter(torch_template.unsqueeze(0)).squeeze(0).type(torch.FloatTensor) / 255 + torch_search = self.color_jitter(torch_search.unsqueeze(0)).squeeze(0).type(torch.FloatTensor) / 255 + return torch_template, torch_search, torch.tensor(bbox) / 255, torch.tensor(mapping), torch.from_numpy( + np.transpose(cv2.resize(image, (256, 256)), (2, 0, 1))), category + + def __len__(self): + return len(self.labels) + + class FolderDataSetLoader(Dataset): def __init__(self, path, classes):