From 5112b36998f5dc7d75f73f8ca5711a6c055b8116 Mon Sep 17 00:00:00 2001 From: b-zhou Date: Thu, 29 Aug 2024 14:40:13 +0200 Subject: [PATCH 01/41] feat: cost over time --- DESCRIPTION | 5 +++++ R/visualization.R | 25 +++++++++++++++++++++++++ R/visualization_app.R | 23 +++++++++++++++++++++++ man-roxygen/param_instance.R | 2 ++ man-roxygen/param_theme.R | 2 ++ tests/testthat/test_visualization.R | 17 +++++++++++++++++ working.R | 28 ++++++++++++++++++++++++++++ 7 files changed, 102 insertions(+) create mode 100644 R/visualization.R create mode 100644 R/visualization_app.R create mode 100644 man-roxygen/param_instance.R create mode 100644 man-roxygen/param_theme.R create mode 100644 tests/testthat/test_visualization.R create mode 100644 working.R diff --git a/DESCRIPTION b/DESCRIPTION index 4ab70c0..56549a0 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -18,8 +18,10 @@ Depends: R (>= 3.1.0), rush Imports: + bslib, checkmate, data.table, + ggplot2, lhs, mlr3mbo, mlr3misc (>= 0.15.1), @@ -27,6 +29,7 @@ Imports: mlr3tuningspaces, paradox (>= 1.0.1), R6, + shiny, utils Suggests: catboost, @@ -62,4 +65,6 @@ Collate: 'LearnerClassifAutoXgboost.R' 'LearnerRegrAuto.R' 'helper.R' + 'visualization_app.R' + 'visualization.R' 'zzz.R' diff --git a/R/visualization.R b/R/visualization.R new file mode 100644 index 0000000..9785bc7 --- /dev/null +++ b/R/visualization.R @@ -0,0 +1,25 @@ +#' @title Cost-Over-Time Plot +#' +#' @description Plots cost over time using [ggplot2]. +#' +#' @template param_instance +#' @template param_theme +#' +#' @export +cost_over_time = function(instance, x = "config_id", theme = ggplot2::theme_minimal(), ...) { + # there should be only a single objective, e.g. 
`classif.ce` + cost = instance$objective$codomain$data$id[[1]] + + .data = NULL + ggplot2::ggplot(data = instance$archive$data, ggplot2::aes( + x = seq_len(nrow(instance$archive$data)), + y = .data[[cost]] + )) + + ggplot2::geom_point() + + ggplot2::geom_line() + + ggplot2::labs( + title = "Cost over time", + x = "configuration ID" + ) + + theme +} diff --git a/R/visualization_app.R b/R/visualization_app.R new file mode 100644 index 0000000..5581061 --- /dev/null +++ b/R/visualization_app.R @@ -0,0 +1,23 @@ +#' @title Shiny App for Visualizing AutoML Results +#' +#' @template param_instance +#' @export +visualize = function(instance) { + ui = bslib::page_navbar( + title = "Visualization for mlr3automl", + bslib::nav_panel("Cost over time", shiny::plotOutput("cost_over_time")), + bslib::nav_panel("Marginal plot"), + bslib::nav_panel("Parallel coordinates"), + bslib::nav_panel("Partial dependency plot") + ) + + server = function(input, output, session) { + session$onSessionEnded(stopApp) + output$cost_over_time = renderPlot({ + cost_over_time(instance) + }) + } + + shiny::shinyApp(ui = ui, server = server) +} + diff --git a/man-roxygen/param_instance.R b/man-roxygen/param_instance.R new file mode 100644 index 0000000..fdc1225 --- /dev/null +++ b/man-roxygen/param_instance.R @@ -0,0 +1,2 @@ +#' @param instance (`[TuningInstanceAsyncSingleCrit]`) +#' The tuning instance to visualize. \ No newline at end of file diff --git a/man-roxygen/param_theme.R b/man-roxygen/param_theme.R new file mode 100644 index 0000000..dd9bcfd --- /dev/null +++ b/man-roxygen/param_theme.R @@ -0,0 +1,2 @@ +#' @param theme ([ggplot2::theme()])\cr +#' The [ggplot2::theme_minimal()] is applied by default to all plots. 
\ No newline at end of file diff --git a/tests/testthat/test_visualization.R b/tests/testthat/test_visualization.R new file mode 100644 index 0000000..4777464 --- /dev/null +++ b/tests/testthat/test_visualization.R @@ -0,0 +1,17 @@ +rush_plan(n_workers = 2) +skip_if_not_installed("ranger") + +task = tsk("penguins") +learner = lrn("classif.auto", + learner_ids = "ranger", + small_data_size = 1, + resampling = rsmp("holdout"), + measure = msr("classif.ce"), + terminator = trm("evals", n_evals = 6) +) +learner$train(task) + +test_that("cost over time works", { + plot = cost_over_time(learner$instance) + expect_class(plot, c("gg", "ggplot")) +}) diff --git a/working.R b/working.R new file mode 100644 index 0000000..6b6ae2d --- /dev/null +++ b/working.R @@ -0,0 +1,28 @@ +rush_plan(n_workers = 3) +task = tsk("spam") +learner = lrn("classif.auto", + learner_ids = c("glmnet", "svm", "ranger"), + small_data_size = 1, + resampling = rsmp("holdout"), + measure = msr("classif.ce"), + terminator = trm("evals", n_evals = 20) +) +learner$train(task) + +# instance = learner$instance + + +# cost over time plot +cost_over_time = function(instance) { + archive_table = copy(learner$instance$archive$data) + set(archive_table, j = "config_id", value = seq_len(nrow(archive_table))) + + # there should be only a single objective, e.g. 
`classif.ce` + cost = instance$objective$codomain$data$id[[1]] + + ggplot2::ggplot(data = archive_table) + + ggplot2::geom_point(ggplot2::aes(x = config_id, y = .data[[cost]])) +} + +cost_over_time(instance) + + ggplot2::theme_bw() From 0041485141da494c03142e6b462334891aa7a1b7 Mon Sep 17 00:00:00 2001 From: b-zhou Date: Thu, 29 Aug 2024 15:54:50 +0200 Subject: [PATCH 02/41] feat: marginal plot --- R/visualization.R | 58 +++++++++++++++++++++++--- R/visualization_app.R | 48 +++++++++++++++++---- man-roxygen/archive.R | 2 + man-roxygen/param_instance.R | 2 - man-roxygen/{param_theme.R => theme.R} | 0 5 files changed, 95 insertions(+), 15 deletions(-) create mode 100644 man-roxygen/archive.R delete mode 100644 man-roxygen/param_instance.R rename man-roxygen/{param_theme.R => theme.R} (100%) diff --git a/R/visualization.R b/R/visualization.R index 9785bc7..2e05bcd 100644 --- a/R/visualization.R +++ b/R/visualization.R @@ -2,17 +2,18 @@ #' #' @description Plots cost over time using [ggplot2]. #' -#' @template param_instance -#' @template param_theme +#' @template archive +#' @template theme #' #' @export -cost_over_time = function(instance, x = "config_id", theme = ggplot2::theme_minimal(), ...) { +cost_over_time = function(archive, theme = ggplot2::theme_minimal()) { # there should be only a single objective, e.g. 
`classif.ce` - cost = instance$objective$codomain$data$id[[1]] + cost = archive$codomain$data$id[[1]] .data = NULL - ggplot2::ggplot(data = instance$archive$data, ggplot2::aes( - x = seq_len(nrow(instance$archive$data)), + n_configs = nrow(archive$data) + ggplot2::ggplot(data = archive$data, ggplot2::aes( + x = seq_len(n_configs), y = .data[[cost]] )) + ggplot2::geom_point() + @@ -23,3 +24,48 @@ cost_over_time = function(instance, x = "config_id", theme = ggplot2::theme_mini ) + theme } + + +#' @title 2D Marginal Plot for Hyperparameters +#' +#' @description +#' +#' @template instance +#' @param x (`character(1)`) +#' Name of the parameter to be mapped to the x-axis. +#' @param y (`character(1)`) +#' Name of the parameter to be mapped to the y-axis. +#' @template theme +#' +#' @export +marginal_plot = function(archive, x = NULL, y = NULL, theme = ggplot2::theme_minimal()) { + param_ids = archive$search_space$data$id + assert_choice(x, param_ids) + assert_choice(y, param_ids) + + # there should be only a single objective, e.g. 
`classif.ce` + cost = archive$codomain$data$id[[1]] + + .data = NULL + g = ggplot2::ggplot(data = archive$data, ggplot2::aes( + x = .data[[x]], + y = .data[[y]], + fill = .data[[cost]] + )) + + ggplot2::geom_point() + + ggplot2::scale_fill_viridis_c() + + ggplot2::labs( + title = "2D marginal plot" + ) + + theme + + # log-scale params + if (archive$search_space$is_logscale[[x]]) { + g = g + ggplot2::scale_x_log10() + } + if (archive$search_space$is_logscale[[y]]) { + g = g + ggplot2::scale_y_log10() + } + + return(g) +} diff --git a/R/visualization_app.R b/R/visualization_app.R index 5581061..1af88a3 100644 --- a/R/visualization_app.R +++ b/R/visualization_app.R @@ -1,20 +1,54 @@ #' @title Shiny App for Visualizing AutoML Results #' -#' @template param_instance +#' @template archive #' @export -visualize = function(instance) { +visualize = function(archive) { + param_ids = archive$search_space$data$id + ui = bslib::page_navbar( + id = "nav", title = "Visualization for mlr3automl", - bslib::nav_panel("Cost over time", shiny::plotOutput("cost_over_time")), - bslib::nav_panel("Marginal plot"), - bslib::nav_panel("Parallel coordinates"), - bslib::nav_panel("Partial dependency plot") + sidebar = bslib::sidebar( + shiny::conditionalPanel( + "input.nav === 'Marginal Plots'", + shiny::selectInput("x", + label = "Select parameter for x-axis", + choices = param_ids, + ), + shiny::selectInput("y", + label = "Select parameter for y-axis", + choices = param_ids, + ) + # TBD: helpText + ) + ), + bslib::nav_panel( + "Cost Over Time", + bslib::card(shiny::plotOutput("cost_over_time")) + ), + bslib::nav_panel( + "Marginal Plots", + bslib::card(shiny::plotOutput("marginal_plot")) + ), + bslib::nav_panel( + "Parallel Coordinates", + bslib::card("TBD") + ), + bslib::nav_panel( + "Partial Dependency Plots", + bslib::card("TBD") + ) ) server = function(input, output, session) { session$onSessionEnded(stopApp) + output$cost_over_time = renderPlot({ - cost_over_time(instance) + 
cost_over_time(archive) + }) + + output$marginal_plot = renderPlot({ + marginal_plot(archive, x = input$x, y = input$y) }) } diff --git a/man-roxygen/archive.R b/man-roxygen/archive.R new file mode 100644 index 0000000..2fef8ae --- /dev/null +++ b/man-roxygen/archive.R @@ -0,0 +1,2 @@ +#' @param archive (`[ArchiveAsyncTuning]`) +#' The tuning archive to visualize. \ No newline at end of file diff --git a/man-roxygen/param_instance.R b/man-roxygen/param_instance.R deleted file mode 100644 index fdc1225..0000000 --- a/man-roxygen/param_instance.R +++ /dev/null @@ -1,2 +0,0 @@ -#' @param instance (`[TuningInstanceAsyncSingleCrit]`) -#' The tuning instance to visualize. \ No newline at end of file diff --git a/man-roxygen/param_theme.R b/man-roxygen/theme.R similarity index 100% rename from man-roxygen/param_theme.R rename to man-roxygen/theme.R From e60210bf0b687e347ddba4c1503564c2b098c054 Mon Sep 17 00:00:00 2001 From: b-zhou Date: Thu, 29 Aug 2024 16:27:35 +0200 Subject: [PATCH 03/41] refactor: layout & comments --- R/visualization.R | 12 +++++++----- R/visualization_app.R | 10 ++++------ 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/R/visualization.R b/R/visualization.R index 2e05bcd..4979dce 100644 --- a/R/visualization.R +++ b/R/visualization.R @@ -1,13 +1,13 @@ #' @title Cost-Over-Time Plot #' -#' @description Plots cost over time using [ggplot2]. +#' @description #' #' @template archive #' @template theme #' #' @export cost_over_time = function(archive, theme = ggplot2::theme_minimal()) { - # there should be only a single objective, e.g. `classif.ce` + # there should only be one objective, e.g. `classif.ce` cost = archive$codomain$data$id[[1]] .data = NULL @@ -43,11 +43,13 @@ marginal_plot = function(archive, x = NULL, y = NULL, theme = ggplot2::theme_min assert_choice(x, param_ids) assert_choice(y, param_ids) - # there should be only a single objective, e.g. `classif.ce` + # there should only be one objective, e.g. 
`classif.ce` cost = archive$codomain$data$id[[1]] - .data = NULL - g = ggplot2::ggplot(data = archive$data, ggplot2::aes( + data = na.omit(archive$data, cols = c(x, y)) + + .data = NULL + g = ggplot2::ggplot(data = data, ggplot2::aes( x = .data[[x]], y = .data[[y]], fill = .data[[cost]] diff --git a/R/visualization_app.R b/R/visualization_app.R index 1af88a3..3c5408e 100644 --- a/R/visualization_app.R +++ b/R/visualization_app.R @@ -11,15 +11,14 @@ visualize = function(archive) { sidebar = bslib::sidebar( shiny::conditionalPanel( "input.nav === 'Marginal Plots'", - shiny::selectInput("x", - label = "Select parameter for x-axis", - choices = param_ids, - ), shiny::selectInput("y", label = "Select parameter for y-axis", choices = param_ids, + ), + shiny::selectInput("x", + label = "Select parameter for x-axis", + choices = param_ids, ) - # TBD: helpText ) ), bslib::nav_panel( @@ -54,4 +53,3 @@ visualize = function(archive) { shiny::shinyApp(ui = ui, server = server) } - From f6c1b77e765c1b771c5ea6ec506c48446b175ca1 Mon Sep 17 00:00:00 2001 From: b-zhou Date: Thu, 29 Aug 2024 18:29:05 +0200 Subject: [PATCH 04/41] feat: parallel coordinates --- R/visualization.R | 145 +++++++++++++++++++++++++++++++++++++----- R/visualization_app.R | 45 +++++++++++-- working.R | 2 +- 3 files changed, 171 insertions(+), 21 deletions(-) diff --git a/R/visualization.R b/R/visualization.R index 4979dce..339a6af 100644 --- a/R/visualization.R +++ b/R/visualization.R @@ -8,13 +8,13 @@ #' @export cost_over_time = function(archive, theme = ggplot2::theme_minimal()) { # there should only be one objective, e.g. 
`classif.ce` - cost = archive$codomain$data$id[[1]] + measure = archive$codomain$data$id[[1]] .data = NULL n_configs = nrow(archive$data) ggplot2::ggplot(data = archive$data, ggplot2::aes( x = seq_len(n_configs), - y = .data[[cost]] + y = .data[[measure]] )) + ggplot2::geom_point() + ggplot2::geom_line() + @@ -25,43 +25,59 @@ cost_over_time = function(archive, theme = ggplot2::theme_minimal()) { theme } - -#' @title 2D Marginal Plot for Hyperparameters +#' @title Marginal Plot for Hyperparameters #' #' @description #' #' @template instance #' @param x (`character(1)`) -#' Name of the parameter to be mapped to the x-axis. +#' Name of the parameter to be mapped to the x-axis. #' @param y (`character(1)`) -#' Name of the parameter to be mapped to the y-axis. +#' Name of the parameter to be mapped to the y-axis. +#' If `NULL` (default), the measure (e.g. `classif.ce`) is mapped to the y-axis. #' @template theme #' #' @export -marginal_plot = function(archive, x = NULL, y = NULL, theme = ggplot2::theme_minimal()) { - param_ids = archive$search_space$data$id +marginal_plot = function(archive, x, y = NULL, theme = ggplot2::theme_minimal()) { + param_ids = archive$cols_x assert_choice(x, param_ids) - assert_choice(y, param_ids) + assert_choice(y, param_ids, null.ok = TRUE) # there should only be one objective, e.g. 
`classif.ce` - cost = archive$codomain$data$id[[1]] + measure = archive$cols_y data = na.omit(archive$data, cols = c(x, y)) .data = NULL + + # no param provided for y + if (is.null(y)) { + g = ggplot2:: ggplot(data = data, ggplot2::aes( + x = .data[[x]], + y = .data[[measure]] + )) + + ggplot2::geom_point(alpha = 0.6) + + ggplot2::labs(title = "Marginal plot") + + theme + + if (archive$search_space$is_logscale[[x]]) { + g = g + ggplot2::scale_x_log10() + } + + return(g) + } + + # param provided for y g = ggplot2::ggplot(data = data, ggplot2::aes( x = .data[[x]], y = .data[[y]], - fill = .data[[cost]] + fill = .data[[measure]] )) + - ggplot2::geom_point() + + ggplot2::geom_point(alpha = 0.6) + ggplot2::scale_fill_viridis_c() + - ggplot2::labs( - title = "2D marginal plot" - ) + + ggplot2::labs(title = "Marginal plot") + theme - # log-scale params if (archive$search_space$is_logscale[[x]]) { g = g + ggplot2::scale_x_log10() } @@ -71,3 +87,100 @@ marginal_plot = function(archive, x = NULL, y = NULL, theme = ggplot2::theme_min return(g) } + + +#' @title Parallel Coordinates Plot +#' +#' @description Adapted from [mlr3viz::autoplot()] with `type == "parallel"`. Since the hyperparameters of each individual learner are conditioned on `branch.selection`, missing values are expected in the archive data. When standardizing the hyperparameter values (referred to as "x values" in the following to be consistent with `mlr3viz` documentation), `na.omit == TRUE` is used to compute `mean()` and `sd()`. +#' +#' @template archive +#' @param cols_x (`character()`) +#' Column names of x values. +#' By default, all untransformed x values from the search space are plotted. +#' @param trafo (`character(1)`) +#' If `FALSE` (default), the untransformed x values are plotted. +#' If `TRUE`, the transformed x values are plotted. 
+#' @template theme +#' +#' @export +parallel_coordinates = function(archive, cols_x = NULL, trafo = FALSE, theme = ggplot2::theme_minimal()) { + assert_subset(cols_x, c(archive$cols_x, paste0("x_domain_", archive$cols_x))) + assert_flag(trafo) + + if (is.null(cols_x)) { + cols_x = archive$cols_x + } + if (trafo) { + cols_x = paste0("x_domain_", cols_x) + } + cols_y = archive$cols_y + + data = as.data.table(archive) + data = data[, c(cols_x, cols_y), with = FALSE] + x_axis = data.table(x = seq(names(data)), variable = names(data)) + + # split data + data_l = data[, .SD, .SDcols = which(sapply(data, function(x) is.character(x) || is.logical(x)))] + data_n = data[, .SD, .SDcols = which(sapply(data, is.numeric))] + data_y = data[, cols_y, with = FALSE] + + # factor columns to numeric + data_c = data_l[, lapply(.SD, function(x) as.numeric(as.factor(x)))] + + # rescale + data_n = data_n[, lapply(.SD, function(x) { + if (sd(x, na.rm = TRUE) > 0) { + (x - mean(x, na.rm = TRUE)) / sd(x, na.rm = TRUE) + } else { + rep(0, length(x)) + } + })] + data_c = data_c[, lapply(.SD, function(x) { + if (sd(x, na.rm = TRUE) > 0) { + (x - mean(unique(x), na.rm = TRUE)) / sd(unique(x), na.rm = TRUE) + } else { + rep(0, length(x)) + } + })] + + # to long format + set(data_n, j = "id", value = seq_row(data_n)) + set(data_y, j = "id", value = seq_row(data_y)) + data_n = melt(data_n, measure.var = setdiff(names(data_n), "id")) + + if (nrow(data_c)) { + # Skip if no factor column is present + set(data_c, j = "id", value = seq_row(data_c)) + data_c = melt(data_c, measure.var = setdiff(names(data_c), "id")) + data_l = data_l[, lapply(.SD, as.character)] # Logical to character + data_l = melt(data_l, measure.var = names(data_l), value.name = "label")[, "label"] + set(data_c, j = "label", value = data_l) + } + + # merge + data = rbindlist(list(data_c, data_n), fill = TRUE) + data = merge(data, x_axis, by = "variable") + data = merge(data, data_y, by = "id") + setorderv(data, "x") + + 
ggplot2::ggplot(data, + mapping = ggplot2::aes( + x = .data[["x"]], + y = .data[["value"]])) + + ggplot2::geom_line( + mapping = ggplot2::aes( + group = .data$id, + color = .data[[cols_y]]), + linewidth = 1) + + ggplot2::geom_vline(ggplot2::aes(xintercept = x)) + + { + if (nrow(data_c)) ggplot2::geom_label( + mapping = ggplot2::aes(label = .data$label), + data = data[!is.na(data$label), ]) + } + + ggplot2::scale_x_continuous(breaks = x_axis$x, labels = x_axis$variable) + + ggplot2::scale_color_viridis_c() + + ggplot2::guides(color = ggplot2::guide_colorbar(barwidth = 0.5, barheight = 10)) + + theme + + ggplot2::theme(axis.title.x = ggplot2::element_blank()) +} diff --git a/R/visualization_app.R b/R/visualization_app.R index 3c5408e..9cf666d 100644 --- a/R/visualization_app.R +++ b/R/visualization_app.R @@ -3,7 +3,7 @@ #' @template archive #' @export visualize = function(archive) { - param_ids = archive$search_space$data$id + param_ids = archive$cols_x ui = bslib::page_navbar( id = "nav", @@ -13,12 +13,33 @@ visualize = function(archive) { "input.nav === 'Marginal Plots'", shiny::selectInput("y", label = "Select parameter for y-axis", - choices = param_ids, + choices = c(param_ids, "NULL"), ), shiny::selectInput("x", label = "Select parameter for x-axis", choices = param_ids, ) + ), + shiny::conditionalPanel( + "input.nav === 'Parallel Coordinates'", + shiny::checkboxGroupInput("cols_x", + label = "Select hyperparameters to plot:", + choices = param_ids, + # select all by default + selected = param_ids + ), + shiny::actionButton("unselect_all", + label = "Unselect all" + ), + shiny::actionButton("select_all", + label = "Select all" + ), + shiny::radioButtons("trafo", + label = "Apply transformation?", + choices = list("No", "Yes"), + selected = "No", + inline = TRUE + ) ) ), bslib::nav_panel( @@ -31,7 +52,7 @@ visualize = function(archive) { ), bslib::nav_panel( "Parallel Coordinates", - bslib::card("TBD") + bslib::card(shiny::plotOutput("parallel_coordinates")) ), 
bslib::nav_panel( "Partial Dependency Plots", @@ -47,7 +68,23 @@ visualize = function(archive) { }) output$marginal_plot = renderPlot({ - marginal_plot(archive, x = input$x, y = input$y) + if (input$y == "NULL") { + marginal_plot(archive, x = input$x) + } else { + marginal_plot(archive, x = input$x, y = input$y) + } + }) + + output$parallel_coordinates = renderPlot({ + if (is.null(input$cols_x)) return() # nothing selected + trafo = input$trafo == "yes" + parallel_coordinates(archive, cols_x = input$cols_x, trafo = trafo) + }) + shiny::observeEvent(input$unselect_all, { + shiny::updateCheckboxGroupInput(session, "cols_x", choices = param_ids, selected = NULL) + }) + shiny::observeEvent(input$select_all, { + shiny::updateCheckboxGroupInput(session, "cols_x", choices = param_ids, selected = param_ids) }) } diff --git a/working.R b/working.R index 6b6ae2d..090b6f6 100644 --- a/working.R +++ b/working.R @@ -1,7 +1,7 @@ rush_plan(n_workers = 3) task = tsk("spam") learner = lrn("classif.auto", - learner_ids = c("glmnet", "svm", "ranger"), + learner_ids = c("glmnet", "svm"), small_data_size = 1, resampling = rsmp("holdout"), measure = msr("classif.ce"), From 007d3b94024297295f5f9e1966965d7816c7e41d Mon Sep 17 00:00:00 2001 From: b-zhou Date: Wed, 4 Sep 2024 13:59:33 +0200 Subject: [PATCH 05/41] fix: trafo --- R/visualization.R | 32 +++++++++++++++++--------------- R/visualization_app.R | 2 +- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/R/visualization.R b/R/visualization.R index 339a6af..3f24ebc 100644 --- a/R/visualization.R +++ b/R/visualization.R @@ -8,12 +8,11 @@ #' @export cost_over_time = function(archive, theme = ggplot2::theme_minimal()) { # there should only be one objective, e.g. 
`classif.ce` - measure = archive$codomain$data$id[[1]] + measure = archive$cols_y .data = NULL - n_configs = nrow(archive$data) ggplot2::ggplot(data = archive$data, ggplot2::aes( - x = seq_len(n_configs), + x = seq_row(archive$data), y = .data[[measure]] )) + ggplot2::geom_point() + @@ -43,17 +42,20 @@ marginal_plot = function(archive, x, y = NULL, theme = ggplot2::theme_minimal()) assert_choice(x, param_ids) assert_choice(y, param_ids, null.ok = TRUE) + # use transformed values if trafo is set + x_trafo = paste0("x_domain_", x) + # there should only be one objective, e.g. `classif.ce` measure = archive$cols_y - data = na.omit(archive$data, cols = c(x, y)) + data = na.omit(as.data.table(archive), cols = c(x_trafo, y)) .data = NULL # no param provided for y if (is.null(y)) { g = ggplot2:: ggplot(data = data, ggplot2::aes( - x = .data[[x]], + x = .data[[x_trafo]], y = .data[[measure]] )) + ggplot2::geom_point(alpha = 0.6) + @@ -69,13 +71,13 @@ marginal_plot = function(archive, x, y = NULL, theme = ggplot2::theme_minimal()) # param provided for y g = ggplot2::ggplot(data = data, ggplot2::aes( - x = .data[[x]], + x = .data[[x_trafo]], y = .data[[y]], - fill = .data[[measure]] + col = .data[[measure]] )) + ggplot2::geom_point(alpha = 0.6) + - ggplot2::scale_fill_viridis_c() + - ggplot2::labs(title = "Marginal plot") + + ggplot2::scale_color_viridis_c() + + ggplot2::labs(title = "Marginal plot", x = x) + theme if (archive$search_space$is_logscale[[x]]) { @@ -129,17 +131,17 @@ parallel_coordinates = function(archive, cols_x = NULL, trafo = FALSE, theme = g # rescale data_n = data_n[, lapply(.SD, function(x) { - if (sd(x, na.rm = TRUE) > 0) { - (x - mean(x, na.rm = TRUE)) / sd(x, na.rm = TRUE) - } else { + if (sd(x, na.rm = TRUE) %in% c(0, NA)) { rep(0, length(x)) + } else { + (x - mean(x, na.rm = TRUE)) / sd(x, na.rm = TRUE) } })] data_c = data_c[, lapply(.SD, function(x) { - if (sd(x, na.rm = TRUE) > 0) { - (x - mean(unique(x), na.rm = TRUE)) / sd(unique(x), na.rm = 
TRUE) - } else { + if (sd(x, na.rm = TRUE) %in% c(0, NA)) { rep(0, length(x)) + } else { + (x - mean(unique(x), na.rm = TRUE)) / sd(unique(x), na.rm = TRUE) } })] diff --git a/R/visualization_app.R b/R/visualization_app.R index 9cf666d..4061998 100644 --- a/R/visualization_app.R +++ b/R/visualization_app.R @@ -77,7 +77,7 @@ visualize = function(archive) { output$parallel_coordinates = renderPlot({ if (is.null(input$cols_x)) return() # nothing selected - trafo = input$trafo == "yes" + trafo = input$trafo == "Yes" parallel_coordinates(archive, cols_x = input$cols_x, trafo = trafo) }) shiny::observeEvent(input$unselect_all, { From f8966f98a211075ab4bcf5ee0c90b7f03a41235a Mon Sep 17 00:00:00 2001 From: b-zhou Date: Wed, 4 Sep 2024 14:00:17 +0200 Subject: [PATCH 06/41] feat: pdp --- DESCRIPTION | 1 + R/visualization.R | 111 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 112 insertions(+) diff --git a/DESCRIPTION b/DESCRIPTION index 56549a0..8959a21 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -22,6 +22,7 @@ Imports: checkmate, data.table, ggplot2, + iml, lhs, mlr3mbo, mlr3misc (>= 0.15.1), diff --git a/R/visualization.R b/R/visualization.R index 3f24ebc..6869140 100644 --- a/R/visualization.R +++ b/R/visualization.R @@ -186,3 +186,114 @@ parallel_coordinates = function(archive, cols_x = NULL, trafo = FALSE, theme = g theme + ggplot2::theme(axis.title.x = ggplot2::element_blank()) } + + +#' @title Partial Dependence Plot +#' +#' @description Creates a partial dependenc plot (PDP) via the `[iml]` package. +#' +#' @param instance ([TuningInstanceSingleCritAsync]) +#' Tuning instance, e.g., stored in the field `$instance` of any `mlr3automl` learner. +#' @param x (`character(1)`) +#' Name of the parameter to be mapped to the x-axis. +#' @param y (`character(1)`) +#' Name of the parameter to be mapped to the y-axis. +#' If `NULL` (default), the measure (e.g. `classif.ce`) is mapped to the y-axis. 
+#' @param grid_size (`numeric(1)` | `numeric(2)`) +#' The size of the grid. See `grid.size` of `[iml::FeatureEffect]`. +#' @param center_at (`numeric(1)`) +#' Value at which the plot was centered. Ignored in the case of two features. +#' See `center.at` of `[iml::FeatureEffect]`. +#' @param type (`character(1)`) +#' Type of the two-parameter partial dependence plot. Possible options are listed below. +#' \itemize{ +#' \item `"contour"`: Create a contour plot using `[ggplot2::geom_contour_filled]`. Only supported if both parameters are numerical. +#' \item `"heatmap"`: Create a heatmap using `[ggplot2::geom_raster]`. This is the default setting in `iml` +#' } +#' Ignored if only one parameter is provided. +#' @template theme +#' +#' @export +partial_dependence_plot = function( + instance, x, y = NULL, grid_size = 20, center_at = NULL, + type = "heatmap", + theme = ggplot2::theme_minimal() +) { + archive = instance$archive$clone(deep = TRUE) + param_ids = c(x, y) + assert_subset(param_ids, archive$cols_x) + + branch = tstrsplit(param_ids, "\\.")[[1]] + branch = unique(branch) + if (length(branch) > 1) { + stop("Parameters from different branches cannot be plotted in the same PDP.") + } + + if (!is.null(y)) { + assert_choice(type, c("contour", "heatmap")) + } + + non_numeric = some(param_ids, function(param_id) { + !is.numeric(archive$data[[param_id]]) + }) + if (non_numeric && type == "contour") { + stop("Contour plot not supported for non-numeric parameters") + } + + # iml does not allow logical columns, so encode into factor + # NOT WORKING + walk(param_ids, function(param_id) { + if (is.logical(archive$data[[param_id]])) { + set(archive$data, j = param_id, value = as.factor(archive$data[[param_id]])) + } + }) + + surrogate = default_surrogate(instance) + surrogate$archive = archive + surrogate$update() + + predictor = iml::Predictor$new( + model = surrogate, + data = as.data.table(archive)[branch.selection == branch, param_ids, with = FALSE], + predict.function = 
function(model, newdata) { + model$predict(setDT(newdata)[, param_ids, with = FALSE])$mean + } + ) + + eff = iml::FeatureEffect$new( + predictor, + param_ids, + method = "pdp", + center.at = center_at, + grid.size = grid_size + ) + + measure = archive$cols_y + .data = NULL + + if (is.null(y)) { + g = eff$plot() + + ggplot2::scale_fill_viridis_c(direction = -1) + + ggplot2::labs(fill = measure) + + theme + return(g) + } + + g = switch(type, + + contour = ggplot2::ggplot(eff$results, ggplot2::aes( + x = .data[[x]], y = .data[[y]], z = .data$.value + )) + + ggplot2::geom_contour_filled() + + ggplot2::scale_fill_viridis_d(direction = -1), + + heatmap = ggplot2::ggplot(eff$results, ggplot2::aes( + x = .data[[x]], y = .data[[y]], + fill = .data$.value, color = .data$.value + )) + + ggplot2::geom_raster() + + ggplot2::scale_fill_viridis_c(direction = -1) + ) + + g + ggplot2::labs(fill = measure) + theme +} From ff704444e519be1eaa0dd7f3934be56d741ddbe3 Mon Sep 17 00:00:00 2001 From: b-zhou Date: Wed, 4 Sep 2024 14:38:06 +0200 Subject: [PATCH 07/41] fix: feature type --- R/visualization.R | 18 ++++++++---------- working.R | 35 +++++++++++++++++++++-------------- 2 files changed, 29 insertions(+), 24 deletions(-) diff --git a/R/visualization.R b/R/visualization.R index 6869140..eb13ac8 100644 --- a/R/visualization.R +++ b/R/visualization.R @@ -219,8 +219,9 @@ partial_dependence_plot = function( type = "heatmap", theme = ggplot2::theme_minimal() ) { - archive = instance$archive$clone(deep = TRUE) + archive = instance$archive param_ids = c(x, y) + objective = archive$cols_y assert_subset(param_ids, archive$cols_x) branch = tstrsplit(param_ids, "\\.")[[1]] @@ -242,11 +243,11 @@ partial_dependence_plot = function( # iml does not allow logical columns, so encode into factor # NOT WORKING - walk(param_ids, function(param_id) { - if (is.logical(archive$data[[param_id]])) { - set(archive$data, j = param_id, value = as.factor(archive$data[[param_id]])) - } - }) + # walk(param_ids, 
function(param_id) { + # if (is.logical(archive$data[[param_id]])) { + # set(archive$data, j = param_id, value = as.factor(archive$data[[param_id]])) + # } + # }) surrogate = default_surrogate(instance) surrogate$archive = archive @@ -255,9 +256,7 @@ partial_dependence_plot = function( predictor = iml::Predictor$new( model = surrogate, data = as.data.table(archive)[branch.selection == branch, param_ids, with = FALSE], - predict.function = function(model, newdata) { - model$predict(setDT(newdata)[, param_ids, with = FALSE])$mean - } + y = as.data.table(archive)[branch.selection = branch, objective, with = FALSE] ) eff = iml::FeatureEffect$new( @@ -268,7 +267,6 @@ partial_dependence_plot = function( grid.size = grid_size ) - measure = archive$cols_y .data = NULL if (is.null(y)) { diff --git a/working.R b/working.R index 090b6f6..2fc1f65 100644 --- a/working.R +++ b/working.R @@ -1,7 +1,7 @@ rush_plan(n_workers = 3) task = tsk("spam") learner = lrn("classif.auto", - learner_ids = c("glmnet", "svm"), + learner_ids = c("ranger", "glmnet", "svm"), small_data_size = 1, resampling = rsmp("holdout"), measure = msr("classif.ce"), @@ -9,20 +9,27 @@ learner = lrn("classif.auto", ) learner$train(task) -# instance = learner$instance +archive = learner$instance$archive$clone(deep = TRUE) +param_ids = archive$cols_x[startsWith(archive$cols_x, "glmnet")] +branch = "glmnet" -# cost over time plot -cost_over_time = function(instance) { - archive_table = copy(learner$instance$archive$data) - set(archive_table, j = "config_id", value = seq_len(nrow(archive_table))) +surrogate = default_surrogate(learner$instance) +surrogate$archive = archive +surrogate$update() - # there should be only a single objective, e.g. 
`classif.ce` - cost = instance$objective$codomain$data$id[[1]] - - ggplot2::ggplot(data = archive_table) + - ggplot2::geom_point(ggplot2::aes(x = config_id, y = .data[[cost]])) -} +predictor = iml::Predictor$new( + model = surrogate, + data = as.data.table(archive)[branch.selection == branch, param_ids, with = FALSE], + predict.function = function(model, newdata) { + model$predict(setDT(newdata)[, param_ids, with = FALSE])$mean + } +) -cost_over_time(instance) + - ggplot2::theme_bw() +effects = iml::FeatureEffect$new( + predictor, + param_ids, + method = "pdp" +) +effects$plot() + + ggplot2::scale_fill_viridis_c() From cecca8074f8ab8c97172403d47205ab654a7b530 Mon Sep 17 00:00:00 2001 From: b-zhou Date: Wed, 4 Sep 2024 17:29:01 +0200 Subject: [PATCH 08/41] fix: imputeoor --- R/visualization.R | 45 ++++++++++++++++++++++++--------------------- 1 file changed, 24 insertions(+), 21 deletions(-) diff --git a/R/visualization.R b/R/visualization.R index eb13ac8..c4f2c63 100644 --- a/R/visualization.R +++ b/R/visualization.R @@ -207,8 +207,8 @@ parallel_coordinates = function(archive, cols_x = NULL, trafo = FALSE, theme = g #' @param type (`character(1)`) #' Type of the two-parameter partial dependence plot. Possible options are listed below. #' \itemize{ +#' \item `"default"`: Use the default setting in `iml`. #' \item `"contour"`: Create a contour plot using `[ggplot2::geom_contour_filled]`. Only supported if both parameters are numerical. -#' \item `"heatmap"`: Create a heatmap using `[ggplot2::geom_raster]`. This is the default setting in `iml` #' } #' Ignored if only one parameter is provided. 
#' @template theme @@ -220,43 +220,47 @@ partial_dependence_plot = function( theme = ggplot2::theme_minimal() ) { archive = instance$archive - param_ids = c(x, y) objective = archive$cols_y - assert_subset(param_ids, archive$cols_x) + assert_subset(c(x, y), archive$cols_x) - branch = tstrsplit(param_ids, "\\.")[[1]] + branch = tstrsplit(c(x, y), "\\.")[[1]] branch = unique(branch) if (length(branch) > 1) { stop("Parameters from different branches cannot be plotted in the same PDP.") } if (!is.null(y)) { - assert_choice(type, c("contour", "heatmap")) + assert_choice(type, c("contour", "default")) } - non_numeric = some(param_ids, function(param_id) { + non_numeric = some(c(x, y), function(param_id) { !is.numeric(archive$data[[param_id]]) }) if (non_numeric && type == "contour") { stop("Contour plot not supported for non-numeric parameters") } - # iml does not allow logical columns, so encode into factor - # NOT WORKING - # walk(param_ids, function(param_id) { - # if (is.logical(archive$data[[param_id]])) { - # set(archive$data, j = param_id, value = as.factor(archive$data[[param_id]])) - # } - # }) + # use all parameters on the branch for surrogate model + # param_ids = archive$cols_x[startsWith(archive$cols_x, branch)] + param_ids = c(x, y) surrogate = default_surrogate(instance) surrogate$archive = archive surrogate$update() + # store the data.table format for later use in predict.function + prototype = archive$data[0, archive$cols_x, with = FALSE] + predictor = iml::Predictor$new( model = surrogate, data = as.data.table(archive)[branch.selection == branch, param_ids, with = FALSE], - y = as.data.table(archive)[branch.selection = branch, objective, with = FALSE] + y = as.data.table(archive)[branch.selection == branch, objective, with = FALSE], + predict.function = function(model, newdata) { + setDT(newdata) + # reconstruct task layout from training to prevent error in imputeoor + newdata = merge(newdata, prototype, by = param_ids, all = TRUE) + 
return(model$predict(newdata)$mean) + } ) eff = iml::FeatureEffect$new( @@ -277,6 +281,8 @@ partial_dependence_plot = function( return(g) } + # x = param_ids[[1]] + # y = param_ids[[2]] g = switch(type, contour = ggplot2::ggplot(eff$results, ggplot2::aes( @@ -285,13 +291,10 @@ partial_dependence_plot = function( ggplot2::geom_contour_filled() + ggplot2::scale_fill_viridis_d(direction = -1), - heatmap = ggplot2::ggplot(eff$results, ggplot2::aes( - x = .data[[x]], y = .data[[y]], - fill = .data$.value, color = .data$.value - )) + - ggplot2::geom_raster() + - ggplot2::scale_fill_viridis_c(direction = -1) + default = eff$plot() ) + + # TBD: remove existing scales, use viridis instead - g + ggplot2::labs(fill = measure) + theme + g + ggplot2::labs(fill = objective) + theme } From ca190bbaf37f001d37b72a42b18d3b3ac25ad25b Mon Sep 17 00:00:00 2001 From: b-zhou Date: Fri, 6 Sep 2024 13:04:08 +0200 Subject: [PATCH 09/41] fix: imputeoor --- R/visualization.R | 45 ++++++++++++++++++++------------------------- working.R | 35 ----------------------------------- 2 files changed, 20 insertions(+), 60 deletions(-) delete mode 100644 working.R diff --git a/R/visualization.R b/R/visualization.R index c4f2c63..32a1de4 100644 --- a/R/visualization.R +++ b/R/visualization.R @@ -28,7 +28,7 @@ cost_over_time = function(archive, theme = ggplot2::theme_minimal()) { #' #' @description #' -#' @template instance +#' @param instance (`[mlr3tuning::TuningInstanceAsyncSingleCrit]`) #' @param x (`character(1)`) #' Name of the parameter to be mapped to the x-axis. #' @param y (`character(1)`) @@ -192,13 +192,11 @@ parallel_coordinates = function(archive, cols_x = NULL, trafo = FALSE, theme = g #' #' @description Creates a partial dependenc plot (PDP) via the `[iml]` package. #' -#' @param instance ([TuningInstanceSingleCritAsync]) -#' Tuning instance, e.g., stored in the field `$instance` of any `mlr3automl` learner. 
+#' @param instance ([`mlr3tuning::TuningInstanceAsyncSingleCrit`]) #' @param x (`character(1)`) #' Name of the parameter to be mapped to the x-axis. #' @param y (`character(1)`) #' Name of the parameter to be mapped to the y-axis. -#' If `NULL` (default), the measure (e.g. `classif.ce`) is mapped to the y-axis. #' @param grid_size (`numeric(1)` | `numeric(2)`) #' The size of the grid. See `grid.size` of `[iml::FeatureEffect]`. #' @param center_at (`numeric(1)`) @@ -216,14 +214,17 @@ parallel_coordinates = function(archive, cols_x = NULL, trafo = FALSE, theme = g #' @export partial_dependence_plot = function( instance, x, y = NULL, grid_size = 20, center_at = NULL, - type = "heatmap", + type = "default", theme = ggplot2::theme_minimal() ) { archive = instance$archive objective = archive$cols_y - assert_subset(c(x, y), archive$cols_x) + assert_choice(x, archive$cols_x) + assert_choice(y, archive$cols_x, null.ok = TRUE) + + param_ids = c(x, y) - branch = tstrsplit(c(x, y), "\\.")[[1]] + branch = tstrsplit(param_ids, "\\.")[[1]] branch = unique(branch) if (length(branch) > 1) { stop("Parameters from different branches cannot be plotted in the same PDP.") @@ -233,17 +234,13 @@ partial_dependence_plot = function( assert_choice(type, c("contour", "default")) } - non_numeric = some(c(x, y), function(param_id) { + non_numeric = some(param_ids, function(param_id) { !is.numeric(archive$data[[param_id]]) }) if (non_numeric && type == "contour") { stop("Contour plot not supported for non-numeric parameters") } - # use all parameters on the branch for surrogate model - # param_ids = archive$cols_x[startsWith(archive$cols_x, branch)] - param_ids = c(x, y) - surrogate = default_surrogate(instance) surrogate$archive = archive surrogate$update() @@ -259,6 +256,15 @@ partial_dependence_plot = function( setDT(newdata) # reconstruct task layout from training to prevent error in imputeoor newdata = merge(newdata, prototype, by = param_ids, all = TRUE) + setcolorder(newdata, 
names(prototype)) + + # convert numeric to integer to match with training task + walk(param_ids, function(param_id) { + if (is.integer(archive$data[[param_id]])) { + set(newdata, j = param_id, value = as.integer(newdata[[param_id]])) + } + }) + return(model$predict(newdata)$mean) } ) @@ -273,16 +279,6 @@ partial_dependence_plot = function( .data = NULL - if (is.null(y)) { - g = eff$plot() + - ggplot2::scale_fill_viridis_c(direction = -1) + - ggplot2::labs(fill = measure) + - theme - return(g) - } - - # x = param_ids[[1]] - # y = param_ids[[2]] g = switch(type, contour = ggplot2::ggplot(eff$results, ggplot2::aes( @@ -291,10 +287,9 @@ partial_dependence_plot = function( ggplot2::geom_contour_filled() + ggplot2::scale_fill_viridis_d(direction = -1), - default = eff$plot() + default = eff$plot(rug = FALSE) ) # TBD: remove existing scales, use viridis instead - - g + ggplot2::labs(fill = objective) + theme + g + ggplot2::scale_fill_viridis_c(name = objective) + theme } diff --git a/working.R b/working.R deleted file mode 100644 index 2fc1f65..0000000 --- a/working.R +++ /dev/null @@ -1,35 +0,0 @@ -rush_plan(n_workers = 3) -task = tsk("spam") -learner = lrn("classif.auto", - learner_ids = c("ranger", "glmnet", "svm"), - small_data_size = 1, - resampling = rsmp("holdout"), - measure = msr("classif.ce"), - terminator = trm("evals", n_evals = 20) -) -learner$train(task) - -archive = learner$instance$archive$clone(deep = TRUE) - -param_ids = archive$cols_x[startsWith(archive$cols_x, "glmnet")] -branch = "glmnet" - -surrogate = default_surrogate(learner$instance) -surrogate$archive = archive -surrogate$update() - -predictor = iml::Predictor$new( - model = surrogate, - data = as.data.table(archive)[branch.selection == branch, param_ids, with = FALSE], - predict.function = function(model, newdata) { - model$predict(setDT(newdata)[, param_ids, with = FALSE])$mean - } -) - -effects = iml::FeatureEffect$new( - predictor, - param_ids, - method = "pdp" -) -effects$plot() + - 
ggplot2::scale_fill_viridis_c() From 608c92f2c207c45255faafa20ed529e8fdc201b7 Mon Sep 17 00:00:00 2001 From: b-zhou Date: Fri, 6 Sep 2024 13:04:34 +0200 Subject: [PATCH 10/41] feat: pdp app --- R/visualization_app.R | 41 +++++++++++++++++++++++++++++++++++++---- 1 file changed, 37 insertions(+), 4 deletions(-) diff --git a/R/visualization_app.R b/R/visualization_app.R index 4061998..dacbba9 100644 --- a/R/visualization_app.R +++ b/R/visualization_app.R @@ -1,9 +1,11 @@ #' @title Shiny App for Visualizing AutoML Results #' -#' @template archive +#' @param instance (`[mlr3tuning::TuningInstanceAsyncSingleCrit]`) #' @export -visualize = function(archive) { +visualize = function(instance) { + archive = instance$archive param_ids = archive$cols_x + branches = unique(archive$data$branch.selection) ui = bslib::page_navbar( id = "nav", @@ -40,6 +42,23 @@ visualize = function(archive) { selected = "No", inline = TRUE ) + ), + shiny::conditionalPanel( + "input.nav === 'Partial Dependence Plots'", + shiny::radioButtons("select_branch", + label = "Select branch:", + choices = branches + ), + shiny::selectInput("select_x", + label = "Select parameter for x-axis:", + choices = param_ids, + selected = param_ids[[1]] + ), + shiny::selectInput("select_y", + label = "Select parameter for y-axis:", + choices = param_ids, + selected = param_ids[[2]] + ) ) ), bslib::nav_panel( @@ -55,8 +74,8 @@ visualize = function(archive) { bslib::card(shiny::plotOutput("parallel_coordinates")) ), bslib::nav_panel( - "Partial Dependency Plots", - bslib::card("TBD") + "Partial Dependence Plots", + bslib::card(shiny::plotOutput("pdp")) ) ) @@ -86,6 +105,20 @@ visualize = function(archive) { shiny::observeEvent(input$select_all, { shiny::updateCheckboxGroupInput(session, "cols_x", choices = param_ids, selected = param_ids) }) + + + shiny::observeEvent(input$select_branch, { + selectable_ids = param_ids[startsWith(param_ids, input$select_branch)] + shiny::updateSelectInput(session, "select_x", choices = 
selectable_ids, selected = selectable_ids[[1]]) + shiny::updateSelectInput(session, "select_y", choices = selectable_ids, selected = selectable_ids[[2]]) + }) + output$pdp = renderPlot({ + if (is.null(input$select_x)) return() + partial_dependence_plot( + instance, x = input$select_x, y = input$select_y, + type = "default", grid_size = 20 + ) + }) } shiny::shinyApp(ui = ui, server = server) From aadf2a26952c4093b14f964732c8330db58aa64f Mon Sep 17 00:00:00 2001 From: b-zhou Date: Fri, 6 Sep 2024 13:17:16 +0200 Subject: [PATCH 11/41] refactor: rename shiny components --- R/visualization_app.R | 53 +++++++++++++++++++++++-------------------- 1 file changed, 28 insertions(+), 25 deletions(-) diff --git a/R/visualization_app.R b/R/visualization_app.R index dacbba9..94bc699 100644 --- a/R/visualization_app.R +++ b/R/visualization_app.R @@ -10,33 +10,36 @@ visualize = function(instance) { ui = bslib::page_navbar( id = "nav", title = "Visualization for mlr3automl", + # TBD: cost over time: select timestamp_x / timestamp_y / config_id sidebar = bslib::sidebar( shiny::conditionalPanel( + # TBD: select branch, then parameter, as in PDP "input.nav === 'Marginal Plots'", - shiny::selectInput("y", + shiny::selectInput("mp_y", label = "Select parameter for y-axis", choices = c(param_ids, "NULL"), ), - shiny::selectInput("x", + shiny::selectInput("mp_x", label = "Select parameter for x-axis", choices = param_ids, ) ), shiny::conditionalPanel( + # TBD: select branch, then parameter, as in PDP "input.nav === 'Parallel Coordinates'", - shiny::checkboxGroupInput("cols_x", + shiny::checkboxGroupInput("pc_cols_x", label = "Select hyperparameters to plot:", choices = param_ids, # select all by default selected = param_ids ), - shiny::actionButton("unselect_all", + shiny::actionButton("pc_unselect_all", label = "Unselect all" ), - shiny::actionButton("select_all", + shiny::actionButton("pc_select_all", label = "Select all" ), - shiny::radioButtons("trafo", + 
shiny::radioButtons("pc_trafo", label = "Apply transformation?", choices = list("No", "Yes"), selected = "No", @@ -45,16 +48,16 @@ visualize = function(instance) { ), shiny::conditionalPanel( "input.nav === 'Partial Dependence Plots'", - shiny::radioButtons("select_branch", + shiny::radioButtons("pdp_branch", label = "Select branch:", choices = branches ), - shiny::selectInput("select_x", + shiny::selectInput("pdp_x", label = "Select parameter for x-axis:", choices = param_ids, selected = param_ids[[1]] ), - shiny::selectInput("select_y", + shiny::selectInput("pdp_y", label = "Select parameter for y-axis:", choices = param_ids, selected = param_ids[[2]] @@ -87,35 +90,35 @@ visualize = function(instance) { }) output$marginal_plot = renderPlot({ - if (input$y == "NULL") { - marginal_plot(archive, x = input$x) + if (input$mp_y == "NULL") { + marginal_plot(archive, x = input$mp_x) } else { - marginal_plot(archive, x = input$x, y = input$y) + marginal_plot(archive, x = input$mp_x, y = input$mp_y) } }) output$parallel_coordinates = renderPlot({ - if (is.null(input$cols_x)) return() # nothing selected - trafo = input$trafo == "Yes" - parallel_coordinates(archive, cols_x = input$cols_x, trafo = trafo) + if (is.null(input$pc_cols_x)) return() # nothing selected + trafo = input$pc_trafo == "Yes" + parallel_coordinates(archive, cols_x = input$pc_cols_x, trafo = trafo) }) - shiny::observeEvent(input$unselect_all, { - shiny::updateCheckboxGroupInput(session, "cols_x", choices = param_ids, selected = NULL) + shiny::observeEvent(input$pc_unselect_all, { + shiny::updateCheckboxGroupInput(session, "pc_cols_x", choices = param_ids, selected = NULL) }) - shiny::observeEvent(input$select_all, { - shiny::updateCheckboxGroupInput(session, "cols_x", choices = param_ids, selected = param_ids) + shiny::observeEvent(input$pc_select_all, { + shiny::updateCheckboxGroupInput(session, "pc_cols_x", choices = param_ids, selected = param_ids) }) - shiny::observeEvent(input$select_branch, { - 
selectable_ids = param_ids[startsWith(param_ids, input$select_branch)] - shiny::updateSelectInput(session, "select_x", choices = selectable_ids, selected = selectable_ids[[1]]) - shiny::updateSelectInput(session, "select_y", choices = selectable_ids, selected = selectable_ids[[2]]) + shiny::observeEvent(input$pdp_branch, { + selectable_ids = param_ids[startsWith(param_ids, input$pdp_branch)] + shiny::updateSelectInput(session, "pdp_x", choices = selectable_ids, selected = selectable_ids[[1]]) + shiny::updateSelectInput(session, "pdp_y", choices = selectable_ids, selected = selectable_ids[[2]]) }) output$pdp = renderPlot({ - if (is.null(input$select_x)) return() + if (is.null(input$pdp_x)) return() partial_dependence_plot( - instance, x = input$select_x, y = input$select_y, + instance, x = input$pdp_x, y = input$pdp_y, type = "default", grid_size = 20 ) }) From 93318c0ab238ad0059d2489b26212a59c61dbb2a Mon Sep 17 00:00:00 2001 From: b-zhou Date: Fri, 6 Sep 2024 13:26:52 +0200 Subject: [PATCH 12/41] refactor: pdp --- R/visualization.R | 12 +++++------- R/visualization_app.R | 2 +- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/R/visualization.R b/R/visualization.R index 32a1de4..1cf9266 100644 --- a/R/visualization.R +++ b/R/visualization.R @@ -197,11 +197,8 @@ parallel_coordinates = function(archive, cols_x = NULL, trafo = FALSE, theme = g #' Name of the parameter to be mapped to the x-axis. #' @param y (`character(1)`) #' Name of the parameter to be mapped to the y-axis. -#' @param grid_size (`numeric(1)` | `numeric(2)`) +#' @param grid.size (`numeric(1)` | `numeric(2)`) #' The size of the grid. See `grid.size` of `[iml::FeatureEffect]`. -#' @param center_at (`numeric(1)`) -#' Value at which the plot was centered. Ignored in the case of two features. -#' See `center.at` of `[iml::FeatureEffect]`. #' @param type (`character(1)`) #' Type of the two-parameter partial dependence plot. Possible options are listed below. 
#' \itemize{ @@ -213,7 +210,7 @@ parallel_coordinates = function(archive, cols_x = NULL, trafo = FALSE, theme = g #' #' @export partial_dependence_plot = function( - instance, x, y = NULL, grid_size = 20, center_at = NULL, + instance, x, y, grid.size = 20, center_at = NULL, type = "default", theme = ggplot2::theme_minimal() ) { @@ -273,8 +270,7 @@ partial_dependence_plot = function( predictor, param_ids, method = "pdp", - center.at = center_at, - grid.size = grid_size + grid.size = grid.size ) .data = NULL @@ -287,6 +283,8 @@ partial_dependence_plot = function( ggplot2::geom_contour_filled() + ggplot2::scale_fill_viridis_d(direction = -1), + # FIXME: rug = TRUE causes error when, e.g., x = "svm.cost", y = "svm.degree" + # related to the problem that degree is missing for some instances? default = eff$plot(rug = FALSE) ) diff --git a/R/visualization_app.R b/R/visualization_app.R index 94bc699..877aeb1 100644 --- a/R/visualization_app.R +++ b/R/visualization_app.R @@ -119,7 +119,7 @@ visualize = function(instance) { if (is.null(input$pdp_x)) return() partial_dependence_plot( instance, x = input$pdp_x, y = input$pdp_y, - type = "default", grid_size = 20 + type = "default" ) }) } From 9f4f9196c901598f4c7cbc6352b602d29bb0dde6 Mon Sep 17 00:00:00 2001 From: b-zhou Date: Thu, 12 Sep 2024 15:03:16 +0200 Subject: [PATCH 13/41] feat: pareto front --- R/visualization.R | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/R/visualization.R b/R/visualization.R index 1cf9266..ee42fc1 100644 --- a/R/visualization.R +++ b/R/visualization.R @@ -291,3 +291,29 @@ partial_dependence_plot = function( # TBD: remove existing scales, use viridis instead g + ggplot2::scale_fill_viridis_c(name = objective) + theme } + + +#' @title Pareto Front +#' +#' @description Plots the Pareto front with x-axis representing the tuning objective (e.g. `"classif.ce`) and y-axis representing time (the `runtime_learners` column in the archive). 
+#' +#' @param instance ([`mlr3tuning::TuningInstanceAsyncSingleCrit`]) +pareto_front = function(instance, theme = ggplot2::theme_minimal()) { + # adopted from `Archive$best()` for multi-crit + archive = instance$archive + tab = archive$finished_data + ymat = t(as.matrix(tab[, c(archive$cols_y, "runtime_learners"), with = FALSE])) + ymat = archive$codomain$maximization_to_minimization * ymat + best = tab[!bbotk::is_dominated(ymat)] + + .data = NULL + ggplot2::ggplot() + + ggplot2::geom_point(data = archive$data, + ggplot2::aes(x = .data[[archive$cols_y]], y = .data$runtime_learners), + alpha = 0.2 + ) + + ggplot2::geom_step(data = best, + ggplot2::aes(x = .data[[archive$cols_y]], y = .data$runtime_learners) + ) + + theme +} From 1389a1593d51507fe15e30cc94854b64269f687b Mon Sep 17 00:00:00 2001 From: b-zhou Date: Thu, 12 Sep 2024 15:06:16 +0200 Subject: [PATCH 14/41] docs: templates --- R/visualization.R | 21 ++++++++++++--------- man-roxygen/archive.R | 2 -- man-roxygen/param_instance.R | 3 +++ man-roxygen/param_theme.R | 2 ++ man-roxygen/theme.R | 2 -- 5 files changed, 17 insertions(+), 13 deletions(-) delete mode 100644 man-roxygen/archive.R create mode 100644 man-roxygen/param_instance.R create mode 100644 man-roxygen/param_theme.R delete mode 100644 man-roxygen/theme.R diff --git a/R/visualization.R b/R/visualization.R index ee42fc1..aa210a8 100644 --- a/R/visualization.R +++ b/R/visualization.R @@ -2,8 +2,8 @@ #' #' @description #' -#' @template archive -#' @template theme +#' @template param_instance +#' @template param_theme #' #' @export cost_over_time = function(archive, theme = ggplot2::theme_minimal()) { @@ -28,13 +28,13 @@ cost_over_time = function(archive, theme = ggplot2::theme_minimal()) { #' #' @description #' -#' @param instance (`[mlr3tuning::TuningInstanceAsyncSingleCrit]`) +#' @template param_instance #' @param x (`character(1)`) #' Name of the parameter to be mapped to the x-axis. 
#' @param y (`character(1)`) #' Name of the parameter to be mapped to the y-axis. #' If `NULL` (default), the measure (e.g. `classif.ce`) is mapped to the y-axis. -#' @template theme +#' @template param_theme #' #' @export marginal_plot = function(archive, x, y = NULL, theme = ggplot2::theme_minimal()) { @@ -95,14 +95,14 @@ marginal_plot = function(archive, x, y = NULL, theme = ggplot2::theme_minimal()) #' #' @description Adapted from [mlr3viz::autoplot()] with `type == "parallel"`. Since the hyperparameters of each individual learner are conditioned on `branch.selection`, missing values are expected in the archive data. When standardizing the hyperparameter values (referred to as "x values" in the following to be consistent with `mlr3viz` documentation), `na.omit == TRUE` is used to compute `mean()` and `sd()`. #' -#' @template archive +#' @template param_instance #' @param cols_x (`character()`) #' Column names of x values. #' By default, all untransformed x values from the search space are plotted. #' @param trafo (`character(1)`) #' If `FALSE` (default), the untransformed x values are plotted. #' If `TRUE`, the transformed x values are plotted. -#' @template theme +#' @template param_theme #' #' @export parallel_coordinates = function(archive, cols_x = NULL, trafo = FALSE, theme = ggplot2::theme_minimal()) { @@ -192,7 +192,7 @@ parallel_coordinates = function(archive, cols_x = NULL, trafo = FALSE, theme = g #' #' @description Creates a partial dependenc plot (PDP) via the `[iml]` package. #' -#' @param instance ([`mlr3tuning::TuningInstanceAsyncSingleCrit`]) +#' @template param_instance #' @param x (`character(1)`) #' Name of the parameter to be mapped to the x-axis. #' @param y (`character(1)`) @@ -206,7 +206,7 @@ parallel_coordinates = function(archive, cols_x = NULL, trafo = FALSE, theme = g #' \item `"contour"`: Create a contour plot using `[ggplot2::geom_contour_filled]`. Only supported if both parameters are numerical. 
#' } #' Ignored if only one parameter is provided. -#' @template theme +#' @template param_theme #' #' @export partial_dependence_plot = function( @@ -297,7 +297,10 @@ partial_dependence_plot = function( #' #' @description Plots the Pareto front with x-axis representing the tuning objective (e.g. `"classif.ce`) and y-axis representing time (the `runtime_learners` column in the archive). #' -#' @param instance ([`mlr3tuning::TuningInstanceAsyncSingleCrit`]) +#' @template param_instance +#' @template param_theme +#' +#' @export pareto_front = function(instance, theme = ggplot2::theme_minimal()) { # adopted from `Archive$best()` for multi-crit archive = instance$archive diff --git a/man-roxygen/archive.R b/man-roxygen/archive.R deleted file mode 100644 index 2fef8ae..0000000 --- a/man-roxygen/archive.R +++ /dev/null @@ -1,2 +0,0 @@ -#' @param archive (`[ArchiveAsyncTuning]`) -#' The tuning archive to visualize. \ No newline at end of file diff --git a/man-roxygen/param_instance.R b/man-roxygen/param_instance.R new file mode 100644 index 0000000..0a13556 --- /dev/null +++ b/man-roxygen/param_instance.R @@ -0,0 +1,3 @@ +#' @param instance (`[TuningInstanceAsyncSingleCrit]`)\cr +#' Single-criterion tuning instance with Rush. +#' For [mlr3automl] learners, the tuning instance is stored in the field `$instance`. \ No newline at end of file diff --git a/man-roxygen/param_theme.R b/man-roxygen/param_theme.R new file mode 100644 index 0000000..b4a6efd --- /dev/null +++ b/man-roxygen/param_theme.R @@ -0,0 +1,2 @@ +#' @param theme ([ggplot2::theme()])\cr +#' The [ggplot2::theme_minimal()] is applied by default to all plots. \ No newline at end of file diff --git a/man-roxygen/theme.R b/man-roxygen/theme.R deleted file mode 100644 index dd9bcfd..0000000 --- a/man-roxygen/theme.R +++ /dev/null @@ -1,2 +0,0 @@ -#' @param theme ([ggplot2::theme()])\cr -#' The [ggplot2::theme_minimal()] is applied by default to all plots. 
\ No newline at end of file From 64337dedc2c20d76860e86b70eb030cbb3be5ff5 Mon Sep 17 00:00:00 2001 From: b-zhou Date: Thu, 12 Sep 2024 15:08:58 +0200 Subject: [PATCH 15/41] feat: pareto front app --- R/visualization_app.R | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/R/visualization_app.R b/R/visualization_app.R index 877aeb1..052edf4 100644 --- a/R/visualization_app.R +++ b/R/visualization_app.R @@ -79,6 +79,10 @@ visualize = function(instance) { bslib::nav_panel( "Partial Dependence Plots", bslib::card(shiny::plotOutput("pdp")) + ), + bslib::nav_panel( + "Pareto Front", + bslib::card(shiny::plotOutput("pf")) ) ) @@ -122,6 +126,11 @@ visualize = function(instance) { type = "default" ) }) + + + output$pf = renderPlot({ + pareto_front(instance) + }) } shiny::shinyApp(ui = ui, server = server) From eb03e54fad9038d848a9902234d1251f73e6c863 Mon Sep 17 00:00:00 2001 From: b-zhou Date: Thu, 12 Sep 2024 15:51:51 +0200 Subject: [PATCH 16/41] feat: select branch ui --- DESCRIPTION | 1 + R/helpers_app.R | 39 ++++++++++++++++++++++ R/visualization.R | 3 +- R/visualization_app.R | 78 ++++++++++++++++++++++++++----------------- 4 files changed, 89 insertions(+), 32 deletions(-) create mode 100644 R/helpers_app.R diff --git a/DESCRIPTION b/DESCRIPTION index 8959a21..bea7757 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -66,6 +66,7 @@ Collate: 'LearnerClassifAutoXgboost.R' 'LearnerRegrAuto.R' 'helper.R' + 'helpers_app.R' 'visualization_app.R' 'visualization.R' 'zzz.R' diff --git a/R/helpers_app.R b/R/helpers_app.R new file mode 100644 index 0000000..06d20c0 --- /dev/null +++ b/R/helpers_app.R @@ -0,0 +1,39 @@ +#' @title Custom conditionalPanel for hyperparameter selection +#' +#' @description +#' Used for Marginal Plots and Partial Dependence Plots. +#' +#' @param condition (`character(1)`)\cr +#' Passed to the `condition` argument of shiny::conditionalPanel. +#' @param prefix (`character(1)`)\cr +#' Prefix of input slot names. 
+#' @param learner_ids (`character()`)\cr +#' Vector of all possible learner/branch IDs. +#' @param param_ids (`character()`)\cr +#' Vector of all possible param IDs. +#' +param_panel = function(condition, prefix, learner_ids, param_ids) { + assert_string(condition) + assert_string(prefix) + assert_character(learner_ids) + assert_character(param_ids) + + shiny::conditionalPanel( + condition, + shiny::selectInput(paste0(prefix, "_branch"), + label = "Select branch:", + choices = learner_ids + ), + # choices and selected are just placeholders for initialization + shiny::selectInput(paste0(prefix, "_x"), + label = "Select x-axis:", + choices = param_ids, + selected = param_ids[[1]] + ), + shiny::selectInput(paste0(prefix, "_y"), + label = "Select y-axis:", + choices = param_ids, + selected = param_ids[[2]] + ) + ) +} diff --git a/R/visualization.R b/R/visualization.R index aa210a8..b092a57 100644 --- a/R/visualization.R +++ b/R/visualization.R @@ -59,7 +59,6 @@ marginal_plot = function(archive, x, y = NULL, theme = ggplot2::theme_minimal()) y = .data[[measure]] )) + ggplot2::geom_point(alpha = 0.6) + - ggplot2::labs(title = "Marginal plot") + theme if (archive$search_space$is_logscale[[x]]) { @@ -77,7 +76,7 @@ marginal_plot = function(archive, x, y = NULL, theme = ggplot2::theme_minimal()) )) + ggplot2::geom_point(alpha = 0.6) + ggplot2::scale_color_viridis_c() + - ggplot2::labs(title = "Marginal plot", x = x) + + ggplot2::labs(x = x) + theme if (archive$search_space$is_logscale[[x]]) { diff --git a/R/visualization_app.R b/R/visualization_app.R index 052edf4..0f3df12 100644 --- a/R/visualization_app.R +++ b/R/visualization_app.R @@ -5,28 +5,27 @@ visualize = function(instance) { archive = instance$archive param_ids = archive$cols_x - branches = unique(archive$data$branch.selection) + learner_ids = unique(archive$data$branch.selection) ui = bslib::page_navbar( id = "nav", title = "Visualization for mlr3automl", - # TBD: cost over time: select timestamp_x / timestamp_y 
/ config_id + sidebar = bslib::sidebar( - shiny::conditionalPanel( - # TBD: select branch, then parameter, as in PDP + # TBD: cost over time: select timestamp_x / timestamp_y / config_id + param_panel( "input.nav === 'Marginal Plots'", - shiny::selectInput("mp_y", - label = "Select parameter for y-axis", - choices = c(param_ids, "NULL"), - ), - shiny::selectInput("mp_x", - label = "Select parameter for x-axis", - choices = param_ids, - ) + "mp", + learner_ids, + param_ids ), shiny::conditionalPanel( # TBD: select branch, then parameter, as in PDP "input.nav === 'Parallel Coordinates'", + shiny::selectInput("pc_branch", + label = "Select branch:", + choices = learner_ids + ), shiny::checkboxGroupInput("pc_cols_x", label = "Select hyperparameters to plot:", choices = param_ids, @@ -46,24 +45,14 @@ visualize = function(instance) { inline = TRUE ) ), - shiny::conditionalPanel( + param_panel( "input.nav === 'Partial Dependence Plots'", - shiny::radioButtons("pdp_branch", - label = "Select branch:", - choices = branches - ), - shiny::selectInput("pdp_x", - label = "Select parameter for x-axis:", - choices = param_ids, - selected = param_ids[[1]] - ), - shiny::selectInput("pdp_y", - label = "Select parameter for y-axis:", - choices = param_ids, - selected = param_ids[[2]] - ) + "pdp", + learner_ids, + param_ids ) ), + bslib::nav_panel( "Cost Over Time", bslib::card(shiny::plotOutput("cost_over_time")) @@ -89,10 +78,20 @@ visualize = function(instance) { server = function(input, output, session) { session$onSessionEnded(stopApp) + + # Cost over time output$cost_over_time = renderPlot({ cost_over_time(archive) }) + + # Marginal plots + shiny::observeEvent(input$mp_branch, { + selectable_ids = param_ids[startsWith(param_ids, input$mp_branch)] + shiny::updateSelectInput(session, "mp_x", choices = selectable_ids, selected = selectable_ids[[1]]) + shiny::updateSelectInput(session, "mp_y", choices = selectable_ids, selected = selectable_ids[[2]]) + }) + output$marginal_plot = 
renderPlot({ if (input$mp_y == "NULL") { marginal_plot(archive, x = input$mp_x) @@ -101,24 +100,42 @@ visualize = function(instance) { } }) + + # Parallel Coordinates + shiny::observeEvent(input$pc_branch, { + selectable_ids = param_ids[startsWith(param_ids, input$pc_branch)] + shiny::updateCheckboxGroupInput(session, + "pc_cols_x", + choices = selectable_ids, + # select all by default + selected = selectable_ids + ) + }) + output$parallel_coordinates = renderPlot({ if (is.null(input$pc_cols_x)) return() # nothing selected trafo = input$pc_trafo == "Yes" parallel_coordinates(archive, cols_x = input$pc_cols_x, trafo = trafo) }) + shiny::observeEvent(input$pc_unselect_all, { - shiny::updateCheckboxGroupInput(session, "pc_cols_x", choices = param_ids, selected = NULL) + selectable_ids = param_ids[startsWith(param_ids, input$pc_branch)] + shiny::updateCheckboxGroupInput(session, "pc_cols_x", choices = selectable_ids, selected = NULL) }) + shiny::observeEvent(input$pc_select_all, { - shiny::updateCheckboxGroupInput(session, "pc_cols_x", choices = param_ids, selected = param_ids) + selectable_ids = param_ids[startsWith(param_ids, input$pc_branch)] + shiny::updateCheckboxGroupInput(session, "pc_cols_x", choices = selectable_ids, selected = selectable_ids) }) + # Partial Dependence Plots shiny::observeEvent(input$pdp_branch, { selectable_ids = param_ids[startsWith(param_ids, input$pdp_branch)] shiny::updateSelectInput(session, "pdp_x", choices = selectable_ids, selected = selectable_ids[[1]]) shiny::updateSelectInput(session, "pdp_y", choices = selectable_ids, selected = selectable_ids[[2]]) }) + output$pdp = renderPlot({ if (is.null(input$pdp_x)) return() partial_dependence_plot( @@ -126,8 +143,9 @@ visualize = function(instance) { type = "default" ) }) - + + # Pareto Front output$pf = renderPlot({ pareto_front(instance) }) From 938e132229dea2db641d0637558836f3a5940c32 Mon Sep 17 00:00:00 2001 From: b-zhou Date: Thu, 12 Sep 2024 16:25:33 +0200 Subject: [PATCH 17/41] feat: 
select time variable for cost over time --- R/visualization.R | 33 ++++++++++++++++++++++----------- R/visualization_app.R | 13 ++++++++++++- 2 files changed, 34 insertions(+), 12 deletions(-) diff --git a/R/visualization.R b/R/visualization.R index b092a57..708a686 100644 --- a/R/visualization.R +++ b/R/visualization.R @@ -3,24 +3,35 @@ #' @description #' #' @template param_instance +#' @param time (`character(1)`)\cr +#' Column in the archive to be interpreted as the time variable, e.g. "timestamp_xs", "timestamp_ys". +#' If `NULL` (default), the configuration ID will be used. #' @template param_theme #' #' @export -cost_over_time = function(archive, theme = ggplot2::theme_minimal()) { +cost_over_time = function(archive, time = NULL, theme = ggplot2::theme_minimal()) { # there should only be one objective, e.g. `classif.ce` - measure = archive$cols_y - + objective = archive$cols_y + .data = NULL - ggplot2::ggplot(data = archive$data, ggplot2::aes( - x = seq_row(archive$data), - y = .data[[measure]] + if (is.null(time)) { + x = seq_row(archive$data) + g = ggplot2::ggplot(data = as.data.table(archive), ggplot2::aes( + x = x, + y = .data[[objective]] )) + - ggplot2::geom_point() + + ggplot2::labs(x = "configuration ID") + } else { + assert_choice(time, names(as.data.table(archive))) + g = ggplot2::ggplot(data = as.data.table(archive), ggplot2::aes( + x = .data[[time]], + y = .data[[objective]] + )) + } + + + g + ggplot2::geom_point() + ggplot2::geom_line() + - ggplot2::labs( - title = "Cost over time", - x = "configuration ID" - ) + theme } diff --git a/R/visualization_app.R b/R/visualization_app.R index 0f3df12..4baceb2 100644 --- a/R/visualization_app.R +++ b/R/visualization_app.R @@ -13,6 +13,13 @@ visualize = function(instance) { sidebar = bslib::sidebar( # TBD: cost over time: select timestamp_x / timestamp_y / config_id + shiny::conditionalPanel( + "input.nav === 'Cost Over Time'", + shiny::radioButtons("cot_x", + label = "Select x-axis:", + choices = 
c("configuration ID", "timestamp_xs", "timestamp_ys") + ) + ), param_panel( "input.nav === 'Marginal Plots'", "mp", @@ -81,7 +88,11 @@ visualize = function(instance) { # Cost over time output$cost_over_time = renderPlot({ - cost_over_time(archive) + if (input$cot_x == "configuration ID") { + cost_over_time(archive) + } else { + cost_over_time(archive, time = input$cot_x) + } }) From 2a5efcca618570665cda65f05e7b0cdd9bc416fe Mon Sep 17 00:00:00 2001 From: b-zhou Date: Fri, 4 Oct 2024 11:53:27 +0200 Subject: [PATCH 18/41] chore: add author --- DESCRIPTION | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/DESCRIPTION b/DESCRIPTION index 3d91b39..ec2a66a 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -4,7 +4,8 @@ Version: 0.0.1 Authors@R: c( person("Damir", "Pulatov", , "damirpolat@protonmail.com", role = c("cre", "aut")), person("Marc", "Becker", , "marcbecker@posteo.de", role = "aut", - comment = c(ORCID = "0000-0002-8115-0400")) + comment = c(ORCID = "0000-0002-8115-0400")), + person("Baisu", "Zhou", "baisu.zhou@outlook.com", role = "aut") ) Description: Flexible AutoML system for the 'mlr3' ecosystem. 
License: LGPL-3 From 4edfbb67230e7806c6b360b4f4f60c53f90214c1 Mon Sep 17 00:00:00 2001 From: b-zhou Date: Fri, 4 Oct 2024 14:12:20 +0200 Subject: [PATCH 19/41] fix: pdp faeture type --- R/visualization.R | 64 +++++++++++++++++++++++++++-------------------- 1 file changed, 37 insertions(+), 27 deletions(-) diff --git a/R/visualization.R b/R/visualization.R index 708a686..89e92ef 100644 --- a/R/visualization.R +++ b/R/visualization.R @@ -225,13 +225,10 @@ partial_dependence_plot = function( theme = ggplot2::theme_minimal() ) { archive = instance$archive - objective = archive$cols_y assert_choice(x, archive$cols_x) assert_choice(y, archive$cols_x, null.ok = TRUE) - param_ids = c(x, y) - - branch = tstrsplit(param_ids, "\\.")[[1]] + branch = tstrsplit(c(x, y), "\\.")[[1]] branch = unique(branch) if (length(branch) > 1) { stop("Parameters from different branches cannot be plotted in the same PDP.") @@ -241,44 +238,57 @@ partial_dependence_plot = function( assert_choice(type, c("contour", "default")) } - non_numeric = some(param_ids, function(param_id) { + non_numeric = some(c(x, y), function(param_id) { !is.numeric(archive$data[[param_id]]) }) if (non_numeric && type == "contour") { stop("Contour plot not supported for non-numeric parameters") } - surrogate = default_surrogate(instance) - surrogate$archive = archive - surrogate$update() - - # store the data.table format for later use in predict.function - prototype = archive$data[0, archive$cols_x, with = FALSE] + # prepare data for surrogate model + archive_data = as.data.table(archive)[, c(archive$cols_x, archive$cols_y), with = FALSE] + archive_data = archive_data[!is.na(archive_data[[archive$cols_y]]), ] + archive_data[, archive$cols_x := lapply(.SD, function(col) { + # iml does not accept lgcl features + if (is.logical(col)) return(factor(col, levels = c(FALSE, TRUE))) + # also convert integer to double to avoid imputeoor error + if (is.integer(col)) return(as.numeric(col)) + return(col) + }), .SDcols = 
archive$cols_x] + task = as_task_regr(archive_data, target = archive$cols_y) + + # train surrogate model + surrogate = po("imputeoor", + multiplier = 3, + affect_columns = selector_type(c("numeric", "character", "factor", "ordered")) + ) %>>% default_rf() + surrogate = GraphLearner$new(surrogate) + surrogate$train(task) + + # # store the data.table format for later use in predict.function + # prototype = archive_data[0, archive$cols_x, with = FALSE] + + # new data to compute PDP + pdp_data = generate_design_random(archive$search_space, n = 1e3)$data + # same type conversion as above + pdp_data[, archive$cols_x := lapply(.SD, function(col) { + if (is.logical(col)) return(factor(col, levels = c(FALSE, TRUE))) + if (is.integer(col)) return(as.numeric(col)) + return(col) + }), .SDcols = archive$cols_x] + pdp_data_types = pdp_data[, lapply(.SD, storage.mode)] predictor = iml::Predictor$new( model = surrogate, - data = as.data.table(archive)[branch.selection == branch, param_ids, with = FALSE], - y = as.data.table(archive)[branch.selection == branch, objective, with = FALSE], + data = pdp_data[, archive$cols_x, with = FALSE], predict.function = function(model, newdata) { - setDT(newdata) - # reconstruct task layout from training to prevent error in imputeoor - newdata = merge(newdata, prototype, by = param_ids, all = TRUE) - setcolorder(newdata, names(prototype)) - - # convert numeric to integer to match with training task - walk(param_ids, function(param_id) { - if (is.integer(archive$data[[param_id]])) { - set(newdata, j = param_id, value = as.integer(newdata[[param_id]])) - } - }) - - return(model$predict(newdata)$mean) + model$predict_newdata(newdata)$response } ) eff = iml::FeatureEffect$new( predictor, - param_ids, + c(x, y), method = "pdp", grid.size = grid.size ) From 14c7ef572d7d0fc234bdb2ba9a386947fad7da64 Mon Sep 17 00:00:00 2001 From: b-zhou Date: Fri, 4 Oct 2024 15:23:21 +0200 Subject: [PATCH 20/41] fix: pdp app --- R/helpers_app.R | 9 ++++++--- 
R/visualization.R | 2 +- R/visualization_app.R | 37 +++++++++++++++++++++++-------------- 3 files changed, 30 insertions(+), 18 deletions(-) diff --git a/R/helpers_app.R b/R/helpers_app.R index 06d20c0..f537f5c 100644 --- a/R/helpers_app.R +++ b/R/helpers_app.R @@ -4,15 +4,17 @@ #' Used for Marginal Plots and Partial Dependence Plots. #' #' @param condition (`character(1)`)\cr -#' Passed to the `condition` argument of shiny::conditionalPanel. +#' Passed to the `condition` argument of `[shiny::conditionalPanel]`. #' @param prefix (`character(1)`)\cr #' Prefix of input slot names. #' @param learner_ids (`character()`)\cr #' Vector of all possible learner/branch IDs. #' @param param_ids (`character()`)\cr #' Vector of all possible param IDs. +#' @param ... (anything) +#' Additional arguments passed to `[shiny::conditionalPanel]`. #' -param_panel = function(condition, prefix, learner_ids, param_ids) { +param_panel = function(condition, prefix, learner_ids, param_ids, ...) { assert_string(condition) assert_string(prefix) assert_character(learner_ids) @@ -34,6 +36,7 @@ param_panel = function(condition, prefix, learner_ids, param_ids) { label = "Select y-axis:", choices = param_ids, selected = param_ids[[2]] - ) + ), + ... 
) } diff --git a/R/visualization.R b/R/visualization.R index 89e92ef..1118db5 100644 --- a/R/visualization.R +++ b/R/visualization.R @@ -309,7 +309,7 @@ partial_dependence_plot = function( ) # TBD: remove existing scales, use viridis instead - g + ggplot2::scale_fill_viridis_c(name = objective) + theme + g + ggplot2::scale_fill_viridis_c(name = archive$cols_y, direction = -1) + theme } diff --git a/R/visualization_app.R b/R/visualization_app.R index 4baceb2..5b9c3b4 100644 --- a/R/visualization_app.R +++ b/R/visualization_app.R @@ -12,7 +12,6 @@ visualize = function(instance) { title = "Visualization for mlr3automl", sidebar = bslib::sidebar( - # TBD: cost over time: select timestamp_x / timestamp_y / config_id shiny::conditionalPanel( "input.nav === 'Cost Over Time'", shiny::radioButtons("cot_x", @@ -27,7 +26,6 @@ visualize = function(instance) { param_ids ), shiny::conditionalPanel( - # TBD: select branch, then parameter, as in PDP "input.nav === 'Parallel Coordinates'", shiny::selectInput("pc_branch", label = "Select branch:", @@ -56,7 +54,10 @@ visualize = function(instance) { "input.nav === 'Partial Dependence Plots'", "pdp", learner_ids, - param_ids + param_ids, + shiny::actionButton("pdp_process", + label = "Process" + ) ) ), @@ -103,7 +104,7 @@ visualize = function(instance) { shiny::updateSelectInput(session, "mp_y", choices = selectable_ids, selected = selectable_ids[[2]]) }) - output$marginal_plot = renderPlot({ + output$marginal_plot = shiny::renderPlot({ if (input$mp_y == "NULL") { marginal_plot(archive, x = input$mp_x) } else { @@ -123,7 +124,7 @@ visualize = function(instance) { ) }) - output$parallel_coordinates = renderPlot({ + output$parallel_coordinates = shiny::renderPlot({ if (is.null(input$pc_cols_x)) return() # nothing selected trafo = input$pc_trafo == "Yes" parallel_coordinates(archive, cols_x = input$pc_cols_x, trafo = trafo) @@ -146,18 +147,26 @@ visualize = function(instance) { shiny::updateSelectInput(session, "pdp_x", choices = 
selectable_ids, selected = selectable_ids[[1]]) shiny::updateSelectInput(session, "pdp_y", choices = selectable_ids, selected = selectable_ids[[2]]) }) - - output$pdp = renderPlot({ - if (is.null(input$pdp_x)) return() - partial_dependence_plot( - instance, x = input$pdp_x, y = input$pdp_y, - type = "default" - ) - }) + + # generate plot only after pressing the "Process" button + # because it takes quite a while... + output$pdp = shiny::bindEvent( + shiny::renderPlot({ + if (is.null(input$pdp_x)) return() + progress <- shiny::Progress$new() + on.exit(progress$close()) + progress$set(message = "Making plot. Please wait.") + partial_dependence_plot( + instance, x = input$pdp_x, y = input$pdp_y, + type = "default" + ) + }), + input$pdp_process + ) # Pareto Front - output$pf = renderPlot({ + output$pf = shiny::renderPlot({ pareto_front(instance) }) } From 7c1bfe569ff8d94e11e36519cc9061648962860b Mon Sep 17 00:00:00 2001 From: b-zhou Date: Fri, 4 Oct 2024 17:30:30 +0200 Subject: [PATCH 21/41] refactor: pdp params --- R/visualization.R | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/R/visualization.R b/R/visualization.R index 1118db5..5bdc80d 100644 --- a/R/visualization.R +++ b/R/visualization.R @@ -207,8 +207,6 @@ parallel_coordinates = function(archive, cols_x = NULL, trafo = FALSE, theme = g #' Name of the parameter to be mapped to the x-axis. #' @param y (`character(1)`) #' Name of the parameter to be mapped to the y-axis. -#' @param grid.size (`numeric(1)` | `numeric(2)`) -#' The size of the grid. See `grid.size` of `[iml::FeatureEffect]`. #' @param type (`character(1)`) #' Type of the two-parameter partial dependence plot. Possible options are listed below. #' \itemize{ @@ -217,11 +215,12 @@ parallel_coordinates = function(archive, cols_x = NULL, trafo = FALSE, theme = g #' } #' Ignored if only one parameter is provided. #' @template param_theme +#' @param ... (anything) +#' Arguments passed to `[iml::FeatureEffect]`. 
#' #' @export partial_dependence_plot = function( - instance, x, y, grid.size = 20, center_at = NULL, - type = "default", + instance, x, y, type = "default", theme = ggplot2::theme_minimal() ) { archive = instance$archive @@ -289,8 +288,7 @@ partial_dependence_plot = function( eff = iml::FeatureEffect$new( predictor, c(x, y), - method = "pdp", - grid.size = grid.size + method = "pdp" ) .data = NULL From 7eae7bd61061300c809af5f4d33736ef66769996 Mon Sep 17 00:00:00 2001 From: b-zhou Date: Mon, 7 Oct 2024 18:59:39 +0200 Subject: [PATCH 22/41] test: cot & mp --- DESCRIPTION | 3 +- tests/testthat/test_visualization.R | 136 +++++++++++++++++++++++++--- 2 files changed, 125 insertions(+), 14 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index ec2a66a..7556563 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -46,7 +46,8 @@ Suggests: ranger, rpart, testthat (>= 3.0.0), - xgboost + xgboost, + vdiffr Remotes: catboost/catboost/catboost/R-package, mlr-org/mlr3, diff --git a/tests/testthat/test_visualization.R b/tests/testthat/test_visualization.R index 4777464..267cdaa 100644 --- a/tests/testthat/test_visualization.R +++ b/tests/testthat/test_visualization.R @@ -1,17 +1,127 @@ +skip_on_cran() +skip_if_not_installed("rush") +flush_redis() + rush_plan(n_workers = 2) -skip_if_not_installed("ranger") - -task = tsk("penguins") -learner = lrn("classif.auto", - learner_ids = "ranger", - small_data_size = 1, - resampling = rsmp("holdout"), - measure = msr("classif.ce"), - terminator = trm("evals", n_evals = 6) -) -learner$train(task) +skip_if_not_installed(c("glmnet", "kknn", "ranger", "e1071")) +# cost over time test_that("cost over time works", { - plot = cost_over_time(learner$instance) - expect_class(plot, c("gg", "ggplot")) + task = tsk("penguins") + + set.seed(1453) + learner = lrn("classif.auto_ranger", + small_data_size = 1, + resampling = rsmp("holdout"), + measure = msr("classif.ce"), + terminator = trm("evals", n_evals = 6) + ) + learner$train(task) + + 
vdiffr::expect_doppelganger("cot-config-id", cost_over_time(learner$instance)) + vdiffr::expect_doppelganger("cot-timestamp-x", cost_over_time(learner$instance, time = "timestamp_x")) + vdiffr::expect_doppelganger("cot-timestamp-y", cost_over_time(learner$instance, time = "timestamp_y")) }) + + +# marginal plots +test_that("marginal plot works", { + task = tsk("penguins") + + # numeric vs numeric + set.seed(1453) + learner_glmnet = lrn("classif.auto_glmnet", + small_data_size = 1, + resampling = rsmp("holdout"), + measure = msr("classif.ce"), + terminator = trm("evals", n_evals = 6) + ) + learner_glmnet$train(task) + vdiffr::expect_doppelganger( + "mp-numeric-numeric", + marginal_plot(learner_glmnet$instance, x = "glmnet.alpha", y = "glmnet.s") + ) + vdiffr::expect_doppelganger( + "mp-numeric-numeric", + marginal_plot(learner_glmnet$instance, x = "glmnet.s", y = "glmnet.alpha") + ) + + # numeric vs factor + set.seed(1453) + learner_kknn = lrn("classif.auto_kknn", + small_data_size = 1, + resampling = rsmp("holdout"), + measure = msr("classif.ce"), + terminator = trm("evals", n_evals = 6) + ) + learner_kknn$train(task) + vdiffr::expect_doppelganger( + "mp-numeric-factor", + marginal_plot(learner_kknn$instance, x = "kknn.distance", y = "kknn.kernel") + ) + vdiffr::expect_doppelganger( + "mp-factor-numeric", + marginal_plot(learner_kknn$instance, x = "kknn.kernel", y = "kknn.distance") + ) + + # numeric vs logical + set.seed(1453) + learner_ranger = lrn("classif.auto_ranger", + small_data_size = 1, + resampling = rsmp("holdout"), + measure = msr("classif.ce"), + terminator = trm("evals", n_evals = 6) + ) + learner_ranger$train(task) + vdiffr::expect_doppelganger( + "mp-numeric-logical", + marginal_plot(learner$instance, x = "ranger.num.trees", y = "ranger.replace") + ) + vdiffr::expect_doppelganger( + "mp-logical-numeric", + marginal_plot(learner$instance, x = "ranger.replace", y = "ranger.num.trees") + ) +}) + +test_that("marginal plot throws error if params on 
different branches", { + task = tsk("penguins") + + set.seed(1453) + learner = lrn("classif.auto", + learner_ids = c("kknn", "svm") + small_data_size = 1, + resampling = rsmp("holdout"), + measure = msr("classif.ce"), + terminator = trm("evals", n_evals = 6) + ) + expect_error( + marginal_plot(learner$instance, x = "kknn.distance", y = "svm.cost") + ) +}) + +test_that("marginal plot handles dependence", { + task = tsk("penguins") + + set.seed(1453) + learner_svm = lrn("classif.auto_svm", + small_data_size = 1, + resampling = rsmp("holdout"), + measure = msr("classif.ce"), + terminator = trm("evals", n_evals = 6) + ) + learner_svm$train(task) + vdiffr::expect_doppelganger( + "mp-dependence", + marginal_plot(learner$instance, x = "svm.kernel", y = "svm.degree") + ) +}) + + +# parallel coordinates + +# pdp + +# pareto front + +# footprint + From 4da1d35649716a3e57a64edb177f414af9d912b7 Mon Sep 17 00:00:00 2001 From: b-zhou Date: Mon, 7 Oct 2024 19:34:24 +0200 Subject: [PATCH 23/41] fix: cot --- R/visualization.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/visualization.R b/R/visualization.R index 5bdc80d..465711c 100644 --- a/R/visualization.R +++ b/R/visualization.R @@ -9,7 +9,8 @@ #' @template param_theme #' #' @export -cost_over_time = function(archive, time = NULL, theme = ggplot2::theme_minimal()) { +cost_over_time = function(instance, time = NULL, theme = ggplot2::theme_minimal()) { + archive = instance$archive # there should only be one objective, e.g. 
`classif.ce` objective = archive$cols_y @@ -29,7 +30,6 @@ cost_over_time = function(archive, time = NULL, theme = ggplot2::theme_minimal() )) } - g + ggplot2::geom_point() + ggplot2::geom_line() + theme From 5a503cfb96372942d885bed752600b49e52e52d8 Mon Sep 17 00:00:00 2001 From: b-zhou Date: Mon, 7 Oct 2024 19:41:36 +0200 Subject: [PATCH 24/41] fix: mp trafo --- R/visualization.R | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/R/visualization.R b/R/visualization.R index 465711c..a762042 100644 --- a/R/visualization.R +++ b/R/visualization.R @@ -48,18 +48,20 @@ cost_over_time = function(instance, time = NULL, theme = ggplot2::theme_minimal( #' @template param_theme #' #' @export -marginal_plot = function(archive, x, y = NULL, theme = ggplot2::theme_minimal()) { +marginal_plot = function(instance, x, y = NULL, theme = ggplot2::theme_minimal()) { + archive = instance$archive param_ids = archive$cols_x assert_choice(x, param_ids) assert_choice(y, param_ids, null.ok = TRUE) # use transformed values if trafo is set x_trafo = paste0("x_domain_", x) + y_trafo = if (!is.null(y)) paste0("x_domain_", y) else NULL # there should only be one objective, e.g. 
`classif.ce` measure = archive$cols_y - data = na.omit(as.data.table(archive), cols = c(x_trafo, y)) + data = na.omit(as.data.table(archive), cols = c(x_trafo, y_trafo)) .data = NULL @@ -82,12 +84,12 @@ marginal_plot = function(archive, x, y = NULL, theme = ggplot2::theme_minimal()) # param provided for y g = ggplot2::ggplot(data = data, ggplot2::aes( x = .data[[x_trafo]], - y = .data[[y]], + y = .data[[y_trafo]], col = .data[[measure]] )) + ggplot2::geom_point(alpha = 0.6) + ggplot2::scale_color_viridis_c() + - ggplot2::labs(x = x) + + ggplot2::labs(x = x, y = y) + theme if (archive$search_space$is_logscale[[x]]) { From 4c2e357a536a553ea1af4119f7d4c79283750b0a Mon Sep 17 00:00:00 2001 From: b-zhou Date: Mon, 7 Oct 2024 19:43:56 +0200 Subject: [PATCH 25/41] fix: pc archive --- R/visualization.R | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/R/visualization.R b/R/visualization.R index a762042..91abd42 100644 --- a/R/visualization.R +++ b/R/visualization.R @@ -117,7 +117,11 @@ marginal_plot = function(instance, x, y = NULL, theme = ggplot2::theme_minimal() #' @template param_theme #' #' @export -parallel_coordinates = function(archive, cols_x = NULL, trafo = FALSE, theme = ggplot2::theme_minimal()) { +parallel_coordinates = function( + instance, cols_x = NULL, trafo = FALSE, + theme = ggplot2::theme_minimal() +) { + archive = instance$archive assert_subset(cols_x, c(archive$cols_x, paste0("x_domain_", archive$cols_x))) assert_flag(trafo) From 43ea5a50f733773c9103a459b5dde840bb4ac1d6 Mon Sep 17 00:00:00 2001 From: b-zhou Date: Mon, 7 Oct 2024 19:42:55 +0200 Subject: [PATCH 26/41] fix: pdp extra args --- R/visualization.R | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/R/visualization.R b/R/visualization.R index 91abd42..aeafaa4 100644 --- a/R/visualization.R +++ b/R/visualization.R @@ -227,7 +227,8 @@ parallel_coordinates = function( #' @export partial_dependence_plot = function( instance, x, y, type = "default", - theme 
= ggplot2::theme_minimal() + theme = ggplot2::theme_minimal(), + ... ) { archive = instance$archive assert_choice(x, archive$cols_x) @@ -294,7 +295,8 @@ partial_dependence_plot = function( eff = iml::FeatureEffect$new( predictor, c(x, y), - method = "pdp" + method = "pdp", + ... ) .data = NULL From dec1f9c56f1273dc98c06deb7024ff6ea8602db2 Mon Sep 17 00:00:00 2001 From: b-zhou Date: Mon, 7 Oct 2024 19:43:37 +0200 Subject: [PATCH 27/41] test: cot & mp fixes --- tests/testthat/test_visualization.R | 32 +++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/tests/testthat/test_visualization.R b/tests/testthat/test_visualization.R index 267cdaa..92b43f6 100644 --- a/tests/testthat/test_visualization.R +++ b/tests/testthat/test_visualization.R @@ -1,8 +1,5 @@ skip_on_cran() skip_if_not_installed("rush") -flush_redis() - -rush_plan(n_workers = 2) skip_if_not_installed(c("glmnet", "kknn", "ranger", "e1071")) # cost over time @@ -10,6 +7,9 @@ test_that("cost over time works", { task = tsk("penguins") set.seed(1453) + flush_redis() + rush_plan(n_workers = 2) + learner = lrn("classif.auto_ranger", small_data_size = 1, resampling = rsmp("holdout"), @@ -19,8 +19,8 @@ test_that("cost over time works", { learner$train(task) vdiffr::expect_doppelganger("cot-config-id", cost_over_time(learner$instance)) - vdiffr::expect_doppelganger("cot-timestamp-x", cost_over_time(learner$instance, time = "timestamp_x")) - vdiffr::expect_doppelganger("cot-timestamp-y", cost_over_time(learner$instance, time = "timestamp_y")) + vdiffr::expect_doppelganger("cot-timestamp-xs", cost_over_time(learner$instance, time = "timestamp_xs")) + vdiffr::expect_doppelganger("cot-timestamp-ys", cost_over_time(learner$instance, time = "timestamp_ys")) }) @@ -30,6 +30,9 @@ test_that("marginal plot works", { # numeric vs numeric set.seed(1453) + flush_redis() + rush_plan(n_workers = 2) + learner_glmnet = lrn("classif.auto_glmnet", small_data_size = 1, resampling = 
rsmp("holdout"), @@ -48,6 +51,9 @@ test_that("marginal plot works", { # numeric vs factor set.seed(1453) + flush_redis() + rush_plan(n_workers = 2) + learner_kknn = lrn("classif.auto_kknn", small_data_size = 1, resampling = rsmp("holdout"), @@ -66,6 +72,9 @@ test_that("marginal plot works", { # numeric vs logical set.seed(1453) + flush_redis() + rush_plan(n_workers = 2) + learner_ranger = lrn("classif.auto_ranger", small_data_size = 1, resampling = rsmp("holdout"), @@ -75,11 +84,11 @@ test_that("marginal plot works", { learner_ranger$train(task) vdiffr::expect_doppelganger( "mp-numeric-logical", - marginal_plot(learner$instance, x = "ranger.num.trees", y = "ranger.replace") + marginal_plot(learner_ranger$instance, x = "ranger.num.trees", y = "ranger.replace") ) vdiffr::expect_doppelganger( "mp-logical-numeric", - marginal_plot(learner$instance, x = "ranger.replace", y = "ranger.num.trees") + marginal_plot(learner_ranger$instance, x = "ranger.replace", y = "ranger.num.trees") ) }) @@ -87,13 +96,17 @@ test_that("marginal plot throws error if params on different branches", { task = tsk("penguins") set.seed(1453) + flush_redis() + rush_plan(n_workers = 2) + learner = lrn("classif.auto", - learner_ids = c("kknn", "svm") + learner_ids = c("kknn", "svm"), small_data_size = 1, resampling = rsmp("holdout"), measure = msr("classif.ce"), terminator = trm("evals", n_evals = 6) ) + learner$train(task) expect_error( marginal_plot(learner$instance, x = "kknn.distance", y = "svm.cost") ) @@ -103,6 +116,9 @@ test_that("marginal plot handles dependence", { task = tsk("penguins") set.seed(1453) + flush_redis() + rush_plan(n_workers = 2) + learner_svm = lrn("classif.auto_svm", small_data_size = 1, resampling = rsmp("holdout"), From 5e00a219a68c2a358c9ce84087d9c9040812dc80 Mon Sep 17 00:00:00 2001 From: b-zhou Date: Mon, 7 Oct 2024 19:46:50 +0200 Subject: [PATCH 28/41] fix: change archive to instance --- R/visualization_app.R | 10 +++++----- 1 file changed, 5 insertions(+), 5 
deletions(-) diff --git a/R/visualization_app.R b/R/visualization_app.R index 5b9c3b4..2685961 100644 --- a/R/visualization_app.R +++ b/R/visualization_app.R @@ -90,9 +90,9 @@ visualize = function(instance) { # Cost over time output$cost_over_time = renderPlot({ if (input$cot_x == "configuration ID") { - cost_over_time(archive) + cost_over_time(instance) } else { - cost_over_time(archive, time = input$cot_x) + cost_over_time(instance, time = input$cot_x) } }) @@ -106,9 +106,9 @@ visualize = function(instance) { output$marginal_plot = shiny::renderPlot({ if (input$mp_y == "NULL") { - marginal_plot(archive, x = input$mp_x) + marginal_plot(instance, x = input$mp_x) } else { - marginal_plot(archive, x = input$mp_x, y = input$mp_y) + marginal_plot(instance, x = input$mp_x, y = input$mp_y) } }) @@ -127,7 +127,7 @@ visualize = function(instance) { output$parallel_coordinates = shiny::renderPlot({ if (is.null(input$pc_cols_x)) return() # nothing selected trafo = input$pc_trafo == "Yes" - parallel_coordinates(archive, cols_x = input$pc_cols_x, trafo = trafo) + parallel_coordinates(instance, cols_x = input$pc_cols_x, trafo = trafo) }) shiny::observeEvent(input$pc_unselect_all, { From e6f0311ea283993ff1d91ead54d3e5a99cb52d0e Mon Sep 17 00:00:00 2001 From: b-zhou Date: Mon, 7 Oct 2024 21:05:55 +0200 Subject: [PATCH 29/41] test: mp different branches --- tests/testthat/test_visualization.R | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/testthat/test_visualization.R b/tests/testthat/test_visualization.R index 92b43f6..5a47fd9 100644 --- a/tests/testthat/test_visualization.R +++ b/tests/testthat/test_visualization.R @@ -92,7 +92,7 @@ test_that("marginal plot works", { ) }) -test_that("marginal plot throws error if params on different branches", { +test_that("marginal plot accepts params on different branches", { task = tsk("penguins") set.seed(1453) @@ -107,7 +107,8 @@ test_that("marginal plot throws error if params on different branches", { 
terminator = trm("evals", n_evals = 6) ) learner$train(task) - expect_error( + expect_doppelganger( + "mp-different-branches", marginal_plot(learner$instance, x = "kknn.distance", y = "svm.cost") ) }) From d7b461a70994aa43a0d63a21722082fffc630e4a Mon Sep 17 00:00:00 2001 From: b-zhou Date: Mon, 7 Oct 2024 22:22:58 +0200 Subject: [PATCH 30/41] test: mp fixes --- tests/testthat/test_visualization.R | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/testthat/test_visualization.R b/tests/testthat/test_visualization.R index 5a47fd9..05c7344 100644 --- a/tests/testthat/test_visualization.R +++ b/tests/testthat/test_visualization.R @@ -45,7 +45,7 @@ test_that("marginal plot works", { marginal_plot(learner_glmnet$instance, x = "glmnet.alpha", y = "glmnet.s") ) vdiffr::expect_doppelganger( - "mp-numeric-numeric", + "mp-numeric-numeric2", marginal_plot(learner_glmnet$instance, x = "glmnet.s", y = "glmnet.alpha") ) @@ -107,7 +107,7 @@ test_that("marginal plot accepts params on different branches", { terminator = trm("evals", n_evals = 6) ) learner$train(task) - expect_doppelganger( + vdiffr::expect_doppelganger( "mp-different-branches", marginal_plot(learner$instance, x = "kknn.distance", y = "svm.cost") ) @@ -129,7 +129,7 @@ test_that("marginal plot handles dependence", { learner_svm$train(task) vdiffr::expect_doppelganger( "mp-dependence", - marginal_plot(learner$instance, x = "svm.kernel", y = "svm.degree") + marginal_plot(learner_svm$instance, x = "svm.kernel", y = "svm.degree") ) }) From da0e9473992432f95c0ae8fba8c3c529e73f0ecc Mon Sep 17 00:00:00 2001 From: b-zhou Date: Tue, 8 Oct 2024 11:24:02 +0200 Subject: [PATCH 31/41] fix: pdp sample size --- R/visualization.R | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/R/visualization.R b/R/visualization.R index aeafaa4..41d5281 100644 --- a/R/visualization.R +++ b/R/visualization.R @@ -275,7 +275,15 @@ partial_dependence_plot = function( # prototype = 
archive_data[0, archive$cols_x, with = FALSE] # new data to compute PDP - pdp_data = generate_design_random(archive$search_space, n = 1e3)$data + # https://github.com/automl/DeepCAVE/blob/58d6801508468841eda038803b12fa2bbf7a0cb8/deepcave/plugins/hyperparameter/pdp.py#L334 + samples_per_param = 10 + num_samples = samples_per_param * nrow(archive_data) + max_samples = 10000 + if (num_samples > max_samples) { + num_samples = max_samples + } + + pdp_data = generate_design_random(archive$search_space, n = num_samples)$data # same type conversion as above pdp_data[, archive$cols_x := lapply(.SD, function(col) { if (is.logical(col)) return(factor(col, levels = c(FALSE, TRUE))) @@ -307,15 +315,16 @@ partial_dependence_plot = function( x = .data[[x]], y = .data[[y]], z = .data$.value )) + ggplot2::geom_contour_filled() + - ggplot2::scale_fill_viridis_d(direction = -1), + ggplot2::scale_fill_viridis_d(), # FIXME: rug = TRUE causes error when, e.g., x = "svm.cost", y = "svm.degree" # related to the problem that degree is missing for some instances? 
default = eff$plot(rug = FALSE) + + ggplot2::scale_fill_viridis_c(name = archive$cols_y) ) # TBD: remove existing scales, use viridis instead - g + ggplot2::scale_fill_viridis_c(name = archive$cols_y, direction = -1) + theme + g + theme } From 77f94a4853d43a898c7c1d4277078b8f7df9003c Mon Sep 17 00:00:00 2001 From: b-zhou Date: Tue, 8 Oct 2024 12:50:57 +0200 Subject: [PATCH 32/41] fix: rcmd check warnings --- DESCRIPTION | 8 ++++--- NAMESPACE | 6 +++++ R/visualization.R | 17 +++++++------- R/visualization_app.R | 4 ++-- man-roxygen/param_instance.R | 2 +- man-roxygen/param_theme.R | 2 +- man/cost_over_time.Rd | 23 ++++++++++++++++++ man/marginal_plot.Rd | 26 ++++++++++++++++++++ man/mlr3automl-package.Rd | 1 + man/parallel_coordinates.Rd | 32 +++++++++++++++++++++++++ man/param_panel.Rd | 27 +++++++++++++++++++++ man/pareto_front.Rd | 19 +++++++++++++++ man/partial_dependence_plot.Rd | 43 ++++++++++++++++++++++++++++++++++ man/visualize.Rd | 14 +++++++++++ 14 files changed, 209 insertions(+), 15 deletions(-) create mode 100644 man/cost_over_time.Rd create mode 100644 man/marginal_plot.Rd create mode 100644 man/parallel_coordinates.Rd create mode 100644 man/param_panel.Rd create mode 100644 man/pareto_front.Rd create mode 100644 man/partial_dependence_plot.Rd create mode 100644 man/visualize.Rd diff --git a/DESCRIPTION b/DESCRIPTION index 7556563..2e87e7f 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -5,7 +5,7 @@ Authors@R: c( person("Damir", "Pulatov", , "damirpolat@protonmail.com", role = c("cre", "aut")), person("Marc", "Becker", , "marcbecker@posteo.de", role = "aut", comment = c(ORCID = "0000-0002-8115-0400")), - person("Baisu", "Zhou", "baisu.zhou@outlook.com", role = "aut") + person("Baisu", "Zhou", , "baisu.zhou@outlook.com", role = "aut") ) Description: Flexible AutoML system for the 'mlr3' ecosystem. 
License: LGPL-3 @@ -19,6 +19,7 @@ Depends: R (>= 3.1.0), rush Imports: + bbotk, bslib, checkmate, data.table, @@ -32,6 +33,7 @@ Imports: paradox (>= 1.0.1), R6, shiny, + stats, utils Suggests: catboost, @@ -82,9 +84,9 @@ Collate: 'build_graph.R' 'estimate_memory.R' 'helper.R' + 'helpers_app.R' 'internal_measure.R' 'train_auto.R' - 'helpers_app.R' - 'visualization_app.R' 'visualization.R' + 'visualization_app.R' 'zzz.R' diff --git a/NAMESPACE b/NAMESPACE index a0dd2d3..df912ff 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -27,10 +27,16 @@ export(LearnerRegrAutoNnet) export(LearnerRegrAutoRanger) export(LearnerRegrAutoSVM) export(LearnerRegrAutoXgboost) +export(cost_over_time) export(estimate_memory) export(internal_measure_catboost) export(internal_measure_lightgbm) export(internal_measure_xgboost) +export(marginal_plot) +export(parallel_coordinates) +export(pareto_front) +export(partial_dependence_plot) +export(visualize) import(R6) import(checkmate) import(data.table) diff --git a/R/visualization.R b/R/visualization.R index 41d5281..3d453b0 100644 --- a/R/visualization.R +++ b/R/visualization.R @@ -1,6 +1,6 @@ #' @title Cost-Over-Time Plot #' -#' @description +#' @description Plots the cost (objective) over time, where the time variable can be set by the user. #' #' @template param_instance #' @param time (`character(1)`)\cr @@ -35,9 +35,9 @@ cost_over_time = function(instance, time = NULL, theme = ggplot2::theme_minimal( theme } -#' @title Marginal Plot for Hyperparameters +#' @title Marginal Plot #' -#' @description +#' @description Creates 2D marginal plots for evaluated configurations. #' #' @template param_instance #' @param x (`character(1)`) @@ -105,7 +105,7 @@ marginal_plot = function(instance, x, y = NULL, theme = ggplot2::theme_minimal() #' @title Parallel Coordinates Plot #' -#' @description Adapted from [mlr3viz::autoplot()] with `type == "parallel"`. 
Since the hyperparameters of each individual learner are conditioned on `branch.selection`, missing values are expected in the archive data. When standardizing the hyperparameter values (referred to as "x values" in the following to be consistent with `mlr3viz` documentation), `na.omit == TRUE` is used to compute `mean()` and `sd()`. +#' @description Adapted from [mlr3viz::autoplot()] with `type == "parallel"`. Since the hyperparameters of each individual learner are conditioned on `branch.selection`, missing values are expected in the archive data. When standardizing the hyperparameter values (referred to as "x values" in the following to be consistent with `mlr3viz` documentation), `na.omit == TRUE` is used to compute `mean()` and `stats::sd()`. #' #' @template param_instance #' @param cols_x (`character()`) @@ -147,17 +147,17 @@ parallel_coordinates = function( # rescale data_n = data_n[, lapply(.SD, function(x) { - if (sd(x, na.rm = TRUE) %in% c(0, NA)) { + if (stats::sd(x, na.rm = TRUE) %in% c(0, NA)) { rep(0, length(x)) } else { - (x - mean(x, na.rm = TRUE)) / sd(x, na.rm = TRUE) + (x - mean(x, na.rm = TRUE)) / stats::sd(x, na.rm = TRUE) } })] data_c = data_c[, lapply(.SD, function(x) { - if (sd(x, na.rm = TRUE) %in% c(0, NA)) { + if (stats::sd(x, na.rm = TRUE) %in% c(0, NA)) { rep(0, length(x)) } else { - (x - mean(unique(x), na.rm = TRUE)) / sd(unique(x), na.rm = TRUE) + (x - mean(unique(x), na.rm = TRUE)) / stats::sd(unique(x), na.rm = TRUE) } })] @@ -181,6 +181,7 @@ parallel_coordinates = function( data = merge(data, data_y, by = "id") setorderv(data, "x") + .data = NULL ggplot2::ggplot(data, mapping = ggplot2::aes( x = .data[["x"]], diff --git a/R/visualization_app.R b/R/visualization_app.R index 2685961..cc7e0c8 100644 --- a/R/visualization_app.R +++ b/R/visualization_app.R @@ -84,11 +84,11 @@ visualize = function(instance) { ) server = function(input, output, session) { - session$onSessionEnded(stopApp) + session$onSessionEnded(shiny::stopApp) # Cost 
over time - output$cost_over_time = renderPlot({ + output$cost_over_time = shiny::renderPlot({ if (input$cot_x == "configuration ID") { cost_over_time(instance) } else { diff --git a/man-roxygen/param_instance.R b/man-roxygen/param_instance.R index 0a13556..64750c5 100644 --- a/man-roxygen/param_instance.R +++ b/man-roxygen/param_instance.R @@ -1,3 +1,3 @@ #' @param instance (`[TuningInstanceAsyncSingleCrit]`)\cr #' Single-criterion tuning instance with Rush. -#' For [mlr3automl] learners, the tuning instance is stored in the field `$instance`. \ No newline at end of file +#' For [mlr3automl] learners, the tuning instance is stored in the field `$instance`. diff --git a/man-roxygen/param_theme.R b/man-roxygen/param_theme.R index b4a6efd..f976c95 100644 --- a/man-roxygen/param_theme.R +++ b/man-roxygen/param_theme.R @@ -1,2 +1,2 @@ #' @param theme ([ggplot2::theme()])\cr -#' The [ggplot2::theme_minimal()] is applied by default to all plots. \ No newline at end of file +#' The [ggplot2::theme_minimal()] is applied by default to all plots. diff --git a/man/cost_over_time.Rd b/man/cost_over_time.Rd new file mode 100644 index 0000000..14d52a9 --- /dev/null +++ b/man/cost_over_time.Rd @@ -0,0 +1,23 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/visualization.R +\name{cost_over_time} +\alias{cost_over_time} +\title{Cost-Over-Time Plot} +\usage{ +cost_over_time(instance, time = NULL, theme = ggplot2::theme_minimal()) +} +\arguments{ +\item{instance}{(\verb{[TuningInstanceAsyncSingleCrit]})\cr +Single-criterion tuning instance with Rush. +For \link{mlr3automl} learners, the tuning instance is stored in the field \verb{$instance}.} + +\item{time}{(\code{character(1)})\cr +Column in the archive to be interpreted as the time variable, e.g. "timestamp_xs", "timestamp_ys". 
+If \code{NULL} (default), the configuration ID will be used.} + +\item{theme}{(\code{\link[ggplot2:theme]{ggplot2::theme()}})\cr +The \code{\link[ggplot2:ggtheme]{ggplot2::theme_minimal()}} is applied by default to all plots.} +} +\description{ +Plots the cost (objective) over time, where the time variable can be set by the user. +} diff --git a/man/marginal_plot.Rd b/man/marginal_plot.Rd new file mode 100644 index 0000000..e5ea290 --- /dev/null +++ b/man/marginal_plot.Rd @@ -0,0 +1,26 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/visualization.R +\name{marginal_plot} +\alias{marginal_plot} +\title{Marginal Plot} +\usage{ +marginal_plot(instance, x, y = NULL, theme = ggplot2::theme_minimal()) +} +\arguments{ +\item{instance}{(\verb{[TuningInstanceAsyncSingleCrit]})\cr +Single-criterion tuning instance with Rush. +For \link{mlr3automl} learners, the tuning instance is stored in the field \verb{$instance}.} + +\item{x}{(\code{character(1)}) +Name of the parameter to be mapped to the x-axis.} + +\item{y}{(\code{character(1)}) +Name of the parameter to be mapped to the y-axis. +If \code{NULL} (default), the measure (e.g. \code{classif.ce}) is mapped to the y-axis.} + +\item{theme}{(\code{\link[ggplot2:theme]{ggplot2::theme()}})\cr +The \code{\link[ggplot2:ggtheme]{ggplot2::theme_minimal()}} is applied by default to all plots.} +} +\description{ +Creates 2D marginal plots for evaluated configurations. 
+} diff --git a/man/mlr3automl-package.Rd b/man/mlr3automl-package.Rd index 6fd4cfd..b6d272c 100644 --- a/man/mlr3automl-package.Rd +++ b/man/mlr3automl-package.Rd @@ -24,6 +24,7 @@ Useful links: Authors: \itemize{ \item Marc Becker \email{marcbecker@posteo.de} (\href{https://orcid.org/0000-0002-8115-0400}{ORCID}) + \item Baisu Zhou \email{baisu.zhou@outlook.com} } } diff --git a/man/parallel_coordinates.Rd b/man/parallel_coordinates.Rd new file mode 100644 index 0000000..699ce67 --- /dev/null +++ b/man/parallel_coordinates.Rd @@ -0,0 +1,32 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/visualization.R +\name{parallel_coordinates} +\alias{parallel_coordinates} +\title{Parallel Coordinates Plot} +\usage{ +parallel_coordinates( + instance, + cols_x = NULL, + trafo = FALSE, + theme = ggplot2::theme_minimal() +) +} +\arguments{ +\item{instance}{(\verb{[TuningInstanceAsyncSingleCrit]})\cr +Single-criterion tuning instance with Rush. +For \link{mlr3automl} learners, the tuning instance is stored in the field \verb{$instance}.} + +\item{cols_x}{(\code{character()}) +Column names of x values. +By default, all untransformed x values from the search space are plotted.} + +\item{trafo}{(\code{character(1)}) +If \code{FALSE} (default), the untransformed x values are plotted. +If \code{TRUE}, the transformed x values are plotted.} + +\item{theme}{(\code{\link[ggplot2:theme]{ggplot2::theme()}})\cr +The \code{\link[ggplot2:ggtheme]{ggplot2::theme_minimal()}} is applied by default to all plots.} +} +\description{ +Adapted from \code{\link[mlr3viz:reexports]{mlr3viz::autoplot()}} with \code{type == "parallel"}. Since the hyperparameters of each individual learner are conditioned on \code{branch.selection}, missing values are expected in the archive data. 
When standardizing the hyperparameter values (referred to as "x values" in the following to be consistent with \code{mlr3viz} documentation), \code{na.rm = TRUE} is used to compute \code{mean()} and \code{stats::sd()}. +} diff --git a/man/param_panel.Rd b/man/param_panel.Rd new file mode 100644 index 0000000..63aae15 --- /dev/null +++ b/man/param_panel.Rd @@ -0,0 +1,27 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/helpers_app.R +\name{param_panel} +\alias{param_panel} +\title{Custom conditionalPanel for hyperparameter selection} +\usage{ +param_panel(condition, prefix, learner_ids, param_ids, ...) +} +\arguments{ +\item{condition}{(\code{character(1)})\cr +Passed to the \code{condition} argument of \verb{[shiny::conditionalPanel]}.} + +\item{prefix}{(\code{character(1)})\cr +Prefix of input slot names.} + +\item{learner_ids}{(\code{character()})\cr +Vector of all possible learner/branch IDs.} + +\item{param_ids}{(\code{character()})\cr +Vector of all possible param IDs.} + +\item{...}{(anything) +Additional arguments passed to \verb{[shiny::conditionalPanel]}.} +} +\description{ +Used for Marginal Plots and Partial Dependence Plots. +} diff --git a/man/pareto_front.Rd b/man/pareto_front.Rd new file mode 100644 index 0000000..c58f17c --- /dev/null +++ b/man/pareto_front.Rd @@ -0,0 +1,19 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/visualization.R +\name{pareto_front} +\alias{pareto_front} +\title{Pareto Front} +\usage{ +pareto_front(instance, theme = ggplot2::theme_minimal()) +} +\arguments{ +\item{instance}{(\verb{[TuningInstanceAsyncSingleCrit]})\cr +Single-criterion tuning instance with Rush. 
+For \link{mlr3automl} learners, the tuning instance is stored in the field \verb{$instance}.} + +\item{theme}{(\code{\link[ggplot2:theme]{ggplot2::theme()}})\cr +The \code{\link[ggplot2:ggtheme]{ggplot2::theme_minimal()}} is applied by default to all plots.} +} +\description{ +Plots the Pareto front with x-axis representing the tuning objective (e.g. \code{"classif.ce"}) and y-axis representing time (the \code{runtime_learners} column in the archive). +} diff --git a/man/partial_dependence_plot.Rd b/man/partial_dependence_plot.Rd new file mode 100644 index 0000000..f4c76ea --- /dev/null +++ b/man/partial_dependence_plot.Rd @@ -0,0 +1,43 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/visualization.R +\name{partial_dependence_plot} +\alias{partial_dependence_plot} +\title{Partial Dependence Plot} +\usage{ +partial_dependence_plot( + instance, + x, + y, + type = "default", + theme = ggplot2::theme_minimal(), + ... +) +} +\arguments{ +\item{instance}{(\verb{[TuningInstanceAsyncSingleCrit]})\cr +Single-criterion tuning instance with Rush. +For \link{mlr3automl} learners, the tuning instance is stored in the field \verb{$instance}.} + +\item{x}{(\code{character(1)}) +Name of the parameter to be mapped to the x-axis.} + +\item{y}{(\code{character(1)}) +Name of the parameter to be mapped to the y-axis.} + +\item{type}{(\code{character(1)}) +Type of the two-parameter partial dependence plot. Possible options are listed below. +\itemize{ +\item \code{"default"}: Use the default setting in \code{iml}. +\item \code{"contour"}: Create a contour plot using \verb{[ggplot2::geom_contour_filled]}. Only supported if both parameters are numerical. 
+} +Ignored if only one parameter is provided.} + +\item{theme}{(\code{\link[ggplot2:theme]{ggplot2::theme()}})\cr +The \code{\link[ggplot2:ggtheme]{ggplot2::theme_minimal()}} is applied by default to all plots.} + +\item{...}{(anything) +Arguments passed to \verb{[iml::FeatureEffect]}.} +} +\description{ +Creates a partial dependence plot (PDP) via the \verb{[iml]} package. +} diff --git a/man/visualize.Rd b/man/visualize.Rd new file mode 100644 index 0000000..3e07cef --- /dev/null +++ b/man/visualize.Rd @@ -0,0 +1,14 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/visualization_app.R +\name{visualize} +\alias{visualize} +\title{Shiny App for Visualizing AutoML Results} +\usage{ +visualize(instance) +} +\arguments{ +\item{instance}{(\verb{[mlr3tuning::TuningInstanceAsyncSingleCrit]})} +} +\description{ +Shiny App for Visualizing AutoML Results +} From c1ed1cec2516c6a72c9d82193443e01b0ffc9fed Mon Sep 17 00:00:00 2001 From: b-zhou Date: Sun, 27 Oct 2024 17:42:41 +0100 Subject: [PATCH 33/41] feat: larger base_size --- R/visualization.R | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/R/visualization.R b/R/visualization.R index 3d453b0..048eec1 100644 --- a/R/visualization.R +++ b/R/visualization.R @@ -9,7 +9,7 @@ #' @template param_theme #' #' @export -cost_over_time = function(instance, time = NULL, theme = ggplot2::theme_minimal()) { +cost_over_time = function(instance, time = NULL, theme = ggplot2::theme_minimal(base_size = 14)) { archive = instance$archive # there should only be one objective, e.g. 
`classif.ce` objective = archive$cols_y @@ -48,7 +48,7 @@ cost_over_time = function(instance, time = NULL, theme = ggplot2::theme_minimal( #' @template param_theme #' #' @export -marginal_plot = function(instance, x, y = NULL, theme = ggplot2::theme_minimal()) { +marginal_plot = function(instance, x, y = NULL, theme = ggplot2::theme_minimal(base_size = 14)) { archive = instance$archive param_ids = archive$cols_x assert_choice(x, param_ids) @@ -119,7 +119,7 @@ marginal_plot = function(instance, x, y = NULL, theme = ggplot2::theme_minimal() #' @export parallel_coordinates = function( instance, cols_x = NULL, trafo = FALSE, - theme = ggplot2::theme_minimal() + theme = ggplot2::theme_minimal(base_size = 14) ) { archive = instance$archive assert_subset(cols_x, c(archive$cols_x, paste0("x_domain_", archive$cols_x))) @@ -228,7 +228,7 @@ parallel_coordinates = function( #' @export partial_dependence_plot = function( instance, x, y, type = "default", - theme = ggplot2::theme_minimal(), + theme = ggplot2::theme_minimal(base_size = 14), ... 
) { archive = instance$archive @@ -337,7 +337,7 @@ partial_dependence_plot = function( #' @template param_theme #' #' @export -pareto_front = function(instance, theme = ggplot2::theme_minimal()) { +pareto_front = function(instance, theme = ggplot2::theme_minimal(base_size = 14)) { # adopted from `Archive$best()` for multi-crit archive = instance$archive tab = archive$finished_data From f66d6f97af6e2a601d0b80b8a80fb93be1d5f1c2 Mon Sep 17 00:00:00 2001 From: b-zhou Date: Sun, 27 Oct 2024 18:43:34 +0100 Subject: [PATCH 34/41] feat: cost over time incumbent --- R/visualization.R | 56 +++++++++++++++++++++++++++++++------------ R/visualization_app.R | 8 +++++-- 2 files changed, 47 insertions(+), 17 deletions(-) diff --git a/R/visualization.R b/R/visualization.R index 048eec1..9defdc5 100644 --- a/R/visualization.R +++ b/R/visualization.R @@ -6,32 +6,58 @@ #' @param time (`character(1)`)\cr #' Column in the archive to be interpreted as the time variable, e.g. "timestamp_xs", "timestamp_ys". #' If `NULL` (default), the configuration ID will be used. +#' @param incumbent (`logical(1)`)\cr +#' Whether to plot the incumbent points only. Defaults to `TRUE`. #' @template param_theme #' #' @export -cost_over_time = function(instance, time = NULL, theme = ggplot2::theme_minimal(base_size = 14)) { +cost_over_time = function( + instance, time = NULL, incumbent = TRUE, + theme = ggplot2::theme_minimal(base_size = 14) +) { archive = instance$archive - # there should only be one objective, e.g. 
`classif.ce` - objective = archive$cols_y + archive_data = as.data.table(archive) + set(archive_data, j = "config_id", value = seq_row(archive_data)) + archive_data = archive_data[archive_data$state == "finished"] + assert_choice(time, names(archive_data), null.ok = TRUE) + + x = if (is.null(time)) { + archive_data$config_id + } else { + archive_data[[time]] + } + xlabel = time %??% "configuration ID" .data = NULL - if (is.null(time)) { - x = seq_row(archive$data) - g = ggplot2::ggplot(data = as.data.table(archive), ggplot2::aes( + if (!incumbent) { + g = ggplot2::ggplot(data = archive_data, ggplot2::aes( x = x, - y = .data[[objective]] + y = .data[[archive$cols_y]] )) + - ggplot2::labs(x = "configuration ID") + ggplot2::geom_point() + + ggplot2::geom_line() + + ggplot2::labs(x = xlabel) + + theme + return(g) + } + + # if incumbent, plot the best objective at each time point + dt = data.table(x, archive_data[[archive$cols_y]]) + names(dt) = c("time", "objective") + min_or_max = if (archive$codomain$maximization_to_minimization == 1) { + min } else { - assert_choice(time, names(as.data.table(archive))) - g = ggplot2::ggplot(data = as.data.table(archive), ggplot2::aes( - x = .data[[time]], - y = .data[[objective]] - )) + max } - - g + ggplot2::geom_point() + + objective = NULL # avoid RMD check issue + dt[, objective := min_or_max(objective), by = "time"] + dt = unique(dt, by = "time") + + .data = NULL + ggplot2::ggplot(dt, ggplot2::aes(x = .data$time, y = .data$objective)) + + ggplot2::geom_point() + ggplot2::geom_line() + + ggplot2::labs(x = xlabel, y = archive$cols_y) + theme } diff --git a/R/visualization_app.R b/R/visualization_app.R index cc7e0c8..f0b2fa2 100644 --- a/R/visualization_app.R +++ b/R/visualization_app.R @@ -17,6 +17,10 @@ visualize = function(instance) { shiny::radioButtons("cot_x", label = "Select x-axis:", choices = c("configuration ID", "timestamp_xs", "timestamp_ys") + ), + shiny::checkboxInput("cot_incumbent", + label = "Plot incumbent", + 
value = TRUE ) ), param_panel( @@ -90,9 +94,9 @@ visualize = function(instance) { # Cost over time output$cost_over_time = shiny::renderPlot({ if (input$cot_x == "configuration ID") { - cost_over_time(instance) + cost_over_time(instance, incumbent = input$cot_incumbent) } else { - cost_over_time(instance, time = input$cot_x) + cost_over_time(instance, time = input$cot_x, incumbent = input$cot_incumbent) } }) From 536769f091e273aab98095e06b706eecb1c86cbf Mon Sep 17 00:00:00 2001 From: b-zhou Date: Sun, 27 Oct 2024 19:47:46 +0100 Subject: [PATCH 35/41] doc: plots --- man/cost_over_time.Rd | 10 +++++++++- man/marginal_plot.Rd | 7 ++++++- man/parallel_coordinates.Rd | 2 +- man/pareto_front.Rd | 2 +- man/partial_dependence_plot.Rd | 2 +- 5 files changed, 18 insertions(+), 5 deletions(-) diff --git a/man/cost_over_time.Rd b/man/cost_over_time.Rd index 14d52a9..869e1d0 100644 --- a/man/cost_over_time.Rd +++ b/man/cost_over_time.Rd @@ -4,7 +4,12 @@ \alias{cost_over_time} \title{Cost-Over-Time Plot} \usage{ -cost_over_time(instance, time = NULL, theme = ggplot2::theme_minimal()) +cost_over_time( + instance, + time = NULL, + incumbent = TRUE, + theme = ggplot2::theme_minimal(base_size = 14) +) } \arguments{ \item{instance}{(\verb{[TuningInstanceAsyncSingleCrit]})\cr @@ -15,6 +20,9 @@ For \link{mlr3automl} learners, the tuning instance is stored in the field \verb Column in the archive to be interpreted as the time variable, e.g. "timestamp_xs", "timestamp_ys". If \code{NULL} (default), the configuration ID will be used.} +\item{incumbent}{(\code{logical(1)})\cr +Whether to plot the incumbent points only. 
Defaults to \code{TRUE}.} + \item{theme}{(\code{\link[ggplot2:theme]{ggplot2::theme()}})\cr The \code{\link[ggplot2:ggtheme]{ggplot2::theme_minimal()}} is applied by default to all plots.} } diff --git a/man/marginal_plot.Rd b/man/marginal_plot.Rd index e5ea290..2cee31c 100644 --- a/man/marginal_plot.Rd +++ b/man/marginal_plot.Rd @@ -4,7 +4,12 @@ \alias{marginal_plot} \title{Marginal Plot} \usage{ -marginal_plot(instance, x, y = NULL, theme = ggplot2::theme_minimal()) +marginal_plot( + instance, + x, + y = NULL, + theme = ggplot2::theme_minimal(base_size = 14) +) } \arguments{ \item{instance}{(\verb{[TuningInstanceAsyncSingleCrit]})\cr diff --git a/man/parallel_coordinates.Rd b/man/parallel_coordinates.Rd index 699ce67..8c9b039 100644 --- a/man/parallel_coordinates.Rd +++ b/man/parallel_coordinates.Rd @@ -8,7 +8,7 @@ parallel_coordinates( instance, cols_x = NULL, trafo = FALSE, - theme = ggplot2::theme_minimal() + theme = ggplot2::theme_minimal(base_size = 14) ) } \arguments{ diff --git a/man/pareto_front.Rd b/man/pareto_front.Rd index c58f17c..8a3ab97 100644 --- a/man/pareto_front.Rd +++ b/man/pareto_front.Rd @@ -4,7 +4,7 @@ \alias{pareto_front} \title{Pareto Front} \usage{ -pareto_front(instance, theme = ggplot2::theme_minimal()) +pareto_front(instance, theme = ggplot2::theme_minimal(base_size = 14)) } \arguments{ \item{instance}{(\verb{[TuningInstanceAsyncSingleCrit]})\cr diff --git a/man/partial_dependence_plot.Rd b/man/partial_dependence_plot.Rd index f4c76ea..869eebf 100644 --- a/man/partial_dependence_plot.Rd +++ b/man/partial_dependence_plot.Rd @@ -9,7 +9,7 @@ partial_dependence_plot( x, y, type = "default", - theme = ggplot2::theme_minimal(), + theme = ggplot2::theme_minimal(base_size = 14), ... 
) } From b4f5291e20daf9ddc5a870a4acb2ac545c00bdb2 Mon Sep 17 00:00:00 2001 From: b-zhou Date: Sun, 3 Nov 2024 15:00:15 +0100 Subject: [PATCH 36/41] fix: cost over time correct incumbent --- R/visualization.R | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/R/visualization.R b/R/visualization.R index 9defdc5..8645820 100644 --- a/R/visualization.R +++ b/R/visualization.R @@ -41,16 +41,16 @@ cost_over_time = function( return(g) } - # if incumbent, plot the best objective at each time point + # if incumbent: dt = data.table(x, archive_data[[archive$cols_y]]) names(dt) = c("time", "objective") min_or_max = if (archive$codomain$maximization_to_minimization == 1) { - min + cummin } else { - max + cummax } objective = NULL # avoid RMD check issue - dt[, objective := min_or_max(objective), by = "time"] + dt[, objective := min_or_max(objective)] dt = unique(dt, by = "time") .data = NULL From d3fe32ef1a16406fcaca3d8d319220403631bf4b Mon Sep 17 00:00:00 2001 From: b-zhou Date: Sun, 3 Nov 2024 18:28:48 +0100 Subject: [PATCH 37/41] feat: marginal plot surface --- R/visualization.R | 138 +++++++++++++++++++++++++++++++++++----------- 1 file changed, 107 insertions(+), 31 deletions(-) diff --git a/R/visualization.R b/R/visualization.R index 8645820..78886d6 100644 --- a/R/visualization.R +++ b/R/visualization.R @@ -66,65 +66,141 @@ cost_over_time = function( #' @description Creates 2D marginal plots for evaluated configurations. #' #' @template param_instance -#' @param x (`character(1)`) +#' @param x (`character(1)`)\cr #' Name of the parameter to be mapped to the x-axis. -#' @param y (`character(1)`) +#' @param y (`character(1)`)\cr #' Name of the parameter to be mapped to the y-axis. #' If `NULL` (default), the measure (e.g. `classif.ce`) is mapped to the y-axis. +#' @param trafo (`logical(1)`) +#' If `FALSE`, the untransformed parameter values are plotted. +#' If `TRUE` (default), the transformed values are plotted. 
+#' @param surface (`character(1)`)\cr +#' If `TRUE` (default), interpolate the prediction surface with a surrogate model. +#' Ignored if `y` is provided. +#' Not supported for categorical parameters. +#' @param grid_resolution (`numeric()`)\cr +#' Number of grid points per axis for the surface plot. +#' Ignored if `y` is not provided or `surface` is set to `FALSE`. +#' Not supported for categorical parameters. #' @template param_theme #' #' @export -marginal_plot = function(instance, x, y = NULL, theme = ggplot2::theme_minimal(base_size = 14)) { +marginal_plot = function( + instance, x, y = NULL, trafo = TRUE, + surface = TRUE, grid_resolution = 100, + theme = ggplot2::theme_minimal(base_size = 14) +) { archive = instance$archive - param_ids = archive$cols_x + param_ids = setdiff(archive$cols_x, "branch.selection") assert_choice(x, param_ids) assert_choice(y, param_ids, null.ok = TRUE) # use transformed values if trafo is set - x_trafo = paste0("x_domain_", x) - y_trafo = if (!is.null(y)) paste0("x_domain_", y) else NULL + xname = x + yname = y + x = if (trafo) paste0("x_domain_", x) else x + y = if (trafo && !is.null(y)) paste0("x_domain_", y) else NULL - # there should only be one objective, e.g. 
`classif.ce` - measure = archive$cols_y - - data = na.omit(as.data.table(archive), cols = c(x_trafo, y_trafo)) + data = na.omit(as.data.table(archive), cols = c(x, y, archive$cols_y)) .data = NULL # no param provided for y if (is.null(y)) { - g = ggplot2:: ggplot(data = data, ggplot2::aes( - x = .data[[x_trafo]], - y = .data[[measure]] + g = ggplot2::ggplot(data = data, ggplot2::aes( + x = .data[[x]], + y = .data[[archive$cols_y]] )) + ggplot2::geom_point(alpha = 0.6) + + ggplot2::labs(x = xname) + + ggplot2::scale_x_continuous( + expand = c(0.01, 0.01), + transform = if (archive$search_space$is_logscale[[xname]]) "log10" else "identity" + ) + theme - if (archive$search_space$is_logscale[[x]]) { - g = g + ggplot2::scale_x_log10() - } - return(g) } - # param provided for y - g = ggplot2::ggplot(data = data, ggplot2::aes( - x = .data[[x_trafo]], - y = .data[[y_trafo]], - col = .data[[measure]] - )) + + # param provided for y, but surface is FALSE + if (!surface) { + g = ggplot2::ggplot( + data = data, + ggplot2::aes(x = .data[[x]], y = .data[[y]], fill = .data[[archive$cols_y]]), + shape = 21, + size = 3, + stroke = 0.5 + ) + ggplot2::geom_point(alpha = 0.6) + - ggplot2::scale_color_viridis_c() + - ggplot2::labs(x = x, y = y) + + ggplot2::scale_x_continuous( + expand = c(0.01, 0.01), + transform = if (archive$search_space$is_logscale[[xname]]) "log10" else "identity" + ) + + ggplot2::scale_y_continuous( + expand = c(0.01, 0.01), + transform = if (archive$search_space$is_logscale[[yname]]) "log10" else "identity" + ) + + ggplot2::guides(fill = ggplot2::guide_colorbar(barwidth = 0.5, barheight = 10)) + + ggplot2::scale_fill_viridis_c() + + ggplot2::labs(x = xname, y = yname) + theme - if (archive$search_space$is_logscale[[x]]) { - g = g + ggplot2::scale_x_log10() - } - if (archive$search_space$is_logscale[[y]]) { - g = g + ggplot2::scale_y_log10() + return(g) } + # surface is TRUE + assert_data_table(data[, x, with = FALSE], types = "numeric") + 
assert_data_table(data[, y, with = FALSE], types = "numeric") + assert_number(grid_resolution) + + # adopted from https://github.com/mlr-org/mlr3viz/blob/db6e547bf25220e710599456f564494cbdfe6e68/R/OptimInstanceBatchSingleCrit.R#L268-L310 + surrogate_data = data[, c(x, y, archive$cols_y), with = FALSE] + task = mlr3::TaskRegr$new("surface", surrogate_data, target = archive$cols_y) + surrogate = lrn("regr.ranger") + surrogate$train(task) + # assert_learner(surrogate, task) + + x_min = archive$search_space$lower[[xname]] + x_max = archive$search_space$upper[[xname]] + y_min = archive$search_space$lower[[yname]] + y_max = archive$search_space$upper[[yname]] + + x_grid = seq(x_min, x_max, by = (x_max - x_min) / grid_resolution) + y_grid = seq(y_min, y_max, by = (y_max - y_min) / grid_resolution) + + x_grid = if (trafo && archive$search_space$is_logscale[[xname]]) exp(x_grid) else x_grid + y_grid = if (trafo && archive$search_space$is_logscale[[yname]]) exp(y_grid) else y_grid + + data_i = set_names(expand.grid(x_grid, y_grid), c(x, y)) + + setDT(data_i)[, (archive$cols_y) := surrogate$predict_newdata(data_i)$response] + + g = ggplot2::ggplot() + + ggplot2::geom_raster( + data = data_i, + ggplot2::aes(x = .data[[x]], y = .data[[y]], fill = .data[[archive$cols_y]]) + ) + + ggplot2::geom_point( + data, + mapping = ggplot2::aes(x = .data[[x]], y = .data[[y]], fill = .data[[archive$cols_y]]), + shape = 21, + size = 3, + stroke = 0.5, + alpha = 0.8 + ) + + ggplot2::scale_x_continuous( + expand = c(0.01, 0.01), + transform = if (archive$search_space$is_logscale[[xname]]) "log10" else "identity" + ) + + ggplot2::scale_y_continuous( + expand = c(0.01, 0.01), + transform = if (archive$search_space$is_logscale[[yname]]) "log10" else "identity" + ) + + ggplot2::guides(fill = ggplot2::guide_colorbar(barwidth = 0.5, barheight = 10)) + + ggplot2::scale_fill_viridis_c() + + ggplot2::labs(x = xname, y = yname) + + theme + return(g) } @@ -137,7 +213,7 @@ marginal_plot = 
function(instance, x, y = NULL, theme = ggplot2::theme_minimal(b #' @param cols_x (`character()`) #' Column names of x values. #' By default, all untransformed x values from the search space are plotted. -#' @param trafo (`character(1)`) +#' @param trafo (`logical(1)`) #' If `FALSE` (default), the untransformed x values are plotted. #' If `TRUE`, the transformed x values are plotted. #' @template param_theme From 24aeeeef468569ba36b401d41027dbc0aaa962c3 Mon Sep 17 00:00:00 2001 From: b-zhou Date: Sun, 3 Nov 2024 18:29:04 +0100 Subject: [PATCH 38/41] feat: marginal plot surface selection button --- R/visualization_app.R | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/R/visualization_app.R b/R/visualization_app.R index f0b2fa2..19a5ad7 100644 --- a/R/visualization_app.R +++ b/R/visualization_app.R @@ -27,7 +27,13 @@ visualize = function(instance) { "input.nav === 'Marginal Plots'", "mp", learner_ids, - param_ids + param_ids, + shiny::radioButtons("mp_surface", + label = "Plot surface?", + choices = list("No", "Yes"), + selected = "No", + inline = TRUE + ) ), shiny::conditionalPanel( "input.nav === 'Parallel Coordinates'", @@ -109,10 +115,11 @@ visualize = function(instance) { }) output$marginal_plot = shiny::renderPlot({ + surface = input$mp_surface == "Yes" if (input$mp_y == "NULL") { - marginal_plot(instance, x = input$mp_x) + marginal_plot(instance, x = input$mp_x, surface = surface) } else { - marginal_plot(instance, x = input$mp_x, y = input$mp_y) + marginal_plot(instance, x = input$mp_x, y = input$mp_y, surface = surface) } }) From ee39b6a1dcb957d34d7c3fd03e8e4ed8c36fe588 Mon Sep 17 00:00:00 2001 From: b-zhou Date: Sun, 3 Nov 2024 19:12:28 +0100 Subject: [PATCH 39/41] fix: check for param type before adjusting axis scale --- R/visualization.R | 64 +++++++++++++++++++++++++++-------------------- 1 file changed, 37 insertions(+), 27 deletions(-) diff --git a/R/visualization.R b/R/visualization.R index 78886d6..658df4c 100644 --- 
a/R/visualization.R +++ b/R/visualization.R @@ -112,14 +112,16 @@ marginal_plot = function( y = .data[[archive$cols_y]] )) + ggplot2::geom_point(alpha = 0.6) + - ggplot2::labs(x = xname) + - ggplot2::scale_x_continuous( - expand = c(0.01, 0.01), - transform = if (archive$search_space$is_logscale[[xname]]) "log10" else "identity" - ) + - theme + ggplot2::labs(x = xname) - return(g) + if (archive$search_space$is_number[[xname]]) { + g = g + ggplot2::scale_x_continuous( + expand = c(0.01, 0.01), + transform = if (archive$search_space$is_logscale[[xname]]) "log10" else "identity" + ) + } + + return(g + theme) } # param provided for y, but surface is FALSE @@ -132,20 +134,24 @@ marginal_plot = function( stroke = 0.5 ) + ggplot2::geom_point(alpha = 0.6) + - ggplot2::scale_x_continuous( - expand = c(0.01, 0.01), - transform = if (archive$search_space$is_logscale[[xname]]) "log10" else "identity" - ) + - ggplot2::scale_y_continuous( - expand = c(0.01, 0.01), - transform = if (archive$search_space$is_logscale[[yname]]) "log10" else "identity" - ) + ggplot2::guides(fill = ggplot2::guide_colorbar(barwidth = 0.5, barheight = 10)) + ggplot2::scale_fill_viridis_c() + - ggplot2::labs(x = xname, y = yname) + - theme + ggplot2::labs(x = xname, y = yname) - return(g) + if (archive$search_space$is_number[[xname]]) { + g = g + ggplot2::scale_x_continuous( + expand = c(0.01, 0.01), + transform = if (archive$search_space$is_logscale[[xname]]) "log10" else "identity" + ) + } + if (archive$search_space$is_number[[yname]]) { + g = g + ggplot2::scale_y_continuous( + expand = c(0.01, 0.01), + transform = if (archive$search_space$is_logscale[[yname]]) "log10" else "identity" + ) + } + + return(g + theme) } # surface is TRUE @@ -188,20 +194,24 @@ marginal_plot = function( stroke = 0.5, alpha = 0.8 ) + - ggplot2::scale_x_continuous( + ggplot2::guides(fill = ggplot2::guide_colorbar(barwidth = 0.5, barheight = 10)) + + ggplot2::scale_fill_viridis_c() + + ggplot2::labs(x = xname, y = yname) + + if 
(archive$search_space$is_number[[xname]]) { + g = g + ggplot2::scale_x_continuous( expand = c(0.01, 0.01), transform = if (archive$search_space$is_logscale[[xname]]) "log10" else "identity" - ) + - ggplot2::scale_y_continuous( + ) + } + if (archive$search_space$is_number[[yname]]) { + g = g + ggplot2::scale_y_continuous( expand = c(0.01, 0.01), transform = if (archive$search_space$is_logscale[[yname]]) "log10" else "identity" - ) + - ggplot2::guides(fill = ggplot2::guide_colorbar(barwidth = 0.5, barheight = 10)) + - ggplot2::scale_fill_viridis_c() + - ggplot2::labs(x = xname, y = yname) + - theme + ) + } - return(g) + return(g + theme) } From a28987cded33abfcefd6ccb4e2e6138c8c3a99d8 Mon Sep 17 00:00:00 2001 From: b-zhou Date: Sun, 3 Nov 2024 19:14:28 +0100 Subject: [PATCH 40/41] doc: update mp & fix pc --- R/visualization.R | 2 +- man/marginal_plot.Rd | 21 +++++++++++++++++++-- man/parallel_coordinates.Rd | 2 +- 3 files changed, 21 insertions(+), 4 deletions(-) diff --git a/R/visualization.R b/R/visualization.R index 658df4c..4b8279f 100644 --- a/R/visualization.R +++ b/R/visualization.R @@ -78,7 +78,7 @@ cost_over_time = function( #' If `TRUE` (default), interpolate the prediction surface with a surrogate model. #' Ignored if `y` is provided. #' Not supported for categorical parameters. -#' @param grid_resolution (`numeric()`)\cr +#' @param grid_resolution (`numeric(1)`)\cr #' Number of grid points per axis for the surface plot. #' Ignored if `y` is not provided or `surface` is set to `FALSE`. #' Not supported for categorical parameters. diff --git a/man/marginal_plot.Rd b/man/marginal_plot.Rd index 2cee31c..c2f3690 100644 --- a/man/marginal_plot.Rd +++ b/man/marginal_plot.Rd @@ -8,6 +8,9 @@ marginal_plot( instance, x, y = NULL, + trafo = TRUE, + surface = TRUE, + grid_resolution = 100, theme = ggplot2::theme_minimal(base_size = 14) ) } @@ -16,13 +19,27 @@ marginal_plot( Single-criterion tuning instance with Rush. 
For \link{mlr3automl} learners, the tuning instance is stored in the field \verb{$instance}.} -\item{x}{(\code{character(1)}) +\item{x}{(\code{character(1)})\cr Name of the parameter to be mapped to the x-axis.} -\item{y}{(\code{character(1)}) +\item{y}{(\code{character(1)})\cr Name of the parameter to be mapped to the y-axis. If \code{NULL} (default), the measure (e.g. \code{classif.ce}) is mapped to the y-axis.} +\item{trafo}{(\code{logical(1)}) +If \code{FALSE}, the untransformed parameter values are plotted. +If \code{TRUE} (default), the transformed values are plotted.} + +\item{surface}{(\code{character(1)})\cr +If \code{TRUE} (default), interpolate the prediction surface with a surrogate model. +Ignored if \code{y} is provided. +Not supported for categorical parameters.} + +\item{grid_resolution}{(\code{numeric(1)})\cr +Number of grid points per axis for the surface plot. +Ignored if \code{y} is not provided or \code{surface} is set to \code{FALSE}. +Not supported for categorical parameters.} + \item{theme}{(\code{\link[ggplot2:theme]{ggplot2::theme()}})\cr The \code{\link[ggplot2:ggtheme]{ggplot2::theme_minimal()}} is applied by default to all plots.} } diff --git a/man/parallel_coordinates.Rd b/man/parallel_coordinates.Rd index 8c9b039..ddc3180 100644 --- a/man/parallel_coordinates.Rd +++ b/man/parallel_coordinates.Rd @@ -20,7 +20,7 @@ For \link{mlr3automl} learners, the tuning instance is stored in the field \verb Column names of x values. By default, all untransformed x values from the search space are plotted.} -\item{trafo}{(\code{character(1)}) +\item{trafo}{(\code{logical(1)}) If \code{FALSE} (default), the untransformed x values are plotted. 
If \code{TRUE}, the transformed x values are plotted.} From e69546abc2ba69a2d12cdd1d8b5993f4128bd583 Mon Sep 17 00:00:00 2001 From: b-zhou Date: Sun, 3 Nov 2024 19:15:14 +0100 Subject: [PATCH 41/41] test: update mp tests --- tests/testthat/test_visualization.R | 81 +++++++++++++++++------------ 1 file changed, 48 insertions(+), 33 deletions(-) diff --git a/tests/testthat/test_visualization.R b/tests/testthat/test_visualization.R index 05c7344..83cef1e 100644 --- a/tests/testthat/test_visualization.R +++ b/tests/testthat/test_visualization.R @@ -6,9 +6,9 @@ skip_if_not_installed(c("glmnet", "kknn", "ranger", "e1071")) test_that("cost over time works", { task = tsk("penguins") - set.seed(1453) flush_redis() - rush_plan(n_workers = 2) + set.seed(1453) + rush_plan(n_workers = 1) learner = lrn("classif.auto_ranger", small_data_size = 1, @@ -29,9 +29,9 @@ test_that("marginal plot works", { task = tsk("penguins") # numeric vs numeric - set.seed(1453) flush_redis() - rush_plan(n_workers = 2) + set.seed(1453) + rush_plan(n_workers = 1) learner_glmnet = lrn("classif.auto_glmnet", small_data_size = 1, @@ -44,15 +44,23 @@ test_that("marginal plot works", { "mp-numeric-numeric", marginal_plot(learner_glmnet$instance, x = "glmnet.alpha", y = "glmnet.s") ) + vdiffr::expect_doppelganger( + "mp-numeric-numeric-no-surface", + marginal_plot(learner_glmnet$instance, x = "glmnet.alpha", y = "glmnet.s", surface = FALSE) + ) + vdiffr::expect_doppelganger( + "mp-numeric-numeric-different-grid", + marginal_plot(learner_glmnet$instance, x = "glmnet.alpha", y = "glmnet.s", surface = TRUE, grid_resolution = 2) + ) vdiffr::expect_doppelganger( "mp-numeric-numeric2", marginal_plot(learner_glmnet$instance, x = "glmnet.s", y = "glmnet.alpha") ) # numeric vs factor - set.seed(1453) flush_redis() - rush_plan(n_workers = 2) + set.seed(1453) + rush_plan(n_workers = 1) learner_kknn = lrn("classif.auto_kknn", small_data_size = 1, @@ -61,19 +69,26 @@ test_that("marginal plot works", { terminator = 
trm("evals", n_evals = 6) ) learner_kknn$train(task) + + # plot surface by default => not supported for categorical param + expect_error( + marginal_plot(learner_kknn$instance, x = "kknn.distance", y = "kknn.kernel"), + "Assertion(.+)numeric(.+)character" + ) + vdiffr::expect_doppelganger( "mp-numeric-factor", - marginal_plot(learner_kknn$instance, x = "kknn.distance", y = "kknn.kernel") + marginal_plot(learner_kknn$instance, x = "kknn.distance", y = "kknn.kernel", surface = FALSE) ) vdiffr::expect_doppelganger( "mp-factor-numeric", - marginal_plot(learner_kknn$instance, x = "kknn.kernel", y = "kknn.distance") + marginal_plot(learner_kknn$instance, x = "kknn.kernel", y = "kknn.distance", surface = FALSE) ) # numeric vs logical - set.seed(1453) flush_redis() - rush_plan(n_workers = 2) + set.seed(1453) + rush_plan(n_workers = 1) learner_ranger = lrn("classif.auto_ranger", small_data_size = 1, @@ -84,41 +99,41 @@ test_that("marginal plot works", { learner_ranger$train(task) vdiffr::expect_doppelganger( "mp-numeric-logical", - marginal_plot(learner_ranger$instance, x = "ranger.num.trees", y = "ranger.replace") + marginal_plot(learner_ranger$instance, x = "ranger.num.trees", y = "ranger.replace", surface = FALSE) ) vdiffr::expect_doppelganger( "mp-logical-numeric", - marginal_plot(learner_ranger$instance, x = "ranger.replace", y = "ranger.num.trees") + marginal_plot(learner_ranger$instance, x = "ranger.replace", y = "ranger.num.trees", surface = FALSE) ) }) -test_that("marginal plot accepts params on different branches", { - task = tsk("penguins") +# test_that("marginal plot accepts params on different branches", { +# task = tsk("penguins") - set.seed(1453) - flush_redis() - rush_plan(n_workers = 2) +# set.seed(1453) +# flush_redis() +# rush_plan(n_workers = 1) - learner = lrn("classif.auto", - learner_ids = c("kknn", "svm"), - small_data_size = 1, - resampling = rsmp("holdout"), - measure = msr("classif.ce"), - terminator = trm("evals", n_evals = 6) - ) - 
learner$train(task) - vdiffr::expect_doppelganger( - "mp-different-branches", - marginal_plot(learner$instance, x = "kknn.distance", y = "svm.cost") - ) -}) +# learner = lrn("classif.auto", +# learner_ids = c("kknn", "svm"), +# small_data_size = 1, +# resampling = rsmp("holdout"), +# measure = msr("classif.ce"), +# terminator = trm("evals", n_evals = 6) +# ) +# learner$train(task) +# vdiffr::expect_doppelganger( +# "mp-different-branches", +# marginal_plot(learner$instance, x = "kknn.distance", y = "svm.cost") +# ) +# }) test_that("marginal plot handles dependence", { task = tsk("penguins") - set.seed(1453) flush_redis() - rush_plan(n_workers = 2) + set.seed(1453) + rush_plan(n_workers = 1) learner_svm = lrn("classif.auto_svm", small_data_size = 1, @@ -129,7 +144,7 @@ test_that("marginal plot handles dependence", { learner_svm$train(task) vdiffr::expect_doppelganger( "mp-dependence", - marginal_plot(learner_svm$instance, x = "svm.kernel", y = "svm.degree") + marginal_plot(learner_svm$instance, x = "svm.kernel", y = "svm.degree", surface = FALSE) ) })