Skip to content

Commit 7b13e55

Browse files
authored
Merge pull request #42 from cmu-delphi/epi-workflow
Epi workflow
2 parents e8cfd8e + 83020c6 commit 7b13e55

14 files changed

+761
-20
lines changed

DESCRIPTION

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,34 +10,42 @@ Description: What the package does (one paragraph).
1010
License: MIT + file LICENSE
1111
URL: https://github.com/cmu-delphi/epipredict/,
1212
https://cmu-delphi.github.io/epiprocess
13+
Depends:
14+
R (>= 3.5.0)
1315
Imports:
1416
assertthat,
1517
cli,
1618
dplyr,
19+
epiprocess,
1720
glue,
21+
hardhat (>= 1.0.0.9000),
1822
magrittr,
1923
purrr,
20-
recipes,
24+
recipes (>= 0.2.0.9001),
2125
rlang,
2226
stats,
27+
tensr,
2328
tibble,
2429
tidyr,
2530
tidyselect,
26-
tensr
31+
workflows
2732
Suggests:
2833
covidcast,
2934
data.table,
30-
epiprocess,
3135
ggplot2,
3236
knitr,
3337
lubridate,
38+
parsnip (>= 0.2.1.9001),
3439
RcppRoll,
3540
rmarkdown,
3641
testthat (>= 3.0.0)
3742
VignetteBuilder:
3843
knitr
3944
Remotes:
40-
dajmcdon/epiprocess
45+
dajmcdon/epiprocess,
46+
tidymodels/hardhat,
47+
tidymodels/parsnip,
48+
tidymodels/recipes
4149
Config/testthat/edition: 3
4250
Encoding: UTF-8
4351
Roxygen: list(markdown = TRUE)

NAMESPACE

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,25 @@ S3method(epi_keys,recipe)
88
S3method(epi_recipe,default)
99
S3method(epi_recipe,epi_df)
1010
S3method(epi_recipe,formula)
11+
S3method(predict,epi_workflow)
1112
S3method(prep,step_epi_ahead)
1213
S3method(prep,step_epi_lag)
1314
S3method(print,step_epi_ahead)
1415
S3method(print,step_epi_lag)
1516
export("%>%")
17+
export(add_epi_recipe)
1618
export(arx_args_list)
1719
export(arx_forecaster)
1820
export(create_lags_and_leads)
21+
export(default_epi_recipe_blueprint)
1922
export(df_mat_mul)
2023
export(epi_keys)
2124
export(epi_recipe)
25+
export(epi_workflow)
2226
export(get_precision)
2327
export(grab_names)
28+
export(is_epi_recipe)
29+
export(is_epi_workflow)
2430
export(knn_iteraive_ar_args_list)
2531
export(knn_iteraive_ar_forecaster)
2632
export(knnarx_args_list)
@@ -33,6 +39,7 @@ import(recipes)
3339
importFrom(magrittr,"%>%")
3440
importFrom(rlang,"!!")
3541
importFrom(rlang,":=")
42+
importFrom(rlang,is_null)
3643
importFrom(stats,as.formula)
3744
importFrom(stats,lm)
3845
importFrom(stats,model.frame)

