Skip to content

Commit 39259d9

Browse files
committed
predict method implemented for epi_workflow
* needs tests/validation * needs improved documentation
1 parent e84778c commit 39259d9

File tree

7 files changed

+132
-37
lines changed

7 files changed

+132
-37
lines changed

NAMESPACE

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ export(epi_workflow)
2525
export(get_precision)
2626
export(grab_names)
2727
export(is_epi_recipe)
28+
export(is_epi_workflow)
2829
export(knn_iteraive_ar_args_list)
2930
export(knn_iteraive_ar_forecaster)
3031
export(knnarx_args_list)
@@ -37,6 +38,7 @@ import(recipes)
3738
importFrom(magrittr,"%>%")
3839
importFrom(rlang,"!!")
3940
importFrom(rlang,":=")
41+
importFrom(rlang,is_null)
4042
importFrom(stats,as.formula)
4143
importFrom(stats,lm)
4244
importFrom(stats,model.frame)

R/epi_keys.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,4 @@ epi_keys.epi_df <- function(x) {
2323
epi_keys.recipe <- function(x) {
2424
x$var_info$variable[x$var_info$role %in% c("time_value", "geo_value", "key")]
2525
}
26+

R/epi_recipe.R

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ epi_form2args <- function(formula, data, ...) {
212212

213213

214214

215-
#' Test for `epi_df` format
215+
#' Test for `epi_recipe`
216216
#'
217217
#' @param x An object.
218218
#' @return `TRUE` if the object inherits from `epi_recipe`.
@@ -256,6 +256,7 @@ is_epi_recipe <- function(x) {
256256
#' @examples
257257
#' library(recipes)
258258
#' library(magrittr)
259+
#' library(workflows)
259260
#'
260261
#' recipe <- epi_recipe(mpg ~ cyl, mtcars) %>%
261262
#' step_log(cyl)

R/epi_workflow.R

Lines changed: 74 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,93 @@
1+
#' Create an epi_workflow
2+
#'
3+
#' This is a container object that unifies preprocessing, fitting, prediction,
4+
#' and postprocessing for predictive modeling on epidemiological data. It extends
5+
#' the functionality of a [`workflows::workflow()`] to handle the typical panel
6+
#' data structures found in this field. This extension is handled completely
7+
#' internally, and should be invisible to the user. For all intents and purposes,
8+
#' this operates exactly like a [`workflows::workflow()`]. For more details
9+
#' and numerous examples, see there.
10+
#'
11+
#' @inheritParams workflows::workflow
12+
#'
13+
#' @return A new `epi_workflow` object.
14+
#' @seealso workflows::workflow
15+
#' @importFrom rlang is_null
116
#' @export
217
epi_workflow <- function(preprocessor = NULL, spec = NULL) {
3-
out <- workflow(preprocessor, spec)
18+
out <- workflows::workflow(spec = spec)
419
class(out) <- c("epi_workflow", class(out))
5-
}
620

7-
predict.epi_workflow <-
8-
function(object, new_data, type = NULL, opts = list(), forecast_date, ...) {
9-
out <- predict(object, new_data, type = type, opts = opts, ...)
10-
if (is_epi_df(new_data)) {
11-
ek <- epi_keys(new_data)
21+
if (is_epi_recipe(preprocessor)) {
22+
return(add_epi_recipe(out, preprocessor))
23+
}
1224

13-
}
25+
if (!is_null(preprocessor)) {
26+
return(workflows:::add_preprocessor(out, preprocessor))
1427
}
28+
out
29+
}
1530

31+
#' Test for an `epi_workflow`
32+
#'
33+
#' @param x An object.
34+
#' @return `TRUE` if the object inherits from `epi_workflow`.
35+
#'
36+
#' @export
1637
is_epi_workflow <- function(x) {
1738
inherits(x, "epi_workflow")
1839
}
1940

20-
workflow <- function(preprocessor = NULL, spec = NULL) {
21-
out <- new_workflow()
22-
23-
if (!is_null(preprocessor)) {
24-
out <- add_preprocessor(out, preprocessor)
25-
}
2641

27-
if (!is_null(spec)) {
28-
out <- add_model(out, spec)
42+
predict.epi_workflow <-
43+
function(object, new_data, type = NULL, opts = list(),
44+
forecast_date = NULL, ...) {
45+
if (!workflows::is_trained_workflow(object)) {
46+
rlang::abort(
47+
c("Can't predict on an untrained epi_workflow.",
48+
i = "Do you need to call `fit()`?"))
49+
}
50+
the_fit <- workflows::extract_fit_parsnip(object)
51+
mold <- workflows::extract_mold(object)
52+
forged <- hardhat::forge(new_data, blueprint = mold$blueprint)
53+
preds <- predict(the_fit, forged$predictors, type = type, opts = opts, ...)
54+
keys <- grab_forged_keys(forged, mold, new_data)
55+
out <- dplyr::bind_cols(keys, preds, forecast_date)
56+
out
2957
}
3058

31-
out
32-
}
33-
34-
add_preprocessor <- function(x, preprocessor, ..., call = caller_env()) {
35-
check_dots_empty()
36-
37-
if (is_formula(preprocessor)) {
38-
return(add_formula(x, preprocessor))
59+
grab_forged_keys <- function(forged, mold, new_data) {
60+
keys <- c("time_value", "geo_value", "key")
61+
forged_names <- names(forged$extras$roles)
62+
molded_names <- names(mold$extras$roles)
63+
extras <- dplyr::bind_cols(forged$extras$roles[forged_names %in% keys])
64+
# 1. these are the keys in the test data after prep/bake
65+
new_keys <- names(extras)
66+
# 2. these are the keys in the training data
67+
old_keys <- purrr::map_chr(mold$extras$roles[molded_names %in% keys], names)
68+
# 3. these are the keys in the test data as input
69+
new_df_keys <- epi_keys(new_data)
70+
if (! (setequal(old_keys, new_df_keys) && setequal(new_keys, new_df_keys))) {
71+
rlang::warn(c(
72+
"Not all epi keys that were present in the training data are available",
73+
"in `new_data`. Predictions will have only the available keys.")
74+
)
3975
}
40-
41-
if (is_recipe(preprocessor)) {
42-
return(add_recipe(x, preprocessor))
76+
if (epiprocess::is_epi_df(new_data) || keys[1:2] %in% new_keys) {
77+
l <- list()
78+
if (length(new_keys) > 2) l <- list(other_keys = new_keys[-c(1:2)])
79+
extras <- as_epi_df(extras, additional_metadata = l)
4380
}
81+
extras
82+
}
4483

45-
if (is_workflow_variables(preprocessor)) {
46-
return(add_variables(x, variables = preprocessor))
47-
}
84+
new_epi_workflow <- function(
85+
pre = workflows:::new_stage_pre(),
86+
fit = workflows:::new_stage_fit(),
87+
post = workflows:::new_stage_post(),
88+
trained = FALSE) {
4889

49-
abort(
50-
"`preprocessor` must be a formula, recipe, or a set of workflow variables.",
51-
call = call
52-
)
90+
out <- workflows:::new_workflow(
91+
pre = pre, fit = fit, post = post, trained = trained)
92+
class(out) <- c("epi_workflow", class(out))
5393
}

man/epi_workflow.Rd

Lines changed: 34 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/is_epi_recipe.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/is_epi_workflow.Rd

Lines changed: 17 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)