128 changes: 5 additions & 123 deletions R/dataset-coco.R
@@ -215,19 +215,13 @@ coco_detection_dataset <- torch::dataset(
#' Loads the MS COCO dataset for instance segmentation tasks.
#'
#' @rdname coco_segmentation_dataset
#' @param root Root directory where the dataset is stored or will be downloaded to.
#' @param train Logical. If TRUE, loads the training split; otherwise, loads the validation split.
#' @param year Character. Dataset version year. One of \code{"2014"} or \code{"2017"}.
#' @param download Logical. If TRUE, downloads the dataset if it's not already present in the \code{root} directory.
#' @param transform Optional transform function applied to the image.
#' @param target_transform Optional transform function applied to the target.
#' @inheritParams coco_detection_dataset
#' @param target_transform Optional transform function applied to the target.
#' Use \code{target_transform_coco_masks} to convert polygon annotations to binary masks.
#'
#' @return An object of class `coco_segmentation_dataset`. Each item is a list:
#' - `x`: a `(C, H, W)` array representing the image.
#' - `y$boxes`: a `(N, 4)` `torch_tensor` of bounding boxes in the format \eqn{(x_{min}, y_{min}, x_{max}, y_{max})}.
#' - `y$labels`: an integer `torch_tensor` with the class label for each object.
#' - `y$area`: a float `torch_tensor` indicating the area of each object.
#' - `y$iscrowd`: a boolean `torch_tensor`, where `TRUE` marks the object as part of a crowd.
#' - `y$segmentation`: a list of segmentation polygons for each object.
#' - `y$masks`: a `(N, H, W)` boolean `torch_tensor` containing binary segmentation masks (when using target_transform_coco_masks).
@@ -239,8 +233,8 @@ coco_detection_dataset <- torch::dataset(
#'
#' @details
#' The returned image `x` is in CHW format (channels, height, width), matching the torch convention.
#' The dataset `y` offers instance segmentation annotations including bounding boxes, labels,
#' areas, crowd indicators, and segmentation masks from the official COCO annotations.
#' The dataset `y` offers instance segmentation annotations including labels,
#' crowd indicators, and segmentation masks from the official COCO annotations.
#'
#' Files are downloaded to a \code{coco} subdirectory in the torch cache directory for better organization.
#'
Expand All @@ -262,70 +256,10 @@ coco_detection_dataset <- torch::dataset(
#' }
#' @family segmentation_dataset
#' @seealso \code{\link{coco_detection_dataset}} for object detection tasks
#' @importFrom jsonlite fromJSON
#' @export
coco_segmentation_dataset <- torch::dataset(
name = "coco_segmentation_dataset",
resources = data.frame(
year = rep(c(2017, 2014), each = 4 ),
content = rep(c("image", "annotation"), time = 2, each = 2),
split = rep(c("train", "val"), time = 4),
url = c("http://images.cocodataset.org/zips/train2017.zip", "http://images.cocodataset.org/zips/val2017.zip",
rep("http://images.cocodataset.org/annotations/annotations_trainval2017.zip", time = 2),
"http://images.cocodataset.org/zips/train2014.zip", "http://images.cocodataset.org/zips/val2014.zip",
rep("http://images.cocodataset.org/annotations/annotations_trainval2014.zip", time = 2)),
size = c("800 MB", "800 MB", rep("770 MB", time = 2), "6.33 GB", "6.33 GB", rep("242 MB", time = 2)),
md5 = c(c("cced6f7f71b7629ddf16f17bbcfab6b2", "442b8da7639aecaf257c1dceb8ba8c80"),
rep("f4bbac642086de4f52a3fdda2de5fa2c", time = 2),
c("0da8cfa0e090c266b78f30e2d2874f1a", "a3d79f5ed8d289b7a7554ce06a5782b3"),
rep("0a379cfc70b0e71301e0f377548639bd", time = 2)),
stringsAsFactors = FALSE
),

initialize = function(
root = tempdir(),
train = TRUE,
year = c("2017", "2014"),
download = FALSE,
transform = NULL,
target_transform = NULL
) {

year <- match.arg(year)
split <- ifelse(train, "train", "val")

root <- fs::path_expand(root)
self$root <- root
self$year <- year
self$split <- split
self$transform <- transform
self$target_transform <- target_transform
self$archive_size <- self$resources[self$resources$year == year & self$resources$split == split & self$resources$content == "image", ]$size

self$data_dir <- fs::path(root, glue::glue("coco{year}"))

image_year <- ifelse(year == "2016", "2014", year)
self$image_dir <- fs::path(self$data_dir, glue::glue("{split}{image_year}"))
self$annotation_file <- fs::path(self$data_dir, "annotations",
glue::glue("instances_{split}{year}.json"))

if (download) {
cli_inform("Dataset {.cls {class(self)[[1]]}} (~{.emph {self$archive_size}}) will be downloaded and processed if not already available.")
self$download()
}

if (!self$check_exists()) {
runtime_error("Dataset not found. You can use `download = TRUE` to download it.")
}

self$load_annotations()

cli_inform("{.cls {class(self)[[1]]}} dataset loaded with {length(self$image_ids)} images.")
},

check_exists = function() {
fs::file_exists(self$annotation_file) && fs::dir_exists(self$image_dir)
},
inherit = coco_detection_dataset,

.getitem = function(index) {
image_id <- self$image_ids[index]
Expand All @@ -341,28 +275,18 @@ coco_segmentation_dataset <- torch::dataset(
anns <- self$annotations[self$annotations$image_id == image_id, ]

if (nrow(anns) > 0) {
boxes_wh <- torch::torch_tensor(do.call(rbind, anns$bbox), dtype = torch::torch_float())
boxes <- box_xywh_to_xyxy(boxes_wh)

label_ids <- anns$category_id
labels <- as.character(self$categories$name[match(label_ids, self$categories$id)])

area <- torch::torch_tensor(anns$area, dtype = torch::torch_float())
iscrowd <- torch::torch_tensor(as.logical(anns$iscrowd), dtype = torch::torch_bool())

} else {
# empty annotation
boxes <- torch::torch_zeros(c(0, 4), dtype = torch::torch_float())
labels <- character()
area <- torch::torch_empty(0, dtype = torch::torch_float())
iscrowd <- torch::torch_empty(0, dtype = torch::torch_bool())
anns$segmentation <- list()
}

y <- list(
boxes = boxes,
labels = labels,
area = area,
iscrowd = iscrowd,
segmentation = anns$segmentation
)
Expand All @@ -384,48 +308,6 @@ coco_segmentation_dataset <- torch::dataset(
class(result) <- c("image_with_segmentation_mask", class(result))
}
result
},

.length = function() {
length(self$image_ids)
},

download = function() {
annotation_filter <- self$resources$year == self$year & self$resources$split == self$split & self$resources$content == "annotation"
image_filter <- self$resources$year == self$year & self$resources$split == self$split & self$resources$content == "image"

cli_inform("Downloading {.cls {class(self)[[1]]}}...")

ann_zip <- download_and_cache(self$resources[annotation_filter, ]$url, prefix = "coco")
archive <- download_and_cache(self$resources[image_filter, ]$url, prefix = "coco")

if (tools::md5sum(archive) != self$resources[image_filter, ]$md5) {
runtime_error("Corrupt file! Delete the file in {archive} and try again.")
}

utils::unzip(ann_zip, exdir = self$data_dir)
utils::unzip(archive, exdir = self$data_dir)

cli_inform("Dataset {.cls {class(self)[[1]]}} downloaded and extracted successfully.")
},

load_annotations = function() {
data <- jsonlite::fromJSON(self$annotation_file)

self$image_metadata <- setNames(
split(data$images, seq_len(nrow(data$images))),
as.character(data$images$id)
)

self$annotations <- data$annotations
self$categories <- data$categories
self$category_names <- setNames(self$categories$name, self$categories$id)

ids <- as.numeric(names(self$image_metadata))
image_paths <- fs::path(self$image_dir,
sapply(ids, function(id) self$image_metadata[[as.character(id)]]$file_name))
exist <- fs::file_exists(image_paths)
self$image_ids <- ids[exist]
}
)
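After this refactor, `coco_segmentation_dataset` inherits its resources table, download logic, and annotation loading from `coco_detection_dataset` and only overrides `.getitem`. A minimal usage sketch of the resulting API (the `root` path is a placeholder, and the first call with `download = TRUE` fetches the COCO archives if they are not already cached):

    # Sketch only: "~/datasets" is a hypothetical directory; downloading
    # pulls the annotation archive plus the image archive for the split.
    ds <- coco_segmentation_dataset(
      root = "~/datasets",
      train = FALSE,
      year = "2017",
      download = TRUE,
      target_transform = target_transform_coco_masks
    )
    item <- ds[1]
    item$x           # (C, H, W) image array
    item$y$labels    # character class labels, one per annotated object
    item$y$masks     # (N, H, W) boolean masks added by target_transform_coco_masks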

6 changes: 2 additions & 4 deletions man/coco_segmentation_dataset.Rd

Some generated files are not rendered by default.

6 changes: 3 additions & 3 deletions tests/testthat/test-dataset-coco.R
@@ -46,7 +46,7 @@ test_that("coco_detection_dataset loads a single example correctly", {
expect_false("segmentation" %in% names(y))
})

test_that("coco_ dataset loads a single segmentation example correctly", {
test_that("coco_segmentation_dataset loads a single segmentation example correctly", {
skip_if(Sys.getenv("TEST_LARGE_DATASETS", unset = 0) != 1,
"Skipping test: set TEST_LARGE_DATASETS=1 to enable tests requiring large downloads.")

Expand All @@ -64,7 +64,7 @@ test_that("coco_ dataset loads a single segmentation example correctly", {
expect_length(dim(item$x), 3)

expect_type(y, "list")
expect_named(y, c("boxes", "labels", "area", "iscrowd", "segmentation", "image_height", "image_width", "masks"))
expect_named(y, c("labels", "iscrowd", "segmentation", "image_height", "image_width", "masks"))

expect_tensor(y$masks)
expect_equal(y$masks$ndim, 3)
Expand All @@ -85,7 +85,7 @@ test_that("coco_detection_dataset batches correctly using dataloader", {
expect_true(all(vapply(batch$x, is_torch_tensor, logical(1))))

expect_type(batch$y, "list")
expect_named(batch$y[[1]], c("boxes", "labels", "area", "iscrowd", "segmentation"))
expect_named(batch$y[[1]], c("boxes", "labels", "area", "iscrowd"))
expect_tensor(batch$y[[1]]$boxes)
expect_equal(batch$y[[1]]$boxes$ndim, 2)
expect_equal(batch$y[[1]]$boxes$size(2), 4)
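For context, the batching pattern this test exercises can be sketched as follows. This is a sketch rather than the suite's actual setup: `ds` stands for an already constructed detection dataset, and the list-preserving `collate_fn` is one way to batch COCO targets, whose per-image tensor sizes vary.

    # Hypothetical collate: keep images and targets as plain lists instead
    # of stacking, since box and label counts differ from image to image.
    dl <- torch::dataloader(ds, batch_size = 2, collate_fn = function(batch) {
      list(
        x = lapply(batch, function(item) item$x),
        y = lapply(batch, function(item) item$y)
      )
    })
    iter <- torch::dataloader_make_iter(dl)
    batch <- torch::dataloader_next(iter)
    batch$y[[1]]$boxes  # (N, 4) xyxy boxes for the first image in the batch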
26 changes: 6 additions & 20 deletions tests/testthat/test-segmentation-transforms.R
@@ -66,7 +66,7 @@ test_that("target_transform_coco_masks() handles multiple objects", {
expect_equal(result$masks$shape, c(2, 100, 100))
})

# Trimap mask transformation tests

test_that("target_transform_trimap_masks() converts trimap to masks", {
skip_if_not_installed("torch")
Expand All @@ -89,12 +89,12 @@ test_that("target_transform_trimap_masks() creates mutually exclusive masks", {

# Dataset behavior tests (skipped unless TEST_LARGE_DATASETS=1)

test_that("coco_detection_dataset with target_transform produces masks", {
test_that("coco_segmentation_dataset with target_transform produces masks", {
skip_on_cran()
skip_if(Sys.getenv("TEST_LARGE_DATASETS", unset = 0) != 1,
"Set TEST_LARGE_DATASETS=1 to enable")
tmp <- withr::local_tempdir()
ds <- coco_detection_dataset(root = tmp, train = FALSE, year = "2017",
ds <- coco_segmentation_dataset(root = tmp, train = FALSE, year = "2017",
download = TRUE, target_transform = target_transform_coco_masks)
y <- ds[1]$y
expect_true("masks" %in% names(y))
Expand All @@ -103,21 +103,7 @@ test_that("coco_detection_dataset with target_transform produces masks", {
expect_tensor_dtype(y$masks, torch::torch_bool())
})

test_that("oxfordiiitpet with target_transform produces masks", {
skip_on_cran()
skip_if(Sys.getenv("TEST_LARGE_DATASETS", unset = 0) != 1,
"Set TEST_LARGE_DATASETS=1 to enable")
tmp <- withr::local_tempdir()
ds <- oxfordiiitpet_segmentation_dataset(root = tmp, train = TRUE,
download = TRUE, target_transform = target_transform_trimap_masks)
y <- ds[1]$y
expect_true("masks" %in% names(y))
expect_tensor(y$masks)
expect_tensor_shape(y$masks, c(3, NA, NA))
expect_tensor_dtype(y$masks, torch::torch_bool())
})

# Integration with draw_segmentation_masks()

test_that("transformed masks work with draw_segmentation_masks()", {
skip_if_not_installed("torch")
Expand All @@ -132,7 +118,7 @@ test_that("transformed masks work with draw_segmentation_masks()", {
expect_equal(result$shape[1], 3)
})

# Existing helper tests

test_that("coco_polygon_to_mask handles single polygon", {
skip_if_not_installed("torch")
Expand All @@ -150,5 +136,5 @@ test_that("coco_polygon_to_mask handles empty polygon", {
skip_if_not_installed("magick")
mask <- coco_polygon_to_mask(list(), 100, 100)
expect_tensor(mask)
expect_equal(as.numeric(mask$sum()$item()), 0)
expect_equal_to_r(mask$sum(), 0)
})
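As a companion to these helper tests, a sketch of the non-empty case. The flat polygon encoding and the `(polygons, height, width)` argument order are inferred from the empty-polygon call above, so treat both as assumptions:

    # A single 30 x 30 square polygon in COCO's flat
    # [x1, y1, x2, y2, ...] coordinate encoding.
    poly <- list(c(10, 10, 40, 10, 40, 40, 10, 40))
    mask <- coco_polygon_to_mask(poly, 100, 100)
    mask$shape         # 100 x 100 mask
    mask$sum()$item()  # roughly 900, the square's area in pixels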