From 0d464c72cd4f39d0e6d5e39f53c7db5a8a6a1396 Mon Sep 17 00:00:00 2001 From: Paolo Bosetti Date: Thu, 30 Oct 2025 11:11:50 +0100 Subject: [PATCH 1/2] add_predictions() supports further arguments to the underlying predict() --- DESCRIPTION | 2 +- R/predictions.R | 22 ++++++++++++++++------ man/add_predictions.Rd | 9 +++++++-- man/bootstrap.Rd | 4 ++-- man/resample_bootstrap.Rd | 4 ++-- man/resample_partition.Rd | 4 ++-- tests/testthat/test-predictions.R | 7 +++++++ 7 files changed, 37 insertions(+), 15 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 3513c46..de9d2f5 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -31,4 +31,4 @@ Config/testthat/edition: 3 Encoding: UTF-8 LazyData: true Roxygen: list(markdown = TRUE) -RoxygenNote: 7.2.3 +RoxygenNote: 7.3.2 diff --git a/R/predictions.R b/R/predictions.R index 3abfa14..155063f 100644 --- a/R/predictions.R +++ b/R/predictions.R @@ -7,7 +7,9 @@ #' `predict()` documentation for given `model` to determine valid values. #' @param ... `gather_predictions` and `spread_predictions` take #' multiple models. The name will be taken from either the argument -#' name of the name of the model. +#' name of the name of the model. `add_predictions` passes further arguments +#' to the underlying `predict` generic method; this allows, for example, to +#' also get confidence or prediction bands (when available). #' @param .pred,.model The variable names used by `gather_predictions`. #' @return A data frame. `add_prediction` adds a single new column, #' with default name `pred`, to the input `data`. @@ -26,11 +28,19 @@ #' grid <- data.frame(x = seq(0, 1, length = 10)) #' grid %>% add_predictions(m1) #' +#' # To also get confidence bands: +#' grid %>% add_predictions(m1, interval="confidence", level=0.99) +#' #' m2 <- lm(y ~ poly(x, 2), data = df) #' grid %>% spread_predictions(m1, m2) #' grid %>% gather_predictions(m1, m2) -add_predictions <- function(data, model, var = "pred", type = NULL) { - data[[var]] <- predict2(model, data, type = type) +add_predictions <- function(data, model, var = "pred", type = NULL, ...) { + pred <- predict2(model, data, type = type, ...) + if ("matrix" %in% class(pred)) { + data <- cbind(data, pred) + } else { + data[[var]] <- predict2(model, data, type = type, ...) + } data } @@ -56,10 +66,10 @@ gather_predictions <- function(data, ..., .pred = "pred", .model = "model", type vctrs::vec_rbind(!!!df, .names_to = .model) } -predict2 <- function(model, data, type = NULL) { +predict2 <- function(model, data, type = NULL, ...) { if (is.null(type)) { - stats::predict(model, data) + stats::predict(model, data, ...) } else { - stats::predict(model, data, type = type) + stats::predict(model, data, type = type, ...) } } diff --git a/man/add_predictions.Rd b/man/add_predictions.Rd index 886be73..ca58e46 100644 --- a/man/add_predictions.Rd +++ b/man/add_predictions.Rd @@ -6,7 +6,7 @@ \alias{gather_predictions} \title{Add predictions to a data frame} \usage{ -add_predictions(data, model, var = "pred", type = NULL) +add_predictions(data, model, var = "pred", type = NULL, ...) spread_predictions(data, ..., type = NULL) @@ -24,7 +24,9 @@ gather_predictions(data, ..., .pred = "pred", .model = "model", type = NULL) \item{...}{\code{gather_predictions} and \code{spread_predictions} take multiple models. The name will be taken from either the argument -name of the name of the model.} +name of the name of the model. \code{add_predictions} passes further arguments +to the underlying \code{predict} generic method; this allows, for example, to +also get confidence or prediction bands (when available).} \item{.pred, .model}{The variable names used by \code{gather_predictions}.} } @@ -49,6 +51,9 @@ m1 <- lm(y ~ x, data = df) grid <- data.frame(x = seq(0, 1, length = 10)) grid \%>\% add_predictions(m1) +# To also get confidence bands: +grid \%>\% add_predictions(m1, interval="confidence", level=0.99) + m2 <- lm(y ~ poly(x, 2), data = df) grid \%>\% spread_predictions(m1, m2) grid \%>\% gather_predictions(m1, m2) diff --git a/man/bootstrap.Rd b/man/bootstrap.Rd index 5c48ec0..26e9352 100644 --- a/man/bootstrap.Rd +++ b/man/bootstrap.Rd @@ -31,8 +31,8 @@ hist(subset(tidied, term == "(Intercept)")$estimate) } \seealso{ Other resampling techniques: +\code{\link{resample}()}, \code{\link{resample_bootstrap}()}, -\code{\link{resample_partition}()}, -\code{\link{resample}()} +\code{\link{resample_partition}()} } \concept{resampling techniques} diff --git a/man/resample_bootstrap.Rd b/man/resample_bootstrap.Rd index b56a243..f008b09 100644 --- a/man/resample_bootstrap.Rd +++ b/man/resample_bootstrap.Rd @@ -20,7 +20,7 @@ coef(lm(mpg ~ wt, data = resample_bootstrap(mtcars))) \seealso{ Other resampling techniques: \code{\link{bootstrap}()}, -\code{\link{resample_partition}()}, -\code{\link{resample}()} +\code{\link{resample}()}, +\code{\link{resample_partition}()} } \concept{resampling techniques} diff --git a/man/resample_partition.Rd b/man/resample_partition.Rd index 41490ce..6a5141f 100644 --- a/man/resample_partition.Rd +++ b/man/resample_partition.Rd @@ -24,7 +24,7 @@ rmse(mod, ex$train) \seealso{ Other resampling techniques: \code{\link{bootstrap}()}, -\code{\link{resample_bootstrap}()}, -\code{\link{resample}()} +\code{\link{resample}()}, +\code{\link{resample_bootstrap}()} } \concept{resampling techniques} diff --git a/tests/testthat/test-predictions.R b/tests/testthat/test-predictions.R index 4354d51..daf01cb 100644 --- a/tests/testthat/test-predictions.R +++ b/tests/testthat/test-predictions.R @@ -27,4 +27,11 @@ test_that("*_predictions() return expected shapes", { expect_equal(nrow(out), nrow(df) * 2) }) +test_that("add_predictions() provide intervals", { + df <- tibble::tibble(x = 1:5, y = c(1, 4, 3, 2, 5)) + mod <- lm(y ~ x, data = df) + pred <- stats::predict(mod, df, interval="conf") + out <- add_predictions(df, mod, interval = "confidence") + expect_equal(out$lwr, as.numeric(pred[,"lwr"])) +}) From 349b7c087cd3214617a600e5ee13834a1a5e2ae7 Mon Sep 17 00:00:00 2001 From: Paolo Bosetti Date: Thu, 30 Oct 2025 12:07:23 +0100 Subject: [PATCH 2/2] Updared deprecated upload-artifact@v3 github action --- .github/workflows/test-coverage.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-coverage.yaml b/.github/workflows/test-coverage.yaml index 27d4528..d1e76b5 100644 --- a/.github/workflows/test-coverage.yaml +++ b/.github/workflows/test-coverage.yaml @@ -44,7 +44,7 @@ jobs: - name: Upload test results if: failure() - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v5 with: name: coverage-test-failures path: ${{ runner.temp }}/package