Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ Collate:
'CallbackSetUnfreeze.R'
'ContextTorch.R'
'DataBackendLazy.R'
'DataBackendLazyTensors.R'
'utils.R'
'DataDescriptor.R'
'LearnerTorch.R'
Expand Down
8 changes: 8 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,15 @@ S3method("[[<-",lazy_tensor)
S3method(as.data.table,DictionaryMlr3torchCallbacks)
S3method(as.data.table,DictionaryMlr3torchLosses)
S3method(as.data.table,DictionaryMlr3torchOptimizers)
S3method(as_data_backend,dataset)
S3method(as_data_descriptor,dataset)
S3method(as_lazy_tensor,DataDescriptor)
S3method(as_lazy_tensor,dataset)
S3method(as_lazy_tensor,numeric)
S3method(as_lazy_tensor,torch_tensor)
S3method(as_lazy_tensors,dataset)
S3method(as_task_classif,dataset)
S3method(as_task_regr,dataset)
S3method(as_torch_callback,R6ClassGenerator)
S3method(as_torch_callback,TorchCallback)
S3method(as_torch_callback,character)
Expand All @@ -27,6 +31,8 @@ S3method(as_torch_optimizer,character)
S3method(as_torch_optimizer,torch_optimizer_generator)
S3method(c,lazy_tensor)
S3method(col_info,DataBackendLazy)
S3method(col_info,DataBackendLazyTensors)
S3method(distinct_values,lazy_tensor)
S3method(format,lazy_tensor)
S3method(hash_input,TorchIngressToken)
S3method(hash_input,lazy_tensor)
Expand Down Expand Up @@ -71,6 +77,7 @@ export(CallbackSetTB)
export(CallbackSetUnfreeze)
export(ContextTorch)
export(DataBackendLazy)
export(DataBackendLazyTensors)
export(DataDescriptor)
export(LearnerTorch)
export(LearnerTorchFeatureless)
Expand Down Expand Up @@ -161,6 +168,7 @@ export(TorchLoss)
export(TorchOptimizer)
export(as_data_descriptor)
export(as_lazy_tensor)
export(as_lazy_tensors)
export(as_lr_scheduler)
export(as_torch_callback)
export(as_torch_callbacks)
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
This means that for binary classification tasks, `t_loss("cross_entropy")` now generates
`nn_bce_with_logits_loss` instead of `nn_cross_entropy_loss`.
This also came with a reparametrization of the `t_loss("cross_entropy")` loss (thanks to @tdhock, #374).
* fix: `NA` is now a valid shape for lazy tensors.
* feat: `lazy_tensor`s of length 0 can now be materialized.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since I have not used lazy tensors very much, it would help to see a more detailed description of changes here in NEWS. I also would expect some changes to https://mlr3torch.mlr-org.com/articles/lazy_tensor.html but I do not see any in this PR yet.
What is the typical use case which motivates this PR?


# mlr3torch 0.2.1

Expand Down
249 changes: 249 additions & 0 deletions R/DataBackendLazyTensors.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@

#' @title Special Backend for Lazy Tensors
#' @description
#' This backend essentially allows you to use a [`torch::dataset`] directly with
#' an [`mlr3::Learner`].
#'
#' * The data cannot contain missing values, as [`lazy_tensor`]s do not support them.
#' For this reason, calling `$missings()` will always return `0` for all columns.
#' * The `$distinct()` method will consider two lazy tensors that refer to the same element of a
#' [`DataDescriptor`] to be identical.
#' This means that it might underreport the number of distinct values of lazy tensor columns.
#'
#' @export
#' @examplesIf torch::torch_is_installed()
#' # used as feature in all backends
#' x = torch_randn(100, 10)
#' # regression
#' ds_regr = tensor_dataset(x = x, y = torch_randn(100, 1))
#' be_regr = as_data_backend(ds_regr, converter = list(y = as.numeric))
#' be_regr$head()
#'
#'
#' # binary classification: underlying target tensor must be float in [0, 1]
#' ds_binary = tensor_dataset(x = x, y = torch_randint(0, 2, c(100, 1))$float())
#' be_binary = as_data_backend(ds_binary, converter = list(
#' y = function(x) factor(as.integer(x), levels = c(0, 1), labels = c("A", "yes"))
#' ))
#' be_binary$head()
#'
#' # multi-class classification: underlying target tensor must be integer in [1, K]
#' ds_multiclass = tensor_dataset(x = x, y = torch_randint(1, 4, size = c(100, 1)))
#' be_multiclass = as_data_backend(ds_multiclass, converter = list(y = as.numeric))
#' be_multiclass$head()

DataBackendLazyTensors = R6Class("DataBackendLazyTensors",
  cloneable = FALSE,
  inherit = DataBackendDataTable,
  public = list(
    #' @field chunk_size (`integer(1)`)\cr
    #' Maximum number of rows that are materialized at once when lazy tensor columns are converted.
    chunk_size = NULL,
    #' @description
    #' Create a new instance of this [R6][R6::R6Class] class.
    #' @param data (`data.table`)\cr
    #'   Data containing (among others) [`lazy_tensor`] columns.
    #' @param primary_key (`character(1)`)\cr
    #'   Name of the column used as primary key.
    #' @param converter (named `list()` of `function`s)\cr
    #'   A named list of functions that convert the lazy tensor columns to their R representation.
    #'   The names must be the names of the columns that need conversion.
    #' @param cache (`character()`)\cr
    #'   Names of the columns that should be cached.
    #'   Per default, all columns that are converted are cached.
    #' @param chunk_size (`integer(1)`)\cr
    #'   Maximum number of rows that are materialized at once during conversion.
    initialize = function(data, primary_key, converter, cache = names(converter), chunk_size = 100) {
      private$.converter = assert_list(converter, types = "function", any.missing = FALSE)
      if (length(converter)) {
        # converters are looked up by column name; an unnamed one would silently never be applied
        assert_names(names(converter), type = "unique")
      }
      assert_subset(names(converter), colnames(data))
      private$.cached_cols = assert_subset(cache, names(converter), empty.ok = TRUE)
      self$chunk_size = assert_int(chunk_size, lower = 1L)
      walk(names(private$.converter), function(nm) {
        if (!inherits(data[[nm]], "lazy_tensor")) {
          stopf("Column '%s' is not a lazy tensor.", nm)
        }
      })
      super$initialize(data, primary_key)
      # the cache starts out with the primary key only; converted columns are added on first access
      private$.data_cache = private$.data[, primary_key, with = FALSE]
    },
    #' @description
    #' Returns a slice of the data.
    #' Columns that have a converter are returned in their converted R representation, unless the option
    #' `mlr3torch.data_loading` is `TRUE`, in which case the lazy tensors are returned as they are stored.
    #' Converted values of the cached columns are stored and reused in later calls.
    #' @param rows (`integer()`)\cr
    #'   Row ids (values of the primary key column).
    #' @param cols (`character()`)\cr
    #'   Column names.
    data = function(rows, cols) {
      rows = assert_integerish(rows, coerce = TRUE)
      assert_names(cols, type = "unique")

      if (getOption("mlr3torch.data_loading", FALSE)) {
        # no caching, no materialization as this is called in the training loop
        return(super$data(rows, cols))
      }

      # Like the parent backend, ignore unknown columns and row ids. Doing it up front keeps the
      # bookkeeping below (cache lookup, reordering, cache writes) aligned with what is actually loaded.
      cols = intersect(cols, self$colnames)
      rows = rows[rows %in% self$rownames]

      expensive_cols = intersect(cols, private$.cached_cols)
      # Nothing to look up when no cached column is requested (the lookup below cannot handle
      # zero columns) or when a requested column has never been cached: load everything directly.
      if (!length(expensive_cols) || !all(expensive_cols %in% names(private$.data_cache))) {
        return(private$.load_and_cache(rows, cols))
      }

      other_cols = setdiff(cols, expensive_cols)
      cache_hit = private$.data_cache[list(rows), expensive_cols, on = self$primary_key, with = FALSE]
      # rows that were never loaded hold NA in the cache
      complete = complete.cases(cache_hit)
      cache_hit = cache_hit[complete]

      # non-cached columns still have to pass through their converter (if they have one)
      rest = private$.load_and_cache(rows, other_cols)

      if (nrow(cache_hit) == length(rows)) {
        tbl = cbind(cache_hit, rest)
        setcolorder(tbl, cols)
        return(tbl)
      }

      combined = rbindlist(list(cache_hit, private$.load_and_cache(rows[!complete], expensive_cols)))
      # `combined` holds the cached rows first and the freshly loaded rows after them;
      # map both groups back to the order in which the rows were requested
      reorder = vector("integer", nrow(combined))
      reorder[complete] = seq_len(nrow(cache_hit))
      reorder[!complete] = nrow(cache_hit) + seq_len(nrow(combined) - nrow(cache_hit))

      tbl = cbind(combined[reorder], rest)
      setcolorder(tbl, cols)
      tbl
    },
    #' @description
    #' Retrieve the first `n` rows.
    #' @param n (`integer(1)`)\cr
    #'   Number of rows.
    head = function(n = 6L) {
      if (getOption("mlr3torch.data_loading", FALSE)) {
        return(super$head(n))
      }

      # use the stored row ids instead of assuming that they are 1, ..., n, and never request
      # more rows than there are
      self$data(self$rownames[seq_len(min(n, self$nrow))], self$colnames)
    },
    #' @description
    #' Returns the number of missing values per column, which is always `0`.
    #' @param rows (`integer()`)\cr
    #'   Row ids. Ignored.
    #' @param cols (`character()`)\cr
    #'   Column names.
    missings = function(rows, cols) {
      set_names(rep(0L, length(cols)), cols)
    }
  ),
  active = list(
    #' @field converter (named `list()` of `function`s)\cr
    #' The functions that convert the lazy tensor columns to their R representation. Read-only.
    converter = function(rhs) {
      assert_ro_binding(rhs)
      private$.converter
    }
  ),
  private = list(
    # Loads `cols` for `rows`, replaces every column that has a converter by its converted value and
    # stores the converted values of the cached columns in the cache.
    # Call this function only with rows that are not in the cache yet.
    .load_and_cache = function(rows, cols) {
      tbl = super$data(rows, cols)
      cols_to_convert = intersect(names(private$.converter), names(tbl))
      if (!length(cols_to_convert)) {
        return(tbl)
      }
      tbl_to_mat = tbl[, cols_to_convert, with = FALSE]
      n = nrow(tbl_to_mat)

      tbl_mat = if (n == 0) {
        # one empty tensor per column to convert
        set_names(lapply(cols_to_convert, function(nm) torch_empty(0)), cols_to_convert)
      } else {
        # materialize in chunks of at most `chunk_size` rows and concatenate the chunks per column
        chunks = split(seq_len(n), rep(seq_len(ceiling(n / self$chunk_size)), each = self$chunk_size, length.out = n))
        set_names(lapply(transpose_list(lapply(chunks, function(chunk) {
          materialize(tbl_to_mat[chunk, ], rbind = TRUE)
        })), torch_cat, dim = 1L), names(tbl_to_mat))
      }

      # positions of the loaded rows in the cache; row ids are matched against the primary key
      # so that this also works when the ids are not 1, ..., nrow
      pos = match(rows, private$.data_cache[[self$primary_key]])
      pos = pos[!is.na(pos)]

      for (nm in cols_to_convert) {
        converted = private$.converter[[nm]](tbl_mat[[nm]])
        tbl[[nm]] = converted

        if (nm %in% private$.cached_cols && length(pos) > 0L) {
          set(private$.data_cache, i = pos, j = nm, value = converted)
        }
      }
      tbl
    },
    .data_cache = NULL,
    .converter = NULL,
    .cached_cols = NULL
  )
)

# Converts a torch dataset into a DataBackendLazyTensors.
# @param x
#   The torch dataset.
# @param dataset_shapes
#   The shapes of the dataset elements, forwarded to `as_lazy_tensors()`.
#   Defaults to `NULL` so that the argument can be omitted by the caller.
# @param ...
#   Forwarded to `as_lazy_tensors()` as well as to `DataBackendLazyTensors$new()`
#   (e.g. `converter`, `cache`, `chunk_size`).
#' @export
as_data_backend.dataset = function(x, dataset_shapes = NULL, ...) {
  tbl = as_lazy_tensors(x, dataset_shapes, ...)
  # consecutive row ids serve as the primary key of the backend
  tbl$row_id = seq_len(nrow(tbl))
  DataBackendLazyTensors$new(tbl, primary_key = "row_id", ...)
}

# Creates a classification task from a torch dataset.
# @param x
#   The torch dataset; must have at least 2 rows because a batch of size 2 is drawn to check the target.
# @param target
#   Name of the dataset element that is used as the target.
# @param levels
#   The class labels. Two levels are treated as binary classification (float target of shape
#   (batch_size, 1), 0/1 encoded), more levels as multi-class (integer target of shape (batch_size),
#   encoded as 1, ..., length(levels)).
# @param converter
#   Named list of converter functions. If `NULL`, a converter for the target is created from `levels`
#   after checking dtype and shape of the target.
# @param dataset_shapes, chunk_size, cache
#   Forwarded to `as_data_backend()`.
# @param ...
#   Forwarded to `as_task_classif()` for the created backend.
#' @export
as_task_classif.dataset = function(x, target, levels, converter = NULL, dataset_shapes = NULL, chunk_size = 100, cache = names(converter), ...) {
  if (length(x) < 2) {
    stopf("Dataset must have at least 2 rows.")
  }
  batch = dataloader(x, batch_size = 2)$.iter()$.next()
  if (is.null(converter)) {
    # fail with a clear message instead of dereferencing NULL below
    assert_choice(target, names(batch))
    y = batch[[target]]
    if (length(levels) == 2) {
      if (y$dtype != torch_float()) {
        stopf("Target must be a float tensor, but has dtype %s", y$dtype)
      }
      if (!test_equal(y$shape, c(2L, 1L))) {
        stopf("Target must be a float tensor of shape (batch_size, 1), but has shape (batch_size, %s)",
          paste(y$shape[-1L], collapse = ", "))
      }
      convert = crate(function(x) factor(as.integer(x), levels = 0:1, labels = levels), levels)
    } else {
      if (y$dtype != torch_int()) {
        stopf("Target must be an integer tensor, but has dtype %s", y$dtype)
      }
      if (!test_equal(y$shape, 2L)) {
        stopf("Target must be an integer tensor of shape (batch_size), but has shape (batch_size, %s)",
          paste(y$shape[-1L], collapse = ", "))
      }
      # The codes must be fixed explicitly: without `levels =`, factor() derives them from the values
      # that happen to be present, so a subset of rows that lacks a class would fail or be mislabelled.
      convert = crate(function(x) factor(as.integer(x), levels = seq_along(levels), labels = levels), levels)
    }
    converter = set_names(list(convert), target)
  }
  # `cache` defaults to names(converter) and is only evaluated here, i.e. after the default converter was set
  be = as_data_backend(x, dataset_shapes, converter = converter, cache = cache, chunk_size = chunk_size)
  as_task_classif(be, target = target, ...)
}

# Creates a regression task from a torch dataset.
# @param x
#   The torch dataset; must have at least 2 rows because a batch of size 2 is drawn to check the target.
# @param target
#   Name of the dataset element that is used as the target.
#   It must be a float tensor of shape (batch_size, 1).
# @param converter
#   Named list of converter functions. If `NULL`, the target is converted with `as.numeric`.
# @param dataset_shapes
#   Shapes of the dataset elements; passed through `get_or_check_dataset_shapes()`.
# @param chunk_size, cache
#   Forwarded to `as_data_backend()`.
# @param ...
#   Forwarded to `as_task_regr()` for the created backend.
#' @export
as_task_regr.dataset = function(x, target, converter = NULL, dataset_shapes = NULL, chunk_size = 100, cache = names(converter), ...) {
  if (length(x) < 2) {
    stopf("Dataset must have at least 2 rows.")
  }
  if (is.null(converter)) {
    converter = set_names(list(as.numeric), target)
  }
  batch = dataloader(x, batch_size = 2)$.iter()$.next()
  # fail with a clear message instead of dereferencing NULL below
  assert_choice(target, names(batch))
  y = batch[[target]]

  if (y$dtype != torch_float()) {
    stopf("Target must be a float tensor, but has dtype %s", y$dtype)
  }

  if (!test_equal(y$shape, c(2L, 1L))) {
    stopf("Target must be a float tensor of shape (batch_size, 1), but has shape (batch_size, %s)",
      paste(y$shape[-1L], collapse = ", "))
  }

  dataset_shapes = get_or_check_dataset_shapes(x, dataset_shapes)
  # `cache` defaults to names(converter) and is only evaluated here, i.e. after the default converter was set
  be = as_data_backend(x, dataset_shapes, converter = converter, cache = cache, chunk_size = chunk_size)
  as_task_regr(be, target = target, ...)
}

# Column information (id, type, levels) derived from the first row of the backend.
# Levels are recorded for factor / ordered columns only, excluding the primary key.
#' @export
col_info.DataBackendLazyTensors = function(x, ...) { # nolint
  probe = x$head(1L)
  col_types = map_chr(probe, function(column) class(column)[1L])
  ids = names(col_types)

  is_discrete = col_types %chin% c("factor", "ordered")
  discrete_ids = setdiff(ids[is_discrete], x$primary_key)

  # start with an empty entry for every column and fill in the levels of the discrete ones
  col_levels = insert_named(named_list(ids), map(probe[, discrete_ids, with = FALSE], levels))

  data.table(id = ids, type = unname(col_types), levels = col_levels, key = "id")
}


# Conservative guard against a converter column of a DataBackendLazyTensors having been
# preprocessed by some PipeOp.
# Walks through (possibly nested) rbind / cbind backends and raises an error as soon as a converter
# column of a DataBackendLazyTensors was already seen in a backend that takes precedence over it.
# @param be
#   the backend
# @param candidates
#   the feature and target names
# @param visited
#   Union of all colnames already visited
# @return visited
check_lazy_tensors_backend = function(be, candidates, visited = character()) {
  if (inherits(be, c("DataBackendRbind", "DataBackendCbind"))) {
    parts = be$.__enclos_env__$private$.data
    # b2 goes first because it possibly overshadows some rows/cols of b1
    seen = check_lazy_tensors_backend(parts$b2, candidates, visited)
    return(check_lazy_tensors_backend(parts$b1, candidates, seen))
  }

  if (inherits(be, "DataBackendLazyTensors")) {
    converter_names = names(be$converter)
    converter_cols = converter_names[converter_names %in% visited]
    if (length(converter_cols) > 0L) {
      stopf("A converter column ('%s') from a DataBackendLazyTensors was presumably preprocessed by some PipeOp. This can cause inefficiencies and is therefore not allowed. If you want to preprocess them, please directly encode them as R types.", paste0(converter_cols, collapse = ", ")) # nolint
    }
  }

  union(visited, intersect(candidates, be$colnames))
}
41 changes: 17 additions & 24 deletions R/DataDescriptor.R
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,7 @@ DataDescriptor = R6Class("DataDescriptor",
# For simplicity we here require the first dimension of the shape to be NA so we don't have to deal with it,
# e.g. during subsetting

if (is.null(dataset_shapes)) {
if (is.null(dataset$.getbatch)) {
stopf("dataset_shapes must be provided if dataset does not have a `.getbatch` method.")
}
dataset_shapes = infer_shapes_from_getbatch(dataset)
} else {
assert_compatible_shapes(dataset_shapes, dataset)
}
dataset_shapes = get_or_check_dataset_shapes(dataset, dataset_shapes)

if (is.null(graph)) {
# avoid name conflicts
Expand All @@ -84,8 +77,7 @@ DataDescriptor = R6Class("DataDescriptor",
assert_true(length(graph$pipeops) >= 1L)
}
# no preprocessing, dataset returns only a single element (there we can infer a lot)
simple_case = length(graph$pipeops) == 1L && inherits(graph$pipeops[[1L]], "PipeOpNOP") &&
length(dataset_shapes) == 1L
simple_case = (length(graph$pipeops) == 1L) && inherits(graph$pipeops[[1L]], "PipeOpNOP")

if (is.null(input_map) && nrow(graph$input) == 1L && length(dataset_shapes) == 1L) {
input_map = names(dataset_shapes)
Expand All @@ -100,7 +92,7 @@ DataDescriptor = R6Class("DataDescriptor",
assert_choice(pointer[[2]], graph$pipeops[[pointer[[1]]]]$output$name)
}
if (is.null(pointer_shape) && simple_case) {
pointer_shape = dataset_shapes[[1L]]
pointer_shape = dataset_shapes[[input_map]]
} else {
assert_shape(pointer_shape, null_ok = TRUE)
}
Expand Down Expand Up @@ -225,13 +217,14 @@ infer_shapes_from_getbatch = function(ds) {
}

assert_compatible_shapes = function(shapes, dataset) {
assert_shapes(shapes, null_ok = TRUE, unknown_batch = TRUE, named = TRUE)
shapes = assert_shapes(shapes, null_ok = TRUE, unknown_batch = TRUE, named = TRUE, coerce = TRUE)

# prevent user from e.g. forgetting to wrap the return in a list
example = if (is.null(dataset$.getbatch)) {
dataset$.getitem(1L)
} else {
has_getbatch = !is.null(dataset$.getbatch)
example = if (has_getbatch) {
dataset$.getbatch(1L)
} else {
dataset$.getitem(1L)
}
if (!test_list(example, names = "unique") || !test_permutation(names(example), names(shapes))) {
stopf("Dataset must return a list with named elements that are a permutation of the dataset_shapes names.")
Expand All @@ -242,17 +235,17 @@ assert_compatible_shapes = function(shapes, dataset) {
}
})

if (is.null(dataset$.getbatch)) {
example = map(example, function(x) x$unsqueeze(1))
}

iwalk(shapes, function(dataset_shape, name) {
if (!is.null(dataset_shape) && !test_equal(shapes[[name]][-1], example[[name]]$shape[-1L])) {
expected_shape = example[[name]]$shape
expected_shape[1] = NA
observed_shape = example[[name]]$shape
if (has_getbatch) {
observed_shape[1L] = NA_integer_
} else {
observed_shape = c(NA_integer_, observed_shape)
}
if (!is.null(dataset_shape) && !test_equal(observed_shape, dataset_shape)) {
stopf(paste0("First batch from dataset is incompatible with the provided shape of %s:\n",
"* Provided shape: %s.\n* Expected shape: %s."), name,
shape_to_str(unname(shapes[name])), shape_to_str(list(expected_shape)))
"* Provided shape: %s.\n* Observed shape: %s."), name,
shape_to_str(unname(shapes[name])), shape_to_str(list(observed_shape)))
}
})
}
2 changes: 1 addition & 1 deletion R/LearnerTorch.R
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@
#'
#' For information on the expected target encoding of `y`, see section *Network Head and Target Encoding*.
#' Moreover, one needs to pay attention to respect the row ids of the provided task.
#' It is recommended to relu on [`task_dataset`] for creating the [`dataset`][torch::dataset].
#' It is strongly recommended to use the [`task_dataset`] class to create the dataset.
#'
#' It is also possible to overwrite the private `.dataloader()` method.
#' This must respect the dataloader parameters from the [`ParamSet`][paradox::ParamSet].
Expand Down
Loading
Loading