diff --git a/mipcandy/data/__init__.py b/mipcandy/data/__init__.py
index 2426a32..76c5057 100644
--- a/mipcandy/data/__init__.py
+++ b/mipcandy/data/__init__.py
@@ -4,6 +4,6 @@
 from mipcandy.data.download import download_dataset
 from mipcandy.data.geometric import ensure_num_dimensions, orthographic_views, aggregate_orthographic_views, crop
 from mipcandy.data.inspection import InspectionAnnotation, InspectionAnnotations, load_inspection_annotations, \
-    inspect, ROIDataset
+    inspect, ROIDataset, RandomROIDataset
 from mipcandy.data.io import resample_to_isotropic, load_image, save_image
 from mipcandy.data.visualization import visualize2d, visualize3d, overlay
diff --git a/mipcandy/data/inspection.py b/mipcandy/data/inspection.py
index 31e1312..3d2130e 100644
--- a/mipcandy/data/inspection.py
+++ b/mipcandy/data/inspection.py
@@ -29,6 +29,7 @@ class InspectionAnnotation(object):
     shape: tuple[int, ...]
     foreground_bbox: tuple[int, int, int, int] | tuple[int, int, int, int, int, int]
     ids: tuple[int, ...]
+    foreground_samples: tuple[tuple[int, ...], ...] | None = None

     def foreground_shape(self) -> tuple[int, int] | tuple[int, int, int]:
         r = (self.foreground_bbox[1] - self.foreground_bbox[0], self.foreground_bbox[3] - self.foreground_bbox[2])
@@ -71,7 +72,7 @@ def __len__(self) -> int:

     def save(self, path: str | PathLike[str]) -> None:
         with open(path, "w") as f:
-            dump({"background": self._background, "annotations": self._annotations}, f)
+            dump({"background": self._background, "annotations": [a.to_dict() for a in self._annotations]}, f)

     def _get_shapes(self, get_shape: Callable[[InspectionAnnotation], tuple[int, ...]]) -> tuple[
         tuple[int, ...] | None, tuple[int, ...], tuple[int, ...]]:
@@ -211,8 +212,14 @@ def crop_roi(self, i: int, *, percentile: float = .95) -> tuple[torch.Tensor, to
         return crop(image.unsqueeze(0), roi).squeeze(0), crop(label.unsqueeze(0), roi).squeeze(0)


+def _list_to_tuple(v: Any) -> Any:
+    if isinstance(v, list):
+        return tuple(_list_to_tuple(item) for item in v)
+    return v
+
+
 def _lists_to_tuples(pairs: Sequence[tuple[str, Any]]) -> dict[str, Any]:
-    return {k: tuple(v) if isinstance(v, list) else v for k, v in pairs}
+    return {k: _list_to_tuple(v) for k, v in pairs}


 def load_inspection_annotations(path: str | PathLike[str], dataset: SupervisedDataset) -> InspectionAnnotations:
@@ -223,7 +230,9 @@ def load_inspection_annotations(path: str | PathLike[str], dataset: SupervisedDa


-def inspect(dataset: SupervisedDataset, *, background: int = 0, console: Console = Console()) -> InspectionAnnotations:
+def inspect(dataset: SupervisedDataset, *, background: int = 0, min_foreground_samples: int = 500,
+            max_foreground_samples: int = 10000, min_percent_coverage: float = 0.01,
+            console: Console = Console()) -> InspectionAnnotations:
     r = []
     with Progress(*Progress.get_default_columns(), SpinnerColumn(), console=console) as progress:
         task = progress.add_task("Inspecting dataset...", total=len(dataset))
@@ -233,8 +242,24 @@
             mins = indices.min(dim=0)[0].tolist()
             maxs = indices.max(dim=0)[0].tolist()
             bbox = (mins[1], maxs[1], mins[2], maxs[2])
+            if len(indices) > 0:
+                if len(indices) <= min_foreground_samples:
+                    sampled = indices
+                else:
+                    target_samples = min(
+                        max_foreground_samples,
+                        max(min_foreground_samples, int(np.ceil(len(indices) * min_percent_coverage)))
+                    )
+                    sampled_idx = torch.randperm(len(indices))[:target_samples]
+                    sampled = indices[sampled_idx]
+                # drop the leading channel index so the stored coordinates
+                # align with the spatial dims of `annotation.shape`
+                foreground_samples = tuple(tuple(coord.tolist()[1:]) for coord in sampled)
+            else:
+                foreground_samples = None
             r.append(InspectionAnnotation(
-                label.shape[1:], bbox if label.ndim == 3 else bbox + (mins[3], maxs[3]), tuple(label.unique())
+                label.shape[1:],
+                bbox if label.ndim == 3 else bbox + (mins[3], maxs[3]),
+                tuple(label.unique()),
+                foreground_samples
             ))
     return InspectionAnnotations(dataset, background, *r, device=dataset.device())
@@ -251,8 +276,64 @@ def __len__(self) -> int:
     @override
     def construct_new(self, images: list[torch.Tensor], labels: list[torch.Tensor]) -> Self:
-        return ROIDataset(self._annotations)
+        return self.__class__(self._annotations, percentile=self._percentile)

     @override
     def load(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
         return self._annotations.crop_roi(idx, percentile=self._percentile)
+
+
+class RandomROIDataset(ROIDataset):
+    def __init__(self, annotations: InspectionAnnotations, *, percentile: float = .95,
+                 foreground_oversample_percent: float = 0.33) -> None:
+        super().__init__(annotations, percentile=percentile)
+        self._fg_oversample: float = foreground_oversample_percent
+
+    def _random_roi(self, idx: int) -> tuple[int, int, int, int] | tuple[int, int, int, int, int, int]:
+        annotation = self._annotations[idx]
+        roi_shape = self._annotations.roi_shape(percentile=self._percentile)
+        roi = []
+        for dim_size, patch_size in zip(annotation.shape, roi_shape):
+            left = patch_size // 2
+            right = patch_size - left
+            min_center = left
+            max_center = dim_size - right
+            center = torch.randint(min_center, max_center + 1, (1,)).item()
+            roi.append(center - left)
+            roi.append(center + right)
+        return tuple(roi)
+
+    def _foreground_guided_random_roi(self, idx: int) -> tuple[int, int, int, int] | tuple[
+        int, int, int, int, int, int]:
+        annotation = self._annotations[idx]
+        roi_shape = self._annotations.roi_shape(percentile=self._percentile)
+
+        if annotation.foreground_samples is None or len(annotation.foreground_samples) == 0:
+            return self._random_roi(idx)
+
+        fg_idx = torch.randint(0, len(annotation.foreground_samples), (1,)).item()
+        fg_position = annotation.foreground_samples[fg_idx]
+
+        roi = []
+        for fg_pos, dim_size, patch_size in zip(fg_position, annotation.shape, roi_shape):
+            left = patch_size // 2
+            right = patch_size - left
+            # clamp the center so the patch stays inside the volume
+            center = max(left, min(fg_pos, dim_size - right))
+            roi.append(center - left)
+            roi.append(center + right)
+        return tuple(roi)
+
+    @override
+    def construct_new(self, images: list[torch.Tensor], labels: list[torch.Tensor]) -> Self:
+        return self.__class__(self._annotations, percentile=self._percentile,
+                              foreground_oversample_percent=self._fg_oversample)
+
+    @override
+    def load(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
+        image, label = self._annotations._dataset[idx]
+        force_fg = torch.rand(1).item() < self._fg_oversample
+        if force_fg:
+            roi = self._foreground_guided_random_roi(idx)
+        else:
+            roi = self._random_roi(idx)
+        return crop(image.unsqueeze(0), roi).squeeze(0), crop(label.unsqueeze(0), roi).squeeze(0)
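A minimal usage sketch of the new API, assuming an existing `SupervisedDataset` instance `dataset` (its construction is project-specific) and an illustrative `inspection.json` path. Per case, `inspect` now stores `min(max_foreground_samples, max(min_foreground_samples, ceil(n_foreground * min_percent_coverage)))` foreground coordinates; for example, 2,000,000 foreground voxels at the default 1% coverage would target 20,000 and be capped at 10,000.

```python
from mipcandy.data import RandomROIDataset, inspect, load_inspection_annotations

# `dataset` stands in for any SupervisedDataset of (image, label) pairs.
annotations = inspect(
    dataset,
    background=0,
    min_foreground_samples=500,    # never store fewer than this per case
    max_foreground_samples=10000,  # hard cap per case
    min_percent_coverage=0.01,     # aim for ~1% of foreground voxels in between
)
annotations.save("inspection.json")  # illustrative path

# Later runs can reload the annotations instead of re-inspecting:
annotations = load_inspection_annotations("inspection.json", dataset)

# Roughly a third of the patches are centered on a stored foreground voxel
# (clamped to the volume bounds); the rest are uniform random crops.
patches = RandomROIDataset(annotations, percentile=.95, foreground_oversample_percent=0.33)
image_patch, label_patch = patches.load(0)
```

Because `RandomROIDataset.load` draws a fresh ROI on every call, repeated passes over the same index yield different patches, whereas the plain `ROIDataset` keeps returning the deterministic percentile crop from `crop_roi`.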