Skip to content

Commit 8493d0f

Browse files
committed
rebuild docs
1 parent 39259d9 commit 8493d0f

File tree

7 files changed

+130
-15
lines changed

7 files changed

+130
-15
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ S3method(epi_keys,recipe)
88
S3method(epi_recipe,default)
99
S3method(epi_recipe,epi_df)
1010
S3method(epi_recipe,formula)
11+
S3method(predict,epi_workflow)
1112
S3method(prep,step_epi_ahead)
1213
S3method(prep,step_epi_lag)
1314
S3method(print,step_epi_ahead)

R/epi_keys.R

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,11 @@ epi_keys.recipe <- function(x) {
2424
x$var_info$variable[x$var_info$role %in% c("time_value", "geo_value", "key")]
2525
}
2626

27+
# a mold is a list extracted from a fitted workflow, gives info about
28+
# training data. But it doesn't have a class
29+
epi_keys_mold <- function(mold) {
30+
keys <- c("time_value", "geo_value", "key")
31+
molded_names <- names(mold$extras$roles)
32+
mold_keys <- purrr::map_chr(mold$extras$roles[molded_names %in% keys], names)
33+
unname(mold_keys)
34+
}

R/epi_recipe.R

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -254,14 +254,12 @@ is_epi_recipe <- function(x) {
254254
#'
255255
#' @export
256256
#' @examples
257-
#' library(recipes)
258257
#' library(magrittr)
259-
#' library(workflows)
260258
#'
261259
#' recipe <- epi_recipe(mpg ~ cyl, mtcars) %>%
262260
#' step_log(cyl)
263261
#'
264-
#' workflow <- workflow() %>%
262+
#' workflow <- epi_workflow() %>%
265263
#' add_epi_recipe(recipe)
266264
#'
267265
#' workflow

R/epi_workflow.R

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,39 @@ is_epi_workflow <- function(x) {
3838
inherits(x, "epi_workflow")
3939
}
4040

41-
41+
#' Predict from an epi_workflow
42+
#'
43+
#' @description
44+
#' This is the `predict()` method for a fit epi_workflow object. The nice thing
45+
#' about predicting from an epi_workflow is that it will:
46+
#'
47+
#' - Preprocess `new_data` using the preprocessing method specified when the
48+
#' workflow was created and fit. This is accomplished using
49+
#' [hardhat::forge()], which will apply any formula preprocessing or call
50+
#' [recipes::bake()] if a recipe was supplied.
51+
#'
52+
#' - Call [parsnip::predict.model_fit()] for you using the underlying fit
53+
#' parsnip model.
54+
#'
55+
#' - Ensure that the returned object is an [epiprocess::epi_df] where
56+
#' possible. Specifically, the output will have `time_value` and
57+
#' `geo_value` columns as well as the prediction.
58+
#'
59+
#' @inheritParams parsnip::predict.model_fit
60+
#' @param forecast_date The date on which the forecast is (was) made.
61+
#'
62+
#' @param object An epi_workflow that has been fit by [fit.workflow()]
63+
#'
64+
#' @param new_data A data frame containing the new predictors to preprocess
65+
#' and predict on
66+
#'
67+
#' @return
68+
#' A data frame of model predictions, with as many rows as `new_data` has.
69+
#' If `new_data` is an `epi_df` or a data frame with `time_value` or
70+
#' `geo_value` columns, then the result will have those as well.
71+
#'
72+
#' @name predict-epi_workflow
73+
#' @export
4274
predict.epi_workflow <-
4375
function(object, new_data, type = NULL, opts = list(),
4476
forecast_date = NULL, ...) {
@@ -47,24 +79,24 @@ predict.epi_workflow <-
4779
c("Can't predict on an untrained epi_workflow.",
4880
i = "Do you need to call `fit()`?"))
4981
}
82+
if (!is_null(forecast_date)) forecast_date <- as.Date(forecast_date)
5083
the_fit <- workflows::extract_fit_parsnip(object)
5184
mold <- workflows::extract_mold(object)
5285
forged <- hardhat::forge(new_data, blueprint = mold$blueprint)
5386
preds <- predict(the_fit, forged$predictors, type = type, opts = opts, ...)
5487
keys <- grab_forged_keys(forged, mold, new_data)
55-
out <- dplyr::bind_cols(keys, preds, forecast_date)
88+
out <- dplyr::bind_cols(keys, forecast_date = forecast_date, preds)
5689
out
5790
}
5891

5992
grab_forged_keys <- function(forged, mold, new_data) {
6093
keys <- c("time_value", "geo_value", "key")
61-
forged_names <- names(forged$extras$roles)
62-
molded_names <- names(mold$extras$roles)
63-
extras <- dplyr::bind_cols(forged$extras$roles[forged_names %in% keys])
94+
forged_roles <- names(forged$extras$roles)
95+
extras <- dplyr::bind_cols(forged$extras$roles[forged_roles %in% keys])
6496
# 1. these are the keys in the test data after prep/bake
6597
new_keys <- names(extras)
6698
# 2. these are the keys in the training data
67-
old_keys <- purrr::map_chr(mold$extras$roles[molded_names %in% keys], names)
99+
old_keys <- epi_keys_mold(mold)
68100
# 3. these are the keys in the test data as input
69101
new_df_keys <- epi_keys(new_data)
70102
if (! (setequal(old_keys, new_df_keys) && setequal(new_keys, new_df_keys))) {
@@ -73,10 +105,13 @@ grab_forged_keys <- function(forged, mold, new_data) {
73105
"in `new_data`. Predictions will have only the available keys.")
74106
)
75107
}
76-
if (epiprocess::is_epi_df(new_data) || keys[1:2] %in% new_keys) {
108+
if (epiprocess::is_epi_df(new_data)) {
109+
extras <- epiprocess::as_epi_df(extras)
110+
attr(extras, "metadata") <- attr(new_data, "metadata")
111+
} else if (keys[1:2] %in% new_keys) {
77112
l <- list()
78113
if (length(new_keys) > 2) l <- list(other_keys = new_keys[-c(1:2)])
79-
extras <- as_epi_df(extras, additional_metadata = l)
114+
extras <- epiprocess::as_epi_df(extras, additional_metadata = l)
80115
}
81116
extras
82117
}

man/add_epi_recipe.Rd

Lines changed: 1 addition & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/predict-epi_workflow.Rd

Lines changed: 74 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

musings/example-recipe.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ r <- epi_recipe(x) %>% # if we add this as a class, maybe we get better
5151
slm <- linear_reg()
5252

5353
# actually estimate the model
54-
slm_fit <- workflow() %>%
55-
add_recipe(r) %>%
54+
slm_fit <- epi_workflow() %>%
55+
add_epi_recipe(r) %>%
5656
add_model(slm) %>%
5757
fit(data = x)
5858
# slm_fit <- workflow(r, slm) also works

0 commit comments

Comments
 (0)