|
#' Limits the size of the training window to the most recent observations
#'
#' `step_training_window2` creates a *specification* of a recipe step that
#' limits the size of the training window to the `n_recent` most recent
#' observations in `time_value` per group, where the groups are formed
#' based on the remaining `epi_keys`. Optionally, the window can additionally
#' be restricted to dates that fall close to the current phase of the season.
#'
#' @param n_recent An integer value that represents the number of most recent
#'   observations that are to be kept in the training window per key.
#'   The default value is 50. `Inf` disables this limit.
#' @param seasonal Bool, default FALSE. If TRUE, the training window will slice
#'   through epidemic seasons. This is useful for forecasting models that need
#'   to leverage the data in previous years, but only limited to similar phases
#'   in the epidemic. Most useful to heavily seasonal data, like influenza.
#'   Expects `n_recent` to be finite. Must be a single `TRUE` or `FALSE`.
#' @param seasonal_forward_window An integer value that represents the number
#'   of days after a season week to include in the training window. The default
#'   value is 14. Only valid when `seasonal` is TRUE.
#' @param seasonal_backward_window An integer value that represents the number
#'   of days before a season week to include in the training window. The
#'   default value is 35. Only valid when `seasonal` is TRUE.
#' @param epi_keys An optional character vector for specifying "key" variables
#'   to group on. The default, `NULL`, ensures that every key combination is
#'   limited.
#' @inheritParams step_epi_lag
#' @template step-return
#'
#' @details Note that `step_epi_lead()` and `step_epi_lag()` should come
#' after any filtering step.
#'
#' When `seasonal = TRUE`, the `n_recent` limit is applied first and the
#' seasonal filter is then applied to the rows that remain.
#'
#' @export
#'
#' @examples
#' tib <- tibble(
#'   x = 1:10,
#'   y = 1:10,
#'   time_value = rep(seq(as.Date("2020-01-01"), by = 1, length.out = 5), 2),
#'   geo_value = rep(c("ca", "hi"), each = 5)
#' ) %>%
#'   as_epi_df()
#'
#' epi_recipe(y ~ x, data = tib) %>%
#'   step_training_window2(n_recent = 3) %>%
#'   prep(tib) %>%
#'   bake(new_data = NULL)
#'
#' epi_recipe(y ~ x, data = tib) %>%
#'   step_epi_naomit() %>%
#'   step_training_window2(n_recent = 3) %>%
#'   prep(tib) %>%
#'   bake(new_data = NULL)
step_training_window2 <-
  function(recipe,
           role = NA,
           n_recent = 50,
           seasonal = FALSE,
           seasonal_forward_window = 14,
           seasonal_backward_window = 35,
           epi_keys = NULL,
           id = rand_id("training_window2")) {
    epipredict:::arg_is_scalar(
      n_recent, id, seasonal,
      seasonal_forward_window, seasonal_backward_window
    )
    # `seasonal` drives a plain `if ()` when the step is baked and printed, so
    # reject anything that is not exactly TRUE or FALSE here, where the error
    # can still point at the offending argument.
    if (!isTRUE(seasonal) && !isFALSE(seasonal)) {
      rlang::abort("`seasonal` must be a single `TRUE` or `FALSE`.")
    }
    epipredict:::arg_is_pos(
      n_recent, seasonal_forward_window, seasonal_backward_window
    )
    # `Inf` is allowed and means "no limit"; only finite values have to be
    # whole numbers.
    if (is.finite(n_recent)) epipredict:::arg_is_pos_int(n_recent)
    epipredict:::arg_is_chr(id)
    epipredict:::arg_is_chr(epi_keys, allow_null = TRUE)
    add_step(
      recipe,
      step_training_window2_new(
        role = role,
        trained = FALSE,
        n_recent = n_recent,
        seasonal = seasonal,
        seasonal_forward_window = seasonal_forward_window,
        seasonal_backward_window = seasonal_backward_window,
        epi_keys = epi_keys,
        # The step is always created with `skip = TRUE`.
        skip = TRUE,
        id = id
      )
    )
  }
| 81 | + |
# Low-level constructor for the `training_window2` step: forwards every field
# unchanged to `step()`. Performs no validation of its own.
step_training_window2_new <-
  function(role,
           trained,
           n_recent,
           seasonal,
           seasonal_forward_window,
           seasonal_backward_window,
           epi_keys,
           skip,
           id) {
    step(
      subclass = "training_window2",
      # bookkeeping fields
      role = role,
      trained = trained,
      # window configuration
      n_recent = n_recent,
      seasonal = seasonal,
      seasonal_forward_window = seasonal_forward_window,
      seasonal_backward_window = seasonal_backward_window,
      epi_keys = epi_keys,
      # bookkeeping fields
      skip = skip,
      id = id
    )
  }
