From 0668997419f1e6578ae5a9b0fd2729d9b271a85e Mon Sep 17 00:00:00 2001 From: be-marc Date: Tue, 21 Oct 2025 11:52:35 +0200 Subject: [PATCH 1/8] feat: add gpu support --- R/Auto.R | 16 +++++++++++-- R/AutoCatboost.R | 20 +++++++++++----- R/AutoExtraTrees.R | 13 +++++++---- R/AutoFTTransformer.R | 29 +++++++++++++++++++----- R/AutoGlmnet.R | 12 ++++++---- R/AutoKknn.R | 12 ++++++---- R/AutoLda.R | 10 ++++---- R/AutoLightgbm.R | 18 ++++++++++----- R/AutoMlp.R | 29 +++++++++++++++++++----- R/AutoRanger.R | 12 ++++++---- R/AutoResNet.R | 28 ++++++++++++++++++----- R/AutoSvm.R | 13 +++++++---- R/AutoTabpfn.R | 25 ++++++++++++-------- R/AutoXgboost.R | 19 +++++++++++----- R/LearnerClassifAuto.R | 1 + R/train_auto.R | 5 ++-- tests/testthat/test_LearnerClassifAuto.R | 22 ++++++++++++++++++ 17 files changed, 204 insertions(+), 80 deletions(-) diff --git a/R/Auto.R b/R/Auto.R index 2ff875a..f139a5e 100644 --- a/R/Auto.R +++ b/R/Auto.R @@ -30,15 +30,22 @@ Auto = R6Class("Auto", #' @field packages (`character()`). packages = NULL, + #' @field devices (`character()`). + devices = NULL, + #' @description #' Creates a new instance of this [R6][R6::R6Class] class. - initialize = function(id) { + initialize = function(id, properties = character(0), task_types = character(0), packages = character(0), devices = character(0)) { self$id = assert_string(id) + self$properties = assert_character(properties) + self$task_types = assert_character(task_types) + self$packages = assert_character(packages) + self$devices = assert_character(devices) }, #' @description #' Check if the auto is compatible with the task. - check = function(task, memory_limit = Inf, large_data_set = FALSE) { + check = function(task, memory_limit = Inf, large_data_set = FALSE, devices) { if (!task$task_type %in% self$task_types) { lg$info("Learner '%s' is not compatible with task type '%s'", self$id, task$task_type) return(FALSE) @@ -51,6 +58,11 @@ Auto = R6Class("Auto", lg$info("Learner '%s' is not compatible with large data sets", self$id) return(FALSE) } + if (any(devices %nin% self$devices)) { + lg$info("Learner '%s' is not compatible with devices '%s'", self$id, as_short_string(devices)) + return(FALSE) + } + TRUE }, diff --git a/R/AutoCatboost.R b/R/AutoCatboost.R index a689a69..a5541c2 100644 --- a/R/AutoCatboost.R +++ b/R/AutoCatboost.R @@ -19,28 +19,36 @@ AutoCatboost = R6Class("AutoCatboost", #' @description #' Creates a new instance of this [R6][R6::R6Class] class. initialize = function(id = "catboost") { - super$initialize(id = id) - self$task_types = c("classif", "regr") - self$properties = c("internal_tuning", "large_data_sets") - self$packages = c("mlr3", "mlr3extralearners", "catboost") + super$initialize( + id = id, + properties = c("internal_tuning", "large_data_sets"), + task_types = c("classif", "regr"), + packages = c("mlr3", "mlr3extralearners", "catboost"), + devices = c("cpu", "cuda") + ) }, #' @description #' Create the graph for the auto. - graph = function(task, measure, n_threads, timeout) { + graph = function(task, measure, n_threads, timeout, devices) { assert_task(task) assert_measure(measure) assert_count(n_threads) assert_count(timeout) + assert_subset(devices, self$devices) require_namespaces("mlr3extralearners") + # catboost only supports gpu via cuda + task_type = if ("cuda" %in% devices) "GPU" else "CPU" + learner = lrn(sprintf("%s.catboost", task$task_type), id = "catboost", iterations = self$search_space(task)$upper["catboost.iterations"] %??% 1000L, early_stopping_rounds = self$early_stopping_rounds(task), use_best_model = TRUE, - eval_metric = self$internal_measure(measure, task)) + eval_metric = self$internal_measure(measure, task), + task_type = task_type) set_threads(learner, n_threads) po("removeconstants", id = "catboost_removeconstants") %>>% diff --git a/R/AutoExtraTrees.R b/R/AutoExtraTrees.R index de03a2b..320e9a1 100644 --- a/R/AutoExtraTrees.R +++ b/R/AutoExtraTrees.R @@ -20,10 +20,13 @@ AutoExtraTrees = R6Class("AutoExtraTrees", #' @description #' Creates a new instance of this [R6][R6::R6Class] class. initialize = function(id = "extra_trees") { - super$initialize(id = id) - self$task_types = c("classif", "regr") - self$properties = c("large_data_sets", "hyperparameter-free") - self$packages = c("mlr3", "mlr3learners", "ranger") + super$initialize( + id = id, + properties = c("large_data_sets", "hyperparameter-free"), + task_types = c("classif", "regr"), + packages = c("mlr3", "mlr3learners", "ranger"), + devices = "cpu" + ) }, #' @description @@ -33,7 +36,7 @@ AutoExtraTrees = R6Class("AutoExtraTrees", #' @param measure ([mlr3::Measure]). #' @param n_threads (`numeric(1)`). #' @param timeout (`numeric(1)`). - graph = function(task, measure, n_threads, timeout) { + graph = function(task, measure, n_threads, timeout, devices) { assert_task(task) assert_measure(measure) assert_count(n_threads) diff --git a/R/AutoFTTransformer.R b/R/AutoFTTransformer.R index ec2aea4..0ee61e6 100644 --- a/R/AutoFTTransformer.R +++ b/R/AutoFTTransformer.R @@ -19,22 +19,38 @@ AutoFTTransformer = R6Class("AutoFTTransformer", #' @description #' Creates a new instance of this [R6][R6::R6Class] class. initialize = function(id = "ft_transformer") { - super$initialize(id = id) - self$task_types = c("classif", "regr") - self$properties = "internal_tuning" - self$packages = c("mlr3", "mlr3torch") + super$initialize( + id = id, + properties = "internal_tuning", + task_types = c("classif", "regr"), + packages = c("mlr3", "mlr3torch"), + devices = c("cpu", "cuda") + ) + }, + + #' @description + #' Check if the auto is compatible with the task. + check = function(task, memory_limit = Inf, large_data_set = FALSE, devices = "cpu") { + if ("cuda" %nin% devices && task$nrow > 1e3) { + lg$info("Learner '%s' is not compatible with tasks with more than 1,000 rows when using 'cpu' as device", self$id) + return(FALSE) + } + super$check(task, memory_limit, large_data_set, devices) }, #' @description #' Create the graph for the auto. - graph = function(task, measure, n_threads, timeout) { + graph = function(task, measure, n_threads, timeout, devices) { assert_task(task) assert_measure(measure) assert_count(n_threads) assert_count(timeout) + assert_subset(devices, self$devices) require_namespaces("mlr3torch") + device = if ("cuda" %in% devices) "cuda" else "auto" + # copied from mlr3tuningspaces no_wd = function(name) { # this will also disable weight decay for the input projection bias of the attention heads @@ -74,7 +90,8 @@ AutoFTTransformer = R6Class("AutoFTTransformer", patience = self$early_stopping_rounds(task), batch_size = 32L, attention_n_heads = 8L, - opt.param_groups = rtdl_param_groups + opt.param_groups = rtdl_param_groups, + device = "cuda" ) set_threads(learner, n_threads) diff --git a/R/AutoGlmnet.R b/R/AutoGlmnet.R index 2d76890..d26ac12 100644 --- a/R/AutoGlmnet.R +++ b/R/AutoGlmnet.R @@ -19,15 +19,17 @@ AutoGlmnet = R6Class("AutoGlmnet", #' @description #' Creates a new instance of this [R6][R6::R6Class] class. initialize = function(id = "glmnet") { - super$initialize(id = id) - self$task_types = c("classif", "regr") - self$properties = character() - self$packages = c("mlr3", "mlr3learners", "glmnet") + super$initialize(id = id, + properties = character(), + task_types = c("classif", "regr"), + packages = c("mlr3", "mlr3learners", "glmnet"), + devices = "cpu" + ) }, #' @description #' Create the graph for the auto. - graph = function(task, measure, n_threads, timeout) { + graph = function(task, measure, n_threads, timeout, devices) { assert_task(task) assert_measure(measure) assert_count(n_threads) diff --git a/R/AutoKknn.R b/R/AutoKknn.R index a4a5d28..752cc2e 100644 --- a/R/AutoKknn.R +++ b/R/AutoKknn.R @@ -19,15 +19,17 @@ AutoKknn = R6Class("AutoKknn", #' @description #' Creates a new instance of this [R6][R6::R6Class] class. initialize = function(id = "kknn") { - super$initialize(id = id) - self$task_types = c("classif", "regr") - self$properties = character() - self$packages = c("mlr3", "mlr3learners", "kknn") + super$initialize(id = id, + properties = character(), + task_types = c("classif", "regr"), + packages = c("mlr3", "mlr3learners", "kknn"), + devices = "cpu" + ) }, #' @description #' Create the graph for the auto. - graph = function(task, measure, n_threads, timeout) { + graph = function(task, measure, n_threads, timeout, devices) { assert_task(task) assert_measure(measure) assert_count(n_threads) diff --git a/R/AutoLda.R b/R/AutoLda.R index db28465..3b41cc6 100644 --- a/R/AutoLda.R +++ b/R/AutoLda.R @@ -22,10 +22,12 @@ AutoLda = R6Class("AutoLda", #' @description #' Creates a new instance of this [R6][R6::R6Class] class. initialize = function(id = "lda") { - super$initialize(id = id) - self$task_types = "classif" - self$properties = "hyperparameter-free" - self$packages = c("mlr3", "mlr3learners", "MASS") + super$initialize( + id = id, + task_types = "classif", + properties = "hyperparameter-free", + packages = c("mlr3", "mlr3learners", "MASS"), + devices = "cpu") }, #' @description diff --git a/R/AutoLightgbm.R b/R/AutoLightgbm.R index 9ac69de..b26bfb0 100644 --- a/R/AutoLightgbm.R +++ b/R/AutoLightgbm.R @@ -19,27 +19,33 @@ AutoLightgbm = R6Class("AutoLightgbm", #' @description #' Creates a new instance of this [R6][R6::R6Class] class. initialize = function(id = "lightgbm") { - super$initialize(id = id) - self$task_types = c("classif", "regr") - self$properties = c("internal_tuning", "large_data_sets") - self$packages = c("mlr3", "mlr3extralearners", "lightgbm") + super$initialize(id = id, + properties = c("internal_tuning", "large_data_sets"), + task_types = c("classif", "regr"), + packages = c("mlr3", "mlr3extralearners", "lightgbm"), + devices = c("cpu", "cuda") + ) }, #' @description #' Create the graph for the auto. - graph = function(task, measure, n_threads, timeout) { + graph = function(task, measure, n_threads, timeout, devices) { assert_task(task) assert_measure(measure) assert_count(n_threads) assert_count(timeout) + assert_subset(devices, self$devices) require_namespaces("mlr3extralearners") + device_type = if ("cuda" %in% devices) "gpu" else "cpu" + learner = lrn(sprintf("%s.lightgbm", task$task_type), id = "lightgbm", early_stopping_rounds = self$early_stopping_rounds(task), callbacks = list(cb_timeout_lightgbm(timeout * 0.8)), - eval = self$internal_measure(measure, task)) + eval = self$internal_measure(measure, task), + device_type = device_type) set_threads(learner, n_threads) learner diff --git a/R/AutoMlp.R b/R/AutoMlp.R index 1ef89dc..b4506d5 100644 --- a/R/AutoMlp.R +++ b/R/AutoMlp.R @@ -19,27 +19,44 @@ AutoMlp = R6Class("AutoMlp", #' @description #' Creates a new instance of this [R6][R6::R6Class] class. initialize = function(id = "mlp") { - super$initialize(id = id) - self$task_types = c("classif", "regr") - self$properties = "internal_tuning" - self$packages = c("mlr3", "mlr3torch") + super$initialize( + id = id, + properties = "internal_tuning", + task_types = c("classif", "regr"), + packages = c("mlr3", "mlr3torch"), + devices = c("cpu", "cuda") + ) + }, + + #' @description + #' Check if the auto is compatible with the task. + check = function(task, memory_limit = Inf, large_data_set = FALSE, devices = "cpu") { + if ("cuda" %nin% devices && task$nrow > 1e3) { + lg$info("Learner '%s' is not compatible with tasks with more than 1,000 rows when using 'cpu' as device", self$id) + return(FALSE) + } + super$check(task, memory_limit, large_data_set, devices) }, #' @description #' Create the graph for the auto. - graph = function(task, measure, n_threads, timeout) { + graph = function(task, measure, n_threads, timeout, devices) { assert_task(task) assert_measure(measure) assert_count(n_threads) assert_count(timeout) + assert_subset(devices, self$devices) require_namespaces("mlr3torch") + device = if ("cuda" %in% devices) "cuda" else "auto" + learner = lrn(sprintf("%s.mlp", task$task_type), id = "mlp", measures_valid = measure, patience = self$early_stopping_rounds(task), - batch_size = 32L + batch_size = 32L, + device = device ) set_threads(learner, n_threads) diff --git a/R/AutoRanger.R b/R/AutoRanger.R index b394902..0fad623 100644 --- a/R/AutoRanger.R +++ b/R/AutoRanger.R @@ -19,15 +19,17 @@ AutoRanger = R6Class("AutoRanger", #' @description #' Creates a new instance of this [R6][R6::R6Class] class. initialize = function(id = "ranger") { - super$initialize(id = id) - self$task_types = c("classif", "regr") - self$properties = "large_data_sets" - self$packages = c("mlr3", "mlr3learners", "ranger") + super$initialize(id = id, + properties = "large_data_sets", + task_types = c("classif", "regr"), + packages = c("mlr3", "mlr3learners", "ranger"), + devices = "cpu" + ) }, #' @description #' Create the graph for the auto. - graph = function(task, measure, n_threads, timeout) { + graph = function(task, measure, n_threads, timeout, devices) { assert_task(task) assert_measure(measure) assert_count(n_threads) diff --git a/R/AutoResNet.R b/R/AutoResNet.R index c48109f..009529c 100644 --- a/R/AutoResNet.R +++ b/R/AutoResNet.R @@ -19,27 +19,43 @@ AutoResNet = R6Class("AutoResNet", #' @description #' Creates a new instance of this [R6][R6::R6Class] class. initialize = function(id = "resnet") { - super$initialize(id = id) - self$task_types = c("classif", "regr") - self$properties = "internal_tuning" - self$packages = c("mlr3", "mlr3torch") + super$initialize(id = id, + properties = "internal_tuning", + task_types = c("classif", "regr"), + packages = c("mlr3", "mlr3torch"), + devices = c("cpu", "cuda") + ) + }, + + #' @description + #' Check if the auto is compatible with the task. + check = function(task, memory_limit = Inf, large_data_set = FALSE, devices = "cpu") { + if ("cuda" %nin% devices && task$nrow > 1e3) { + lg$info("Learner '%s' is not compatible with tasks with more than 1,000 rows when using 'cpu' as device", self$id) + return(FALSE) + } + super$check(task, memory_limit, large_data_set, devices) }, #' @description #' Create the graph for the auto. - graph = function(task, measure, n_threads, timeout) { + graph = function(task, measure, n_threads, timeout, devices) { assert_task(task) assert_measure(measure) assert_count(n_threads) assert_count(timeout) + assert_subset(devices, self$devices) require_namespaces("mlr3torch") + device = if ("cuda" %in% devices) "cuda" else "auto" + learner = lrn(sprintf("%s.tab_resnet", task$task_type), id = "resnet", measures_valid = measure, patience = self$early_stopping_rounds(task), - batch_size = 32L + batch_size = 32L, + device = device ) set_threads(learner, n_threads) diff --git a/R/AutoSvm.R b/R/AutoSvm.R index 241dcc2..2203124 100644 --- a/R/AutoSvm.R +++ b/R/AutoSvm.R @@ -19,19 +19,22 @@ AutoSvm = R6Class("AutoSvm", #' @description #' Creates a new instance of this [R6][R6::R6Class] class. initialize = function(id = "svm") { - super$initialize(id = id) - self$task_types = c("classif", "regr") - self$properties = character(0) - self$packages = c("mlr3", "mlr3learners", "e1071") + super$initialize(id = id, + properties = character(0), + task_types = c("classif", "regr"), + packages = c("mlr3", "mlr3learners", "e1071"), + devices = "cpu" + ) }, #' @description #' Create the graph for the auto. - graph = function(task, measure, n_threads, timeout) { + graph = function(task, measure, n_threads, timeout, devices) { assert_task(task) assert_measure(measure) assert_count(n_threads) assert_count(timeout) + assert_subset(devices, self$devices) require_namespaces("mlr3learners") diff --git a/R/AutoTabpfn.R b/R/AutoTabpfn.R index e286fcc..da09ec7 100644 --- a/R/AutoTabpfn.R +++ b/R/AutoTabpfn.R @@ -21,15 +21,17 @@ AutoTabpfn = R6Class("AutoTabpfn", #' @description #' Creates a new instance of this [R6][R6::R6Class] class. initialize = function(id = "tabpfn") { - super$initialize(id = id) - self$task_types = c("classif", "regr") - self$properties = character(0) - self$packages = c("mlr3", "mlr3extralearners") + super$initialize(id = id, + properties = character(0), + task_types = c("classif", "regr"), + packages = c("mlr3", "mlr3extralearners"), + devices = c("cpu", "cuda") + ) }, #' @description #' Check if the auto is compatible with the task. - check = function(task, memory_limit = Inf, large_data_set = FALSE) { + check = function(task, memory_limit = Inf, large_data_set = FALSE, devices = "cpu") { ok = check_python_packages(c("fastai", "torch")) if (!isTRUE(ok)) { lg$info(ok) @@ -37,24 +39,27 @@ AutoTabpfn = R6Class("AutoTabpfn", return(FALSE) } - if (task$nrow > 1e3) { - lg$info("Learner '%s' is not compatible with tasks with more than 1,000 rows", self$id) + if ("cuda" %nin% devices && task$nrow > 1e3) { + lg$info("Learner '%s' is not compatible with tasks with more than 1,000 rows when using 'cpu' as device", self$id) return(FALSE) } - super$check(task, memory_limit, large_data_set) + super$check(task, memory_limit, large_data_set, devices) }, #' @description #' Create the graph for the auto. - graph = function(task, measure, n_threads, timeout) { + graph = function(task, measure, n_threads, timeout, devices) { assert_task(task) assert_measure(measure) assert_count(n_threads) assert_count(timeout) + assert_subset(devices, self$devices) require_namespaces("mlr3extralearners") - learner = lrn(sprintf("%s.tabpfn", task$task_type), id = "tabpfn") + device = if ("cuda" %in% devices) "cuda" else "cpu" + + learner = lrn(sprintf("%s.tabpfn", task$task_type), id = "tabpfn", device = device) set_threads(learner, n_threads) diff --git a/R/AutoXgboost.R b/R/AutoXgboost.R index 8cfb32e..0758314 100644 --- a/R/AutoXgboost.R +++ b/R/AutoXgboost.R @@ -19,28 +19,35 @@ AutoXgboost = R6Class("AutoXgboost", #' @description #' Creates a new instance of this [R6][R6::R6Class] class. initialize = function(id = "xgboost") { - super$initialize(id = id) - self$task_types = c("classif", "regr") - self$properties = c("internal_tuning", "large_data_sets") - self$packages = c("mlr3", "mlr3learners", "xgboost") + super$initialize(id = id, + properties = c("internal_tuning", "large_data_sets"), + task_types = c("classif", "regr"), + packages = c("mlr3", "mlr3learners", "xgboost"), + devices = c("cpu", "gpu") + ) }, #' @description #' Create the graph for the auto. - graph = function(task, measure, n_threads, timeout) { + graph = function(task, measure, n_threads, timeout, devices) { assert_task(task) assert_measure(measure) assert_count(n_threads) assert_count(timeout) + assert_subset(devices, self$devices) require_namespaces("mlr3learners") + device = if ("cuda" %in% devices) "cuda" else "cpu" + learner = lrn(sprintf("%s.xgboost", task$task_type), id = "xgboost", early_stopping_rounds = self$early_stopping_rounds(task), callbacks = list(cb_timeout_xgboost(timeout * 0.8)), eval_metric = self$internal_measure(measure, task), - nrounds = 5000L) + nrounds = 5000L, + device = device + ) set_threads(learner, n_threads) po("removeconstants", id = "xgboost_removeconstants") %>>% diff --git a/R/LearnerClassifAuto.R b/R/LearnerClassifAuto.R index 8c82b26..3c79711 100644 --- a/R/LearnerClassifAuto.R +++ b/R/LearnerClassifAuto.R @@ -36,6 +36,7 @@ LearnerClassifAuto = R6Class("LearnerClassifAuto", # system n_threads = p_int(lower = 1L, init = 1L, tags = c("train", "catboost", "lightgbm", "ranger", "xgboost")), memory_limit = p_int(lower = 1L, init = 32000L, tags = c("train", "catboost", "lightgbm", "ranger", "xgboost")), + devices = p_uty(init = c("cpu", "cuda"), tags = c("train", "super"), custom_check = crate({function(x) check_subset(x, c("cpu", "cuda"))})), # large data large_data_size = p_int(lower = 1L, init = 1e6, tags = c("train", "super")), # small data diff --git a/R/train_auto.R b/R/train_auto.R index 624691e..09f0369 100644 --- a/R/train_auto.R +++ b/R/train_auto.R @@ -30,7 +30,7 @@ train_auto = function(self, private, task) { } # initialize graph learner - autos = keep(autos, function(auto) auto$check(task, memory_limit = memory_limit, large_data_set = large_data_set)) + autos = keep(autos, function(auto) auto$check(task, memory_limit = memory_limit, large_data_set = large_data_set, devices = pv$devices)) if (!length(autos)) { error_config("No learner is compatible with the task.") @@ -40,14 +40,13 @@ train_auto = function(self, private, task) { error_config("All learners have no hyperparameters to tune. Combine with other learners.") } - branches = map(autos, function(auto) auto$graph(task, pv$measure, n_threads, pv$learner_timeout)) + branches = map(autos, function(auto) auto$graph(task, pv$measure, n_threads, pv$learner_timeout, pv$devices)) graph_learner = as_learner(po("branch", options = names(branches)) %>>% gunion(unname(branches)) %>>% po("unbranch", options = names(branches)), clone = TRUE) graph_learner$id = "graph_learner" graph_learner$predict_type = pv$measure$predict_type - if (pv$encapsulate_learner) { fallback = lrn(sprintf("%s.featureless", task$task_type)) fallback$predict_type = pv$measure$predict_type diff --git a/tests/testthat/test_LearnerClassifAuto.R b/tests/testthat/test_LearnerClassifAuto.R index a0b1b4a..9726b8c 100644 --- a/tests/testthat/test_LearnerClassifAuto.R +++ b/tests/testthat/test_LearnerClassifAuto.R @@ -304,3 +304,25 @@ test_that("initial design runtime limit works", { expect_class(learner$train(task), "LearnerClassifAuto") }) + +test_that("devices works", { + skip_on_cran() + skip_if_not_installed("rush") + skip_if_not_installed(all_packages) + flush_redis() + + rush_plan(n_workers = 2, worker_type = "remote") + mirai::daemons(2) + + task = tsk("penguins") + learner = lrn("classif.auto", + devices = "cpu", + measure = msr("classif.ce"), + terminator = trm("evals", n_evals = 10), + initial_design_size = 1, + encapsulate_learner = FALSE, + encapsulate_mbo = FALSE + ) + + expect_class(learner$train(task), "LearnerClassifAuto") +}) From c21fe90b749cb769d35b46e4bc67efd0181e6e74 Mon Sep 17 00:00:00 2001 From: be-marc Date: Thu, 23 Oct 2025 18:23:12 +0200 Subject: [PATCH 2/8] ... --- R/Auto.R | 2 +- R/AutoLda.R | 2 +- R/AutoTabpfn.R | 2 +- tests/testthat/test_LearnerClassifAuto.R | 23 +++++++++++++++++++++++ 4 files changed, 26 insertions(+), 3 deletions(-) diff --git a/R/Auto.R b/R/Auto.R index f139a5e..e8500b5 100644 --- a/R/Auto.R +++ b/R/Auto.R @@ -68,7 +68,7 @@ Auto = R6Class("Auto", #' @description #' Create the graph for the auto. - graph = function(task, measure, n_threads, timeout) { + graph = function(task, measure, n_threads, timeout, devices) { stop("Abstract") }, diff --git a/R/AutoLda.R b/R/AutoLda.R index 3b41cc6..4fde1ba 100644 --- a/R/AutoLda.R +++ b/R/AutoLda.R @@ -32,7 +32,7 @@ AutoLda = R6Class("AutoLda", #' @description #' Create the graph for the auto. - graph = function(task, measure, n_threads, timeout) { + graph = function(task, measure, n_threads, timeout, devices) { assert_task(task) assert_measure(measure) assert_count(n_threads) diff --git a/R/AutoTabpfn.R b/R/AutoTabpfn.R index da09ec7..8c4f46a 100644 --- a/R/AutoTabpfn.R +++ b/R/AutoTabpfn.R @@ -32,7 +32,7 @@ AutoTabpfn = R6Class("AutoTabpfn", #' @description #' Check if the auto is compatible with the task. check = function(task, memory_limit = Inf, large_data_set = FALSE, devices = "cpu") { - ok = check_python_packages(c("fastai", "torch")) + ok = check_python_packages(c("torch")) if (!isTRUE(ok)) { lg$info(ok) lg$info("Remove tabpfn from search space") diff --git a/tests/testthat/test_LearnerClassifAuto.R b/tests/testthat/test_LearnerClassifAuto.R index 9726b8c..cb4e9fb 100644 --- a/tests/testthat/test_LearnerClassifAuto.R +++ b/tests/testthat/test_LearnerClassifAuto.R @@ -326,3 +326,26 @@ test_that("devices works", { expect_class(learner$train(task), "LearnerClassifAuto") }) + +test_that("devices works", { + skip_if(TRUE) + skip_on_cran() + skip_if_not_installed("rush") + skip_if_not_installed(all_packages) + flush_redis() + + rush_plan(n_workers = 2, worker_type = "remote") + mirai::daemons(2) + + task = tsk("penguins") + learner = lrn("classif.auto", + devices = "cuda", + measure = msr("classif.ce"), + terminator = trm("evals", n_evals = 10), + initial_design_size = 1, + encapsulate_learner = FALSE, + encapsulate_mbo = FALSE + ) + + expect_class(learner$train(task), "LearnerClassifAuto") +}) From b59da2581758edbae92088f702cb693e9fbf27ca Mon Sep 17 00:00:00 2001 From: be-marc Date: Thu, 23 Oct 2025 21:07:03 +0200 Subject: [PATCH 3/8] ... --- R/LearnerClassifAuto.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/LearnerClassifAuto.R b/R/LearnerClassifAuto.R index 3c79711..fc28db7 100644 --- a/R/LearnerClassifAuto.R +++ b/R/LearnerClassifAuto.R @@ -36,7 +36,7 @@ LearnerClassifAuto = R6Class("LearnerClassifAuto", # system n_threads = p_int(lower = 1L, init = 1L, tags = c("train", "catboost", "lightgbm", "ranger", "xgboost")), memory_limit = p_int(lower = 1L, init = 32000L, tags = c("train", "catboost", "lightgbm", "ranger", "xgboost")), - devices = p_uty(init = c("cpu", "cuda"), tags = c("train", "super"), custom_check = crate({function(x) check_subset(x, c("cpu", "cuda"))})), + devices = p_uty(init = "cpu", tags = c("train", "super"), custom_check = crate({function(x) check_subset(x, c("cpu", "cuda"))})), # large data large_data_size = p_int(lower = 1L, init = 1e6, tags = c("train", "super")), # small data From 77448c6763f909182b9c79217caeb3445a4d1c5c Mon Sep 17 00:00:00 2001 From: be-marc Date: Thu, 23 Oct 2025 21:48:22 +0200 Subject: [PATCH 4/8] ... --- R/AutoFTTransformer.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/AutoFTTransformer.R b/R/AutoFTTransformer.R index 0ee61e6..5745c11 100644 --- a/R/AutoFTTransformer.R +++ b/R/AutoFTTransformer.R @@ -91,7 +91,7 @@ AutoFTTransformer = R6Class("AutoFTTransformer", batch_size = 32L, attention_n_heads = 8L, opt.param_groups = rtdl_param_groups, - device = "cuda" + device = device ) set_threads(learner, n_threads) From eb2c67286a01d842579e05a2a997afa5ba2f8796 Mon Sep 17 00:00:00 2001 From: be-marc Date: Fri, 24 Oct 2025 12:35:48 +0200 Subject: [PATCH 5/8] ... --- DESCRIPTION | 2 +- NAMESPACE | 1 + R/Auto.R | 14 ++++- R/AutoCatboost.R | 1 + R/AutoExtraTrees.R | 1 + R/AutoFTTransformer.R | 15 +---- R/AutoGlmnet.R | 1 + R/AutoKknn.R | 1 + R/AutoLda.R | 1 + R/AutoLightgbm.R | 1 + R/AutoMlp.R | 15 +---- R/AutoRanger.R | 2 + R/AutoResNet.R | 15 +---- R/AutoSvm.R | 1 + R/AutoTabpfn.R | 1 + R/AutoXgboost.R | 1 + R/LearnerClassifAuto.R | 5 +- R/train_auto.R | 8 ++- R/zzz.R | 2 +- man-roxygen/param_devices.R | 4 ++ man/Auto.Rd | 35 +++++++++-- man/AutoCatboost.Rd | 7 ++- man/AutoExtraTrees.Rd | 7 ++- man/AutoFTTransformer.Rd | 7 ++- man/AutoGlmnet.Rd | 7 ++- man/AutoKknn.Rd | 7 ++- man/AutoLda.Rd | 7 ++- man/AutoLightgbm.Rd | 7 ++- man/AutoMlp.Rd | 7 ++- man/AutoRanger.Rd | 7 ++- man/AutoResNet.Rd | 7 ++- man/AutoSvm.Rd | 7 ++- man/AutoTabpfn.Rd | 19 +++++- man/AutoXgboost.Rd | 7 ++- tests/testthat/helper.R | 15 ++++- tests/testthat/test_Auto.R | 44 ++++++++++++++ tests/testthat/test_LearnerClassifAuto.R | 74 ++++-------------------- 37 files changed, 234 insertions(+), 129 deletions(-) create mode 100644 man-roxygen/param_devices.R create mode 100644 tests/testthat/test_Auto.R diff --git a/DESCRIPTION b/DESCRIPTION index 9e93751..f436aa7 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -21,6 +21,7 @@ Imports: checkmate, data.table, lhs, + mlr3learners (>= 0.12.0), mlr3mbo (>= 0.2.8), mlr3misc (>= 0.15.1), mlr3pipelines, @@ -38,7 +39,6 @@ Suggests: MASS, mirai, mlr3extralearners, - mlr3learners (>= 0.12.0), mlr3torch, mlr3viz, ranger, diff --git a/NAMESPACE b/NAMESPACE index 6a3c9b4..ad8944d 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -30,6 +30,7 @@ import(checkmate) import(data.table) import(lhs) import(mlr3) +import(mlr3learners) import(mlr3mbo) import(mlr3misc) import(mlr3pipelines) diff --git a/R/Auto.R b/R/Auto.R index e8500b5..0435b5f 100644 --- a/R/Auto.R +++ b/R/Auto.R @@ -13,6 +13,7 @@ #' @template param_memory_limit #' @template param_large_data_set #' @template param_size +#' @template param_devices #' #' @export Auto = R6Class("Auto", @@ -35,7 +36,18 @@ Auto = R6Class("Auto", #' @description #' Creates a new instance of this [R6][R6::R6Class] class. - initialize = function(id, properties = character(0), task_types = character(0), packages = character(0), devices = character(0)) { + #' + #' @param id (`character(1)`). + #' @param properties (`character()`). + #' @param task_types (`character()`). + #' @param packages (`character()`). + #' @param devices (`character()`). + initialize = function(id, + properties = character(0), + task_types = character(0), + packages = character(0), + devices = character(0) + ) { self$id = assert_string(id) self$properties = assert_character(properties) self$task_types = assert_character(task_types) diff --git a/R/AutoCatboost.R b/R/AutoCatboost.R index a5541c2..42951f4 100644 --- a/R/AutoCatboost.R +++ b/R/AutoCatboost.R @@ -10,6 +10,7 @@ #' @template param_measure #' @template param_n_threads #' @template param_timeout +#' @template param_devices #' #' @export AutoCatboost = R6Class("AutoCatboost", diff --git a/R/AutoExtraTrees.R b/R/AutoExtraTrees.R index 320e9a1..fc93819 100644 --- a/R/AutoExtraTrees.R +++ b/R/AutoExtraTrees.R @@ -11,6 +11,7 @@ #' @template param_task #' @template param_measure #' @template param_size +#' @template param_devices #' #' @export AutoExtraTrees = R6Class("AutoExtraTrees", diff --git a/R/AutoFTTransformer.R b/R/AutoFTTransformer.R index 5745c11..c8cc74c 100644 --- a/R/AutoFTTransformer.R +++ b/R/AutoFTTransformer.R @@ -10,6 +10,7 @@ #' @template param_measure #' @template param_n_threads #' @template param_timeout +#' @template param_devices #' #' @export AutoFTTransformer = R6Class("AutoFTTransformer", @@ -24,20 +25,10 @@ AutoFTTransformer = R6Class("AutoFTTransformer", properties = "internal_tuning", task_types = c("classif", "regr"), packages = c("mlr3", "mlr3torch"), - devices = c("cpu", "cuda") + devices = "cuda" ) }, - #' @description - #' Check if the auto is compatible with the task. - check = function(task, memory_limit = Inf, large_data_set = FALSE, devices = "cpu") { - if ("cuda" %nin% devices && task$nrow > 1e3) { - lg$info("Learner '%s' is not compatible with tasks with more than 1,000 rows when using 'cpu' as device", self$id) - return(FALSE) - } - super$check(task, memory_limit, large_data_set, devices) - }, - #' @description #' Create the graph for the auto. graph = function(task, measure, n_threads, timeout, devices) { @@ -45,7 +36,7 @@ AutoFTTransformer = R6Class("AutoFTTransformer", assert_measure(measure) assert_count(n_threads) assert_count(timeout) - assert_subset(devices, self$devices) + assert_subset(devices, c("cuda", "cpu")) require_namespaces("mlr3torch") diff --git a/R/AutoGlmnet.R b/R/AutoGlmnet.R index d26ac12..be600b1 100644 --- a/R/AutoGlmnet.R +++ b/R/AutoGlmnet.R @@ -10,6 +10,7 @@ #' @template param_measure #' @template param_n_threads #' @template param_timeout +#' @template param_devices #' #' @export AutoGlmnet = R6Class("AutoGlmnet", diff --git a/R/AutoKknn.R b/R/AutoKknn.R index 752cc2e..867e0d1 100644 --- a/R/AutoKknn.R +++ b/R/AutoKknn.R @@ -10,6 +10,7 @@ #' @template param_measure #' @template param_n_threads #' @template param_timeout +#' @template param_devices #' #' @export AutoKknn = R6Class("AutoKknn", diff --git a/R/AutoLda.R b/R/AutoLda.R index 4fde1ba..9d66116 100644 --- a/R/AutoLda.R +++ b/R/AutoLda.R @@ -13,6 +13,7 @@ #' @template param_size #' @template param_n_threads #' @template param_timeout +#' @template param_devices #' #' @export AutoLda = R6Class("AutoLda", diff --git a/R/AutoLightgbm.R b/R/AutoLightgbm.R index b26bfb0..f14159f 100644 --- a/R/AutoLightgbm.R +++ b/R/AutoLightgbm.R @@ -10,6 +10,7 @@ #' @template param_measure #' @template param_n_threads #' @template param_timeout +#' @template param_devices #' #' @export AutoLightgbm = R6Class("AutoLightgbm", diff --git a/R/AutoMlp.R b/R/AutoMlp.R index b4506d5..8108e99 100644 --- a/R/AutoMlp.R +++ b/R/AutoMlp.R @@ -10,6 +10,7 @@ #' @template param_measure #' @template param_n_threads #' @template param_timeout +#' @template param_devices #' #' @export AutoMlp = R6Class("AutoMlp", @@ -24,20 +25,10 @@ AutoMlp = R6Class("AutoMlp", properties = "internal_tuning", task_types = c("classif", "regr"), packages = c("mlr3", "mlr3torch"), - devices = c("cpu", "cuda") + devices = "cuda" ) }, - #' @description - #' Check if the auto is compatible with the task. - check = function(task, memory_limit = Inf, large_data_set = FALSE, devices = "cpu") { - if ("cuda" %nin% devices && task$nrow > 1e3) { - lg$info("Learner '%s' is not compatible with tasks with more than 1,000 rows when using 'cpu' as device", self$id) - return(FALSE) - } - super$check(task, memory_limit, large_data_set, devices) - }, - #' @description #' Create the graph for the auto. graph = function(task, measure, n_threads, timeout, devices) { @@ -45,7 +36,7 @@ AutoMlp = R6Class("AutoMlp", assert_measure(measure) assert_count(n_threads) assert_count(timeout) - assert_subset(devices, self$devices) + assert_subset(devices, c("cuda", "cpu")) require_namespaces("mlr3torch") diff --git a/R/AutoRanger.R b/R/AutoRanger.R index 0fad623..b019823 100644 --- a/R/AutoRanger.R +++ b/R/AutoRanger.R @@ -10,6 +10,8 @@ #' @template param_measure #' @template param_n_threads #' @template param_timeout +#' @template param_devices +#' #' #' @export AutoRanger = R6Class("AutoRanger", diff --git a/R/AutoResNet.R b/R/AutoResNet.R index 009529c..e3264a7 100644 --- a/R/AutoResNet.R +++ b/R/AutoResNet.R @@ -10,6 +10,7 @@ #' @template param_measure #' @template param_n_threads #' @template param_timeout +#' @template param_devices #' #' @export AutoResNet = R6Class("AutoResNet", @@ -23,20 +24,10 @@ AutoResNet = R6Class("AutoResNet", properties = "internal_tuning", task_types = c("classif", "regr"), packages = c("mlr3", "mlr3torch"), - devices = c("cpu", "cuda") + devices = "cuda" ) }, - #' @description - #' Check if the auto is compatible with the task. - check = function(task, memory_limit = Inf, large_data_set = FALSE, devices = "cpu") { - if ("cuda" %nin% devices && task$nrow > 1e3) { - lg$info("Learner '%s' is not compatible with tasks with more than 1,000 rows when using 'cpu' as device", self$id) - return(FALSE) - } - super$check(task, memory_limit, large_data_set, devices) - }, - #' @description #' Create the graph for the auto. graph = function(task, measure, n_threads, timeout, devices) { @@ -44,7 +35,7 @@ AutoResNet = R6Class("AutoResNet", assert_measure(measure) assert_count(n_threads) assert_count(timeout) - assert_subset(devices, self$devices) + assert_subset(devices, c("cuda", "cpu")) require_namespaces("mlr3torch") diff --git a/R/AutoSvm.R b/R/AutoSvm.R index 2203124..5b49027 100644 --- a/R/AutoSvm.R +++ b/R/AutoSvm.R @@ -10,6 +10,7 @@ #' @template param_measure #' @template param_n_threads #' @template param_timeout +#' @template param_devices #' #' @export AutoSvm = R6Class("AutoSvm", diff --git a/R/AutoTabpfn.R b/R/AutoTabpfn.R index 8c4f46a..d8ba24d 100644 --- a/R/AutoTabpfn.R +++ b/R/AutoTabpfn.R @@ -12,6 +12,7 @@ #' @template param_timeout #' @template param_memory_limit #' @template param_large_data_set +#' @template param_devices #' #' @export AutoTabpfn = R6Class("AutoTabpfn", diff --git a/R/AutoXgboost.R b/R/AutoXgboost.R index 0758314..a225a66 100644 --- a/R/AutoXgboost.R +++ b/R/AutoXgboost.R @@ -10,6 +10,7 @@ #' @template param_measure #' @template param_n_threads #' @template param_timeout +#' @template param_devices #' #' @export AutoXgboost = R6Class("AutoXgboost", diff --git a/R/LearnerClassifAuto.R b/R/LearnerClassifAuto.R index fc28db7..11c2e5b 100644 --- a/R/LearnerClassifAuto.R +++ b/R/LearnerClassifAuto.R @@ -53,7 +53,8 @@ LearnerClassifAuto = R6Class("LearnerClassifAuto", store_models = p_lgl(init = FALSE, tags = c("train", "super")), # debugging encapsulate_learner = p_lgl(init = TRUE, tags = c("train", "super")), - encapsulate_mbo = p_lgl(init = TRUE, tags = c("train", "super")) + encapsulate_mbo = p_lgl(init = TRUE, tags = c("train", "super")), + check_learners = p_lgl(init = TRUE, tags = c("train", "super")) ) # subset to relevant parameters for selected learners param_set = param_set$subset(ids = unique(param_set$ids(any_tags = c("super", learner_ids)))) @@ -65,7 +66,7 @@ LearnerClassifAuto = R6Class("LearnerClassifAuto", id = id, task_type = "classif", param_set = param_set, - packages = union(c("mlr3", "mlr3tuning","mlr3pipelines"), packages), + packages = union(c("mlr3", "mlr3tuning","mlr3pipelines", "mlr3learners"), packages), feature_types = c("logical", "integer", "numeric", "character", "factor"), predict_types = c("response", "prob"), properties = c("missings", "weights", "twoclass", "multiclass"), diff --git a/R/train_auto.R b/R/train_auto.R index 2280cdd..f60f13f 100644 --- a/R/train_auto.R +++ b/R/train_auto.R @@ -30,10 +30,12 @@ train_auto = function(self, private, task) { } # initialize graph learner - autos = keep(autos, function(auto) auto$check(task, memory_limit = memory_limit, large_data_set = large_data_set, devices = pv$devices)) + if (pv$check_learners) { + autos = keep(autos, function(auto) auto$check(task, memory_limit = memory_limit, large_data_set = large_data_set, devices = pv$devices)) - if (!length(autos)) { - error_config("No learner is compatible with the task.") + if (!length(autos)) { + error_config("No learner is compatible with the task.") + } } if (all(map_lgl(autos, function(auto) "hyperparameter-free" %in% auto$properties))) { diff --git a/R/zzz.R b/R/zzz.R index 02a7ca7..9e5be1c 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -9,7 +9,7 @@ #' @import R6 #' @importFrom rush rush_config #' @import lhs - +#' @import mlr3learners "_PACKAGE" diff --git a/man-roxygen/param_devices.R b/man-roxygen/param_devices.R new file mode 100644 index 0000000..bee9b09 --- /dev/null +++ b/man-roxygen/param_devices.R @@ -0,0 +1,4 @@ +#' @param devices (`character()`)\cr +#' Devices to use. +#' Allowed values are `"cpu"` and `"cuda"`. +#' Default is "cpu". diff --git a/man/Auto.Rd b/man/Auto.Rd index a76cc2c..757816e 100644 --- a/man/Auto.Rd +++ b/man/Auto.Rd @@ -16,6 +16,8 @@ This class is the base class for all autos. \item{\code{task_types}}{(\code{character()}).} \item{\code{packages}}{(\code{character()}).} + +\item{\code{devices}}{(\code{character()}).} } \if{html}{\out{}} } @@ -42,14 +44,27 @@ This class is the base class for all autos. \subsection{Method \code{new()}}{ Creates a new instance of this \link[R6:R6Class]{R6} class. \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{Auto$new(id)}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{Auto$new( + id, + properties = character(0), + task_types = character(0), + packages = character(0), + devices = character(0) +)}\if{html}{\out{
}} } \subsection{Arguments}{ \if{html}{\out{
}} \describe{ -\item{\code{id}}{(\code{character(1)})\cr -Identifier for the new instance.} +\item{\code{id}}{(\code{character(1)}).} + +\item{\code{properties}}{(\code{character()}).} + +\item{\code{task_types}}{(\code{character()}).} + +\item{\code{packages}}{(\code{character()}).} + +\item{\code{devices}}{(\code{character()}).} } \if{html}{\out{
}} } @@ -60,7 +75,7 @@ Identifier for the new instance.} \subsection{Method \code{check()}}{ Check if the auto is compatible with the task. \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{Auto$check(task, memory_limit = Inf, large_data_set = FALSE)}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{Auto$check(task, memory_limit = Inf, large_data_set = FALSE, devices)}\if{html}{\out{
}} } \subsection{Arguments}{ @@ -71,6 +86,11 @@ Check if the auto is compatible with the task. \item{\code{memory_limit}}{(\code{integer(1)}).} \item{\code{large_data_set}}{(\code{logical(1)}).} + +\item{\code{devices}}{(\code{character()})\cr +Devices to use. +Allowed values are \code{"cpu"} and \code{"cuda"}. +Default is "cpu".} } \if{html}{\out{}} } @@ -81,7 +101,7 @@ Check if the auto is compatible with the task. \subsection{Method \code{graph()}}{ Create the graph for the auto. \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{Auto$graph(task, measure, n_threads, timeout)}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{Auto$graph(task, measure, n_threads, timeout, devices)}\if{html}{\out{
}} } \subsection{Arguments}{ @@ -94,6 +114,11 @@ Create the graph for the auto. \item{\code{n_threads}}{(\code{integer(1)}).} \item{\code{timeout}}{(\code{integer(1)}).} + +\item{\code{devices}}{(\code{character()})\cr +Devices to use. +Allowed values are \code{"cpu"} and \code{"cuda"}. +Default is "cpu".} } \if{html}{\out{}} } diff --git a/man/AutoCatboost.Rd b/man/AutoCatboost.Rd index 0dcae38..7ca87da 100644 --- a/man/AutoCatboost.Rd +++ b/man/AutoCatboost.Rd @@ -57,7 +57,7 @@ Identifier for the new instance.} \subsection{Method \code{graph()}}{ Create the graph for the auto. \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{AutoCatboost$graph(task, measure, n_threads, timeout)}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{AutoCatboost$graph(task, measure, n_threads, timeout, devices)}\if{html}{\out{
}} } \subsection{Arguments}{ @@ -70,6 +70,11 @@ Create the graph for the auto. \item{\code{n_threads}}{(\code{integer(1)}).} \item{\code{timeout}}{(\code{integer(1)}).} + +\item{\code{devices}}{(\code{character()})\cr +Devices to use. +Allowed values are \code{"cpu"} and \code{"cuda"}. +Default is "cpu".} } \if{html}{\out{}} } diff --git a/man/AutoExtraTrees.Rd b/man/AutoExtraTrees.Rd index 9a07a3d..27a9078 100644 --- a/man/AutoExtraTrees.Rd +++ b/man/AutoExtraTrees.Rd @@ -56,7 +56,7 @@ Identifier for the new instance.} \subsection{Method \code{graph()}}{ Create the graph for the auto. \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{AutoExtraTrees$graph(task, measure, n_threads, timeout)}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{AutoExtraTrees$graph(task, measure, n_threads, timeout, devices)}\if{html}{\out{
}} } \subsection{Arguments}{ @@ -69,6 +69,11 @@ Create the graph for the auto. \item{\code{n_threads}}{(\code{numeric(1)}).} \item{\code{timeout}}{(\code{numeric(1)}).} + +\item{\code{devices}}{(\code{character()})\cr +Devices to use. +Allowed values are \code{"cpu"} and \code{"cuda"}. +Default is "cpu".} } \if{html}{\out{}} } diff --git a/man/AutoFTTransformer.Rd b/man/AutoFTTransformer.Rd index 5eac5db..584fd98 100644 --- a/man/AutoFTTransformer.Rd +++ b/man/AutoFTTransformer.Rd @@ -56,7 +56,7 @@ Identifier for the new instance.} \subsection{Method \code{graph()}}{ Create the graph for the auto. \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{AutoFTTransformer$graph(task, measure, n_threads, timeout)}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{AutoFTTransformer$graph(task, measure, n_threads, timeout, devices)}\if{html}{\out{
}} } \subsection{Arguments}{ @@ -69,6 +69,11 @@ Create the graph for the auto. \item{\code{n_threads}}{(\code{integer(1)}).} \item{\code{timeout}}{(\code{integer(1)}).} + +\item{\code{devices}}{(\code{character()})\cr +Devices to use. +Allowed values are \code{"cpu"} and \code{"cuda"}. +Default is "cpu".} } \if{html}{\out{}} } diff --git a/man/AutoGlmnet.Rd b/man/AutoGlmnet.Rd index 42f174a..fc3a514 100644 --- a/man/AutoGlmnet.Rd +++ b/man/AutoGlmnet.Rd @@ -56,7 +56,7 @@ Identifier for the new instance.} \subsection{Method \code{graph()}}{ Create the graph for the auto. \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{AutoGlmnet$graph(task, measure, n_threads, timeout)}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{AutoGlmnet$graph(task, measure, n_threads, timeout, devices)}\if{html}{\out{
}} } \subsection{Arguments}{ @@ -69,6 +69,11 @@ Create the graph for the auto. \item{\code{n_threads}}{(\code{integer(1)}).} \item{\code{timeout}}{(\code{integer(1)}).} + +\item{\code{devices}}{(\code{character()})\cr +Devices to use. +Allowed values are \code{"cpu"} and \code{"cuda"}. +Default is "cpu".} } \if{html}{\out{}} } diff --git a/man/AutoKknn.Rd b/man/AutoKknn.Rd index 67d833a..2c048e2 100644 --- a/man/AutoKknn.Rd +++ b/man/AutoKknn.Rd @@ -56,7 +56,7 @@ Identifier for the new instance.} \subsection{Method \code{graph()}}{ Create the graph for the auto. \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{AutoKknn$graph(task, measure, n_threads, timeout)}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{AutoKknn$graph(task, measure, n_threads, timeout, devices)}\if{html}{\out{
}} } \subsection{Arguments}{ @@ -69,6 +69,11 @@ Create the graph for the auto. \item{\code{n_threads}}{(\code{integer(1)}).} \item{\code{timeout}}{(\code{integer(1)}).} + +\item{\code{devices}}{(\code{character()})\cr +Devices to use. +Allowed values are \code{"cpu"} and \code{"cuda"}. +Default is "cpu".} } \if{html}{\out{}} } diff --git a/man/AutoLda.Rd b/man/AutoLda.Rd index b160d1d..84f80d2 100644 --- a/man/AutoLda.Rd +++ b/man/AutoLda.Rd @@ -56,7 +56,7 @@ Identifier for the new instance.} \subsection{Method \code{graph()}}{ Create the graph for the auto. \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{AutoLda$graph(task, measure, n_threads, timeout)}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{AutoLda$graph(task, measure, n_threads, timeout, devices)}\if{html}{\out{
}} } \subsection{Arguments}{ @@ -73,6 +73,11 @@ Create the graph for the auto. \item{\code{n_threads}}{(\code{integer(1)}).} \item{\code{timeout}}{(\code{integer(1)}).} + +\item{\code{devices}}{(\code{character()})\cr +Devices to use. +Allowed values are \code{"cpu"} and \code{"cuda"}. +Default is "cpu".} } \if{html}{\out{}} } diff --git a/man/AutoLightgbm.Rd b/man/AutoLightgbm.Rd index aa02d0d..bc854d4 100644 --- a/man/AutoLightgbm.Rd +++ b/man/AutoLightgbm.Rd @@ -57,7 +57,7 @@ Identifier for the new instance.} \subsection{Method \code{graph()}}{ Create the graph for the auto. \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{AutoLightgbm$graph(task, measure, n_threads, timeout)}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{AutoLightgbm$graph(task, measure, n_threads, timeout, devices)}\if{html}{\out{
}} } \subsection{Arguments}{ @@ -70,6 +70,11 @@ Create the graph for the auto. \item{\code{n_threads}}{(\code{integer(1)}).} \item{\code{timeout}}{(\code{integer(1)}).} + +\item{\code{devices}}{(\code{character()})\cr +Devices to use. +Allowed values are \code{"cpu"} and \code{"cuda"}. +Default is "cpu".} } \if{html}{\out{}} } diff --git a/man/AutoMlp.Rd b/man/AutoMlp.Rd index c7000a4..d5bbc78 100644 --- a/man/AutoMlp.Rd +++ b/man/AutoMlp.Rd @@ -56,7 +56,7 @@ Identifier for the new instance.} \subsection{Method \code{graph()}}{ Create the graph for the auto. \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{AutoMlp$graph(task, measure, n_threads, timeout)}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{AutoMlp$graph(task, measure, n_threads, timeout, devices)}\if{html}{\out{
}} } \subsection{Arguments}{ @@ -69,6 +69,11 @@ Create the graph for the auto. \item{\code{n_threads}}{(\code{integer(1)}).} \item{\code{timeout}}{(\code{integer(1)}).} + +\item{\code{devices}}{(\code{character()})\cr +Devices to use. +Allowed values are \code{"cpu"} and \code{"cuda"}. +Default is "cpu".} } \if{html}{\out{}} } diff --git a/man/AutoRanger.Rd b/man/AutoRanger.Rd index 5cd9cf0..08c3f5f 100644 --- a/man/AutoRanger.Rd +++ b/man/AutoRanger.Rd @@ -56,7 +56,7 @@ Identifier for the new instance.} \subsection{Method \code{graph()}}{ Create the graph for the auto. \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{AutoRanger$graph(task, measure, n_threads, timeout)}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{AutoRanger$graph(task, measure, n_threads, timeout, devices)}\if{html}{\out{
}} } \subsection{Arguments}{ @@ -69,6 +69,11 @@ Create the graph for the auto. \item{\code{n_threads}}{(\code{integer(1)}).} \item{\code{timeout}}{(\code{integer(1)}).} + +\item{\code{devices}}{(\code{character()})\cr +Devices to use. +Allowed values are \code{"cpu"} and \code{"cuda"}. +Default is "cpu".} } \if{html}{\out{}} } diff --git a/man/AutoResNet.Rd b/man/AutoResNet.Rd index e8a212f..fd91667 100644 --- a/man/AutoResNet.Rd +++ b/man/AutoResNet.Rd @@ -56,7 +56,7 @@ Identifier for the new instance.} \subsection{Method \code{graph()}}{ Create the graph for the auto. \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{AutoResNet$graph(task, measure, n_threads, timeout)}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{AutoResNet$graph(task, measure, n_threads, timeout, devices)}\if{html}{\out{
}} } \subsection{Arguments}{ @@ -69,6 +69,11 @@ Create the graph for the auto. \item{\code{n_threads}}{(\code{integer(1)}).} \item{\code{timeout}}{(\code{integer(1)}).} + +\item{\code{devices}}{(\code{character()})\cr +Devices to use. +Allowed values are \code{"cpu"} and \code{"cuda"}. +Default is "cpu".} } \if{html}{\out{}} } diff --git a/man/AutoSvm.Rd b/man/AutoSvm.Rd index d11c9a1..49434c5 100644 --- a/man/AutoSvm.Rd +++ b/man/AutoSvm.Rd @@ -56,7 +56,7 @@ Identifier for the new instance.} \subsection{Method \code{graph()}}{ Create the graph for the auto. \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{AutoSvm$graph(task, measure, n_threads, timeout)}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{AutoSvm$graph(task, measure, n_threads, timeout, devices)}\if{html}{\out{
}} } \subsection{Arguments}{ @@ -69,6 +69,11 @@ Create the graph for the auto. \item{\code{n_threads}}{(\code{integer(1)}).} \item{\code{timeout}}{(\code{integer(1)}).} + +\item{\code{devices}}{(\code{character()})\cr +Devices to use. +Allowed values are \code{"cpu"} and \code{"cuda"}. +Default is "cpu".} } \if{html}{\out{}} } diff --git a/man/AutoTabpfn.Rd b/man/AutoTabpfn.Rd index 34574d4..d01008e 100644 --- a/man/AutoTabpfn.Rd +++ b/man/AutoTabpfn.Rd @@ -56,7 +56,12 @@ Identifier for the new instance.} \subsection{Method \code{check()}}{ Check if the auto is compatible with the task. \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{AutoTabpfn$check(task, memory_limit = Inf, large_data_set = FALSE)}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{AutoTabpfn$check( + task, + memory_limit = Inf, + large_data_set = FALSE, + devices = "cpu" +)}\if{html}{\out{
}} } \subsection{Arguments}{ @@ -67,6 +72,11 @@ Check if the auto is compatible with the task. \item{\code{memory_limit}}{(\code{integer(1)}).} \item{\code{large_data_set}}{(\code{logical(1)}).} + +\item{\code{devices}}{(\code{character()})\cr +Devices to use. +Allowed values are \code{"cpu"} and \code{"cuda"}. +Default is "cpu".} } \if{html}{\out{}} } @@ -77,7 +87,7 @@ Check if the auto is compatible with the task. \subsection{Method \code{graph()}}{ Create the graph for the auto. \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{AutoTabpfn$graph(task, measure, n_threads, timeout)}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{AutoTabpfn$graph(task, measure, n_threads, timeout, devices)}\if{html}{\out{
}} } \subsection{Arguments}{ @@ -90,6 +100,11 @@ Create the graph for the auto. \item{\code{n_threads}}{(\code{integer(1)}).} \item{\code{timeout}}{(\code{integer(1)}).} + +\item{\code{devices}}{(\code{character()})\cr +Devices to use. +Allowed values are \code{"cpu"} and \code{"cuda"}. +Default is "cpu".} } \if{html}{\out{}} } diff --git a/man/AutoXgboost.Rd b/man/AutoXgboost.Rd index c049238..0300a43 100644 --- a/man/AutoXgboost.Rd +++ b/man/AutoXgboost.Rd @@ -57,7 +57,7 @@ Identifier for the new instance.} \subsection{Method \code{graph()}}{ Create the graph for the auto. \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{AutoXgboost$graph(task, measure, n_threads, timeout)}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{AutoXgboost$graph(task, measure, n_threads, timeout, devices)}\if{html}{\out{
}} } \subsection{Arguments}{ @@ -70,6 +70,11 @@ Create the graph for the auto. \item{\code{n_threads}}{(\code{integer(1)}).} \item{\code{timeout}}{(\code{integer(1)}).} + +\item{\code{devices}}{(\code{character()})\cr +Devices to use. +Allowed values are \code{"cpu"} and \code{"cuda"}. +Default is "cpu".} } \if{html}{\out{}} } diff --git a/tests/testthat/helper.R b/tests/testthat/helper.R index b2daf98..45298d9 100644 --- a/tests/testthat/helper.R +++ b/tests/testthat/helper.R @@ -25,7 +25,14 @@ expect_rush_reset = function(rush, type = "kill") { } -test_classif_learner = function(learner_id, initial_design_size = 2, initial_design_type = "lhs", n_evals = 5) { +test_classif_learner = function( + learner_id, + initial_design_size = 2, + initial_design_type = "lhs", + n_evals = 4, + task = NULL, + check_learners = TRUE + ) { skip_on_cran() skip_if_not_installed(unlist(map(mlr_auto$mget(learner_id), "packages"))) skip_if_not_installed("rush") @@ -34,7 +41,7 @@ test_classif_learner = function(learner_id, initial_design_size = 2, initial_des rush_plan(n_workers = 2, worker_type = "remote") mirai::daemons(2) - task = tsk("penguins") + task = if (is.null(task)) tsk("penguins") else task learner = lrn("classif.auto", learner_ids = learner_id, small_data_size = 1, @@ -44,7 +51,8 @@ test_classif_learner = function(learner_id, initial_design_size = 2, initial_des initial_design_type = initial_design_type, initial_design_size = initial_design_size, encapsulate_learner = FALSE, - encapsulate_mbo = FALSE + encapsulate_mbo = FALSE, + check_learners = check_learners ) expect_class(learner$train(task), "LearnerClassifAuto") @@ -84,3 +92,4 @@ test_regr_learner = function(learner_id, n_evals = 6) { learner } +all_packages = c("glmnet", "kknn", "ranger", "e1071", "xgboost", "catboost","lightgbm", "fastai", "mlr3torch") diff --git a/tests/testthat/test_Auto.R b/tests/testthat/test_Auto.R new file mode 100644 index 0000000..13100de --- /dev/null +++ b/tests/testthat/test_Auto.R @@ -0,0 +1,44 @@ +test_that("default design is generated", { + skip_if_not_installed(all_packages) + + autos = mlr_auto$mget(mlr_auto$keys()) + xdt = map_dtr(autos, function(auto) auto$design_default(tsk("penguins")), .fill = TRUE) + expect_data_table(xdt, nrows = length(autos)) + expect_set_equal(xdt$branch.selection, mlr_auto$keys()) +}) + +test_that("lhs design is generated", { + skip_if_not_installed(all_packages) + + autos = mlr_auto$mget(mlr_auto$keys()) + xdt = map_dtr(autos, function(auto) auto$design_lhs(tsk("penguins"), 10L), .fill = TRUE) + expect_data_table(xdt, nrows = length(autos) * 10 - 20 + 2) + expect_set_equal(xdt$branch.selection, mlr_auto$keys()) +}) + +test_that("random design is generated", { + skip_if_not_installed(all_packages) + + autos = mlr_auto$mget(mlr_auto$keys()) + xdt = map_dtr(autos, function(auto) auto$design_random(tsk("penguins"), 10L), .fill = TRUE) + expect_data_table(xdt, nrows = length(autos) * 10 - 20 + 2) + expect_set_equal(xdt$branch.selection, mlr_auto$keys()) + +}) + +test_that("set design is generated", { + skip_if_not_installed(all_packages) + + autos = mlr_auto$mget(mlr_auto$keys()) + xdt = map_dtr(autos, function(auto) auto$design_set(tsk("penguins"), msr("classif.ce"), 10L), .fill = TRUE) + expect_data_table(xdt, nrows = 70L) + expect_set_equal(xdt$branch.selection, c("glmnet", "kknn", "ranger", "svm", "xgboost", "catboost","lightgbm")) +}) + +test_that("estimate memory works", { + skip_if_not_installed(all_packages) + + autos = mlr_auto$mget(mlr_auto$keys()) + memory = map_dbl(autos, function(auto) auto$estimate_memory(tsk("penguins"))) + expect_numeric(memory) +}) diff --git a/tests/testthat/test_LearnerClassifAuto.R b/tests/testthat/test_LearnerClassifAuto.R index cb4e9fb..c10d608 100644 --- a/tests/testthat/test_LearnerClassifAuto.R +++ b/tests/testthat/test_LearnerClassifAuto.R @@ -1,50 +1,3 @@ -all_packages = c("glmnet", "kknn", "ranger", "e1071", "xgboost", "catboost","lightgbm", "fastai", "mlr3torch") - -test_that("default design is generated", { - skip_if_not_installed(all_packages) - - autos = mlr_auto$mget(mlr_auto$keys()) - xdt = map_dtr(autos, function(auto) auto$design_default(tsk("penguins")), .fill = TRUE) - expect_data_table(xdt, nrows = length(autos)) - expect_set_equal(xdt$branch.selection, mlr_auto$keys()) -}) - -test_that("lhs design is generated", { - skip_if_not_installed(all_packages) - - autos = mlr_auto$mget(mlr_auto$keys()) - xdt = map_dtr(autos, function(auto) auto$design_lhs(tsk("penguins"), 10L), .fill = TRUE) - expect_data_table(xdt, nrows = length(autos) * 10 - 20 + 2) - expect_set_equal(xdt$branch.selection, mlr_auto$keys()) -}) - -test_that("random design is generated", { - skip_if_not_installed(all_packages) - - autos = mlr_auto$mget(mlr_auto$keys()) - xdt = map_dtr(autos, function(auto) auto$design_random(tsk("penguins"), 10L), .fill = TRUE) - expect_data_table(xdt, nrows = length(autos) * 10 - 20 + 2) - expect_set_equal(xdt$branch.selection, mlr_auto$keys()) - -}) - -test_that("set design is generated", { - skip_if_not_installed(all_packages) - - autos = mlr_auto$mget(mlr_auto$keys()) - xdt = map_dtr(autos, function(auto) auto$design_set(tsk("penguins"), msr("classif.ce"), 10L), .fill = TRUE) - expect_data_table(xdt, nrows = 70L) - expect_set_equal(xdt$branch.selection, c("glmnet", "kknn", "ranger", "svm", "xgboost", "catboost","lightgbm")) -}) - -test_that("estimate memory works", { - skip_if_not_installed(all_packages) - - autos = mlr_auto$mget(mlr_auto$keys()) - memory = map_dbl(autos, function(auto) auto$estimate_memory(tsk("penguins"))) - expect_numeric(memory) -}) - test_that("LearnerClassifAuto is initialized", { learner = lrn("classif.auto", measure = msr("classif.ce"), @@ -119,34 +72,29 @@ test_that("lightgbm works", { }) test_that("mlp works", { - skip_if(TRUE) - - test_classif_learner("mlp") + task = tsk("penguins") + task$filter(c(1, 153, 277)) + test_classif_learner("mlp", task = task, check_learners = FALSE) }) test_that("resnet works", { - skip_if(TRUE) - - test_classif_learner("resnet") + task = tsk("penguins") + task$filter(c(1, 153, 277)) + test_classif_learner("resnet", task = task, check_learners = FALSE) }) test_that("ft_transformer works", { - skip_if(TRUE) - - test_classif_learner("ft_transformer") + task = tsk("penguins") + task$filter(c(1, 153, 277)) + test_classif_learner("ft_transformer", task = task, check_learners = FALSE) }) test_that("tabpfn works", { - skip_if(TRUE) - + task = tsk("penguins") + task$filter(c(1, 153, 277)) test_classif_learner("tabpfn") }) - -# test_that("fastai works", { -# test_classif_learner("fastai") -# }) - test_that("xgboost, catboost and lightgbm work", { test_classif_learner(c("xgboost", "catboost", "lightgbm")) }) From d2c652793dcdb02afa483e5f9b456a7b416c01c8 Mon Sep 17 00:00:00 2001 From: be-marc Date: Fri, 24 Oct 2025 13:13:52 +0200 Subject: [PATCH 6/8] ... --- R/AutoTabpfn.R | 2 +- tests/testthat/test_LearnerClassifAuto.R | 170 +++++++++++++++++++++-- 2 files changed, 158 insertions(+), 14 deletions(-) diff --git a/R/AutoTabpfn.R b/R/AutoTabpfn.R index d8ba24d..f9b9c34 100644 --- a/R/AutoTabpfn.R +++ b/R/AutoTabpfn.R @@ -33,7 +33,7 @@ AutoTabpfn = R6Class("AutoTabpfn", #' @description #' Check if the auto is compatible with the task. check = function(task, memory_limit = Inf, large_data_set = FALSE, devices = "cpu") { - ok = check_python_packages(c("torch")) + ok = check_python_packages(c("torch", "tabpfn")) if (!isTRUE(ok)) { lg$info(ok) lg$info("Remove tabpfn from search space") diff --git a/tests/testthat/test_LearnerClassifAuto.R b/tests/testthat/test_LearnerClassifAuto.R index c10d608..4fa5e94 100644 --- a/tests/testthat/test_LearnerClassifAuto.R +++ b/tests/testthat/test_LearnerClassifAuto.R @@ -72,27 +72,171 @@ test_that("lightgbm works", { }) test_that("mlp works", { - task = tsk("penguins") - task$filter(c(1, 153, 277)) - test_classif_learner("mlp", task = task, check_learners = FALSE) + skip_on_cran() + skip_if_not_installed(unlist(map(mlr_auto$mget("mlp"), "packages"))) + skip_if_not_installed("rush") + flush_redis() + + expect_true(callr::r(function() { + Sys.setenv(RETICULATE_PYTHON = "managed") + library(mlr3automl) + library(testthat) + library(checkmate) + + rush_plan(n_workers = 2, worker_type = "remote") + mirai::daemons(2) + + mirai::everywhere({ + Sys.setenv(RETICULATE_PYTHON = "managed") + }) + + task = tsk("penguins") + task$filter(c(1, 153, 277)) + + learner = lrn("classif.auto", + learner_ids = "mlp", + small_data_size = 1, + resampling = rsmp("holdout"), + measure = msr("classif.ce"), + terminator = trm("evals", n_evals = 4), + initial_design_type = "lhs", + initial_design_size = 2, + encapsulate_learner = FALSE, + encapsulate_mbo = FALSE, + check_learners = FALSE) + + expect_class(learner$train(task), "LearnerClassifAuto") + expect_subset(learner$model$instance$result$branch.selection, "mlp") + expect_set_equal(learner$model$instance$archive$data$branch.selection, "mlp") + + TRUE + })) }) test_that("resnet works", { - task = tsk("penguins") - task$filter(c(1, 153, 277)) - test_classif_learner("resnet", task = task, check_learners = FALSE) + skip_on_cran() + skip_if_not_installed(unlist(map(mlr_auto$mget("resnet"), "packages"))) + skip_if_not_installed("rush") + flush_redis() + + expect_true(callr::r(function() { + Sys.setenv(RETICULATE_PYTHON = "managed") + library(mlr3automl) + library(testthat) + library(checkmate) + + rush_plan(n_workers = 2, worker_type = "remote") + mirai::daemons(2) + + mirai::everywhere({ + Sys.setenv(RETICULATE_PYTHON = "managed") + }) + + task = tsk("penguins") + task$filter(c(1, 153, 277)) + + learner = lrn("classif.auto", + learner_ids = "resnet", + small_data_size = 1, + resampling = rsmp("holdout"), + measure = msr("classif.ce"), + terminator = trm("evals", n_evals = 4), + initial_design_type = "lhs", + initial_design_size = 2, + encapsulate_learner = FALSE, + encapsulate_mbo = FALSE, + check_learners = FALSE) + + expect_class(learner$train(task), "LearnerClassifAuto") + expect_subset(learner$model$instance$result$branch.selection, "resnet") + expect_set_equal(learner$model$instance$archive$data$branch.selection, "resnet") + + TRUE + })) }) test_that("ft_transformer works", { - task = tsk("penguins") - task$filter(c(1, 153, 277)) - test_classif_learner("ft_transformer", task = task, check_learners = FALSE) + skip_on_cran() + skip_if_not_installed(unlist(map(mlr_auto$mget("ft_transformer"), "packages"))) + skip_if_not_installed("rush") + flush_redis() + + expect_true(callr::r(function() { + Sys.setenv(RETICULATE_PYTHON = "managed") + library(mlr3automl) + library(testthat) + library(checkmate) + + rush_plan(n_workers = 2, worker_type = "remote") + mirai::daemons(2) + + mirai::everywhere({ + Sys.setenv(RETICULATE_PYTHON = "managed") + }) + + task = tsk("penguins") + task$filter(c(1, 153, 277)) + + learner = lrn("classif.auto", + learner_ids = "ft_transformer", + small_data_size = 1, + resampling = rsmp("holdout"), + measure = msr("classif.ce"), + terminator = trm("evals", n_evals = 4), + initial_design_type = "lhs", + initial_design_size = 2, + encapsulate_learner = FALSE, + encapsulate_mbo = FALSE, + check_learners = FALSE) + + expect_class(learner$train(task), "LearnerClassifAuto") + expect_subset(learner$model$instance$result$branch.selection, "ft_transformer") + expect_set_equal(learner$model$instance$archive$data$branch.selection, "ft_transformer") + + TRUE + })) }) test_that("tabpfn works", { - task = tsk("penguins") - task$filter(c(1, 153, 277)) - test_classif_learner("tabpfn") + skip_on_cran() + skip_if_not_installed(unlist(map(mlr_auto$mget("tabpfn"), "packages"))) + skip_if_not_installed("rush") + flush_redis() + + expect_true(callr::r(function() { + Sys.setenv(RETICULATE_PYTHON = "managed") + library(mlr3automl) + library(testthat) + library(checkmate) + + rush_plan(n_workers = 2, worker_type = "remote") + mirai::daemons(2) + + mirai::everywhere({ + Sys.setenv(RETICULATE_PYTHON = "managed") + }) + + task = tsk("penguins") + task$filter(c(1, 153, 277)) + + learner = lrn("classif.auto", + learner_ids = "tabpfn", + small_data_size = 1, + resampling = rsmp("holdout"), + measure = msr("classif.ce"), + terminator = trm("evals", n_evals = 4), + initial_design_type = "lhs", + initial_design_size = 2, + encapsulate_learner = FALSE, + encapsulate_mbo = FALSE, + check_learners = TRUE) + + expect_class(learner$train(task), "LearnerClassifAuto") + expect_subset(learner$model$instance$result$branch.selection, "tabpfn") + expect_set_equal(learner$model$instance$archive$data$branch.selection, "tabpfn") + + TRUE + })) }) test_that("xgboost, catboost and lightgbm work", { @@ -100,7 +244,7 @@ test_that("xgboost, catboost and lightgbm work", { }) test_that("all learner work", { - test_classif_learner(c("catboost", "glmnet", "kknn", "lightgbm", "mlp", "ranger", "svm", "xgboost", "lda", "extra_trees"), initial_design_type = c("lhs", "default")) + test_classif_learner(c("catboost", "glmnet", "kknn", "lightgbm", "ranger", "svm", "xgboost", "lda", "extra_trees"), initial_design_type = c("lhs", "default")) }) test_that("memory limit works", { From 030bb018d1285aaaf78fcb6972eb77fea1838116 Mon Sep 17 00:00:00 2001 From: be-marc Date: Fri, 24 Oct 2025 17:24:44 +0200 Subject: [PATCH 7/8] ... --- R/train_auto.R | 9 +++++++++ tests/testthat/test_LearnerClassifAuto.R | 3 ++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/R/train_auto.R b/R/train_auto.R index f60f13f..8823e0a 100644 --- a/R/train_auto.R +++ b/R/train_auto.R @@ -1,3 +1,12 @@ +initialize_auto = function(self, private, task) { + pv = self$param_set$values + large_data_set = task$nrow * task$ncol > pv$large_data_size + n_workers = rush_config()$n_workers %??% 1L + n_threads = pv$n_threads %??% 1L + memory_limit = (pv$memory_limit %??% Inf) / n_workers + autos = mlr_auto$mget(private$.learner_ids) +} + train_auto = function(self, private, task) { pv = self$param_set$values large_data_set = task$nrow * task$ncol > pv$large_data_size diff --git a/tests/testthat/test_LearnerClassifAuto.R b/tests/testthat/test_LearnerClassifAuto.R index 4fa5e94..d816be7 100644 --- a/tests/testthat/test_LearnerClassifAuto.R +++ b/tests/testthat/test_LearnerClassifAuto.R @@ -198,6 +198,7 @@ test_that("ft_transformer works", { }) test_that("tabpfn works", { + skip_if(TRUE) skip_on_cran() skip_if_not_installed(unlist(map(mlr_auto$mget("tabpfn"), "packages"))) skip_if_not_installed("rush") @@ -243,7 +244,7 @@ test_that("xgboost, catboost and lightgbm work", { test_classif_learner(c("xgboost", "catboost", "lightgbm")) }) -test_that("all learner work", { +test_that("all learner on cpu work", { test_classif_learner(c("catboost", "glmnet", "kknn", "lightgbm", "ranger", "svm", "xgboost", "lda", "extra_trees"), initial_design_type = c("lhs", "default")) }) From 5be377e6c5e4adb4f743f786489c829889fd3ebe Mon Sep 17 00:00:00 2001 From: be-marc Date: Fri, 24 Oct 2025 17:26:24 +0200 Subject: [PATCH 8/8] ... --- R/AutoXgboost.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/AutoXgboost.R b/R/AutoXgboost.R index a225a66..cc67fed 100644 --- a/R/AutoXgboost.R +++ b/R/AutoXgboost.R @@ -39,7 +39,7 @@ AutoXgboost = R6Class("AutoXgboost", require_namespaces("mlr3learners") - device = if ("cuda" %in% devices) "cuda" else "cpu" + device = if ("cuda" %in% devices) "cuda" learner = lrn(sprintf("%s.xgboost", task$task_type), id = "xgboost",