Skip to content

Commit a78a707

Browse files
committed
wip: refactoring covid_hosp_explore
1 parent 7c75590 commit a78a707

File tree

12 files changed

+550
-621
lines changed

12 files changed

+550
-621
lines changed

R/aux_data_utils.R

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -188,38 +188,56 @@ daily_to_weekly <- function(epi_df, agg_method = c("sum", "mean"), day_of_week =
188188
select(-epiweek, -year)
189189
}
190190

191+
#' Aggregate a daily archive to a weekly archive.
192+
#'
193+
#' @param epi_arch the archive to aggregate.
194+
#' @param agg_columns the columns to aggregate.
195+
#' @param agg_method the method to use to aggregate the data, one of "sum" or "mean".
196+
#' @param day_of_week the day of the week to use as the reference day.
197+
#' @param day_of_week_end the day of the week to use as the end of the week.
191198
daily_to_weekly_archive <- function(epi_arch,
192199
agg_columns,
193200
agg_method = c("sum", "mean"),
194201
day_of_week = 4L,
195202
day_of_week_end = 7L) {
203+
# How to aggregate the windowed data.
196204
agg_method <- arg_match(agg_method)
197-
keys <- key_colnames(epi_arch, exclude = "time_value")
205+
# The columns we will later group by when aggregating.
206+
keys <- key_colnames(epi_arch, exclude = c("time_value", "version"))
207+
# The versions we will slide over.
198208
ref_time_values <- epi_arch$DT$version %>%
199209
unique() %>%
200210
sort()
211+
# Choose a fast function to use to slide and aggregate.
201212
if (agg_method == "sum") {
202213
slide_fun <- epi_slide_sum
203214
} else if (agg_method == "mean") {
204215
slide_fun <- epi_slide_mean
205216
}
206-
too_many_tibbles <- epix_slide(
217+
# Slide over the versions and aggregate.
218+
epix_slide(
207219
epi_arch,
208-
.before = 99999999L,
209220
.versions = ref_time_values,
210-
function(x, group, ref_time) {
221+
function(x, group_keys, ref_time) {
222+
# The last day of the week we will slide over.
211223
ref_time_last_week_end <-
212224
floor_date(ref_time, "week", day_of_week_end - 1) # this is over by 1
225+
# The last day of the week we will slide over.
213226
max_time <- max(x$time_value)
227+
# The days we will slide over.
214228
valid_slide_days <- seq.Date(
215229
from = ceiling_date(min(x$time_value), "week", week_start = day_of_week_end - 1),
216230
to = floor_date(max(x$time_value), "week", week_start = day_of_week_end - 1),
217231
by = 7L
218232
)
233+
# If the last day of the week is not the end of the week, add it to the
234+
# list of valid slide days (this will produce an incomplete slide, but
235+
# that's fine for us, since it should only be 1 day, historically.)
219236
if (wday(max_time) != day_of_week_end) {
220237
valid_slide_days <- c(valid_slide_days, max_time)
221238
}
222-
slid_result <- x %>%
239+
# Slide over the days and aggregate.
240+
x %>%
223241
group_by(across(all_of(keys))) %>%
224242
slide_fun(
225243
agg_columns,
@@ -229,18 +247,13 @@ daily_to_weekly_archive <- function(epi_arch,
229247
) %>%
230248
select(-all_of(agg_columns)) %>%
231249
rename_with(~ gsub("slide_value_", "", .x)) %>%
232-
# only keep 1/week
233-
# group_by week, keep the largest in each week
234-
# alternatively
235-
# switch time_value to the designated day of the week
250+
rename_with(~ gsub("_7dsum", "", .x)) %>%
251+
# Round all dates to reference day of the week. These will get
252+
# de-duplicated by compactify in as_epi_archive below.
236253
mutate(time_value = round_date(time_value, "week", day_of_week - 1)) %>%
237254
as_tibble()
238255
}
239-
)
240-
too_many_tibbles %>%
241-
pull(time_value) %>%
242-
max()
243-
too_many_tibbles %>%
256+
) %>%
244257
as_epi_archive(compactify = TRUE)
245258
}
246259

R/forecasters/epipredict_utilities.R

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,11 @@ run_workflow_and_format <- function(preproc,
114114
if (is.null(as_of)) {
115115
as_of <- max(train_data$time_value)
116116
}
117+
118+
# Look at the train data (uncomment for debuggin).
119+
# df <- preproc %>% prep(train_data) %>% bake(train_data)
120+
# browser()
121+
117122
workflow <- epi_workflow(preproc, trainer) %>%
118123
fit(train_data) %>%
119124
add_frosting(postproc)

R/forecasters/forecaster_scaled_pop.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,12 +141,16 @@ scaled_pop <- function(epi_data,
141141
by = c("geo_value" = "abbr")
142142
)
143143
}
144+
144145
# with all the setup done, we execute and format
145146
pred <- run_workflow_and_format(preproc, postproc, trainer, season_data, epi_data)
146147
# now pred has the columns
147148
# (geo_value, forecast_date, target_end_date, quantile, value)
148149
# finally, any postprocessing not supported by epipredict e.g. calibration
149150
# reintroduce color into the value
151+
# if (pred %>% distinct(forecast_date) %>% pull(forecast_date) == as.Date("2023-10-04")) {
152+
# browser()
153+
# }
150154
pred_final <- pred %>%
151155
rename({{ outcome }} := value) %>%
152156
data_coloring(outcome, learned_params, join_cols = key_colnames(epi_data, exclude = "time_value"), nonlin_method = nonlin_method) %>%

R/imports.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ library(parsnip)
2929
library(paws.storage)
3030
library(plotly)
3131
library(purrr)
32+
library(qs2)
3233
library(quantreg)
3334
library(readr)
3435
library(recipes)

R/looping.R

Lines changed: 1 addition & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -1,90 +1,4 @@
1-
#' Generate forecaster predictions on a particular dataset
2-
#'
3-
#' A wrapper that turns a forecaster, parameters, data combination into an
4-
#' actual experiment that outputs a prediction for each day.
5-
#'
6-
#' @param archive the epi_df object
7-
#' @param outcome the name of the target column
8-
#' @param ahead the number of days ahead to forecast
9-
#' @param extra_sources any extra columns used for prediction that aren't
10-
#' the target
11-
#' @param forecaster a function that does the actual forecasting for a given
12-
#' day. See `exampleSpec.R` for an example function and its documentation
13-
#' for the general parameter requirements.
14-
#' @param slide_training a required parameter that governs how much data to
15-
#' exclude before starting the evaluation.
16-
#' @param n_training_pad a required parameter that determines how many extra
17-
#' samples for epix_slide to hand to the forecaster to guarantee that at
18-
#' least `ntraining` examples are available to the forecaster.
19-
#' @param forecaster_args the list of arguments to the forecaster; it must
20-
#' contain `ahead`
21-
#' @param forecaster_args_names a bit of a hack around targets, it contains
22-
#' the names of the `forecaster_args`.
23-
#' @param date_range_step_size the step size (in days) to use when generating
24-
#' the forecast dates.
25-
#' @param cache_key a unique identifier for the cache file
26-
#'
27-
#' @importFrom epiprocess epix_slide
28-
#' @importFrom cli cli_abort
29-
#' @importFrom rlang !!
30-
#' @export
31-
slide_forecaster <- function(epi_archive,
32-
outcome,
33-
ahead,
34-
forecaster = scaled_pop,
35-
slide_training = 0,
36-
n_training_pad = 5,
37-
forecaster_args = list(),
38-
forecaster_args_names = list(),
39-
ref_time_values = NULL,
40-
start_date = NULL,
41-
end_date = NULL,
42-
date_range_step_size = 1L,
43-
cache_key = NULL) {
44-
if (length(forecaster_args) > 0) {
45-
names(forecaster_args) <- forecaster_args_names
46-
}
47-
forecaster_args$ahead <- ahead
48-
if (!is.numeric(forecaster_args$n_training) && !is.null(forecaster_args$n_training)) {
49-
n_training <- as.numeric(forecaster_args$n_training)
50-
net_slide_training <- max(slide_training, n_training) + n_training_pad
51-
} else {
52-
n_training <- Inf
53-
net_slide_training <- slide_training + n_training_pad
54-
}
55-
if (is.null(ref_time_values)) {
56-
# restrict the dataset to areas where training is possible
57-
if (is.null(start_date)) {
58-
start_date <- min(epi_archive$DT$time_value) + net_slide_training
59-
}
60-
if (is.null(end_date)) {
61-
end_date <- max(epi_archive$DT$time_value) - forecaster_args$ahead
62-
}
63-
ref_time_values <- seq.Date(from = start_date, to = end_date, by = date_range_step_size)
64-
}
65-
66-
# first generate the forecasts
67-
before <- n_training + n_training_pad - 1
68-
forecaster_args <- rlang::dots_list(
69-
!!!list(
70-
outcome = outcome
71-
),
72-
!!!forecaster_args,
73-
.homonyms = "last"
74-
)
75-
forecaster_wrapper <- function(x) {
76-
rlang::inject(forecaster(epi_data = x, !!!forecaster_args))
77-
}
78-
epix_slide_simple(
79-
epi_archive,
80-
forecaster_wrapper,
81-
ref_time_values,
82-
before,
83-
cache_key = cache_key
84-
)
85-
}
86-
87-
epix_slide_simple <- function(epi_archive, forecaster, ref_time_values, before, cache_key = NULL) {
1+
epix_slide_simple <- function(epi_archive, forecaster, ref_time_values, before = Inf, cache_key = NULL) {
882
# this is so that changing the object without changing the name doesn't result in pulling the wrong cache
893
cache_hash <- rlang::hash(epi_archive)
904
dir.create(".exploration_cache/slide_cache", showWarnings = FALSE, recursive = TRUE)

R/scoring.R

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,51 @@
11
# Scoring and Evaluation Functions
22

3-
evaluate_predictions <- function(predictions_cards, truth_data) {
4-
checkmate::assert_data_frame(predictions_cards)
3+
evaluate_predictions <- function(forecasts, truth_data) {
4+
checkmate::assert_data_frame(forecasts)
55
checkmate::assert_data_frame(truth_data)
66
checkmate::assert_names(
7-
names(predictions_cards),
7+
names(forecasts),
88
must.include = c("model", "geo_value", "forecast_date", "target_end_date", "quantile", "prediction")
99
)
1010
checkmate::assert_names(
1111
names(truth_data),
1212
must.include = c("geo_value", "target_end_date", "true_value")
1313
)
1414

15-
left_join(predictions_cards, truth_data, by = c("geo_value", "target_end_date")) %>%
16-
scoringutils::score(metrics = c("interval_score", "ae_median", "coverage")) %>%
17-
scoringutils::add_coverage(by = c("model", "geo_value", "forecast_date", "target_end_date"), ranges = c(80)) %>%
18-
scoringutils::summarize_scores(by = c("model", "geo_value", "forecast_date", "target_end_date")) %>%
15+
# joined_forecasts <- left_join(forecasts, truth_data, by = c("geo_value", "target_end_date"))
16+
17+
# joined_forecasts %>%
18+
# group_by(model, geo_value, forecast_date, target_end_date) %>%
19+
# summarize(increasing = all(prediction - shift(prediction, 1, 0) > 0)) %>%
20+
# ungroup() %>%
21+
# filter(!increasing)
22+
23+
pred_final %>%
24+
group_by(geo_value, forecast_date, target_end_date) %>%
25+
summarize(increasing = all(value - shift(value, 1, 0) > 0)) %>%
26+
ungroup() %>%
27+
filter(!increasing)
28+
29+
# joined_forecasts %>% filter(geo_value == "ma", forecast_date == "2023-10-07", target_end_date == "2023-10-21") %>% print(n=50)
30+
31+
forecast_obj <- left_join(forecasts, truth_data, by = c("geo_value", "target_end_date")) %>%
32+
scoringutils::as_forecast_quantile(
33+
quantile_level = "quantile",
34+
observed = "true_value",
35+
predicted = "prediction",
36+
forecast_unit = c("model", "geo_value", "forecast_date", "target_end_date")
37+
)
38+
39+
# browser()
40+
scores <- forecast_obj %>%
41+
scoringutils::score(metrics = get_metrics(.)) %>%
1942
as_tibble() %>%
2043
select(
21-
model,
22-
geo_value,
23-
forecast_date,
24-
target_end_date,
25-
wis = interval_score,
44+
model, geo_value, forecast_date, target_end_date,
45+
wis,
2646
ae = ae_median,
27-
coverage_80
47+
coverage_50 = interval_coverage_50,
48+
coverage_90 = interval_coverage_90
2849
) %>%
2950
mutate(ahead = as.numeric(target_end_date - forecast_date))
3051
}

0 commit comments

Comments
 (0)