diff --git a/DESCRIPTION b/DESCRIPTION
index 4772f52fd..045182c1f 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -91,6 +91,7 @@ Collate:
'CallbackSetUnfreeze.R'
'ContextTorch.R'
'DataBackendLazy.R'
+ 'DataBackendLazyTensors.R'
'utils.R'
'DataDescriptor.R'
'LearnerTorch.R'
diff --git a/NAMESPACE b/NAMESPACE
index 122ebafb6..1ed76e181 100644
--- a/NAMESPACE
+++ b/NAMESPACE
@@ -7,11 +7,15 @@ S3method("[[<-",lazy_tensor)
S3method(as.data.table,DictionaryMlr3torchCallbacks)
S3method(as.data.table,DictionaryMlr3torchLosses)
S3method(as.data.table,DictionaryMlr3torchOptimizers)
+S3method(as_data_backend,dataset)
S3method(as_data_descriptor,dataset)
S3method(as_lazy_tensor,DataDescriptor)
S3method(as_lazy_tensor,dataset)
S3method(as_lazy_tensor,numeric)
S3method(as_lazy_tensor,torch_tensor)
+S3method(as_lazy_tensors,dataset)
+S3method(as_task_classif,dataset)
+S3method(as_task_regr,dataset)
S3method(as_torch_callback,R6ClassGenerator)
S3method(as_torch_callback,TorchCallback)
S3method(as_torch_callback,character)
@@ -27,6 +31,8 @@ S3method(as_torch_optimizer,character)
S3method(as_torch_optimizer,torch_optimizer_generator)
S3method(c,lazy_tensor)
S3method(col_info,DataBackendLazy)
+S3method(col_info,DataBackendLazyTensors)
+S3method(distinct_values,lazy_tensor)
S3method(format,lazy_tensor)
S3method(hash_input,TorchIngressToken)
S3method(hash_input,lazy_tensor)
@@ -71,6 +77,7 @@ export(CallbackSetTB)
export(CallbackSetUnfreeze)
export(ContextTorch)
export(DataBackendLazy)
+export(DataBackendLazyTensors)
export(DataDescriptor)
export(LearnerTorch)
export(LearnerTorchFeatureless)
@@ -161,6 +168,7 @@ export(TorchLoss)
export(TorchOptimizer)
export(as_data_descriptor)
export(as_lazy_tensor)
+export(as_lazy_tensors)
export(as_lr_scheduler)
export(as_torch_callback)
export(as_torch_callbacks)
diff --git a/NEWS.md b/NEWS.md
index 9f2cbf174..1d581c0b2 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -11,6 +11,8 @@
This means that for binary classification tasks, `t_loss("cross_entropy")` now generates
`nn_bce_with_logits_loss` instead of `nn_cross_entropy_loss`.
This also came with a reparametrization of the `t_loss("cross_entropy")` loss (thanks to @tdhock, #374).
+* fix: `NA` is now a valid shape for lazy tensors.
+* feat: `lazy_tensor`s of length 0 can now be materialized.
# mlr3torch 0.2.1
diff --git a/R/DataBackendLazyTensors.R b/R/DataBackendLazyTensors.R
new file mode 100644
index 000000000..2328a86b4
--- /dev/null
+++ b/R/DataBackendLazyTensors.R
@@ -0,0 +1,249 @@
+
+#' @title Special Backend for Lazy Tensors
+#' @description
+#' This backend essentially allows you to use a [`torch::dataset`] directly with
+#' an [`mlr3::Learner`].
+#'
+#' * The data cannot contain missing values, as [`lazy_tensor`]s do not support them.
+#' For this reason, calling `$missings()` will always return `0` for all columns.
+#' * The `$distinct()` method will consider two lazy tensors that refer to the same element of a
+#' [`DataDescriptor`] to be identical.
+#' This means that it may underreport the number of distinct values of lazy tensor columns.
+#'
+#' @export
+#' @examplesIf torch::torch_is_installed()
+#' # used as feature in all backends
+#' x = torch_randn(100, 10)
+#' # regression
+#' ds_regr = tensor_dataset(x = x, y = torch_randn(100, 1))
+#' be_regr = as_data_backend(ds_regr, converter = list(y = as.numeric))
+#' be_regr$head()
+#'
+#'
+#' # binary classification: underlying target tensor must be float with values 0 or 1
+#' ds_binary = tensor_dataset(x = x, y = torch_randint(0, 2, c(100, 1))$float())
+#' be_binary = as_data_backend(ds_binary, converter = list(
+#' y = function(x) factor(as.integer(x), levels = c(0, 1), labels = c("A", "yes"))
+#' ))
+#' be_binary$head()
+#'
+#' # multi-class classification: underlying target tensor must be integer in [1, K]
+#' ds_multiclass = tensor_dataset(x = x, y = torch_randint(1, 4, size = c(100, 1)))
+#' be_multiclass = as_data_backend(ds_multiclass, converter = list(y = as.numeric))
+#' be_multiclass$head()
+
+DataBackendLazyTensors = R6Class("DataBackendLazyTensors",
+  cloneable = FALSE,
+  inherit = DataBackendDataTable,
+  public = list(
+    #' @field chunk_size (`integer(1)`)\cr Maximum number of rows that are materialized at once.
+    chunk_size = NULL,
+    #' @description
+    #' Create a new instance of this [R6][R6::R6Class] class.
+    #' @param data (`data.table`)\cr
+    #'   Data containing (among others) [`lazy_tensor`] columns.
+    #' @param primary_key (`character(1)`)\cr
+    #'   Name of the column used as primary key.
+    #' @param converter (named `list()` of `function`s)\cr
+    #'   A named list of functions that convert the lazy tensor columns to their R representation.
+    #'   The names must be the names of the columns that need conversion.
+    #' @param cache (`character()`)\cr
+    #'   Names of the columns that should be cached.
+    #'   Per default, all columns that are converted are cached.
+    #' @param chunk_size (`integer(1)`)\cr
+    #'   Maximum number of rows that are materialized at once when converting columns.
+    initialize = function(data, primary_key, converter, cache = names(converter), chunk_size = 100) {
+      private$.converter = assert_list(converter, types = "function", any.missing = FALSE)
+      assert_subset(names(converter), colnames(data))
+      private$.cached_cols = assert_subset(cache, names(converter))
+      self$chunk_size = assert_int(chunk_size, lower = 1L)
+      walk(names(private$.converter), function(nm) {
+        if (!inherits(data[[nm]], "lazy_tensor")) {
+          stopf("Column '%s' is not a lazy tensor.", nm)
+        }
+      })
+      super$initialize(data, primary_key)
+      # the cache initially holds only the primary key column (selected by name, keeping its name)
+      private$.data_cache = private$.data[, primary_key, with = FALSE]
+    },
+    data = function(rows, cols) {
+      rows = assert_integerish(rows, coerce = TRUE)
+      assert_names(cols, type = "unique")
+      if (getOption("mlr3torch.data_loading", FALSE)) {
+        # no caching, no materialization as this is called in the training loop
+        return(super$data(rows, cols))
+      }
+      cached = intersect(cols, private$.cached_cols)
+      if (length(cached) == 0L || !all(cached %in% names(private$.data_cache))) {
+        # nothing can be read from the cache (yet): load, convert and cache what was requested
+        return(private$.load_and_cache(rows, cols))
+      }
+      cache_hit = private$.data_cache[list(rows), cached, on = self$primary_key, with = FALSE]
+      complete = complete.cases(cache_hit)
+      cache_hit = cache_hit[complete]
+      if (!all(complete)) {
+        # load the rows that are missing from the cache and restore the requested row order
+        loaded = private$.load_and_cache(rows[!complete], cached)
+        reorder = vector("integer", length(complete))
+        reorder[complete] = seq_len(nrow(cache_hit))
+        reorder[!complete] = nrow(cache_hit) + seq_len(nrow(loaded))
+        cache_hit = rbindlist(list(cache_hit, loaded))[reorder]
+      }
+      # converter columns that are not cached still have to be converted on every access
+      tbl = cbind(cache_hit, private$.load_and_cache(rows, setdiff(cols, cached)))
+      setcolorder(tbl, cols)
+      tbl
+    },
+    head = function(n = 6L) {
+      if (getOption("mlr3torch.data_loading", FALSE)) {
+        return(super$head(n))
+      }
+      # rows are addressed by primary key values, which need not be 1..n; never request more than exist
+      self$data(self$rownames[seq_len(min(n, self$nrow))], self$colnames)
+    },
+    missings = function(rows, cols) {
+      set_names(rep(0L, length(cols)), cols)
+    }
+  ),
+  active = list(
+    converter = function(rhs) {
+      assert_ro_binding(rhs)
+      private$.converter
+    }
+  ),
+  private = list(
+    # Loads `cols` for `rows`, converts the converter columns among them and caches the cacheable ones.
+    .load_and_cache = function(rows, cols) {
+      tbl = super$data(rows, cols)
+      cols_to_convert = intersect(names(private$.converter), names(tbl))
+      if (length(cols_to_convert) == 0L) return(tbl)
+      n = nrow(tbl)
+      tbl_mat = if (n == 0L) {
+        # one empty tensor per column, so that every converter receives an input
+        set_names(lapply(cols_to_convert, function(nm) torch_empty(0)), cols_to_convert)
+      } else {
+        # materialize chunks of at most `chunk_size` rows and concatenate them along the first dimension
+        tbl_to_mat = tbl[, cols_to_convert, with = FALSE]
+        chunks = split(seq_len(n), ceiling(seq_len(n) / self$chunk_size))
+        set_names(lapply(transpose_list(lapply(chunks, function(chunk) {
+          materialize(tbl_to_mat[chunk, ], rbind = TRUE)
+        })), torch_cat, dim = 1L), cols_to_convert)
+      }
+      # the cache is read by primary key, so it must be written by key too (keys need not be row positions)
+      cache_rows = match(rows, private$.data_cache[[self$primary_key]])
+      cache_rows = cache_rows[!is.na(cache_rows)]
+      for (nm in cols_to_convert) {
+        converted = private$.converter[[nm]](tbl_mat[[nm]])
+        tbl[[nm]] = converted
+        if (nm %in% private$.cached_cols && length(cache_rows) > 0L) {
+          set(private$.data_cache, i = cache_rows, j = nm, value = converted)
+        }
+      }
+      tbl
+    },
+    .data_cache = NULL,
+    .converter = NULL,
+    .cached_cols = NULL
+  )
+)
+
+#' @export
+as_data_backend.dataset = function(x, dataset_shapes = NULL, ...) {
+  tbl = as_lazy_tensors(x, dataset_shapes, ...) # one lazy_tensor column per dataset element
+  tbl$row_id = seq_len(nrow(tbl)) # integer primary key 1..n
+  DataBackendLazyTensors$new(tbl, primary_key = "row_id", ...) # `...`: converter, cache, chunk_size
+}
+
+#' @export
+as_task_classif.dataset = function(x, target, levels, converter = NULL, dataset_shapes = NULL, chunk_size = 100, cache = names(converter), ...) {
+  if (length(x) < 2) {
+    stopf("Dataset must have at least 2 rows.")
+  }
+  batch = dataloader(x, batch_size = 2)$.iter()$.next()
+  if (is.null(converter)) {
+    y = batch[[target]]
+    if (length(levels) == 2) {
+      # binary: float tensor of shape (batch_size, 1); 0 maps to levels[1], 1 to levels[2]
+      if (y$dtype != torch_float()) {
+        stopf("Target must be a float tensor, but has dtype %s", y$dtype)
+      }
+      if (!test_equal(y$shape, c(2L, 1L))) {
+        stopf("Target must be a float tensor of shape (batch_size, 1), but has shape (batch_size, %s)",
+          paste(y$shape[-1L], collapse = ", "))
+      }
+      converter = set_names(list(crate(function(x) factor(as.integer(x), levels = 0:1, labels = levels), levels)), target)
+    } else {
+      # multi-class: integer tensor of shape (batch_size); class k maps to levels[k]
+      if (y$dtype != torch_int()) {
+        stopf("Target must be an integer tensor, but has dtype %s", y$dtype)
+      }
+      if (!test_equal(y$shape, 2L)) {
+        stopf("Target must be an integer tensor of shape (batch_size), but has shape (batch_size, %s)",
+          paste(y$shape[-1L], collapse = ", "))
+      }
+      # fix the codes to 1..K: a subset of rows need not contain every class, so they must not be inferred
+      converter = set_names(list(crate(function(x) factor(as.integer(x), levels = seq_along(levels), labels = levels), levels)), target)
+    }
+  }
+  be = as_data_backend(x, dataset_shapes, converter = converter, cache = cache, chunk_size = chunk_size)
+  as_task_classif(be, target = target, ...)
+}
+
+#' @export
+as_task_regr.dataset = function(x, target, converter = NULL, dataset_shapes = NULL, chunk_size = 100, cache = names(converter), ...) {
+  if (length(x) < 2) {
+    stopf("Dataset must have at least 2 rows.")
+  }
+  # without an explicit converter, the target is converted with `as.numeric`
+  converter = if (is.null(converter)) set_names(list(as.numeric), target) else converter
+
+  # the target encoding is validated on a batch of two rows
+  y = dataloader(x, batch_size = 2)$.iter()$.next()[[target]]
+  if (y$dtype != torch_float()) {
+    stopf("Target must be a float tensor, but has dtype %s", y$dtype)
+  }
+  has_expected_shape = test_equal(y$shape, c(2L, 1L))
+  if (!has_expected_shape) {
+    stopf("Target must be a float tensor of shape (batch_size, 1), but has shape (batch_size, %s)",
+      paste(y$shape[-1L], collapse = ", "))
+  }
+
+  shapes = get_or_check_dataset_shapes(x, dataset_shapes)
+  be = as_data_backend(x, shapes, converter = converter, cache = cache, chunk_size = chunk_size)
+  as_task_regr(be, target = target, ...)
+}
+
+#' @export
+col_info.DataBackendLazyTensors = function(x, ...) { # nolint
+  peek = x$head(1L) # column classes and factor levels are read off the first row
+  col_types = map_chr(peek, function(col) class(col)[1L])
+  fct_cols = setdiff(names(col_types)[col_types %chin% c("factor", "ordered")], x$primary_key)
+  fct_levels = map(peek[, fct_cols, with = FALSE], levels)
+  data.table(id = names(col_types), type = unname(col_types), levels = insert_named(named_list(names(col_types)), fct_levels), key = "id")
+}
+
+
+# conservative check that avoids that a pseudo-lazy-tensor is preprocessed by some pipeop
+# @param be
+# the backend
+# @param candidates
+# the feature and target names
+# @param visited
+# Union of all colnames already visited
+# @return visited
+check_lazy_tensors_backend = function(be, candidates, visited = character()) {
+  if (inherits(be, c("DataBackendRbind", "DataBackendCbind"))) {
+    # composite backend: b2 is checked before b1, because b2 possibly overshadows some b1 rows/cols
+    parts = be$.__enclos_env__$private$.data
+    seen_in_b2 = check_lazy_tensors_backend(parts$b2, candidates, visited)
+    return(check_lazy_tensors_backend(parts$b1, candidates, seen_in_b2))
+  }
+  if (inherits(be, "DataBackendLazyTensors")) {
+    conv_names = names(be$converter)
+    converter_cols = conv_names[conv_names %in% visited]
+    if (length(converter_cols) > 0L) {
+      stopf("A converter column ('%s') from a DataBackendLazyTensors was presumably preprocessed by some PipeOp. This can cause inefficiencies and is therefore not allowed. If you want to preprocess them, please directly encode them as R types.", paste0(converter_cols, collapse = ", ")) # nolint
+    }
+  }
+  union(visited, intersect(candidates, be$colnames))
+}
diff --git a/R/DataDescriptor.R b/R/DataDescriptor.R
index 1bf3cd68d..6a1d65740 100644
--- a/R/DataDescriptor.R
+++ b/R/DataDescriptor.R
@@ -60,14 +60,7 @@ DataDescriptor = R6Class("DataDescriptor",
# For simplicity we here require the first dimension of the shape to be NA so we don't have to deal with it,
# e.g. during subsetting
- if (is.null(dataset_shapes)) {
- if (is.null(dataset$.getbatch)) {
- stopf("dataset_shapes must be provided if dataset does not have a `.getbatch` method.")
- }
- dataset_shapes = infer_shapes_from_getbatch(dataset)
- } else {
- assert_compatible_shapes(dataset_shapes, dataset)
- }
+ dataset_shapes = get_or_check_dataset_shapes(dataset, dataset_shapes)
if (is.null(graph)) {
# avoid name conflicts
@@ -84,8 +77,7 @@ DataDescriptor = R6Class("DataDescriptor",
assert_true(length(graph$pipeops) >= 1L)
}
# no preprocessing, dataset returns only a single element (there we can infer a lot)
- simple_case = length(graph$pipeops) == 1L && inherits(graph$pipeops[[1L]], "PipeOpNOP") &&
- length(dataset_shapes) == 1L
+ simple_case = (length(graph$pipeops) == 1L) && inherits(graph$pipeops[[1L]], "PipeOpNOP")
if (is.null(input_map) && nrow(graph$input) == 1L && length(dataset_shapes) == 1L) {
input_map = names(dataset_shapes)
@@ -100,7 +92,7 @@ DataDescriptor = R6Class("DataDescriptor",
assert_choice(pointer[[2]], graph$pipeops[[pointer[[1]]]]$output$name)
}
if (is.null(pointer_shape) && simple_case) {
- pointer_shape = dataset_shapes[[1L]]
+ pointer_shape = dataset_shapes[[input_map]]
} else {
assert_shape(pointer_shape, null_ok = TRUE)
}
@@ -225,13 +217,14 @@ infer_shapes_from_getbatch = function(ds) {
}
assert_compatible_shapes = function(shapes, dataset) {
- assert_shapes(shapes, null_ok = TRUE, unknown_batch = TRUE, named = TRUE)
+ shapes = assert_shapes(shapes, null_ok = TRUE, unknown_batch = TRUE, named = TRUE, coerce = TRUE)
# prevent user from e.g. forgetting to wrap the return in a list
- example = if (is.null(dataset$.getbatch)) {
- dataset$.getitem(1L)
- } else {
+ has_getbatch = !is.null(dataset$.getbatch)
+ example = if (has_getbatch) {
dataset$.getbatch(1L)
+ } else {
+ dataset$.getitem(1L)
}
if (!test_list(example, names = "unique") || !test_permutation(names(example), names(shapes))) {
stopf("Dataset must return a list with named elements that are a permutation of the dataset_shapes names.")
@@ -242,17 +235,17 @@ assert_compatible_shapes = function(shapes, dataset) {
}
})
- if (is.null(dataset$.getbatch)) {
- example = map(example, function(x) x$unsqueeze(1))
- }
-
iwalk(shapes, function(dataset_shape, name) {
- if (!is.null(dataset_shape) && !test_equal(shapes[[name]][-1], example[[name]]$shape[-1L])) {
- expected_shape = example[[name]]$shape
- expected_shape[1] = NA
+ observed_shape = example[[name]]$shape
+ if (has_getbatch) {
+ observed_shape[1L] = NA_integer_
+ } else {
+ observed_shape = c(NA_integer_, observed_shape)
+ }
+ if (!is.null(dataset_shape) && !test_equal(observed_shape, dataset_shape)) {
stopf(paste0("First batch from dataset is incompatible with the provided shape of %s:\n",
- "* Provided shape: %s.\n* Expected shape: %s."), name,
- shape_to_str(unname(shapes[name])), shape_to_str(list(expected_shape)))
+ "* Provided shape: %s.\n* Observed shape: %s."), name,
+ shape_to_str(unname(shapes[name])), shape_to_str(list(observed_shape)))
}
})
}
diff --git a/R/LearnerTorch.R b/R/LearnerTorch.R
index af1db6e5e..068d54a84 100644
--- a/R/LearnerTorch.R
+++ b/R/LearnerTorch.R
@@ -109,7 +109,7 @@
#'
#' For information on the expected target encoding of `y`, see section *Network Head and Target Encoding*.
#' Moreover, one needs to pay attention respect the row ids of the provided task.
-#' It is recommended to relu on [`task_dataset`] for creating the [`dataset`][torch::dataset].
+#' It is strongly recommended to use the [`task_dataset`] class to create the dataset.
#'
#' It is also possible to overwrite the private `.dataloader()` method.
#' This must respect the dataloader parameters from the [`ParamSet`][paradox::ParamSet].
diff --git a/R/lazy_tensor.R b/R/lazy_tensor.R
index d050f8545..00b397575 100644
--- a/R/lazy_tensor.R
+++ b/R/lazy_tensor.R
@@ -197,6 +197,19 @@ as_lazy_tensor.torch_tensor = function(x, ...) { # nolint
as_lazy_tensor(ds, dataset_shapes = list(x = c(NA, dim(x)[-1])))
}
+#' @export
+as_lazy_tensors = function(x, ...) {
+ UseMethod("as_lazy_tensors") # S3 generic: dispatches on the class of `x`
+}
+
+#' @export
+as_lazy_tensors.dataset = function(x, dataset_shapes = NULL, ...) {
+  shapes = get_or_check_dataset_shapes(x, dataset_shapes)
+  element_names = names(shapes) # one lazy_tensor column is created per named dataset element
+  columns = map_dtc(element_names, function(nm) as_lazy_tensor(x, dataset_shapes = shapes, input_map = nm))
+  set_names(columns, element_names)
+}
+
#' Assert Lazy Tensor
#'
#' Asserts whether something is a lazy tensor.
@@ -339,3 +352,13 @@ rep.lazy_tensor = function(x, ...) {
rep_len.lazy_tensor = function(x, ...) {
set_class(NextMethod(), c("lazy_tensor", "list"))
}
+
+
+#' @export
+distinct_values.lazy_tensor = function(x, drop = TRUE, na_rm = TRUE) {
+  # an empty lazy tensor is returned unchanged
+  if (length(x) == 0L) {
+    return(x)
+  }
+  lazy_tensor(dd(x), distinct_values(map_int(x, 1))) # distinctness is decided on the ids (first entry)
+}
\ No newline at end of file
diff --git a/R/learner_torch_methods.R b/R/learner_torch_methods.R
index 79cebaa4a..7259bf587 100644
--- a/R/learner_torch_methods.R
+++ b/R/learner_torch_methods.R
@@ -18,8 +18,10 @@ learner_torch_predict = function(self, private, super, task, param_vals) {
private$.encode_prediction(predict_tensor = predict_tensor, task = task)
}
+
learner_torch_train = function(self, private, super, task, param_vals) {
# Here, all param_vals (like seed = "random" or device = "auto") have already been resolved
+ check_lazy_tensors_backend(task$backend, c(task$feature_names, task$target_names))
dataset_train = private$.dataset(task, param_vals)
dataset_train = as_multi_tensor_dataset(dataset_train, param_vals)
loader_train = private$.dataloader(dataset_train, param_vals)
@@ -356,3 +358,5 @@ as_multi_tensor_dataset = function(dataset, param_vals) {
dataset
}
}
+
+
diff --git a/R/materialize.R b/R/materialize.R
index 849024ad4..ee113830d 100644
--- a/R/materialize.R
+++ b/R/materialize.R
@@ -63,6 +63,13 @@ materialize.list = function(x, device = "cpu", rbind = FALSE, cache = "auto", ..
map(x, function(col) {
if (is_lazy_tensor(col)) {
+ if (length(col) == 0L) {
+ if (rbind) {
+ return(torch_empty(0L))
+ } else {
+ return(list())
+ }
+ }
materialize_internal(col, device = device, cache = cache, rbind = rbind)
} else {
col
@@ -76,16 +83,30 @@ materialize.list = function(x, device = "cpu", rbind = FALSE, cache = "auto", ..
#' @method materialize data.frame
#' @export
materialize.data.frame = function(x, device = "cpu", rbind = FALSE, cache = "auto", ...) { # nolint
+  # A zero-row frame has only empty columns: lazy tensor columns become an empty tensor
+  # (`rbind = TRUE`) or an empty list, all other columns are returned unchanged.
+  if (nrow(x) == 0L) {
+    return(map(as.list(x), function(col) {
+      if (!is_lazy_tensor(col)) col else if (rbind) torch_empty(0L) else list()
+    }))
+  }
materialize(as.list(x), device = device, rbind = rbind, cache = cache)
}
#' @export
materialize.lazy_tensor = function(x, device = "cpu", rbind = FALSE, ...) { # nolint
+  # an empty lazy tensor yields an empty tensor (`rbind = TRUE`) or an empty list
+  if (length(x) == 0L && rbind) {
+    return(torch_empty(0L))
+  }
+  if (length(x) == 0L) {
+    return(list())
+  }
materialize_internal(x = x, device = device, cache = NULL, rbind = rbind)
}
-get_input = function(ds, ids, varying_shapes, rbind) {
+get_input = function(ds, ids, varying_shapes) {
if (is.null(ds$.getbatch)) { # .getindex is never NULL but a function that errs if it was not defined
x = map(ids, function(id) map(ds$.getitem(id), function(x) x$unsqueeze(1)))
if (varying_shapes) {
@@ -154,9 +175,6 @@ get_output = function(input, graph, varying_shapes, rbind, device) {
#' @return [`lazy_tensor()`]
#' @keywords internal
materialize_internal = function(x, device = "cpu", cache = NULL, rbind) {
- if (!length(x)) {
- stopf("Cannot materialize lazy tensor of length 0.")
- }
do_caching = !is.null(cache)
ids = map_int(x, 1)
@@ -183,7 +201,7 @@ materialize_internal = function(x, device = "cpu", cache = NULL, rbind) {
}
if (!do_caching || !input_hit) {
- input = get_input(ds, ids, varying_shapes, rbind)
+ input = get_input(ds, ids, varying_shapes)
}
if (do_caching && !input_hit) {
diff --git a/R/shape.R b/R/shape.R
index d1fdda83d..7970c37ec 100644
--- a/R/shape.R
+++ b/R/shape.R
@@ -30,7 +30,7 @@ test_shape = function(shape, null_ok = FALSE, unknown_batch = NULL, len = NULL)
if (is.null(shape) && null_ok) {
return(TRUE)
}
- ok = test_integerish(shape, min.len = 2L, all.missing = FALSE, any.missing = TRUE, len = len)
+ ok = test_integerish(shape, min.len = 1L, any.missing = TRUE, len = len)
if (!ok) {
return(FALSE)
diff --git a/R/task_dataset.R b/R/task_dataset.R
index bd088d1bf..10ab6e268 100644
--- a/R/task_dataset.R
+++ b/R/task_dataset.R
@@ -81,13 +81,21 @@ task_dataset = dataset("task_dataset",
.getbatch = function(index) {
cache = if (self$cache_lazy_tensors) new.env()
- datapool = self$task$data(rows = self$task$row_ids[index], cols = self$all_features)
+ datapool = withr::with_options(list(mlr3torch.data_loading = TRUE), {
+ self$task$data(rows = self$task$row_ids[index], cols = self$all_features)
+ })
+
x = lapply(self$feature_ingress_tokens, function(it) {
it$batchgetter(datapool[, it$features, with = FALSE], cache = cache)
})
y = if (!is.null(self$target_batchgetter)) {
- self$target_batchgetter(datapool[, self$task$target_names, with = FALSE])
+ target = datapool[, self$task$target_names, with = FALSE]
+ if (!inherits(target[[1L]], "lazy_tensor")) {
+ self$target_batchgetter(target)
+ } else {
+ materialize(target[[1L]], rbind = TRUE)
+ }
}
out = list(x = x, .index = torch_tensor(index, dtype = torch_long()))
if (!is.null(y)) out$y = y
diff --git a/R/utils.R b/R/utils.R
index 2f9d68302..bcb17af66 100644
--- a/R/utils.R
+++ b/R/utils.R
@@ -190,7 +190,10 @@ list_to_batch = function(tensors) {
}
auto_cache_lazy_tensors = function(lts) {
- any(duplicated(map_chr(lts, function(x) dd(x)$dataset_hash)))
+ if (length(lts) <= 1L) {
+ return(FALSE) # with at most one element there can be no duplicated dataset hash
+ }
+ anyDuplicated(unlist(map_if(lts, function(x) length(x) > 0, function(x) dd(x)$dataset_hash))) > 0L
}
#' Replace the head of a network
@@ -300,6 +303,18 @@ infer_shapes = function(shapes_in, param_vals, output_names, fn, rowwise, id) {
set_names(list(sout), output_names)
}
+get_or_check_dataset_shapes = function(dataset, dataset_shapes) {
+  if (!is.null(dataset_shapes)) {
+    assert_compatible_shapes(dataset_shapes, dataset) # provided shapes are checked and returned as given
+    return(dataset_shapes)
+  }
+  if (is.null(dataset$.getbatch)) {
+    stopf("dataset_shapes must be provided if dataset does not have a `.getbatch` method.")
+  }
+  inferred = infer_shapes_from_getbatch(dataset)
+  inferred
+}
+
#' @title Network Output Dimension
#' @description
#' Calculates the output dimension of a neural network for a given task that is expected by
diff --git a/TODO.md b/TODO.md
new file mode 100644
index 000000000..62d953809
--- /dev/null
+++ b/TODO.md
@@ -0,0 +1,25 @@
+* Add `as_lazy_tensors()`
+* Make it easier to se
+* Fix the bug that the shapes are reported as unknown in the example below, and simplify the code.
+ ```r
+ ds = dataset("test",
+ initialize = function() {
+ self$x = torch_randn(100, 10)
+ self$y = torch_randn(100, 1)
+ },
+ .getitem = function(i) {
+ list(x = self$x[i, ], y = self$y[i])
+ },
+ .length = function() {
+ nrow(self$x)
+ }
+ )()
+ x_lt = as_lazy_tensor(ds, list(x = c(NA, 10), y = c(NA, 1)), input_map = "x")
+ y_lt = as_lazy_tensor(ds, list(x = c(NA, 10), y = c(NA, 1)), input_map = "y")
+
+ tbl = data.table(x = x_lt, y = y_lt)
+ ```
+* Add checks on usage of `DataBackendLazyTensors` in `task_dataset`
+* Add an optimization so that truth values don't have to be loaded twice during resampling, i.e.
+ once for making the predictions and once for retrieving the truth column.
+* only allow caching converter columns in `DataBackendLazyTensors` (probably just remove the `cache` parameter)
\ No newline at end of file
diff --git a/man/DataBackendLazyTensors.Rd b/man/DataBackendLazyTensors.Rd
new file mode 100644
index 000000000..42880dec7
--- /dev/null
+++ b/man/DataBackendLazyTensors.Rd
@@ -0,0 +1,123 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/DataBackendLazyTensors.R
+\name{DataBackendLazyTensors}
+\alias{DataBackendLazyTensors}
+\title{Special Backend for Lazy Tensors}
+\description{
+This backend essentially allows you to use a \code{\link[torch:dataset]{torch::dataset}} directly with
+an \code{\link[mlr3:Learner]{mlr3::Learner}}.
+\itemize{
+\item The data cannot contain missing values, as \code{\link{lazy_tensor}}s do not support them.
+For this reason, calling \verb{$missings()} will always return \code{0} for all columns.
+\item The \verb{$distinct()} method will consider two lazy tensors that refer to the same element of a
+\code{\link{DataDescriptor}} to be identical.
+This means that it may underreport the number of distinct values of lazy tensor columns.
+}
+}
+\examples{
+\dontshow{if (torch::torch_is_installed()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf}
+# used as feature in all backends
+x = torch_randn(100, 10)
+# regression
+ds_regr = tensor_dataset(x = x, y = torch_randn(100, 1))
+be_regr = as_data_backend(ds_regr, converter = list(y = as.numeric))
+be_regr$head()
+
+
+# binary classification: underlying target tensor must be float with values 0 or 1
+ds_binary = tensor_dataset(x = x, y = torch_randint(0, 2, c(100, 1))$float())
+be_binary = as_data_backend(ds_binary, converter = list(
+ y = function(x) factor(as.integer(x), levels = c(0, 1), labels = c("A", "yes"))
+))
+be_binary$head()
+
+# multi-class classification: underlying target tensor must be integer in [1, K]
+ds_multiclass = tensor_dataset(x = x, y = torch_randint(1, 4, size = c(100, 1)))
+be_multiclass = as_data_backend(ds_multiclass, converter = list(y = as.numeric))
+be_multiclass$head()
+\dontshow{\}) # examplesIf}
+}
+\section{Super classes}{
+\code{\link[mlr3:DataBackend]{mlr3::DataBackend}} -> \code{\link[mlr3:DataBackendDataTable]{mlr3::DataBackendDataTable}} -> \code{DataBackendLazyTensors}
+}
+\section{Methods}{
+\subsection{Public methods}{
+\itemize{
+\item \href{#method-DataBackendLazyTensors-new}{\code{DataBackendLazyTensors$new()}}
+\item \href{#method-DataBackendLazyTensors-data}{\code{DataBackendLazyTensors$data()}}
+\item \href{#method-DataBackendLazyTensors-head}{\code{DataBackendLazyTensors$head()}}
+\item \href{#method-DataBackendLazyTensors-missings}{\code{DataBackendLazyTensors$missings()}}
+}
+}
+\if{html}{\out{
+Inherited methods
+
+