|
| 1 | +#' Create an epi_workflow |
| 2 | +#' |
| 3 | +#' This is a container object that unifies preprocessing, fitting, prediction, |
| 4 | +#' and postprocessing for predictive modeling on epidemiological data. It extends |
| 5 | +#' the functionality of a [`workflows::workflow()`] to handle the typical panel |
| 6 | +#' data structures found in this field. This extension is handled completely |
| 7 | +#' internally, and should be invisible to the user. For all intents and purposes, |
| 8 | +#' this operates exactly like a [`workflows::workflow()`]. For more details |
| 9 | +#' and numerous examples, see there. |
| 10 | +#' |
| 11 | +#' @inheritParams workflows::workflow |
| 12 | +#' |
| 13 | +#' @return A new `epi_workflow` object. |
| 14 | +#' @seealso workflows::workflow |
| 15 | +#' @importFrom rlang is_null |
1 | 16 | #' @export |
2 | 17 | epi_workflow <- function(preprocessor = NULL, spec = NULL) { |
3 | | - out <- workflow(preprocessor, spec) |
| 18 | + out <- workflows::workflow(spec = spec) |
4 | 19 | class(out) <- c("epi_workflow", class(out)) |
5 | | -} |
6 | 20 |
|
7 | | -predict.epi_workflow <- |
8 | | - function(object, new_data, type = NULL, opts = list(), forecast_date, ...) { |
9 | | - out <- predict(object, new_data, type = type, opts = opts, ...) |
10 | | - if (is_epi_df(new_data)) { |
11 | | - ek <- epi_keys(new_data) |
| 21 | + if (is_epi_recipe(preprocessor)) { |
| 22 | + return(add_epi_recipe(out, preprocessor)) |
| 23 | + } |
12 | 24 |
|
13 | | - } |
| 25 | + if (!is_null(preprocessor)) { |
| 26 | + return(workflows:::add_preprocessor(out, preprocessor)) |
14 | 27 | } |
| 28 | + out |
| 29 | +} |
15 | 30 |
|
| 31 | +#' Test for an `epi_workflow` |
| 32 | +#' |
| 33 | +#' @param x An object. |
| 34 | +#' @return `TRUE` if the object inherits from `epi_workflow`. |
| 35 | +#' |
| 36 | +#' @export |
16 | 37 | is_epi_workflow <- function(x) { |
17 | 38 | inherits(x, "epi_workflow") |
18 | 39 | } |
19 | 40 |
|
20 | | -workflow <- function(preprocessor = NULL, spec = NULL) { |
21 | | - out <- new_workflow() |
22 | | - |
23 | | - if (!is_null(preprocessor)) { |
24 | | - out <- add_preprocessor(out, preprocessor) |
25 | | - } |
26 | 41 |
|
27 | | - if (!is_null(spec)) { |
28 | | - out <- add_model(out, spec) |
| 42 | +predict.epi_workflow <- |
| 43 | + function(object, new_data, type = NULL, opts = list(), |
| 44 | + forecast_date = NULL, ...) { |
| 45 | + if (!workflows::is_trained_workflow(object)) { |
| 46 | + rlang::abort( |
| 47 | + c("Can't predict on an untrained epi_workflow.", |
| 48 | + i = "Do you need to call `fit()`?")) |
| 49 | + } |
| 50 | + the_fit <- workflows::extract_fit_parsnip(object) |
| 51 | + mold <- workflows::extract_mold(object) |
| 52 | + forged <- hardhat::forge(new_data, blueprint = mold$blueprint) |
| 53 | + preds <- predict(the_fit, forged$predictors, type = type, opts = opts, ...) |
| 54 | + keys <- grab_forged_keys(forged, mold, new_data) |
| 55 | + out <- dplyr::bind_cols(keys, preds, forecast_date) |
| 56 | + out |
29 | 57 | } |
30 | 58 |
|
31 | | - out |
32 | | -} |
33 | | - |
34 | | -add_preprocessor <- function(x, preprocessor, ..., call = caller_env()) { |
35 | | - check_dots_empty() |
36 | | - |
37 | | - if (is_formula(preprocessor)) { |
38 | | - return(add_formula(x, preprocessor)) |
| 59 | +grab_forged_keys <- function(forged, mold, new_data) { |
| 60 | + keys <- c("time_value", "geo_value", "key") |
| 61 | + forged_names <- names(forged$extras$roles) |
| 62 | + molded_names <- names(mold$extras$roles) |
| 63 | + extras <- dplyr::bind_cols(forged$extras$roles[forged_names %in% keys]) |
| 64 | + # 1. these are the keys in the test data after prep/bake |
| 65 | + new_keys <- names(extras) |
| 66 | + # 2. these are the keys in the training data |
| 67 | + old_keys <- purrr::map_chr(mold$extras$roles[molded_names %in% keys], names) |
| 68 | + # 3. these are the keys in the test data as input |
| 69 | + new_df_keys <- epi_keys(new_data) |
| 70 | + if (! (setequal(old_keys, new_df_keys) && setequal(new_keys, new_df_keys))) { |
| 71 | + rlang::warn(c( |
| 72 | + "Not all epi keys that were present in the training data are available", |
| 73 | + "in `new_data`. Predictions will have only the available keys.") |
| 74 | + ) |
39 | 75 | } |
40 | | - |
41 | | - if (is_recipe(preprocessor)) { |
42 | | - return(add_recipe(x, preprocessor)) |
| 76 | + if (epiprocess::is_epi_df(new_data) || keys[1:2] %in% new_keys) { |
| 77 | + l <- list() |
| 78 | + if (length(new_keys) > 2) l <- list(other_keys = new_keys[-c(1:2)]) |
| 79 | + extras <- as_epi_df(extras, additional_metadata = l) |
43 | 80 | } |
| 81 | + extras |
| 82 | +} |
44 | 83 |
|
45 | | - if (is_workflow_variables(preprocessor)) { |
46 | | - return(add_variables(x, variables = preprocessor)) |
47 | | - } |
| 84 | +new_epi_workflow <- function( |
| 85 | + pre = workflows:::new_stage_pre(), |
| 86 | + fit = workflows:::new_stage_fit(), |
| 87 | + post = workflows:::new_stage_post(), |
| 88 | + trained = FALSE) { |
48 | 89 |
|
49 | | - abort( |
50 | | - "`preprocessor` must be a formula, recipe, or a set of workflow variables.", |
51 | | - call = call |
52 | | - ) |
| 90 | + out <- workflows:::new_workflow( |
| 91 | + pre = pre, fit = fit, post = post, trained = trained) |
| 92 | + class(out) <- c("epi_workflow", class(out)) |
53 | 93 | } |
0 commit comments