diff --git a/DESCRIPTION b/DESCRIPTION index 4772f52fd..045182c1f 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -91,6 +91,7 @@ Collate: 'CallbackSetUnfreeze.R' 'ContextTorch.R' 'DataBackendLazy.R' + 'DataBackendLazyTensors.R' 'utils.R' 'DataDescriptor.R' 'LearnerTorch.R' diff --git a/NAMESPACE b/NAMESPACE index 122ebafb6..1ed76e181 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -7,11 +7,15 @@ S3method("[[<-",lazy_tensor) S3method(as.data.table,DictionaryMlr3torchCallbacks) S3method(as.data.table,DictionaryMlr3torchLosses) S3method(as.data.table,DictionaryMlr3torchOptimizers) +S3method(as_data_backend,dataset) S3method(as_data_descriptor,dataset) S3method(as_lazy_tensor,DataDescriptor) S3method(as_lazy_tensor,dataset) S3method(as_lazy_tensor,numeric) S3method(as_lazy_tensor,torch_tensor) +S3method(as_lazy_tensors,dataset) +S3method(as_task_classif,dataset) +S3method(as_task_regr,dataset) S3method(as_torch_callback,R6ClassGenerator) S3method(as_torch_callback,TorchCallback) S3method(as_torch_callback,character) @@ -27,6 +31,8 @@ S3method(as_torch_optimizer,character) S3method(as_torch_optimizer,torch_optimizer_generator) S3method(c,lazy_tensor) S3method(col_info,DataBackendLazy) +S3method(col_info,DataBackendLazyTensors) +S3method(distinct_values,lazy_tensor) S3method(format,lazy_tensor) S3method(hash_input,TorchIngressToken) S3method(hash_input,lazy_tensor) @@ -71,6 +77,7 @@ export(CallbackSetTB) export(CallbackSetUnfreeze) export(ContextTorch) export(DataBackendLazy) +export(DataBackendLazyTensors) export(DataDescriptor) export(LearnerTorch) export(LearnerTorchFeatureless) @@ -161,6 +168,7 @@ export(TorchLoss) export(TorchOptimizer) export(as_data_descriptor) export(as_lazy_tensor) +export(as_lazy_tensors) export(as_lr_scheduler) export(as_torch_callback) export(as_torch_callbacks) diff --git a/NEWS.md b/NEWS.md index 9f2cbf174..1d581c0b2 100644 --- a/NEWS.md +++ b/NEWS.md @@ -11,6 +11,8 @@ This means that for binary classification tasks, `t_loss("cross_entropy")` now 
generates `nn_bce_with_logits_loss` instead of `nn_cross_entropy_loss`. This also came with a reparametrization of the `t_loss("cross_entropy")` loss (thanks to @tdhock, #374). +* fix: `NA` is now a valid shape for lazy tensors. +* feat: `lazy_tensor`s of length 0 can now be materialized. # mlr3torch 0.2.1 diff --git a/R/DataBackendLazyTensors.R b/R/DataBackendLazyTensors.R new file mode 100644 index 000000000..2328a86b4 --- /dev/null +++ b/R/DataBackendLazyTensors.R @@ -0,0 +1,249 @@ + +#' @title Special Backend for Lazy Tensors +#' @description +#' This backend essentially allows you to use a [`torch::dataset`] directly with +#' an [`mlr3::Learner`]. +#' +#' * The data cannot contain missing values, as [`lazy_tensor`]s do not support them. +#' For this reason, calling `$missings()` will always return `0` for all columns. +#' * The `$distinct()` method will consider two lazy tensors that refer to the same element of a +#' [`DataDescriptor`] to be identical. +#' This means, that it might be underreporting the number of distinct values of lazy tensor columns. 
+#'
+#' @field chunk_size (`integer(1)`)\cr
+#'   Maximum number of rows that are materialized at once when converter columns are loaded.
+#'
+#' @export
+#' @examplesIf torch::torch_is_installed()
+#' # used as feature in all backends
+#' x = torch_randn(100, 10)
+#' # regression
+#' ds_regr = tensor_dataset(x = x, y = torch_randn(100, 1))
+#' be_regr = as_data_backend(ds_regr, converter = list(y = as.numeric))
+#' be_regr$head()
+#'
+#'
+#' # binary classification: underlying target tensor must be float in [0, 1]
+#' ds_binary = tensor_dataset(x = x, y = torch_randint(0, 2, c(100, 1))$float())
+#' be_binary = as_data_backend(ds_binary, converter = list(
+#'   y = function(x) factor(as.integer(x), levels = c(0, 1), labels = c("A", "yes"))
+#' ))
+#' be_binary$head()
+#'
+#' # multi-class classification: underlying target tensor must be integer in [1, K]
+#' ds_multiclass = tensor_dataset(x = x, y = torch_randint(1, 4, size = c(100, 1)))
+#' be_multiclass = as_data_backend(ds_multiclass, converter = list(y = as.numeric))
+#' be_multiclass$head()
+
+DataBackendLazyTensors = R6Class("DataBackendLazyTensors",
+  cloneable = FALSE,
+  inherit = DataBackendDataTable,
+  public = list(
+    chunk_size = NULL,
+    #' @description
+    #' Create a new instance of this [R6][R6::R6Class] class.
+    #' @param data (`data.table`)\cr
+    #'   Data containing (among others) [`lazy_tensor`] columns.
+    #' @param primary_key (`character(1)`)\cr
+    #'   Name of the column used as primary key.
+    #' @param converter (named `list()` of `function`s)\cr
+    #'   A named list of functions that convert the lazy tensor columns to their R representation.
+    #'   The names must be the names of the columns that need conversion.
+    #' @param cache (`character()`)\cr
+    #'   Names of the columns that should be cached.
+    #'   Per default, all columns that are converted are cached.
+    #' @param chunk_size (`integer(1)`)\cr
+    #'   Maximum number of rows that are materialized at once when converter columns are loaded.
+    initialize = function(data, primary_key, converter, cache = names(converter), chunk_size = 100) {
+      private$.converter = assert_list(converter, types = "function", any.missing = FALSE)
+      assert_subset(names(converter), colnames(data))
+      # only converter columns can be cached; an empty selection disables caching
+      private$.cached_cols = assert_subset(cache, names(converter), empty.ok = TRUE)
+      self$chunk_size = assert_int(chunk_size, lower = 1L)
+      walk(names(private$.converter), function(nm) {
+        if (!inherits(data[[nm]], "lazy_tensor")) {
+          stopf("Column '%s' is not a lazy tensor.", nm)
+        }
+      })
+      super$initialize(data, primary_key)
+      # select the column whose name is stored in primary_key from private$.data but keep its name
+      private$.data_cache = private$.data[, primary_key, with = FALSE]
+    },
+    #' @description
+    #' Retrieve data. Converter columns are returned in their converted R representation unless
+    #' the option `mlr3torch.data_loading` is `TRUE`, in which case they stay lazy tensors.
+    #' @param rows (`integer()`)\cr
+    #'   Row ids, i.e. values of the primary key column.
+    #' @param cols (`character()`)\cr
+    #'   Column names.
+    data = function(rows, cols) {
+      rows = assert_integerish(rows, coerce = TRUE)
+      assert_names(cols, type = "unique")
+
+      if (getOption("mlr3torch.data_loading", FALSE)) {
+        # no caching, no materialization as this is called in the training loop
+        return(super$data(rows, cols))
+      }
+
+      expensive_cols = intersect(cols, private$.cached_cols)
+      if (length(expensive_cols) == 0L || !all(expensive_cols %in% names(private$.data_cache))) {
+        # nothing can be served from the cache: either no cached column was requested or at
+        # least one of them has never been stored, so load, convert and cache in a single pass
+        return(private$.load_and_cache(rows, cols))
+      }
+
+      other_cols = setdiff(cols, expensive_cols)
+      cache_hit = private$.data_cache[list(rows), expensive_cols, on = self$primary_key, with = FALSE]
+      # rows whose cached values are still NA have not been loaded so far
+      complete = complete.cases(cache_hit)
+      cache_hit = cache_hit[complete]
+      if (!all(complete)) {
+        combined = rbindlist(list(cache_hit, private$.load_and_cache(rows[!complete], expensive_cols)))
+        # restore the requested row order: cached rows come first in `combined`, new ones after
+        reorder = vector("integer", nrow(combined))
+        reorder[complete] = seq_len(nrow(cache_hit))
+        reorder[!complete] = nrow(cache_hit) + seq_len(nrow(combined) - nrow(cache_hit))
+        cache_hit = combined[reorder]
+      }
+
+      # `other_cols` holds no cached column, so this call never writes to the cache, but it
+      # still applies the converter of converter columns that were excluded from `cache`
+      tbl = cbind(cache_hit, private$.load_and_cache(rows, other_cols))
+      setcolorder(tbl, cols)
+      tbl
+    },
+    #' @description
+    #' Retrieve the first `n` rows in their converted representation.
+    #' @param n (`integer(1)`)\cr
+    #'   Number of rows.
+    head = function(n = 6L) {
+      if (getOption("mlr3torch.data_loading", FALSE)) {
+        return(super$head(n))
+      }
+
+      # use the actual primary key values, which need not be 1, ..., n
+      self$data(utils::head(self$rownames, n), self$colnames)
+    },
+    #' @description
+    #' Always reports `0` missing values, as lazy tensors cannot contain missing values.
+    #' @param rows (`integer()`)\cr
+    #'   Row ids (ignored).
+    #' @param cols (`character()`)\cr
+    #'   Column names.
+    missings = function(rows, cols) {
+      set_names(rep(0L, length(cols)), cols)
+    }
+  ),
+  active = list(
+    converter = function(rhs) {
+      assert_ro_binding(rhs)
+      private$.converter
+    }
+  ),
+  private = list(
+    # call this function only with rows that are not in the cache yet
+    .load_and_cache = function(rows, cols) {
+      # Process columns that need conversion
+      tbl = super$data(rows, cols)
+      cols_to_convert = intersect(names(private$.converter), names(tbl))
+      tbl_to_mat = tbl[, cols_to_convert, with = FALSE]
+      # chunk the rows of tbl_to_mat into chunks of size self$chunk_size, apply materialize
+      n = nrow(tbl_to_mat)
+      chunks = split(seq_len(n), rep(seq_len(ceiling(n / self$chunk_size)), each = self$chunk_size, length.out = n))
+
+      tbl_mat = if (n == 0) {
+        # one empty tensor per converter column so that the names always line up
+        set_names(lapply(cols_to_convert, function(nm) torch_empty(0)), cols_to_convert)
+      } else {
+        set_names(lapply(transpose_list(lapply(chunks, function(chunk) {
+          materialize(tbl_to_mat[chunk, ], rbind = TRUE)
+        })), torch_cat, dim = 1L), names(tbl_to_mat))
+      }
+
+      # translate row ids into row positions of the cache table; ids that do not exist are
+      # dropped by super$data() as well, so positions and converted values stay aligned
+      cache_rows = match(rows, private$.data_cache[[self$primary_key]])
+      cache_rows = cache_rows[!is.na(cache_rows)]
+
+      for (nm in cols_to_convert) {
+        converted = private$.converter[[nm]](tbl_mat[[nm]])
+        tbl[[nm]] = converted
+
+        if (nm %in% private$.cached_cols) {
+          set(private$.data_cache, i = cache_rows, j = nm, value = converted)
+        }
+      }
+      return(tbl)
+    },
+    .data_cache = NULL,
+    .converter = NULL,
+    .cached_cols = NULL
+  )
+)
+
+#' @export
+as_data_backend.dataset = function(x, dataset_shapes = NULL, ...) {
+  # `...` carries the constructor arguments (converter, cache, chunk_size)
+  tbl = as_lazy_tensors(x, dataset_shapes, ...)
+  tbl$row_id = seq_len(nrow(tbl))
+  DataBackendLazyTensors$new(tbl, primary_key = "row_id", ...)
+}
+
+#' @export
+as_task_classif.dataset = function(x, target, levels, converter = NULL, dataset_shapes = NULL, chunk_size = 100, cache = names(converter), ...)
{
+  if (length(x) < 2) {
+    stopf("Dataset must have at least 2 rows.")
+  }
+  # a batch of two rows is sufficient to validate dtype and shape of the target
+  batch = dataloader(x, batch_size = 2)$.iter()$.next()
+  if (is.null(converter)) {
+    # build the default converter; `cache = names(converter)` is evaluated lazily and therefore
+    # picks up the converter that is assigned here
+    target_tensor = batch[[target]]
+    if (length(levels) == 2) {
+      # binary: float tensor of 0/1 with shape (batch_size, 1)
+      if (target_tensor$dtype != torch_float()) {
+        stopf("Target must be a float tensor, but has dtype %s", target_tensor$dtype)
+      }
+      if (!test_equal(target_tensor$shape, c(2L, 1L))) {
+        stopf("Target must be a float tensor of shape (batch_size, 1), but has shape (batch_size, %s)",
+          paste(target_tensor$shape[-1L], collapse = ", "))
+      }
+      converter = set_names(list(crate(function(x) factor(as.integer(x), levels = 0:1, labels = levels), levels)), target)
+    } else {
+      # multi-class: integer tensor of class indices 1, ..., K with shape (batch_size);
+      # torch_long() is accepted as well because it is what torch_tensor() creates from R integers
+      dtype = target_tensor$dtype
+      if (!(dtype == torch_int() || dtype == torch_long())) {
+        stopf("Target must be an integer tensor, but has dtype %s", dtype)
+      }
+      if (!test_equal(target_tensor$shape, 2L)) {
+        stopf("Target must be an integer tensor of shape (batch_size), but has shape (batch_size, %s)",
+          paste(target_tensor$shape[-1L], collapse = ", "))
+      }
+      # the levels are fixed to 1, ..., K so that subsets in which some class does not occur
+      # are still encoded with the full and correctly ordered label set
+      converter = set_names(list(crate(function(x) {
+        factor(as.integer(x), levels = seq_along(levels), labels = levels)
+      }, levels)), target)
+    }
+  }
+  be = as_data_backend(x, dataset_shapes, converter = converter, cache = cache, chunk_size = chunk_size)
+  as_task_classif(be, target = target, ...)
+}
+
+#' @export
+as_task_regr.dataset = function(x, target, converter = NULL, dataset_shapes = NULL, chunk_size = 100, cache = names(converter), ...)
{ + if (length(x) < 2) { + stopf("Dataset must have at least 2 rows.") + } + if (is.null(converter)) { + converter = set_names(list(as.numeric), target) + } + batch = dataloader(x, batch_size = 2)$.iter()$.next() + + if (batch[[target]]$dtype != torch_float()) { + stopf("Target must be a float tensor, but has dtype %s", batch[[target]]$dtype) + } + + if (!test_equal(batch[[target]]$shape, c(2L, 1L))) { + stopf("Target must be a float tensor of shape (batch_size, 1), but has shape (batch_size, %s)", + paste(batch[[target]]$shape[-1L], collapse = ", ")) + } + + dataset_shapes = get_or_check_dataset_shapes(x, dataset_shapes) + be = as_data_backend(x, dataset_shapes, converter = converter, cache = cache, chunk_size = chunk_size) + as_task_regr(be, target = target, ...) +} + +#' @export +col_info.DataBackendLazyTensors = function(x, ...) { # nolint + first_row = x$head(1L) + types = map_chr(first_row, function(x) class(x)[1L]) + discrete = setdiff(names(types)[types %chin% c("factor", "ordered")], x$primary_key) + levels = insert_named(named_list(names(types)), map(first_row[, discrete, with = FALSE], levels)) + data.table(id = names(types), type = unname(types), levels = levels, key = "id") +} + + +# conservative check that avoids that a pseudo-lazy-tensor is preprocessed by some pipeop +# @param be +# the backend +# @param candidates +# the feature and target names +# @param visited +# Union of all colnames already visited +# @return visited +check_lazy_tensors_backend = function(be, candidates, visited = character()) { + if (inherits(be, "DataBackendRbind") || inherits(be, "DataBackendCbind")) { + bs = be$.__enclos_env__$private$.data + # first we check b2, then b1, because b2 possibly overshadows some b1 rows/cols + visited = check_lazy_tensors_backend(bs$b2, candidates, visited) + check_lazy_tensors_backend(bs$b1, candidates, visited) + } else { + if (inherits(be, "DataBackendLazyTensors")) { + if (any(names(be$converter) %in% visited)) { + converter_cols = 
names(be$converter)[names(be$converter) %in% visited] + stopf("A converter column ('%s') from a DataBackendLazyTensors was presumably preprocessed by some PipeOp. This can cause inefficiencies and is therefore not allowed. If you want to preprocess them, please directly encode them as R types.", paste0(converter_cols, collapse = ", ")) # nolint + } + } + union(visited, intersect(candidates, be$colnames)) + } +} diff --git a/R/DataDescriptor.R b/R/DataDescriptor.R index 1bf3cd68d..6a1d65740 100644 --- a/R/DataDescriptor.R +++ b/R/DataDescriptor.R @@ -60,14 +60,7 @@ DataDescriptor = R6Class("DataDescriptor", # For simplicity we here require the first dimension of the shape to be NA so we don't have to deal with it, # e.g. during subsetting - if (is.null(dataset_shapes)) { - if (is.null(dataset$.getbatch)) { - stopf("dataset_shapes must be provided if dataset does not have a `.getbatch` method.") - } - dataset_shapes = infer_shapes_from_getbatch(dataset) - } else { - assert_compatible_shapes(dataset_shapes, dataset) - } + dataset_shapes = get_or_check_dataset_shapes(dataset, dataset_shapes) if (is.null(graph)) { # avoid name conflicts @@ -84,8 +77,7 @@ DataDescriptor = R6Class("DataDescriptor", assert_true(length(graph$pipeops) >= 1L) } # no preprocessing, dataset returns only a single element (there we can infer a lot) - simple_case = length(graph$pipeops) == 1L && inherits(graph$pipeops[[1L]], "PipeOpNOP") && - length(dataset_shapes) == 1L + simple_case = (length(graph$pipeops) == 1L) && inherits(graph$pipeops[[1L]], "PipeOpNOP") if (is.null(input_map) && nrow(graph$input) == 1L && length(dataset_shapes) == 1L) { input_map = names(dataset_shapes) @@ -100,7 +92,7 @@ DataDescriptor = R6Class("DataDescriptor", assert_choice(pointer[[2]], graph$pipeops[[pointer[[1]]]]$output$name) } if (is.null(pointer_shape) && simple_case) { - pointer_shape = dataset_shapes[[1L]] + pointer_shape = dataset_shapes[[input_map]] } else { assert_shape(pointer_shape, null_ok = TRUE) } @@ 
-225,13 +217,14 @@ infer_shapes_from_getbatch = function(ds) { } assert_compatible_shapes = function(shapes, dataset) { - assert_shapes(shapes, null_ok = TRUE, unknown_batch = TRUE, named = TRUE) + shapes = assert_shapes(shapes, null_ok = TRUE, unknown_batch = TRUE, named = TRUE, coerce = TRUE) # prevent user from e.g. forgetting to wrap the return in a list - example = if (is.null(dataset$.getbatch)) { - dataset$.getitem(1L) - } else { + has_getbatch = !is.null(dataset$.getbatch) + example = if (has_getbatch) { dataset$.getbatch(1L) + } else { + dataset$.getitem(1L) } if (!test_list(example, names = "unique") || !test_permutation(names(example), names(shapes))) { stopf("Dataset must return a list with named elements that are a permutation of the dataset_shapes names.") @@ -242,17 +235,17 @@ assert_compatible_shapes = function(shapes, dataset) { } }) - if (is.null(dataset$.getbatch)) { - example = map(example, function(x) x$unsqueeze(1)) - } - iwalk(shapes, function(dataset_shape, name) { - if (!is.null(dataset_shape) && !test_equal(shapes[[name]][-1], example[[name]]$shape[-1L])) { - expected_shape = example[[name]]$shape - expected_shape[1] = NA + observed_shape = example[[name]]$shape + if (has_getbatch) { + observed_shape[1L] = NA_integer_ + } else { + observed_shape = c(NA_integer_, observed_shape) + } + if (!is.null(dataset_shape) && !test_equal(observed_shape, dataset_shape)) { stopf(paste0("First batch from dataset is incompatible with the provided shape of %s:\n", - "* Provided shape: %s.\n* Expected shape: %s."), name, - shape_to_str(unname(shapes[name])), shape_to_str(list(expected_shape))) + "* Provided shape: %s.\n* Observed shape: %s."), name, + shape_to_str(unname(shapes[name])), shape_to_str(list(observed_shape))) } }) } diff --git a/R/LearnerTorch.R b/R/LearnerTorch.R index af1db6e5e..068d54a84 100644 --- a/R/LearnerTorch.R +++ b/R/LearnerTorch.R @@ -109,7 +109,7 @@ #' #' For information on the expected target encoding of `y`, see section *Network 
Head and Target Encoding*. #' Moreover, one needs to pay attention respect the row ids of the provided task. -#' It is recommended to relu on [`task_dataset`] for creating the [`dataset`][torch::dataset]. +#' It is strongly recommended to use the [`task_dataset`] class to create the dataset. #' #' It is also possible to overwrite the private `.dataloader()` method. #' This must respect the dataloader parameters from the [`ParamSet`][paradox::ParamSet]. diff --git a/R/lazy_tensor.R b/R/lazy_tensor.R index d050f8545..00b397575 100644 --- a/R/lazy_tensor.R +++ b/R/lazy_tensor.R @@ -197,6 +197,19 @@ as_lazy_tensor.torch_tensor = function(x, ...) { # nolint as_lazy_tensor(ds, dataset_shapes = list(x = c(NA, dim(x)[-1]))) } +#' @export +as_lazy_tensors = function(x, ...) { + UseMethod("as_lazy_tensors") +} + +#' @export +as_lazy_tensors.dataset = function(x, dataset_shapes = NULL, ...) { + dataset_shapes = get_or_check_dataset_shapes(x, dataset_shapes) + set_names(map_dtc(names(dataset_shapes), function(shape) { + as_lazy_tensor(x, dataset_shapes = dataset_shapes, input_map = shape) + }), names(dataset_shapes)) +} + #' Assert Lazy Tensor #' #' Asserts whether something is a lazy tensor. @@ -339,3 +352,13 @@ rep.lazy_tensor = function(x, ...) { rep_len.lazy_tensor = function(x, ...) 
{ set_class(NextMethod(), c("lazy_tensor", "list")) } + + +#' @export +distinct_values.lazy_tensor = function(x, drop = TRUE, na_rm = TRUE) { + if (!length(x)) { + return(x) + } + ids = distinct_values(map_int(x, 1)) + lazy_tensor(dd(x), ids) +} \ No newline at end of file diff --git a/R/learner_torch_methods.R b/R/learner_torch_methods.R index 79cebaa4a..7259bf587 100644 --- a/R/learner_torch_methods.R +++ b/R/learner_torch_methods.R @@ -18,8 +18,10 @@ learner_torch_predict = function(self, private, super, task, param_vals) { private$.encode_prediction(predict_tensor = predict_tensor, task = task) } + learner_torch_train = function(self, private, super, task, param_vals) { # Here, all param_vals (like seed = "random" or device = "auto") have already been resolved + check_lazy_tensors_backend(task$backend, c(task$feature_names, task$target_names)) dataset_train = private$.dataset(task, param_vals) dataset_train = as_multi_tensor_dataset(dataset_train, param_vals) loader_train = private$.dataloader(dataset_train, param_vals) @@ -356,3 +358,5 @@ as_multi_tensor_dataset = function(dataset, param_vals) { dataset } } + + diff --git a/R/materialize.R b/R/materialize.R index 849024ad4..ee113830d 100644 --- a/R/materialize.R +++ b/R/materialize.R @@ -63,6 +63,13 @@ materialize.list = function(x, device = "cpu", rbind = FALSE, cache = "auto", .. map(x, function(col) { if (is_lazy_tensor(col)) { + if (length(col) == 0L) { + if (rbind) { + return(torch_empty(0L)) + } else { + return(list()) + } + } materialize_internal(col, device = device, cache = cache, rbind = rbind) } else { col @@ -76,16 +83,30 @@ materialize.list = function(x, device = "cpu", rbind = FALSE, cache = "auto", .. #' @method materialize data.frame #' @export materialize.data.frame = function(x, device = "cpu", rbind = FALSE, cache = "auto", ...) 
{ # nolint
+  # A zero-row frame needs no special case here: materialize.list() maps every length-0
+  # lazy_tensor column to an empty tensor (rbind = TRUE) or an empty list and returns all
+  # other columns unchanged.
   materialize(as.list(x), device = device, rbind = rbind, cache = cache)
 }
 
 #' @export
 materialize.lazy_tensor = function(x, device = "cpu", rbind = FALSE, ...) { # nolint
+  if (length(x) == 0L) {
+    if (rbind) {
+      return(torch_empty(0L))
+    } else {
+      return(list())
+    }
+  }
   materialize_internal(x = x, device = device, cache = NULL, rbind = rbind)
 }
 
 
-get_input = function(ds, ids, varying_shapes, rbind) {
+get_input = function(ds, ids, varying_shapes) {
   if (is.null(ds$.getbatch)) { # .getindex is never NULL but a function that errs if it was not defined
     x = map(ids, function(id) map(ds$.getitem(id), function(x) x$unsqueeze(1)))
     if (varying_shapes) {
@@ -154,9 +175,6 @@ get_output = function(input, graph, varying_shapes, rbind, device) {
 #' @return [`lazy_tensor()`]
 #' @keywords internal
 materialize_internal = function(x, device = "cpu", cache = NULL, rbind) {
-  if (!length(x)) {
-    stopf("Cannot materialize lazy tensor of length 0.")
-  }
   do_caching = !is.null(cache)
   ids = map_int(x, 1)
 
@@ -183,7 +201,7 @@ materialize_internal = function(x, device = "cpu", cache = NULL, rbind) {
   }
 
   if (!do_caching || !input_hit) {
-    input = get_input(ds, ids, varying_shapes, rbind)
+    input = get_input(ds, ids, varying_shapes)
   }
 
   if (do_caching && !input_hit) {
diff --git a/R/shape.R b/R/shape.R
index d1fdda83d..7970c37ec 100644
--- a/R/shape.R
+++ b/R/shape.R
@@ -30,7 +30,7 @@ test_shape = function(shape, null_ok = FALSE, unknown_batch = NULL, len = NULL)
   if (is.null(shape) && null_ok) {
     return(TRUE)
   }
-  ok = test_integerish(shape, min.len = 2L, all.missing = FALSE, any.missing = TRUE, len = len)
+  ok = test_integerish(shape, min.len = 1L, any.missing = TRUE, len = len)
 
   if (!ok) {
     return(FALSE)
diff --git a/R/task_dataset.R b/R/task_dataset.R
index bd088d1bf..10ab6e268 100644
--- a/R/task_dataset.R
+++
b/R/task_dataset.R @@ -81,13 +81,21 @@ task_dataset = dataset("task_dataset", .getbatch = function(index) { cache = if (self$cache_lazy_tensors) new.env() - datapool = self$task$data(rows = self$task$row_ids[index], cols = self$all_features) + datapool = withr::with_options(list(mlr3torch.data_loading = TRUE), { + self$task$data(rows = self$task$row_ids[index], cols = self$all_features) + }) + x = lapply(self$feature_ingress_tokens, function(it) { it$batchgetter(datapool[, it$features, with = FALSE], cache = cache) }) y = if (!is.null(self$target_batchgetter)) { - self$target_batchgetter(datapool[, self$task$target_names, with = FALSE]) + target = datapool[, self$task$target_names, with = FALSE] + if (!inherits(target[[1L]], "lazy_tensor")) { + self$target_batchgetter(target) + } else { + materialize(target[[1L]], rbind = TRUE) + } } out = list(x = x, .index = torch_tensor(index, dtype = torch_long())) if (!is.null(y)) out$y = y diff --git a/R/utils.R b/R/utils.R index 2f9d68302..bcb17af66 100644 --- a/R/utils.R +++ b/R/utils.R @@ -190,7 +190,10 @@ list_to_batch = function(tensors) { } auto_cache_lazy_tensors = function(lts) { - any(duplicated(map_chr(lts, function(x) dd(x)$dataset_hash))) + if (length(lts) <= 1L) { + return(FALSE) + } + anyDuplicated(unlist(map_if(lts, function(x) length(x) > 0, function(x) dd(x)$dataset_hash))) > 0L } #' Replace the head of a network @@ -300,6 +303,18 @@ infer_shapes = function(shapes_in, param_vals, output_names, fn, rowwise, id) { set_names(list(sout), output_names) } +get_or_check_dataset_shapes = function(dataset, dataset_shapes) { + if (is.null(dataset_shapes)) { + if (is.null(dataset$.getbatch)) { + stopf("dataset_shapes must be provided if dataset does not have a `.getbatch` method.") + } + dataset_shapes = infer_shapes_from_getbatch(dataset) + } else { + assert_compatible_shapes(dataset_shapes, dataset) + } + dataset_shapes +} + #' @title Network Output Dimension #' @description #' Calculates the output dimension of a 
neural network for a given task that is expected by diff --git a/TODO.md b/TODO.md new file mode 100644 index 000000000..62d953809 --- /dev/null +++ b/TODO.md @@ -0,0 +1,25 @@ +* Add `as_lazy_tensors()` +* Make it easier to se +* Fix the bug that the shapes are reported as unknown below and make the code easier. + ```r + ds = dataset("test", + initialize = function() { + self$x = torch_randn(100, 10) + self$y = torch_randn(100, 1) + }, + .getitem = function(i) { + list(x = self$x[i, ], y = self$y[i]) + }, + .length = function() { + nrow(self$x) + } + )() + x_lt = as_lazy_tensor(ds, list(x = c(NA, 10), y = c(NA, 1)), input_map = "x") + y_lt = as_lazy_tensor(ds, list(x = c(NA, 10), y = c(NA, 1)), input_map = "y") + + tbl = data.table(x = x_lt, y = y_lt) + ``` +* Add checks on usage of `DataBackendLazyTensors` in `task_dataset` +* Add optimization that truths values don't have to be loaded twice during resampling, i.e. + once for making the predictions and once for retrieving the truth column. +* only allow caching converter columns in `DataBackendLazyTensors` (probably just remove the `cache` parameter) \ No newline at end of file diff --git a/man/DataBackendLazyTensors.Rd b/man/DataBackendLazyTensors.Rd new file mode 100644 index 000000000..42880dec7 --- /dev/null +++ b/man/DataBackendLazyTensors.Rd @@ -0,0 +1,123 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/DataBackendLazyTensors.R +\name{DataBackendLazyTensors} +\alias{DataBackendLazyTensors} +\title{Special Backend for Lazy Tensors} +\description{ +This backend essentially allows you to use a \code{\link[torch:dataset]{torch::dataset}} directly with +an \code{\link[mlr3:Learner]{mlr3::Learner}}. +\itemize{ +\item The data cannot contain missing values, as \code{\link{lazy_tensor}}s do not support them. +For this reason, calling \verb{$missings()} will always return \code{0} for all columns. 
+\item The \verb{$distinct()} method will consider two lazy tensors that refer to the same element of a +\code{\link{DataDescriptor}} to be identical. +This means, that it might be underreporting the number of distinct values of lazy tensor columns. +} +} +\examples{ +\dontshow{if (torch::torch_is_installed()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +# used as feature in all backends +x = torch_randn(100, 10) +# regression +ds_regr = tensor_dataset(x = x, y = torch_randn(100, 1)) +be_regr = as_data_backend(ds_regr, converter = list(y = as.numeric)) +be_regr$head() + + +# binary classification: underlying target tensor must be float in [0, 1] +ds_binary = tensor_dataset(x = x, y = torch_randint(0, 2, c(100, 1))$float()) +be_binary = as_data_backend(ds_binary, converter = list( + y = function(x) factor(as.integer(x), levels = c(0, 1), labels = c("A", "yes")) +)) +be_binary$head() + +# multi-class classification: underlying target tensor must be integer in [1, K] +ds_multiclass = tensor_dataset(x = x, y = torch_randint(1, 4, size = c(100, 1))) +be_multiclass = as_data_backend(ds_multiclass, converter = list(y = as.numeric)) +be_multiclass$head() +\dontshow{\}) # examplesIf} +} +\section{Super classes}{ +\code{\link[mlr3:DataBackend]{mlr3::DataBackend}} -> \code{\link[mlr3:DataBackendDataTable]{mlr3::DataBackendDataTable}} -> \code{DataBackendLazyTensors} +} +\section{Methods}{ +\subsection{Public methods}{ +\itemize{ +\item \href{#method-DataBackendLazyTensors-new}{\code{DataBackendLazyTensors$new()}} +\item \href{#method-DataBackendLazyTensors-data}{\code{DataBackendLazyTensors$data()}} +\item \href{#method-DataBackendLazyTensors-head}{\code{DataBackendLazyTensors$head()}} +\item \href{#method-DataBackendLazyTensors-missings}{\code{DataBackendLazyTensors$missings()}} +} +} +\if{html}{\out{ +
Inherited methods + +
+}} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-DataBackendLazyTensors-new}{}}} +\subsection{Method \code{new()}}{ +Create a new instance of this \link[R6:R6Class]{R6} class. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{DataBackendLazyTensors$new( + data, + primary_key, + converter, + cache = names(converter), + chunk_size = 100 +)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{data}}{(\code{data.table})\cr +Data containing (among others) \code{\link{lazy_tensor}} columns.} + +\item{\code{primary_key}}{(\code{character(1)})\cr +Name of the column used as primary key.} + +\item{\code{converter}}{(named \code{list()} of \code{function}s)\cr +A named list of functions that convert the lazy tensor columns to their R representation. +The names must be the names of the columns that need conversion.} + +\item{\code{cache}}{(\code{character()})\cr +Names of the columns that should be cached. +Per default, all columns that are converted are cached.} +} +\if{html}{\out{
}} +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-DataBackendLazyTensors-data}{}}} +\subsection{Method \code{data()}}{ +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{DataBackendLazyTensors$data(rows, cols)}\if{html}{\out{
}} +} + +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-DataBackendLazyTensors-head}{}}} +\subsection{Method \code{head()}}{ +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{DataBackendLazyTensors$head(n = 6L)}\if{html}{\out{
}} +} + +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-DataBackendLazyTensors-missings}{}}} +\subsection{Method \code{missings()}}{ +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{DataBackendLazyTensors$missings(rows, cols)}\if{html}{\out{
}} +} + +} +} diff --git a/man/mlr_learners_torch.Rd b/man/mlr_learners_torch.Rd index d9a924739..a4d630c64 100644 --- a/man/mlr_learners_torch.Rd +++ b/man/mlr_learners_torch.Rd @@ -205,7 +205,7 @@ For networks with more than one input, the names must correspond to the inputs o For information on the expected target encoding of \code{y}, see section \emph{Network Head and Target Encoding}. Moreover, one needs to pay attention respect the row ids of the provided task. -It is recommended to relu on \code{\link{task_dataset}} for creating the \code{\link[torch:dataset]{dataset}}. +It is strongly recommended to use the \code{\link{task_dataset}} class to create the dataset. } It is also possible to overwrite the private \code{.dataloader()} method. diff --git a/tests/testthat/test_DataBackendLazyTensors.R b/tests/testthat/test_DataBackendLazyTensors.R new file mode 100644 index 000000000..06fb91574 --- /dev/null +++ b/tests/testthat/test_DataBackendLazyTensors.R @@ -0,0 +1,333 @@ +test_that("main API works", { + # regression target + ds = tensor_dataset( + x = torch_tensor(matrix(100:1, nrow = 100, ncol = 1)), + y = torch_tensor(as.matrix(1:100, nrow = 100, ncol = 1)) + ) + + be = as_data_backend(ds, converter = list(y = as.numeric), dataset_shapes = list(x = c(NA, 1), y = c(NA, 1))) + + # converted data + + batch1 = be$data(1, c("x", "y")) + expect_class(batch1$x, "lazy_tensor") + expect_equal(length(batch1$x), 1) + expect_equal(materialize(batch1$x, rbind = TRUE), torch_tensor(matrix(100L, nrow = 1, ncol = 1))) + expect_equal(batch1$y, 1) + + batch2 = be$data(2:1, c("x", "y")) + expect_class(batch2$x, "lazy_tensor") + expect_equal(length(batch2$x), 2) + expect_equal(materialize(batch2$x, rbind = TRUE), torch_tensor(matrix(100:99, nrow = 2, ncol = 1))) + expect_equal(batch2$y, c(2, 1)) + + # lt data + batch_lt1 = withr::with_options(list(mlr3torch.data_loading = TRUE), { + be$data(1, c("x", "y")) + }) + expect_class(batch_lt1$x, "lazy_tensor") + 
expect_equal(length(batch_lt1$x), 1) + expect_equal(materialize(batch_lt1$x, rbind = TRUE), torch_tensor(matrix(100L, nrow = 1, ncol = 1))) + # y is still a lazy tensor + expect_class(batch_lt1$y, "lazy_tensor") + expect_equal(length(batch_lt1$y), 1) + + batch_lt2 = withr::with_options(list(mlr3torch.data_loading = TRUE), { + be$data(2:1, c("x", "y")) + }) + expect_class(batch_lt2$x, "lazy_tensor") + expect_equal(length(batch_lt2$x), 2) + expect_equal(materialize(batch_lt2$x, rbind = TRUE), torch_tensor(matrix(100:99, nrow = 2, ncol = 1))) + # y is still a lazy tensor + expect_class(batch_lt2$y, "lazy_tensor") + expect_equal(length(batch_lt2$y), 2) + + # missings + expect_equal(be$missings(1:100, c("y", "x")), c(y = 0, x = 0)) + expect_equal(be$missings(1:100, "y"), c(y = 0)) + expect_equal(be$missings(1:100, "x"), c(x = 0)) + + # head + tbl = be$head(n = 3) + expect_data_table(tbl, nrows = 3, ncols = 3) + expect_class(tbl$x, "lazy_tensor") + expect_equal(materialize(tbl$x, rbind = TRUE), torch_tensor(matrix(100:98, nrow = 3, ncol = 1))) + expect_class(tbl$y, "numeric") + expect_equal(tbl$row_id, as.numeric(1:3)) + expect_class(tbl$row_id, "integer") + expect_equal(tbl$row_id, 1:3) + + # distinct values: this can be expensive + dist = be$distinct(1:3, c("x", "y", "row_id")) + expect_list(dist, len = 3) + expect_equal(materialize(dist$x, rbind = TRUE), torch_tensor(matrix(100:98, nrow = 3, ncol = 1))) + expect_equal(dist$y, c(1, 2, 3)) + expect_equal(dist$row_id, 1:3) +}) + +test_that("classif target works", { + ds = dataset( + initialize = function() { + self$x = torch_tensor(matrix(100:1, nrow = 100, ncol = 1)) + self$y = torch_tensor(matrix(rep(c(0, 1), each = 50), nrow = 100, ncol = 1)) + }, + .getitem = function(i) { + list(x = self$x[i], y = self$y[i]) + }, + .length = function() { + nrow(self$x) + } + )() + + tbl = as_lazy_tensors(ds, list(x = c(NA, 1), y = c(NA, 1))) + tbl$row_id = 1:100 + + be = DataBackendLazyTensors$new(tbl, primary_key = "row_id", 
converter = list( + y = function(x) factor(as.integer(x), levels = c(0, 1), labels = c("yes", "no")) + )) + batch = be$data(c(1, 2, 51, 52), c("x", "y", "row_id")) + expect_class(batch$y, "factor") + expect_equal(batch$y, factor(c("yes", "yes", "no", "no"), levels = c("yes", "no"))) + + batch_lt = withr::with_options(list(mlr3torch.data_loading = TRUE), { + be$data(c(1, 2, 51, 52), c("x", "y", "row_id")) + }) + expect_class(batch_lt$y, "lazy_tensor") + expect_equal(length(batch_lt$y), 4) + expect_equal(materialize(batch_lt$y, rbind = TRUE), torch_tensor(matrix(c(1, 1, 0, 0), nrow = 4, ncol = 1))) +}) + +test_that("errors when weird preprocessing", { +}) + +test_that("chunking works ", { + ds = dataset( + initialize = function() { + self$x = torch_tensor(matrix(100:1, nrow = 100, ncol = 1)) + self$y = torch_tensor(as.matrix(1:100, nrow = 100, ncol = 1)) + self$counter = 0 + }, + .getbatch = function(i) { + self$counter = self$counter + 1 + list(x = self$x[i, drop = FALSE], y = self$y[i, drop = FALSE]) + }, + .length = function() { + nrow(self$x) + } + )() + + be = as_data_backend(ds, dataset_shapes = list(x = c(NA, 1), y = c(NA, 1)), chunk_size = 3, + converter = list(y = as.numeric)) + + counter_prev = ds$counter + be$data(1:3, c("x", "y")) + expect_equal(ds$counter, counter_prev + 1) + counter_prev = ds$counter + be$data(4:10, c("x", "y")) + expect_equal(ds$counter, counter_prev + 3) +}) + +test_that("can retrieve 0 rows", { + ds = tensor_dataset( + x = torch_tensor(matrix(100:1, nrow = 100, ncol = 1)), + y = torch_tensor(as.matrix(1:100, nrow = 100, ncol = 1)) + ) + be = as_data_backend(ds, dataset_shapes = list(x = c(NA, 1), y = c(NA, 1)), + converter = list(y = as.numeric)) + res = be$data(integer(0), c("x", "y", "row_id")) + expect_data_table(res, nrows = 0, ncols = 3) + expect_class(res$x, "lazy_tensor") + expect_class(res$y, "numeric") + expect_equal(res$row_id, integer(0)) +}) + +test_that("task converters work", { + # regression target + ds = 
tensor_dataset( + x = torch_tensor(matrix(100:1, nrow = 100, ncol = 1))$float(), + y = torch_tensor(as.matrix(1:100, nrow = 100, ncol = 1))$float() + ) + task = as_task_regr(ds, target = "y", converter = list(y = as.numeric)) + task$data(integer(0)) + expect_equal(task$head(2)$y, 1:2) + expect_equal(task$feature_names, "x") + expect_equal(task$target_names, "y") + expect_task(task) + + + # binary classification + ds = tensor_dataset( + x = torch_tensor(matrix(100:1, nrow = 100, ncol = 1))$float(), + y = torch_tensor(rep(0:1, times = 50))$float()$unsqueeze(2L) + ) + task = as_task_classif(ds, target = "y", levels = c("yes", "no")) + expect_task(task) + expect_equal(task$head()$y, factor(rep(c("yes", "no"), times = 3), levels = c("yes", "no"))) +}) + +test_that("caching works", { + dsc = dataset( + initialize = function() { + self$x = torch_tensor(matrix(100:1, nrow = 100, ncol = 1)) + self$y = torch_tensor(as.matrix(1:100, nrow = 100, ncol = 1)) + self$counter = 0 + }, + .getitem = function(i) { + self$counter = self$counter + 1 + list(x = self$x[i], y = self$y[i]) + }, + .length = function() { + nrow(self$x) + } + ) + + ds = dsc() + + be = as_data_backend(ds, dataset_shapes = list(x = c(NA, 1), y = c(NA, 1)), + converter = list(y = as.integer), cache = "y") + + check = function(be, ds, rows, cols, n) { + counter_prev = ds$counter + tbl = be$data(rows, cols) + observed_n = ds$counter - counter_prev + expect_equal(observed_n, n) + + if ("x" %in% cols) { + expect_equal(materialize(tbl$x, rbind = TRUE), ds$x[rows]) + } + if ("y" %in% cols) { + expect_equal(tbl$y, as.integer(ds$y[rows])) + } + } + check(be, ds, 1, c("x", "y"), 1) + # y is now in the cache, so .getitem() is not called on $data() + check(be, ds, 1, "y", 0) + + # everything is in the cache + check(be, ds, 1, c("x", "y"), 0) + # lazy tensor causes no materialization + check(be, ds, 1, "x", 0) + + # more than one row also works + check(be, ds, 2:1, "y", 1) + check(be, ds, c(3, 1), "y", 1) + check(be, ds, 1:3, 
"y", 0) + + # when caching more than one, we materialize only once per batch + be2 = as_data_backend(ds, dataset_shapes = list(x = c(NA, 1), y = c(NA, 1)), + converter = list(y = as.integer, x = as.integer), cache = c("y", "x")) + + check2 = function(be, ds, rows, cols, n) { + counter_prev = ds$counter + tbl = be$data(rows, cols) + observed_n = ds$counter - counter_prev + expect_equal(observed_n, n) + + expect_equal(tbl$y, as.integer(ds$y[rows])) + expect_equal(tbl$x, as.integer(ds$x[rows])) + } + + check2(be2, ds, 1, c("x", "y"), 1) + check2(be2, ds, 1, c("x", "y"), 0) + check2(be2, ds, 2:1, c("x", "y"), 1) + check2(be2, ds, 2, c("x", "y"), 0) +}) + +test_that("can train a regression learner", { + x = torch_randn(100, 1) + y = x + torch_randn(100, 1) + ds = tensor_dataset( + x = x, + y = y + ) + + be = as_data_backend(ds, dataset_shapes = list(x = c(NA, 1), y = c(NA, 1)), + converter = list(y = as.numeric)) + task = as_task_regr(be, target = "y") + + learner = lrn("regr.mlp", epochs = 10, batch_size = 100, jit_trace = TRUE, opt.lr = 1, seed = 1) + rr = resample(task, learner, rsmp("insample")) + expect_true(rr$aggregate(msr("regr.rmse")) < 1.5) +}) + +test_that("can train a binary classification learner", { + ds = tensor_dataset( + x = torch_tensor(matrix(100:1, nrow = 100, ncol = 1))$float(), + y = torch_tensor(rep(0:1, each = 50))$float()$unsqueeze(2L) + ) + + be = as_data_backend(ds, dataset_shapes = list(x = c(NA, 1), y = c(NA, 1)), + converter = list(y = function(x) factor(as.integer(x), levels = c(1, 0), labels = c("yes", "no")))) + task = as_task_classif(be, target = "y") + + learner = lrn("classif.mlp", epochs = 10, batch_size = 100, jit_trace = TRUE, opt.lr = 10, seed = 1) + rr = resample(task, learner, rsmp("insample")) + expect_true(rr$aggregate(msr("classif.ce")) < 0.1) +}) + +test_that("can train a multiclass classification learner", { + ds = tensor_dataset( + x = torch_tensor(matrix(100:1, nrow = 100, ncol = 1))$float(), + y = torch_tensor(rep(1:4, 
each = 25)) + ) + + be = as_data_backend(ds, dataset_shapes = list(x = c(NA, 1), y = NA), + converter = list(y = function(x) factor(as.integer(x), levels = 1:4, labels = c("a", "b", "c", "d")))) + task = as_task_classif(be, target = "y") + + learner = lrn("classif.mlp", epochs = 10, batch_size = 100, jit_trace = TRUE, opt.lr = 0.2, seed = 1, + neurons = 100) + rr = resample(task, learner, rsmp("insample")) + # just ensures that we learn something + expect_true(rr$aggregate(msr("classif.ce")) < 0.6) +}) + +test_that("check_lazy_tensors_backend works", { + ds = tensor_dataset( + x = torch_tensor(matrix(100:1, nrow = 100, ncol = 1))$float(), + y = torch_tensor(as.matrix(1:100, nrow = 100, ncol = 1))$float() + ) + + be = as_data_backend(ds, dataset_shapes = list(x = c(NA, 1), y = c(NA, 1)), + converter = list(y = as.numeric)) + task_orig = as_task_regr(be, target = "y") + + expect_error(check_lazy_tensors_backend(task_orig$backend, c("x", "y")), + regexp = NA) + + task1 = task_orig$clone(deep = TRUE)$cbind(data.table(y = 1:100)) + expect_error(check_lazy_tensors_backend(task1$backend, c("x", "y")), + regexp = "A converter column ('y')", fixed = TRUE) + + task2 = task_orig$clone(deep = TRUE)$rbind(data.table(x = as_lazy_tensor(1), y = 2, row_id = 999)) + expect_error(check_lazy_tensors_backend(task2$backend, c("x", "y")), + regexp = "A converter column ('y')", fixed = TRUE) +}) + + +test_that("...", { + ds = dataset( + initialize = function(x, y) { + self$x = torch_randn(100, 3) + self$y = torch_randn(100, 1) + self$counter = 0 + }, + .getbatch = function(i) { + print("hallo") + self$counter = self$counter + 1L + list(x = self$x[i, drop = FALSE], y = self$y[i, drop = FALSE]) + }, + .length = function() 100 + )() + +task = as_task_regr(ds, target = "y") + +counter = ds$counter +task$head() +print(ds$counter - counter) +counter = ds$counter +task$head() +expec +print(ds$counter - counter) + +}) diff --git a/tests/testthat/test_lazy_tensor.R 
b/tests/testthat/test_lazy_tensor.R index b74208f8c..95881b083 100644 --- a/tests/testthat/test_lazy_tensor.R +++ b/tests/testthat/test_lazy_tensor.R @@ -3,8 +3,6 @@ test_that("prototype", { expect_class(proto, "lazy_tensor") expect_true(length(proto) == 0L) expect_error(dd(proto)) - - expect_error(materialize(lazy_tensor()), "Cannot materialize") }) test_that("input checks", { diff --git a/tests/testthat/test_materialize.R b/tests/testthat/test_materialize.R index 170f673a4..f1fc5a8ae 100644 --- a/tests/testthat/test_materialize.R +++ b/tests/testthat/test_materialize.R @@ -17,8 +17,6 @@ test_that("materialize works on lazy_tensor", { expect_equal(torch_cat(map(output_meta_list, function(x) x$unsqueeze(1)), dim = 1L)$shape, output_meta_tnsr$shape) expect_true(output_meta_tnsr$device == torch_device("meta")) - - expect_error(materialize(lazy_tensor()), "Cannot materialize ") }) test_that("materialize works with differing shapes (hence uses .getitem)", { @@ -75,7 +73,7 @@ test_that("materialize works with same shapes and .getitem method", { }) test_that("materialize_internal works", { - expect_error(materialize_internal(lazy_tensor()), "Cannot materialize ") + expect_error(materialize_internal(lazy_tensor()), "Cannot access data descriptor") task = tsk("lazy_iris") x = task$data(1:2, cols = "x")[[1L]] res1 = materialize(x) @@ -184,3 +182,8 @@ test_that("PipeOpFeatureUnion can properly check whether two lazy tensors are id expect_error(graph$train(task), "cannot aggregate different features sharing") }) + +test_that("0-length", { + expect_equal(torch_empty(0L), materialize(lazy_tensor(), rbind = TRUE)) + expect_equal(list(), materialize(lazy_tensor(), rbind = FALSE)) +}) diff --git a/tests/testthat/test_shape.R b/tests/testthat/test_shape.R index dbafcb4dc..b1670a96e 100644 --- a/tests/testthat/test_shape.R +++ b/tests/testthat/test_shape.R @@ -21,4 +21,6 @@ test_that("assert_shape and friends", { expect_error(assert_shape(c(NA, 1, 2), len = 2)) # NULL is ok even 
when len is specified expect_true(check_shape(NULL, null_ok = TRUE, len = 2)) + # NA is a valid shape + expect_true(check_shape(NA)) }) diff --git a/tests/testthat/test_utils.R b/tests/testthat/test_utils.R index f8c0477d1..0f6943f57 100644 --- a/tests/testthat/test_utils.R +++ b/tests/testthat/test_utils.R @@ -56,6 +56,7 @@ test_that("order_named_args works", { expect_error(order_named_args(function(..., x) NULL, list(2, 3, x = 1)), regexp = "`...` must") expect_error(order_named_args(function(y, ..., x) NULL, list(y = 4, 2, 3, x = 1)), regexp = "`...` must") }) + test_that("shape_to_str works", { expect_equal(shape_to_str(1), "(1)") expect_equal(shape_to_str(c(1, 2)), "(1,2)")