diff --git a/R/ParamSet.R b/R/ParamSet.R index 1aeb3087..16275b13 100644 --- a/R/ParamSet.R +++ b/R/ParamSet.R @@ -379,7 +379,9 @@ ParamSet = R6Class("ParamSet", return(private$.values) } + xs = Reduce(function(val, fun) fun(val), self$callbacks, xs) self$assert(xs) + if (length(xs) == 0L) xs = named_list() private$.values = xs }, @@ -390,6 +392,14 @@ ParamSet = R6Class("ParamSet", extra_values = function() { private$.values[names(private$.values) %nin% names(private$.params)] + }, + + callbacks = function(val) { + if (!missing(val)) { + assert_list(val, types = "function", any.missing = FALSE) + private$.callbacks = val + } + private$.callbacks } ), @@ -399,6 +409,7 @@ ParamSet = R6Class("ParamSet", .params = NULL, .values = named_list(), .deps = data.table(id = character(0L), on = character(0L), cond = list()), + .callbacks = list(), # return a slot / AB, as a named vec, named with id (and can enforce a certain vec-type) get_member_with_idnames = function(member, astype) set_names(astype(map(self$params, member)), names(self$params)), diff --git a/R/ParamSetCollection.R b/R/ParamSetCollection.R index fe163466..2f7846bd 100644 --- a/R/ParamSetCollection.R +++ b/R/ParamSetCollection.R @@ -115,13 +115,11 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet, }, values = function(xs) { - sets = private$.sets - names(sets) = map_chr(sets, "set_id") if (!missing(xs)) { - assert_list(xs) + xs = Reduce(function(val, fun) fun(val), self$callbacks, xs) self$assert(xs) # make sure everything is valid and feasible - for (s in sets) { + for (s in private$.sets) { # retrieve sublist for each set, then assign it in set (after removing prefix) psids = names(s$params) if (s$set_id != "") { @@ -134,6 +132,8 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet, s$values = pv } } + sets = private$.sets + names(sets) = map_chr(sets, "set_id") vals = map(sets, "values") vals = unlist(vals, recursive = FALSE) if (!length(vals)) vals = named_list() diff --git a/tests/testthat/test_ParamSet.R b/tests/testthat/test_ParamSet.R index 1956f01c..55b7146b 100644 --- a/tests/testthat/test_ParamSet.R +++ b/tests/testthat/test_ParamSet.R @@ -274,3 +274,130 @@ test_that("required tag, empty param set (#219)", { ps$ids() expect_identical(ps$ids(tags = "required"), character(0)) }) + +test_that("callbacks", { + ps = ParamSet$new(list( + ParamDbl$new(id = "x", lower = 1, tags = c("t1")), + ParamInt$new(id = "y", lower = 1, upper = 2), + ParamFct$new(id = "z", levels = letters[1:3], tags = c("t1")) + )) + ps$callbacks[[1]] = function(x) { + x$x = 2 + x + } + expect_equal(ps$values, named_list()) + ps$values$y = 1 + expect_equal(ps$values, list(y = 1, x = 2)) + ps$values$x = 1 + expect_equal(ps$values, list(y = 1, x = 2)) + ps$callbacks[[2]] = function(x) { + x$x = 1 + x + } + ps$values$y = 1 + expect_equal(ps$values, list(y = 1, x = 1)) + ps$callbacks[[2]] = function(x) { + x$x = 0 + x + } + expect_error({ps$values = list(y = 1)}, "is not >= 1") + + ps$callbacks[[2]] = function(x) { + x$x = 1 + x + } + ps$callbacks[[1]] = function(x) { + x$x = 0 + x + } + ps$values = list(y = 2) + expect_equal(ps$values, list(y = 2, x = 1)) + ps$callbacks[[1]] = NULL + ps$values = list(y = 1, x = 2) + expect_equal(ps$values, list(y = 1, x = 1)) +}) + +test_that("callbacks on ParamSetCollection", { + + psetset = function() { + ps = ParamSet$new(list(ParamUty$new("paramset", custom_check = function(x) check_class(x, "ParamSet", null.ok = TRUE)))) + psc = ParamSetCollection$new(list(ps)) + + psc$callbacks[[1]] = function(x) { + prevset = psc$values$paramset + newset = x$paramset + if (!identical(x$paramset, prevset)) { + psc$params$paramset$assert(newset) + if (!is.null(newset)) { + xcpy = x + xcpy$paramset = NULL + newset$assert(xcpy) + } else { + ParamSet$new()$assert(x) + } + psc$remove_sets("") + psc$add(ps) + if (!is.null(newset)) { + psc$add(newset) + } + } + x + } + psc + } + + ps = psetset() + + ps1 = ParamSet$new(list( + ParamDbl$new(id = "x", lower = 1, tags = c("t1")), + ParamInt$new(id = "y", lower = 1, upper = 2), + ParamFct$new(id = "z", levels = letters[1:3], tags = c("t1")) + )) + + ps2 = ParamDbl$new("a")$rep(3) + + expect_equal(names(ps$params), "paramset") + + ps$values$paramset = ps1 + + expect_equal(names(ps$params), c("paramset", "x", "y", "z")) + + ps$values$x = 1 + expect_equal(ps$values, list(paramset = ps1, x = 1)) + + ps$values = list(paramset = ps2, a_rep_1 = 0) + + expect_equal(ps$values, list(paramset = ps2, a_rep_1 = 0)) + + # The problem here is that there is an ambiguity. suppose + # > psB$values = list(x = 2) + # > ps$values = list(paramset = psA, x = 1) + # Now the command + # (A) > ps$values = list(paramset = psB, x = 1) + # and the command + # (B) > ps$values$paramset = psB + # are functionally the same, but in case (B) we wished we could + # keep the parameter values of psB. However, because (A) is done + # by things like tuning, it takes precedent and must work as + # expected. Therefore the following throws an error. + expect_error({ps$values$paramset = ps1}, "a_rep_1.* not available") + + expect_equal(ps$values, list(paramset = ps2, a_rep_1 = 0)) + + ps$values = c(list(paramset = ps1), ps1$values) + + expect_equal(ps$values, list(paramset = ps1, x = 1)) + + expect_error({ps$values = list(x = 2)}, "Parameter 'x' not available") + + expect_equal(ps$values, list(paramset = ps1, x = 1)) + + expect_error({ps$values$paramset = NULL}, "Parameter 'x' not available") + + expect_equal(ps$values, list(paramset = ps1, x = 1)) + + ps$values = list() + + expect_equal(ps$values, named_list()) + +})