Skip to content

Commit f830d1a

Browse files
committed
feat: add step_training_window with seasonal and use it
1 parent 144e6cd commit f830d1a

File tree

7 files changed

+218
-47
lines changed

7 files changed

+218
-47
lines changed

R/aux_data_utils.R

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,16 +40,11 @@ add_season_info <- function(data) {
4040
}
4141

4242
data %>%
43-
select(-any_of(c("season", "season_week"))) %>%
44-
{
45-
if ("epiweek" %nin% names(.)) {
46-
. <- (.) %>% mutate(epiweek = epiweek(time_value))
47-
}
48-
if ("epiyear" %nin% names(.)) {
49-
. <- (.) %>% mutate(epiyear = epiyear(time_value))
50-
}
51-
.
52-
} %>%
43+
select(-any_of(c("season", "season_week", "epiweek", "epiyear"))) %>%
44+
mutate(
45+
epiweek = epiweek(time_value),
46+
epiyear = epiyear(time_value)
47+
) %>%
5348
left_join(
5449
(.) %>%
5550
distinct(epiweek, epiyear) %>%

R/default_epipredict_args.R

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@ default_args_list <- function(
1818
check_enough_data_n = NULL,
1919
check_enough_data_epi_keys = NULL,
2020
keys_to_ignore = list(),
21-
n_recent = 5 * 7,
22-
n_forward = 3 * 7,
2321
seasonal_window = FALSE,
22+
seasonal_backward_window = 5 * 7,
23+
seasonal_forward_window = 3 * 7,
2424
...) {
2525
# error checking if lags is a list
2626
rlang::check_dots_empty()
@@ -72,9 +72,9 @@ default_args_list <- function(
7272
check_enough_data_n,
7373
check_enough_data_epi_keys,
7474
keys_to_ignore,
75-
n_recent,
76-
n_forward,
77-
seasonal_window
75+
seasonal_window,
76+
seasonal_backward_window,
77+
seasonal_forward_window
7878
),
7979
class = c("arx_fcast", "alist")
8080
)

R/forecasters/epipredict_utilities.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,12 @@ arx_preprocess <- function(preproc, outcome, predictors, args_list) {
3030
}
3131
preproc %<>%
3232
step_epi_ahead(!!outcome, ahead = args_list$ahead) %>%
33-
# TODO: Uncomment after debugging
3433
step_epi_naomit() %>%
35-
step_training_window(
34+
step_training_window2(
3635
n_recent = args_list$n_training,
37-
# n_forward = args_list$n_forward,
38-
# seasonal = args_list$seasonal_window
36+
seasonal = args_list$seasonal_window,
37+
seasonal_backward_window = args_list$seasonal_backward_window,
38+
seasonal_forward_window = args_list$seasonal_forward_window,
3939
)
4040
return(preproc)
4141
}

R/forecasters/forecaster_scaled_pop_seasonal.R

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ scaled_pop_seasonal <- function(epi_data,
4545
center_method = c("median", "mean", "none"),
4646
nonlin_method = c("quart_root", "none"),
4747
seasonal_method = c("none", "flu", "covid", "indicator", "window", "climatological"),
48-
season_backward_window = 5 * 7,
49-
season_forward_window = 3 * 7,
48+
seasonal_backward_window = 5 * 7,
49+
seasonal_forward_window = 3 * 7,
5050
train_residual = FALSE,
5151
trainer = parsnip::linear_reg(),
5252
quantile_levels = covidhub_probs(),
@@ -89,9 +89,9 @@ scaled_pop_seasonal <- function(epi_data,
8989
args_input[["ahead"]] <- ahead
9090
args_input[["quantile_levels"]] <- quantile_levels
9191
args_input[["nonneg"]] <- scale_method == "none"
92-
args_input[["n_training"]] <- season_backward_window
93-
args_input[["n_forward"]] <- season_forward_window + ahead
9492
args_input[["seasonal_window"]] <- "window" %in% seasonal_method
93+
args_input[["seasonal_backward_window"]] <- seasonal_backward_window
94+
args_input[["seasonal_forward_window"]] <- seasonal_forward_window + ahead
9595
args_list <- inject(default_args_list(!!!args_input))
9696
# if you want to hardcode particular predictors in a particular forecaster
9797
predictors <- c(outcome, extra_sources[[1]])
@@ -142,25 +142,6 @@ scaled_pop_seasonal <- function(epi_data,
142142
}
143143
}
144144

145-
# TODO: Replace with step_training_window2
146-
if ("window" %in% seasonal_method) {
147-
last_data_season_week <- epi_data %>%
148-
filter(source == "nhsn") %>%
149-
filter(time_value == max(time_value)) %>%
150-
pull(season_week) %>%
151-
max()
152-
current_season_week <- convert_epiweek_to_season_week(epiyear(epi_as_of(epi_data)), epiweek(epi_as_of(epi_data)))
153-
date_ranges <- epi_data %>%
154-
filter(season_week == last_data_season_week) %>%
155-
pull(time_value) %>%
156-
unique() %>%
157-
map(~ c(.x - seq(from = 7, to = season_backward_window, by = 7), .x + seq(from = 0, to = season_forward_window, by = 7))) %>%
158-
unlist() %>%
159-
as.Date() %>%
160-
unique()
161-
epi_data <- epi_data %>% filter(time_value %in% unlist(date_ranges))
162-
}
163-
164145
if (drop_non_seasons) {
165146
season_data <- epi_data %>% drop_non_seasons()
166147
} else {
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
#' Restrict the training data to a recent and, optionally, seasonal window
#'
#' `step_training_window2` creates a *specification* of a recipe step that
#' keeps only the `n_recent` most recent observations in `time_value` per
#' group, where the groups are formed based on the `epi_keys`. With
#' `seasonal = TRUE` the rows are additionally restricted to dates that lie
#' close to the season week of the most recent observation.
#'
#' @param n_recent An integer value that represents the number of most recent
#'   observations that are to be kept in the training window per key.
#'   The default value is 50. A non-finite value (`Inf`) skips this limit.
#' @param seasonal Bool, default FALSE. If TRUE, the training window will slice
#'   through epidemic seasons. This is useful for forecasting models that need
#'   to leverage the data in previous years, but only limited to similar phases
#'   in the epidemic. Most useful to heavily seasonal data, like influenza.
#'   See Details for how this combines with `n_recent`.
#' @param seasonal_forward_window An integer value that represents the number
#'   of days after a season week to include in the training window. The
#'   default value is 14. Only used when `seasonal` is TRUE.
#' @param seasonal_backward_window An integer value that represents the number
#'   of days before a season week to include in the training window. The
#'   default value is 35. Only used when `seasonal` is TRUE.
#' @param epi_keys An optional character vector for specifying "key" variables
#'   to group on. The default, `NULL`, ensures that every key combination is
#'   limited.
#' @inheritParams step_epi_lag
#' @template step-return
#'
#' @details Note that `step_epi_lead()` and `step_epi_lag()` should come
#'   after any filtering step.
#'
#'   When baking, the `n_recent` limit is applied before the seasonal filter,
#'   so a finite `n_recent` also bounds how far back the seasonal window can
#'   reach.
#'
#' @export
#'
#' @examples
#' tib <- tibble(
#'   x = 1:10,
#'   y = 1:10,
#'   time_value = rep(seq(as.Date("2020-01-01"), by = 1, length.out = 5), 2),
#'   geo_value = rep(c("ca", "hi"), each = 5)
#' ) %>%
#'   as_epi_df()
#'
#' epi_recipe(y ~ x, data = tib) %>%
#'   step_training_window2(n_recent = 3) %>%
#'   prep(tib) %>%
#'   bake(new_data = NULL)
#'
#' epi_recipe(y ~ x, data = tib) %>%
#'   step_epi_naomit() %>%
#'   step_training_window2(n_recent = 3) %>%
#'   prep(tib) %>%
#'   bake(new_data = NULL)
step_training_window2 <- function(recipe,
                                  role = NA,
                                  n_recent = 50,
                                  seasonal = FALSE,
                                  seasonal_forward_window = 14,
                                  seasonal_backward_window = 35,
                                  epi_keys = NULL,
                                  id = rand_id("training_window2")) {
  # Argument checks, delegated to epipredict's internal validators.
  epipredict:::arg_is_scalar(n_recent, id, seasonal, seasonal_forward_window, seasonal_backward_window)
  epipredict:::arg_is_pos(n_recent, seasonal_forward_window, seasonal_backward_window)
  # `Inf` is an accepted "no limit" value, so the integer check only applies
  # to finite inputs.
  if (is.finite(n_recent)) {
    epipredict:::arg_is_pos_int(n_recent)
  }
  epipredict:::arg_is_chr(id)
  epipredict:::arg_is_chr(epi_keys, allow_null = TRUE)

  # Build the untrained step first, then register it on the recipe.
  window_step <- step_training_window2_new(
    role = role,
    trained = FALSE,
    n_recent = n_recent,
    seasonal = seasonal,
    seasonal_forward_window = seasonal_forward_window,
    seasonal_backward_window = seasonal_backward_window,
    epi_keys = epi_keys,
    skip = TRUE,
    id = id
  )
  add_step(recipe, window_step)
}
81+
82+
# Internal constructor: packs the step's fields into a recipe step object of
# subclass "training_window2". Used both for the untrained specification and
# for the trained copy returned by the prep method.
step_training_window2_new <- function(role,
                                      trained,
                                      n_recent,
                                      seasonal,
                                      seasonal_forward_window,
                                      seasonal_backward_window,
                                      epi_keys,
                                      skip,
                                      id) {
  step(
    subclass = "training_window2",
    role = role,
    trained = trained,
    n_recent = n_recent,
    seasonal = seasonal,
    seasonal_forward_window = seasonal_forward_window,
    seasonal_backward_window = seasonal_backward_window,
    epi_keys = epi_keys,
    skip = skip,
    id = id
  )
}
97+
98+
# Prep method: resolves which key columns the window is grouped by and returns
# a trained copy of the step. No statistics are estimated from `training`;
# only the key columns are determined and validated.
#' @export
prep.step_training_window2 <- function(x, training, info = NULL, ...) {
  # Keys detected on the training data act as the fallback when the user did
  # not supply `epi_keys`; an empty character vector is the last resort.
  detected_keys <- epipredict:::epi_keys_only(training)
  group_keys <- x$epi_keys
  if (is.null(group_keys)) {
    group_keys <- detected_keys
  }
  if (is.null(group_keys)) {
    group_keys <- character(0L)
  }

  hardhat::validate_column_names(training, group_keys)

  step_training_window2_new(
    role = x$role,
    trained = TRUE,
    n_recent = x$n_recent,
    seasonal = x$seasonal,
    seasonal_forward_window = x$seasonal_forward_window,
    seasonal_backward_window = x$seasonal_backward_window,
    epi_keys = group_keys,
    skip = x$skip,
    id = x$id
  )
}
117+
118+
# Bake method: filters `new_data` down to the training window.
#
# 1. If `n_recent` is finite, keep the `n_recent` latest rows (by `time_value`)
#    within each combination of the step's `epi_keys`.
# 2. If `seasonal` is TRUE, recompute season info via `add_season_info()` and
#    keep only rows whose `time_value` falls within
#    [-seasonal_backward_window, +seasonal_forward_window] days of any date
#    that shares a season week with the latest observation (and, for an
#    `epi_df`, with the data's `as_of` date).
#
# Returns the filtered data. In the seasonal case the result carries the
# columns written by `add_season_info()`.
#' @export
bake.step_training_window2 <- function(object, new_data, ...) {
  hardhat::validate_column_names(new_data, object$epi_keys)

  if (object$n_recent < Inf) {
    new_data %<>%
      group_by(across(all_of(object$epi_keys))) %>%
      arrange(time_value) %>%
      dplyr::slice_tail(n = object$n_recent) %>%
      ungroup()
  }

  if (object$seasonal) {
    new_data %<>% add_season_info()

    # Season week of the most recent observation in the data.
    last_data_season_week <- new_data %>%
      filter(time_value == max(time_value)) %>%
      pull(season_week) %>%
      max()
    recent_weeks <- c(last_data_season_week)
    if (inherits(new_data, "epi_df")) {
      # Also anchor on the season week of the data's `as_of` date.
      as_of <- epi_as_of(new_data)
      current_season_week <- convert_epiweek_to_season_week(epiyear(as_of), epiweek(as_of))
      recent_weeks <- c(recent_weeks, current_season_week)
    }

    # Every date (across all years present) that falls in an anchor week.
    anchor_dates <- new_data %>%
      filter(season_week %in% recent_weeks) %>%
      pull(time_value) %>%
      unique()

    # Day offsets -backward..-1 and 0..forward. `seq_len()` replaces the
    # `1:n` / `0:n` colon sequences, which silently count downward for
    # non-positive bounds; `floor()` keeps the colon operator's truncation of
    # fractional windows. A backward window below one day now yields no
    # backward offsets.
    backward_days <- seq_len(floor(object$seasonal_backward_window))
    forward_days <- seq_len(floor(object$seasonal_forward_window) + 1) - 1
    day_offsets <- c(-backward_days, forward_days)

    # Vectorised Date arithmetic keeps the Date class throughout. The previous
    # `unlist() %>% as.Date()` round-trip stripped the class and relied on
    # `as.Date(<numeric>)` without `origin`, which errors on R < 4.3.
    date_ranges <- unique(rep(anchor_dates, each = length(day_offsets)) + day_offsets)

    new_data %<>% filter(time_value %in% date_ranges)
  }

  new_data
}
156+
157+
# Print method: shows the window limits of the step. Seasonal steps list
# `n_recent` together with both seasonal window sizes; non-seasonal steps list
# `n_recent` only. Returns `x` invisibly.
#' @export
print.step_training_window2 <- function(x, width = max(20, options()$width - 30), ...) {
  n_recent <- x$n_recent
  if (x$seasonal) {
    title <- "# of seasonal observations per key limited to:"
    seasonal_forward_window <- x$seasonal_forward_window
    seasonal_backward_window <- x$seasonal_backward_window
    shown <- rlang::enquos(n_recent, seasonal_forward_window, seasonal_backward_window)
  } else {
    title <- "# of recent observations per key limited to:"
    shown <- rlang::enquos(n_recent)
  }
  # The same captured values serve as both the trained and untrained display.
  recipes::print_step(recipes::format_selectors(shown, width), shown, x$trained, title, width)
  invisible(x)
}

scripts/flu_hosp_explore.R

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ forecaster_parameter_combinations_ <- rlang::list2(
267267
filter_agg_level = "state",
268268
drop_non_seasons = c(TRUE, FALSE),
269269
n_training = Inf,
270-
season_backward_window = 5,
270+
seasonal_backward_window = 5,
271271
keys_to_ignore = very_latent_locations
272272
),
273273
# Window-based seasonal method shouldn't drop non-seasons
@@ -284,7 +284,7 @@ forecaster_parameter_combinations_ <- rlang::list2(
284284
filter_agg_level = "state",
285285
drop_non_seasons = FALSE,
286286
n_training = Inf,
287-
season_backward_window = 5,
287+
seasonal_backward_window = 5,
288288
keys_to_ignore = very_latent_locations
289289
),
290290
tidyr::expand_grid(
@@ -300,7 +300,7 @@ forecaster_parameter_combinations_ <- rlang::list2(
300300
filter_agg_level = "state",
301301
drop_non_seasons = FALSE,
302302
n_training = Inf,
303-
season_backward_window = 8,
303+
seasonal_backward_window = 8,
304304
keys_to_ignore = very_latent_locations
305305
)
306306
# trying various window sizes
@@ -339,8 +339,8 @@ forecaster_parameter_combinations_ <- rlang::list2(
339339
filter_agg_level = "state",
340340
drop_non_seasons = FALSE,
341341
n_training = Inf,
342-
season_backward_window = c(3, 5, 7, 9, 52),
343-
season_forward_window = c(3, 5, 7),
342+
seasonal_backward_window = c(3, 5, 7, 9, 52),
343+
seasonal_forward_window = c(3, 5, 7),
344344
keys_to_ignore = very_latent_locations
345345
),
346346
climate_linear = expand_grid(
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Manual smoke check for `step_training_window2()` with `seasonal = TRUE`.
source(here::here("R", "load_all.R"))

# Four Alaska observations: two in early November and two in early October.
ak_rows <- tribble(
  ~geo_value, ~time_value, ~version, ~value,
  "ak", "2024-11-08", "2024-11-13", 1,
  "ak", "2024-11-07", "2024-11-13", 2,
  "ak", "2024-10-08", "2024-11-13", 2,
  "ak", "2024-10-07", "2024-11-13", 2,
) %>%
  mutate(time_value = as.Date(time_value), version = as.Date(version))

# California mirrors Alaska with transformed values, plus a copy shifted back
# by 365 days to provide a previous season.
ca_rows <- ak_rows %>% mutate(geo_value = "ca", value = value * 3 + 1)
ca_previous_year <- ca_rows %>% mutate(time_value = time_value - 365)

data <- bind_rows(ak_rows, ca_rows, ca_previous_year) %>%
  as_epi_df()

# debugonce(bake.step_training_window2)
epi_recipe(data) %>%
  step_training_window2(seasonal_backward_window = 5, seasonal_forward_window = 3, seasonal = TRUE) %>%
  prep(data) %>%
  bake(new_data = NULL)

# Seems fine

0 commit comments

Comments
 (0)