Skip to content

Commit 5ade571

Browse files
committed
working example
1 parent bf13ab7 commit 5ade571

File tree

7 files changed

+193
-5
lines changed

7 files changed

+193
-5
lines changed

.Rbuildignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@
88
^_pkgdown\.yml$
99
^docs$
1010
^pkgdown$
11+
^musings$

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44
.Ruserdata
55
docs
66
inst/doc
7+
.DS_Store

R/check-train_window.R

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
#' Check Training Window Length
2+
#'
3+
#' `check_train_window` creates a *specification* of a recipe
4+
#' check that will check if there is insufficient training data
5+
#'
6+
#' @inheritParams check_missing
7+
#' @min_train_window Positive integer. The minimum amount of training
8+
#' data time points required
9+
#' to fit a predictive model. Using less results causes downstream
10+
#' fit calls to return minimal objects rather than crashing.
11+
#' @param warn If `TRUE` the check will throw a warning instead
12+
#' of an error when failing.
13+
#' @param train_window The number of days of training data.
14+
#' This is `NULL` until computed by [prep()].
15+
#' @template check-return
16+
#' @family checks
17+
#' @export
18+
#'
19+
check_train_window <-
20+
function(recipe,
21+
...,
22+
role = NA,
23+
skip = FALSE,
24+
trained = FALSE,
25+
min_train_window = 20,
26+
warn = TRUE,
27+
train_length,
28+
id = rand_id("train_window_check_")) {
29+
add_check(
30+
recipe,
31+
check_train_window_new(
32+
terms = dplyr::enquos(...),
33+
role = role,
34+
skip = skip,
35+
trained = trained,
36+
min_train_window = min_train_window,
37+
warn = warn,
38+
train_length = train_length,
39+
id = id
40+
)
41+
)
42+
}
43+
44+
## Initializes a new object
45+
check_train_window_new <-
46+
function(terms, role, skip, trained, min_train_window, warn,
47+
train_length, id) {
48+
check(
49+
subclass = "train_window",
50+
terms = terms,
51+
role = role,
52+
skip = skip,
53+
trained = trained,
54+
min_train_window = min_train_window,
55+
warn = warn,
56+
train_length = train_length,
57+
id = id
58+
)
59+
}
60+
61+
62+
prep.check_train_window <- function(x,
63+
training,
64+
info = NULL,
65+
...) {
66+
67+
train_length <- nrow(training)
68+
69+
70+
check_train_window_new(
71+
terms = x$terms,
72+
role = x$role,
73+
trained = TRUE,
74+
skip = x$skip,
75+
warn = x$warn,
76+
min_train_window = min_train_window,
77+
warn = warn,
78+
train_length = train_length,
79+
id = x$id
80+
)
81+
}
82+
83+
bake.check_range <- function(object,
84+
new_data,
85+
...) {
86+
87+
mtw <- object$min_train_window
88+
stopifnot(is.numeric(mtw), length(mtw) == 1L, mtw == as.integer(mtw))
89+
90+
n <- nrow(new_data)
91+
n.complete <- sum(complete.cases(new_data))
92+
93+
msg <- NULL
94+
if (n < mtw) {
95+
msg <- paste0(msg, "Total available rows of data is ", n,
96+
"\n < min_train_window ", mtw, ".\n")
97+
}
98+
if (n.complete < mtw) {
99+
msg <- paste0(msg, "Total complete rows of data is ", n.complete,
100+
"\n < min_train_window ", mtw, ".\n")
101+
}
102+
103+
if (object$warn & !is.null(msg)) {
104+
rlang::warn(msg)
105+
} else if (!is.null(msg)) {
106+
rlang::abort(msg)
107+
}
108+
109+
as_tibble(new_data)
110+
}
111+
112+
print.check_train_window <-
113+
function(x, width = max(20, options()$width - 30), ...) {
114+
title <- "Checking number of training observations"
115+
invisible(x)
116+
}
117+
118+

R/epi_ahead.R

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,11 @@ bake.step_epi_ahead <- function(object, new_data, ...) {
121121
}
122122