| 97 | + |
# Resolves the grouping keys for the step and returns a trained copy of it.
# Keys given by the user win; otherwise the keys reported by
# `epipredict:::epi_keys_only()` for the training data are used, and if that
# is also `NULL` the step groups on nothing.
#' @export
prep.step_training_window2 <- function(x, training, info = NULL, ...) {
  # Always looked up, even when the user supplied keys.
  data_keys <- epipredict:::epi_keys_only(training)

  group_keys <- x$epi_keys
  if (is.null(group_keys)) {
    group_keys <- data_keys
  }
  if (is.null(group_keys)) {
    group_keys <- character(0L)
  }

  # Fail now if any resolved key is not a column of the training data.
  hardhat::validate_column_names(training, group_keys)

  step_training_window2_new(
    role = x$role,
    trained = TRUE,
    n_recent = x$n_recent,
    seasonal = x$seasonal,
    seasonal_forward_window = x$seasonal_forward_window,
    seasonal_backward_window = x$seasonal_backward_window,
    epi_keys = group_keys,
    skip = x$skip,
    id = x$id
  )
}
| 117 | + |
# Applies the trained step to `new_data`:
#   1. if `n_recent` is finite, keep the `n_recent` latest rows (by
#      `time_value`) within each combination of `object$epi_keys`;
#   2. if `seasonal` is TRUE, keep only rows whose `time_value` lies within
#      `seasonal_backward_window` days before or `seasonal_forward_window`
#      days after any date that shares a season week with the latest data
#      (and, for `epi_df` input, with the data's as-of date).
# Returns the filtered data. In the seasonal case the result is the output of
# `add_season_info()` filtered by row, so any columns that helper adds are
# kept.
# NOTE(review): the windows are assumed to be whole numbers of days and
# `time_value` is assumed to be a Date -- confirm against callers.
#' @export
bake.step_training_window2 <- function(object, new_data, ...) {
  hardhat::validate_column_names(new_data, object$epi_keys)

  if (object$n_recent < Inf) {
    # `arrange()` ignores the grouping, but sorting the whole table by
    # `time_value` still leaves every group in chronological order, which is
    # all `slice_tail()` needs.
    new_data <- new_data %>%
      group_by(across(all_of(object$epi_keys))) %>%
      arrange(time_value) %>%
      dplyr::slice_tail(n = object$n_recent) %>%
      ungroup()
  }

  if (!object$seasonal) {
    return(new_data)
  }

  new_data <- add_season_info(new_data)
  if (nrow(new_data) == 0L) {
    # Nothing to anchor the seasonal windows on; also avoids the warnings
    # that `max()` emits on empty input.
    return(new_data)
  }

  # Season week(s) that define "the current phase": the one of the latest
  # observation and, when available, the one of the as-of date.
  recent_weeks <- new_data %>%
    filter(time_value == max(time_value)) %>%
    pull(season_week) %>%
    max()
  if (inherits(new_data, "epi_df")) {
    as_of <- epi_as_of(new_data)
    current_season_week <- convert_epiweek_to_season_week(
      epiyear(as_of),
      epiweek(as_of)
    )
    recent_weeks <- c(recent_weeks, current_season_week)
  }

  # Every date, in any season, that falls in one of those season weeks.
  anchor_dates <- new_data %>%
    filter(season_week %in% recent_weeks) %>%
    pull(time_value) %>%
    unique()

  # Day offsets around each anchor date. `seq_len()` avoids the `1:0`
  # count-down that `1:n` produces for a window below one day.
  offsets <- c(
    -seq_len(object$seasonal_backward_window),
    0L,
    seq_len(object$seasonal_forward_window)
  )

  # Expand each anchor by all offsets with plain Date arithmetic. This keeps
  # the Date class, unlike `unlist()` followed by `as.Date(<numeric>)`, which
  # requires an explicit `origin` on R < 4.3.
  date_ranges <- unique(rep(anchor_dates, each = length(offsets)) + offsets)

  new_data %>% filter(time_value %in% date_ranges)
}
| 156 | + |
# Prints a one-line summary of the step: the seasonal variant lists `n_recent`
# and both seasonal windows, the plain variant lists `n_recent` only.
# Returns `x` invisibly.
#' @export
print.step_training_window2 <-
  function(x, width = max(20, options()$width - 30), ...) {
    # The locals deliberately carry the same names as the step fields because
    # they are handed to `rlang::enquos()` by name.
    n_recent <- x$n_recent
    if (x$seasonal) {
      title <- "# of seasonal observations per key limited to:"
      seasonal_forward_window <- x$seasonal_forward_window
      seasonal_backward_window <- x$seasonal_backward_window
      shown <- rlang::enquos(
        n_recent, seasonal_forward_window, seasonal_backward_window
      )
    } else {
      title <- "# of recent observations per key limited to:"
      shown <- rlang::enquos(n_recent)
    }

    tr_obj <- recipes::format_selectors(shown, width)
    recipes::print_step(tr_obj, shown, x$trained, title, width)
    invisible(x)
  }
0 commit comments