From 3dde08724d31d1a5ed7e0332bf3174fcdf89a9af Mon Sep 17 00:00:00 2001
From: Daniel Aioanei
Date: Thu, 14 Sep 2023 21:01:27 +0200
Subject: [PATCH] Reshape the bbox tensor so that the last dimension has size
 4, even when empty.

Otherwise torch.cdist fails during training with "cdist only supports at
least 2D tensors, X2 got: 1D" on batches where none of the target entries
have any objects.
---
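A minimal sketch of the failure mode and of the fix, assuming only stock
PyTorch; the variable names `boxes` and `ref` are illustrative and not
taken from the codebase:

    import torch

    # Converting an empty list of boxes yields a 1-D tensor of shape (0,).
    boxes = torch.as_tensor([], dtype=torch.float32)

    # torch.cdist requires inputs with at least 2 dimensions, so the 1-D
    # empty tensor raises:
    #   RuntimeError: cdist only supports at least 2D tensors, X2 got: 1D
    ref = torch.rand(3, 4)
    try:
        torch.cdist(ref, boxes)
    except RuntimeError as err:
        print(err)

    # reshape(-1, 4) turns the same (empty) data into a well-formed (0, 4)
    # tensor; cdist then returns an empty (3, 0) distance matrix instead
    # of raising.
    print(torch.cdist(ref, boxes.reshape(-1, 4)).shape)  # torch.Size([3, 0])

The .reshape(-1, 4) added throughout the diff below guarantees the (N, 4)
shape regardless of whether any boxes survive filtering.
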
 src/table_datasets.py | 34 ++++++++++++++--------------------
 1 file changed, 14 insertions(+), 20 deletions(-)

diff --git a/src/table_datasets.py b/src/table_datasets.py
index 1fbe017e7..a559f0315 100644
--- a/src/table_datasets.py
+++ b/src/table_datasets.py
@@ -71,12 +71,12 @@ def crop_around_bbox_coco(image, crop_bbox, max_margin, target):
             cropped_labels.append(label)

     if len(cropped_bboxes) > 0:
-        target["boxes"] = torch.as_tensor(cropped_bboxes, dtype=torch.float32)
+        target["boxes"] = torch.as_tensor(cropped_bboxes, dtype=torch.float32).reshape(-1, 4)
         target["labels"] = torch.as_tensor(cropped_labels, dtype=torch.int64)
         w, h = img.size
         target["size"] = torch.tensor([w, h])
         return cropped_image, target
-    
+

     return image, target

@@ -162,7 +162,7 @@ def __call__(self, image, target):
                 cropped_labels.append(label)

         if len(cropped_bboxes) > 0:
-            target["boxes"] = torch.as_tensor(cropped_bboxes, dtype=torch.float32)
+            target["boxes"] = torch.as_tensor(cropped_bboxes, dtype=torch.float32).reshape(-1, 4)
             target["labels"] = torch.as_tensor(cropped_labels, dtype=torch.int64)

         return cropped_image, target
@@ -202,9 +202,9 @@ def __call__(self, image, target):
             if bbox[0] < bbox[2] - 1 and bbox[1] < bbox[3] - 1:
                 resized_bboxes.append(bbox)
                 resized_labels.append(label)
-        
+

         if len(resized_bboxes) > 0:
-            target["boxes"] = torch.as_tensor(resized_bboxes, dtype=torch.float32)
+            target["boxes"] = torch.as_tensor(resized_bboxes, dtype=torch.float32).reshape(-1, 4)
             target["labels"] = torch.as_tensor(resized_labels, dtype=torch.int64)

         return resized_image, target
@@ -290,9 +290,9 @@ def __call__(self, image, target):
             if bbox[0] < bbox[2] and bbox[1] < bbox[3]:
                 cropped_bboxes.append(bbox)
                 cropped_labels.append(label)
-        
+

         if len(cropped_bboxes) > 0:
-            target["boxes"] = torch.as_tensor(cropped_bboxes, dtype=torch.float32)
+            target["boxes"] = torch.as_tensor(cropped_bboxes, dtype=torch.float32).reshape(-1, 4)
             target["labels"] = torch.as_tensor(cropped_labels, dtype=torch.int64)

         return cropped_image, target
@@ -324,9 +324,9 @@ def __call__(self, image, target):
             if bbox[0] < bbox[2] and bbox[1] < bbox[3]:
                 cropped_bboxes.append(bbox)
                 cropped_labels.append(label)
-        
+

         if len(cropped_bboxes) > 0:
-            target["boxes"] = torch.as_tensor(cropped_bboxes, dtype=torch.float32)
+            target["boxes"] = torch.as_tensor(cropped_bboxes, dtype=torch.float32).reshape(-1, 4)
             target["labels"] = torch.as_tensor(cropped_labels, dtype=torch.int64)

         return cropped_image, target
@@ -413,7 +413,7 @@ def __call__(self, image, target):
             bbox = [scale*elem for elem in bbox]
             resized_bboxes.append(bbox)

-        target["boxes"] = torch.as_tensor(resized_bboxes, dtype=torch.float32)
+        target["boxes"] = torch.as_tensor(resized_bboxes, dtype=torch.float32).reshape(-1, 4)

         return resized_image, target
@@ -433,7 +433,7 @@ def __call__(self, image, target):
             bbox = [scale*elem for elem in bbox]
             resized_bboxes.append(bbox)

-        target["boxes"] = torch.as_tensor(resized_bboxes, dtype=torch.float32)
+        target["boxes"] = torch.as_tensor(resized_bboxes, dtype=torch.float32).reshape(-1, 4)

         return resized_image, target
@@ -619,7 +619,7 @@ def __getitem__(self, idx):
         img = Image.open(img_path).convert("RGB")
         w, h = img.size

-        if self.types[idx] == 1: 
+        if self.types[idx] == 1:
             bboxes, labels = read_pascal_voc(annot_path, class_map=self.class_map)

             # Reduce class set
@@ -627,14 +627,8 @@ def __getitem__(self, idx):
             bboxes = [bboxes[idx] for idx in keep_indices]
             labels = [labels[idx] for idx in keep_indices]

-            # Convert to Torch Tensor
-            if len(labels) > 0:
-                bboxes = torch.as_tensor(bboxes, dtype=torch.float32)
-                labels = torch.as_tensor(labels, dtype=torch.int64)
-            else:
-                # Not clear if it's necessary to force the shape of bboxes to be (0, 4)
-                bboxes = torch.empty((0, 4), dtype=torch.float32)
-                labels = torch.empty((0,), dtype=torch.int64)
+            bboxes = torch.as_tensor(bboxes, dtype=torch.float32).reshape(-1, 4)
+            labels = torch.as_tensor(labels, dtype=torch.int64)
         else:
             bboxes = torch.empty((0, 4), dtype=torch.float32)
             labels = torch.empty((0,), dtype=torch.int64)