Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ Imports:
checkmate,
data.table,
lhs,
mlr3learners (>= 0.12.0),
mlr3mbo (>= 0.2.8),
mlr3misc (>= 0.15.1),
mlr3pipelines,
Expand All @@ -38,7 +39,6 @@ Suggests:
MASS,
mirai,
mlr3extralearners,
mlr3learners (>= 0.12.0),
mlr3torch,
mlr3viz,
ranger,
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import(checkmate)
import(data.table)
import(lhs)
import(mlr3)
import(mlr3learners)
import(mlr3mbo)
import(mlr3misc)
import(mlr3pipelines)
Expand Down
30 changes: 27 additions & 3 deletions R/Auto.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#' @template param_memory_limit
#' @template param_large_data_set
#' @template param_size
#' @template param_devices
#'
#' @export
Auto = R6Class("Auto",
Expand All @@ -30,15 +31,33 @@ Auto = R6Class("Auto",
#' @field packages (`character()`).
packages = NULL,

#' @field devices (`character()`).
devices = NULL,

#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(id) {
#'
#' @param id (`character(1)`).
#' @param properties (`character()`).
#' @param task_types (`character()`).
#' @param packages (`character()`).
#' @param devices (`character()`).
initialize = function(id,
properties = character(0),
task_types = character(0),
packages = character(0),
devices = character(0)
) {
self$id = assert_string(id)
self$properties = assert_character(properties)
self$task_types = assert_character(task_types)
self$packages = assert_character(packages)
self$devices = assert_character(devices)
},

#' @description
#' Check if the auto is compatible with the task.
check = function(task, memory_limit = Inf, large_data_set = FALSE) {
check = function(task, memory_limit = Inf, large_data_set = FALSE, devices) {
if (!task$task_type %in% self$task_types) {
lg$info("Learner '%s' is not compatible with task type '%s'", self$id, task$task_type)
return(FALSE)
Expand All @@ -51,12 +70,17 @@ Auto = R6Class("Auto",
lg$info("Learner '%s' is not compatible with large data sets", self$id)
return(FALSE)
}
if (any(devices %nin% self$devices)) {
lg$info("Learner '%s' is not compatible with devices '%s'", self$id, as_short_string(devices))
return(FALSE)
}

TRUE
},

#' @description
#' Create the graph for the auto.
graph = function(task, measure, n_threads, timeout) {
graph = function(task, measure, n_threads, timeout, devices) {
stop("Abstract")
},

Expand Down
21 changes: 15 additions & 6 deletions R/AutoCatboost.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#' @template param_measure
#' @template param_n_threads
#' @template param_timeout
#' @template param_devices
#'
#' @export
AutoCatboost = R6Class("AutoCatboost",
Expand All @@ -19,28 +20,36 @@ AutoCatboost = R6Class("AutoCatboost",
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(id = "catboost") {
super$initialize(id = id)
self$task_types = c("classif", "regr")
self$properties = c("internal_tuning", "large_data_sets")
self$packages = c("mlr3", "mlr3extralearners", "catboost")
super$initialize(
id = id,
properties = c("internal_tuning", "large_data_sets"),
task_types = c("classif", "regr"),
packages = c("mlr3", "mlr3extralearners", "catboost"),
devices = c("cpu", "cuda")
)
},

#' @description
#' Create the graph for the auto.
graph = function(task, measure, n_threads, timeout) {
graph = function(task, measure, n_threads, timeout, devices) {
assert_task(task)
assert_measure(measure)
assert_count(n_threads)
assert_count(timeout)
assert_subset(devices, self$devices)

require_namespaces("mlr3extralearners")

# catboost only supports gpu via cuda
task_type = if ("cuda" %in% devices) "GPU" else "CPU"

learner = lrn(sprintf("%s.catboost", task$task_type),
id = "catboost",
iterations = self$search_space(task)$upper["catboost.iterations"] %??% 1000L,
early_stopping_rounds = self$early_stopping_rounds(task),
use_best_model = TRUE,
eval_metric = self$internal_measure(measure, task))
eval_metric = self$internal_measure(measure, task),
task_type = task_type)
set_threads(learner, n_threads)

po("removeconstants", id = "catboost_removeconstants") %>>%
Expand Down
14 changes: 9 additions & 5 deletions R/AutoExtraTrees.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#' @template param_task
#' @template param_measure
#' @template param_size
#' @template param_devices
#'
#' @export
AutoExtraTrees = R6Class("AutoExtraTrees",
Expand All @@ -20,10 +21,13 @@ AutoExtraTrees = R6Class("AutoExtraTrees",
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(id = "extra_trees") {
super$initialize(id = id)
self$task_types = c("classif", "regr")
self$properties = c("large_data_sets", "hyperparameter-free")
self$packages = c("mlr3", "mlr3learners", "ranger")
super$initialize(
id = id,
properties = c("large_data_sets", "hyperparameter-free"),
task_types = c("classif", "regr"),
packages = c("mlr3", "mlr3learners", "ranger"),
devices = "cpu"
)
},

#' @description
Expand All @@ -33,7 +37,7 @@ AutoExtraTrees = R6Class("AutoExtraTrees",
#' @param measure ([mlr3::Measure]).
#' @param n_threads (`numeric(1)`).
#' @param timeout (`numeric(1)`).
graph = function(task, measure, n_threads, timeout) {
graph = function(task, measure, n_threads, timeout, devices) {
assert_task(task)
assert_measure(measure)
assert_count(n_threads)
Expand Down
20 changes: 14 additions & 6 deletions R/AutoFTTransformer.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#' @template param_measure
#' @template param_n_threads
#' @template param_timeout
#' @template param_devices
#'
#' @export
AutoFTTransformer = R6Class("AutoFTTransformer",
Expand All @@ -19,22 +20,28 @@ AutoFTTransformer = R6Class("AutoFTTransformer",
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(id = "ft_transformer") {
super$initialize(id = id)
self$task_types = c("classif", "regr")
self$properties = "internal_tuning"
self$packages = c("mlr3", "mlr3torch")
super$initialize(
id = id,
properties = "internal_tuning",
task_types = c("classif", "regr"),
packages = c("mlr3", "mlr3torch"),
devices = "cuda"
)
},

#' @description
#' Create the graph for the auto.
graph = function(task, measure, n_threads, timeout) {
graph = function(task, measure, n_threads, timeout, devices) {
assert_task(task)
assert_measure(measure)
assert_count(n_threads)
assert_count(timeout)
assert_subset(devices, c("cuda", "cpu"))

require_namespaces("mlr3torch")

device = if ("cuda" %in% devices) "cuda" else "auto"

# copied from mlr3tuningspaces
no_wd = function(name) {
# this will also disable weight decay for the input projection bias of the attention heads
Expand Down Expand Up @@ -74,7 +81,8 @@ AutoFTTransformer = R6Class("AutoFTTransformer",
patience = self$early_stopping_rounds(task),
batch_size = 32L,
attention_n_heads = 8L,
opt.param_groups = rtdl_param_groups
opt.param_groups = rtdl_param_groups,
device = device
)
set_threads(learner, n_threads)

Expand Down
13 changes: 8 additions & 5 deletions R/AutoGlmnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#' @template param_measure
#' @template param_n_threads
#' @template param_timeout
#' @template param_devices
#'
#' @export
AutoGlmnet = R6Class("AutoGlmnet",
Expand All @@ -19,15 +20,17 @@ AutoGlmnet = R6Class("AutoGlmnet",
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(id = "glmnet") {
super$initialize(id = id)
self$task_types = c("classif", "regr")
self$properties = character()
self$packages = c("mlr3", "mlr3learners", "glmnet")
super$initialize(id = id,
properties = character(),
task_types = c("classif", "regr"),
packages = c("mlr3", "mlr3learners", "glmnet"),
devices = "cpu"
)
},

#' @description
#' Create the graph for the auto.
graph = function(task, measure, n_threads, timeout) {
graph = function(task, measure, n_threads, timeout, devices) {
assert_task(task)
assert_measure(measure)
assert_count(n_threads)
Expand Down
13 changes: 8 additions & 5 deletions R/AutoKknn.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#' @template param_measure
#' @template param_n_threads
#' @template param_timeout
#' @template param_devices
#'
#' @export
AutoKknn = R6Class("AutoKknn",
Expand All @@ -19,15 +20,17 @@ AutoKknn = R6Class("AutoKknn",
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(id = "kknn") {
super$initialize(id = id)
self$task_types = c("classif", "regr")
self$properties = character()
self$packages = c("mlr3", "mlr3learners", "kknn")
super$initialize(id = id,
properties = character(),
task_types = c("classif", "regr"),
packages = c("mlr3", "mlr3learners", "kknn"),
devices = "cpu"
)
},

#' @description
#' Create the graph for the auto.
graph = function(task, measure, n_threads, timeout) {
graph = function(task, measure, n_threads, timeout, devices) {
assert_task(task)
assert_measure(measure)
assert_count(n_threads)
Expand Down
13 changes: 8 additions & 5 deletions R/AutoLda.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#' @template param_size
#' @template param_n_threads
#' @template param_timeout
#' @template param_devices
#'
#' @export
AutoLda = R6Class("AutoLda",
Expand All @@ -22,15 +23,17 @@ AutoLda = R6Class("AutoLda",
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(id = "lda") {
super$initialize(id = id)
self$task_types = "classif"
self$properties = "hyperparameter-free"
self$packages = c("mlr3", "mlr3learners", "MASS")
super$initialize(
id = id,
task_types = "classif",
properties = "hyperparameter-free",
packages = c("mlr3", "mlr3learners", "MASS"),
devices = "cpu")
},

#' @description
#' Create the graph for the auto.
graph = function(task, measure, n_threads, timeout) {
graph = function(task, measure, n_threads, timeout, devices) {
assert_task(task)
assert_measure(measure)
assert_count(n_threads)
Expand Down
19 changes: 13 additions & 6 deletions R/AutoLightgbm.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#' @template param_measure
#' @template param_n_threads
#' @template param_timeout
#' @template param_devices
#'
#' @export
AutoLightgbm = R6Class("AutoLightgbm",
Expand All @@ -19,27 +20,33 @@ AutoLightgbm = R6Class("AutoLightgbm",
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(id = "lightgbm") {
super$initialize(id = id)
self$task_types = c("classif", "regr")
self$properties = c("internal_tuning", "large_data_sets")
self$packages = c("mlr3", "mlr3extralearners", "lightgbm")
super$initialize(id = id,
properties = c("internal_tuning", "large_data_sets"),
task_types = c("classif", "regr"),
packages = c("mlr3", "mlr3extralearners", "lightgbm"),
devices = c("cpu", "cuda")
)
},

#' @description
#' Create the graph for the auto.
graph = function(task, measure, n_threads, timeout) {
graph = function(task, measure, n_threads, timeout, devices) {
assert_task(task)
assert_measure(measure)
assert_count(n_threads)
assert_count(timeout)
assert_subset(devices, self$devices)

require_namespaces("mlr3extralearners")

device_type = if ("cuda" %in% devices) "gpu" else "cpu"

learner = lrn(sprintf("%s.lightgbm", task$task_type),
id = "lightgbm",
early_stopping_rounds = self$early_stopping_rounds(task),
callbacks = list(cb_timeout_lightgbm(timeout * 0.8)),
eval = self$internal_measure(measure, task))
eval = self$internal_measure(measure, task),
device_type = device_type)
set_threads(learner, n_threads)

learner
Expand Down
Loading
Loading