123123
grid <- tidyr::expand_grid(
124-
col = object$columns, lag_val = -object$ahead, ahead_val = object$ahead) %>%
125-
dplyr::mutate(newname = glue::glue("{object$prefix}{ahead_val}_{col}")) %>%
124+
col = object$columns, lag_val = -object$ahead) %>%
125+
dplyr::mutate(
126+
ahead_val = -lag_val,
127+
newname = glue::glue("{object$prefix}{ahead_val}_{col}")
128+
) %>%
126129
dplyr::select(-ahead_val)
127130

128131
## ensure no name clashes
@@ -143,7 +146,7 @@ bake.step_epi_ahead <- function(object, new_data, ...) {
143146
by = ok
144147
)
145148

146-
dplyr::full_join(new_data, lagged, by = object$keys) %>%
149+
dplyr::full_join(new_data, lagged, by = ok) %>%
147150
dplyr::group_by(dplyr::across(dplyr::all_of(ok[-1]))) %>%
148151
dplyr::arrange(time_value) %>%
149152
dplyr::ungroup()

R/epi_lag.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ bake.step_epi_lag <- function(object, new_data, ...) {
106106
by = ok
107107
)
108108

109-
dplyr::full_join(new_data, lagged, by = object$keys) %>%
109+
dplyr::full_join(new_data, lagged, by = ok) %>%
110110
dplyr::group_by(dplyr::across(dplyr::all_of(ok[-1]))) %>%
111111
dplyr::arrange(time_value) %>%
112112
dplyr::ungroup()

R/epi_recipe.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ epi_recipe.epi_df <-
107107
var_info$role <- roles
108108
} else {
109109
var_info <- var_info %>% dplyr::filter(!(variable %in% keys))
110-
var_info$role <- NA
110+
var_info$role <- "raw"
111111
}
112112
## Now we add the keys when necessary
113113
var_info <- dplyr::union(

musings/example-recipe.R

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
library(tidyverse)
2+
library(covidcast)
3+
library(delphi.epidata)
4+
library(epiprocess)
5+
library(tidymodels)
6+
x <- covidcast(
7+
data_source = "jhu-csse",
8+
signals = "confirmed_7dav_incidence_prop",
9+
time_type = "day",
10+
geo_type = "state",
11+
time_values = epirange(20200301, 20211231),
12+
geo_values = "*"
13+
) %>%
14+
fetch_tbl() %>%
15+
select(geo_value, time_value, case_rate = value)
16+
17+
y <- covidcast(
18+
data_source = "jhu-csse",
19+
signals = "deaths_7dav_incidence_prop",
20+
time_type = "day",
21+
geo_type = "state",
22+
time_values = epirange(20200301, 20211231),
23+
geo_values = "*"
24+
) %>%
25+
fetch_tbl() %>%
26+
select(geo_value, time_value, death_rate = value)
27+
28+
x <- x %>%
29+
full_join(y, by = c("geo_value", "time_value")) %>%
30+
as_epi_df()
31+
rm(y)
32+
33+
xx <- x %>% filter(time_value > "2021-12-01")
34+
35+
36+
# Baseline AR3
37+
r <- epi_recipe(x) %>% # if we add this as a class, maybe we get better
38+
# behaviour downstream?
39+
step_epi_ahead(death_rate, ahead = 7) %>%
40+
step_epi_lag(death_rate, lag = c(0, 7, 14)) %>%
41+
step_epi_lag(case_rate, lag = c(0, 7, 14)) %>%
42+
step_naomit(all_predictors()) %>%
43+
# below, `skip` means we don't do this at predict time
44+
# we should probably do something useful here to avoid user error
45+
step_naomit(all_outcomes(), skip = TRUE)
46+
47+
48+
slm <- linear_reg()
49+
50+
slm_fit <- workflow() %>%
51+
add_recipe(r) %>%
52+
add_model(slm) %>%
53+
fit(data = x)
54+
55+
x_latest <- x %>%
56+
filter(!is.na(case_rate), !is.na(death_rate)) %>%
57+
group_by(geo_value) %>%
58+
slice_tail(n = 15) # have lag 0,...,14, so need 15 for a complete case
59+
60+
pp <- predict(slm_fit, new_data = x_latest) # drops the keys...
61+
62+
xl <- x %>%
63+
filter(!is.na(case_rate), !is.na(death_rate)) %>%
64+
group_by(geo_value) %>%
65+
slice_tail(n = 14)

0 commit comments

Comments
 (0)