R/epi_keys.R

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,13 @@ epi_keys.epi_df <- function(x) {
2323
epi_keys.recipe <- function(x) {
2424
x$var_info$variable[x$var_info$role %in% c("time_value", "geo_value", "key")]
2525
}
26+
27+
# a mold is a list extracted from a fitted workflow, gives info about
28+
# training data. But it doesn't have a class
29+
epi_keys_mold <- function(mold) {
30+
keys <- c("time_value", "geo_value", "key")
31+
molded_names <- names(mold$extras$roles)
32+
mold_keys <- purrr::map_chr(mold$extras$roles[molded_names %in% keys], names)
33+
unname(mold_keys)
34+
}
35+

R/epi_recipe.R

Lines changed: 109 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#'
33
#' A recipe is a description of the steps to be applied to a data set in
44
#' order to prepare it for data analysis. This is a loose wrapper
5-
#' around `recipes::recipe()` to properly handle the additional
5+
#' around [recipes::recipe()] to properly handle the additional
66
#' columns present in an `epi_df`
77
#'
88
#' @aliases epi_recipe epi_recipe.default epi_recipe.formula
@@ -51,9 +51,27 @@ epi_recipe.default <- function(x, ...) {
5151
#' as the data given in the `data` argument but can be different after
5252
#' the recipe is trained.}
5353
#'
54-
# @includeRmd man/rmd/recipes.Rmd details
5554
#'
5655
#' @export
56+
#' @examples
57+
#' library(epiprocess)
58+
#' library(dplyr)
59+
#' library(recipes)
60+
#'
61+
#' jhu <- jhu_csse_daily_subset %>%
62+
#' filter(time_value > "2021-08-01") %>%
63+
#' select(geo_value:death_rate_7d_av) %>%
64+
#' rename(case_rate = case_rate_7d_av, death_rate = death_rate_7d_av)
65+
#'
66+
#' r <- epi_recipe(jhu) %>%
67+
#' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>%
68+
#' step_epi_ahead(death_rate, ahead = 7) %>%
69+
#' step_epi_lag(case_rate, lag = c(0, 7, 14)) %>%
70+
#' step_naomit(all_predictors()) %>%
71+
#' # below, `skip` means we don't do this at predict time
72+
#' step_naomit(all_outcomes(), skip = TRUE)
73+
#'
74+
#' r
5775
epi_recipe.epi_df <-
5876
function(x,
5977
formula = NULL,
@@ -137,7 +155,7 @@ epi_recipe.epi_df <-
137155
levels = NULL,
138156
retained = NA
139157
)
140-
class(out) <- "recipe"
158+
class(out) <- c("epi_recipe", "recipe")
141159
out
142160
}
143161

@@ -210,3 +228,91 @@ epi_form2args <- function(formula, data, ...) {
210228
list(x = data, vars = vars, roles = roles)
211229
}
212230

231+
232+
233+
#' Test for `epi_recipe`
234+
#'
235+
#' @param x An object.
236+
#' @return `TRUE` if the object inherits from `epi_recipe`.
237+
#'
238+
#' @export
239+
is_epi_recipe <- function(x) {
240+
inherits(x, "epi_recipe")
241+
}
242+
243+
244+
245+
#' Add an epi_recipe to a workflow
246+
#'
247+
#' @seealso [workflows::add_recipe()]
248+
#' - `add_recipe()` specifies the terms of the model and any preprocessing that
249+
#' is required through the usage of a recipe.
250+
#'
251+
#' - `remove_recipe()` removes the recipe as well as any downstream objects
252+
#'
253+
#' @details
254+
#' Has the same behaviour as [workflows::add_recipe()] but sets a different
255+
#' default blueprint to automatically handle [epiprocess::epi_df] data.
256+
#'
257+
#' @param x A workflow or epi_workflow
258+
#'
259+
#' @param recipe A recipe created using [recipes::recipe()]
260+
#'
261+
#' @param ... Not used.
262+
#'
263+
#' @param blueprint A hardhat blueprint used for fine tuning the preprocessing.
264+
#'
265+
#' [default_epi_recipe_blueprint()] is used.
266+
#'
267+
#' Note that preprocessing done here is separate from preprocessing that
268+
#' might be done automatically by the underlying model.
269+
#'
270+
#' @return
271+
#' `x`, updated with a new recipe preprocessor.
272+
#'
273+
#' @export
274+
#' @examples
275+
#' library(epiprocess)
276+
#' library(dplyr)
277+
#' library(recipes)
278+
#'
279+
#' jhu <- jhu_csse_daily_subset %>%
280+
#' filter(time_value > "2021-08-01") %>%
281+
#' select(geo_value:death_rate_7d_av) %>%
282+
#' rename(case_rate = case_rate_7d_av, death_rate = death_rate_7d_av)
283+
#'
284+
#' r <- epi_recipe(jhu) %>%
285+
#' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>%
286+
#' step_epi_ahead(death_rate, ahead = 7) %>%
287+
#' step_epi_lag(case_rate, lag = c(0, 7, 14)) %>%
288+
#' step_naomit(all_predictors()) %>%
289+
#' step_naomit(all_outcomes(), skip = TRUE)
290+
#'
291+
#' workflow <- epi_workflow() %>%
292+
#' add_epi_recipe(r)
293+
#'
294+
#' workflow
295+
add_epi_recipe <- function(
296+
x, recipe, ..., blueprint = default_epi_recipe_blueprint()) {
297+
workflows::add_recipe(x, recipe, ..., blueprint = blueprint)
298+
}
299+
300+
301+
302+
#' Recipe blueprint that accounts for `epi_df` panel data
303+
#'
304+
#' Used for simplicity. See [hardhat::default_recipe_blueprint()] for more
305+
#' details.
306+
#'
307+
#' @inheritParams hardhat::default_recipe_blueprint
308+
#'
309+
#' @details The `bake_dependent_roles` are automatically set to `epi_df` defaults.
310+
#' @return A recipe blueprint.
311+
#' @export
312+
default_epi_recipe_blueprint <-
313+
function(intercept = FALSE, allow_novel_levels = FALSE, fresh = TRUE,
314+
bake_dependent_roles = c("time_value", "geo_value", "key", "raw"),
315+
composition = "tibble") {
316+
hardhat::default_recipe_blueprint(
317+
intercept, allow_novel_levels, fresh, bake_dependent_roles, composition)
318+
}

0 commit comments

Comments
 (0)