Skip to content

Commit 8196dda

Browse files
authored
fix+refactor: cleanup pipelines (#169)
* fix covid and flu explore to functioning state
* move shared utilities to new directory
* switch to new convention of marking targets global variables with g_
* share convention with prod
* move prod to using params_grid idiom
* lots of small tweaks to utilities and bug fixes
* use_crew = yes default
* simplify targets options
* remove flu_hosp_tiny
* update notebooks
1 parent 64e76a1 commit 8196dda

36 files changed

+2403
-2865
lines changed

Makefile

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@ run:
1313
Rscript scripts/run.R
1414

1515
prod-covid:
16-
export TAR_RUN_PROJECT=covid_hosp_prod; \
17-
Rscript scripts/run_prod.R
16+
export TAR_PROJECT=covid_hosp_prod; \
17+
Rscript scripts/run.R
1818

1919
prod-flu:
20-
export TAR_RUN_PROJECT=flu_hosp_prod; \
21-
Rscript scripts/run_prod.R
20+
export TAR_PROJECT=flu_hosp_prod; \
21+
Rscript scripts/run.R
2222

2323
prod: prod-covid prod-flu update_site netlify
2424

R/aux_data_utils.R

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,17 @@ gen_pop_and_density_data <-
176176
)
177177
}
178178

179-
daily_to_weekly <- function(epi_df, agg_method = c("sum", "mean"), day_of_week = 4L, day_of_week_end = 7L, keys = "geo_value", values = c("value")) {
179+
#' Aggregate a daily archive to a weekly archive.
180+
#'
181+
#' By default, aggregates from Sunday to Saturday and labels with the Wednesday
182+
#' of that week.
183+
#'
184+
#' @param epi_df the archive to aggregate.
185+
#' @param agg_method the method to use to aggregate the data, one of "sum" or "mean".
186+
#' @param keys the columns to group by.
187+
#' @param values the columns to aggregate.
188+
daily_to_weekly <- function(epi_df, agg_method = c("sum", "mean"), keys = "geo_value", values = c("value")) {
189+
agg_method <- arg_match(agg_method)
180190
epi_df %>%
181191
mutate(epiweek = epiweek(time_value), year = epiyear(time_value)) %>%
182192
group_by(across(any_of(c(keys, "epiweek", "year")))) %>%
@@ -299,6 +309,13 @@ drop_non_seasons <- function(epi_data, min_window = 12) {
299309
)
300310
}
301311

312+
get_nwss_coarse_data <- function(disease = c("covid", "flu")) {
313+
disease <- arg_match(disease)
314+
aws.s3::get_bucket_df(prefix = glue::glue("exploration/aux_data/nwss_{disease}_data"), bucket = "forecasting-team-data") %>%
315+
slice_max(LastModified) %>%
316+
pull(Key) %>%
317+
aws.s3::s3read_using(FUN = readr::read_csv, object = ., bucket = "forecasting-team-data")
318+
}
302319

303320
#' add a column summing the values in the hhs region
304321
#' @param hhs_region_table the region table

R/forecasters/data_transforms.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -234,8 +234,7 @@ data_whitening <- function(epi_data, colname, learned_params, nonlin_method = c(
234234
join_cols <- key_colnames(epi_data, exclude = "time_value")
235235
}
236236
nonlin_method <- arg_match(nonlin_method)
237-
res <- epi_data %>%
238-
left_join(learned_params, by = join_cols)
237+
res <- epi_data %>% left_join(learned_params, by = join_cols)
239238
if (nonlin_method == "quart_root") {
240239
res %<>% mutate(across(all_of(colname), ~ (.x + 0.01)^(1 / 4)))
241240
}

R/forecasters/data_validation.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ confirm_sufficient_data <- function(epi_data, ahead, args_input, outcome, extra_
6969
# TODO: Buffer should probably be 2 * n(lags) * n(predictors). But honestly,
7070
# this needs to be fixed in epipredict itself, see
7171
# https://github.com/cmu-delphi/epipredict/issues/106.
72-
if (extra_sources == c("")) {
72+
if (identical(extra_sources, "")) {
7373
extra_sources <- character(0L)
7474
}
7575
has_no_last_nas <- epi_data %>%

R/forecasters/ensemble_linear_climate.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
#' @param other_weights if non null, it should be a tibble giving a list of weights by forecaster and geo_value
1515
#' @importFrom rlang %||%
1616
#' @export
17-
ensemble_linear_climate <- function(forecasts,
17+
ensemble_climate_linear <- function(forecasts,
1818
aheads,
1919
other_weights = NULL,
2020
probs = covidhub_probs(),

R/forecasters/forecaster_climatological.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ climate_linear_ensembled <- function(epi_data,
5959
pred_geo_climate <- climatological_model(epi_data, ahead, geo_agg = FALSE) %>% mutate(forecaster = "climate_geo")
6060
pred_linear <- forecaster_baseline_linear(epi_data, ahead, residual_tail = residual_tail, residual_center = residual_center) %>% mutate(forecaster = "linear")
6161
pred <- bind_rows(pred_climate, pred_linear, pred_geo_climate) %>%
62-
ensemble_linear_climate((args_list$aheads[[1]]) / 7) %>%
62+
ensemble_climate_linear((args_list$aheads[[1]]) / 7) %>%
6363
ungroup()
6464
# undo whitening
6565
pred_final <- pred %>%

R/forecasters/forecaster_no_recent_outcome.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,13 @@ no_recent_outcome <- function(epi_data,
6262
args_input[["quantile_levels"]] <- quantile_levels
6363
args_list <- do.call(default_args_list, args_input)
6464
# if you want to hardcode particular predictors in a particular forecaster
65-
predictors <- c(outcome, extra_sources[[1]])
66-
c(args_list, tmp_pred, trainer) %<-% sanitize_args_predictors_trainer(epi_data, outcome, predictors, trainer, args_list)
67-
if (extra_sources[[1]] == "") {
65+
if (identical(extra_sources[[1]], "")) {
6866
predictors <- character()
6967
} else {
7068
predictors <- extra_sources[[1]]
7169
}
70+
c(args_list, tmp_pred, trainer) %<-% sanitize_args_predictors_trainer(epi_data, outcome, predictors, trainer, args_list)
71+
7272
# end of the copypasta
7373
# finally, any other pre-processing (e.g. smoothing) that isn't performed by
7474
# epipredict

R/forecasters/forecaster_scaled_pop_seasonal.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ scaled_pop_seasonal <- function(epi_data,
5959

6060
epi_data <- validate_epi_data(epi_data)
6161

62+
# TODO: handle this when creating param grid?
6263
if (typeof(seasonal_method) == "list") {
6364
seasonal_method <- seasonal_method[[1]]
6465
}

R/looping.R

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,24 +84,27 @@ slide_forecaster <- function(epi_archive,
8484
)
8585
}
8686

87-
epix_slide_simple <- function(epi_archive, forecaster, ref_time_values, before, cache_key = NULL) {
87+
88+
epix_slide_simple <- function(epi_archive, forecaster, ref_time_values, before = Inf, cache_key = NULL) {
8889
# this is so that changing the object without changing the name doesn't result in pulling the wrong cache
8990
cache_hash <- rlang::hash(epi_archive)
9091
dir.create(".exploration_cache/slide_cache", showWarnings = FALSE, recursive = TRUE)
91-
purrr::map(ref_time_values, function(tv) {
92+
out <- purrr::map(ref_time_values, function(tv) {
9293
if (is.null(cache_key)) {
9394
epi_df <- epi_archive %>%
94-
epix_as_of(tv, min_time_value = tv - before)
95+
epix_as_of(min(tv, .$versions_end), min_time_value = tv - before)
9596
} else {
9697
file_path <- glue::glue(".exploration_cache/slide_cache/{cache_key}_{cache_hash}_{before}_{tv}.parquet")
9798
if (file.exists(file_path)) {
9899
epi_df <- qs::qread(file_path)
99100
} else {
100101
epi_df <- epi_archive %>%
101-
epix_as_of(tv, min_time_value = tv - before)
102+
epix_as_of(min(tv, .$versions_end), min_time_value = tv - before)
102103
qs::qsave(epi_df, file_path)
103104
}
104105
}
105106
epi_df %>% forecaster()
106107
}) %>% bind_rows()
108+
gc()
109+
return(out)
107110
}

R/scoring.R

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,24 @@ evaluate_predictions <- function(forecasts, truth_data) {
2222

2323
scores <- forecast_obj %>%
2424
scoringutils::score(metrics = get_metrics(.)) %>%
25-
as_tibble() %>%
25+
as_tibble()
26+
missing_metrics <- setdiff(
27+
c("model", "geo_value", "forecast_date", "target_end_date", "wis", "ae_median", "interval_coverage_50", "interval_coverage_90"),
28+
names(scores)
29+
)
30+
if (length(missing_metrics) > 0) {
31+
cli::cli_abort(c(
32+
"scoring error",
33+
"i" = "missing metrics: {missing_metrics}",
34+
"i" = "if wis is missing, then likely quantile monotonicity was violated"
35+
))
36+
}
37+
scores %>%
2638
select(
27-
model, geo_value, forecast_date, target_end_date,
39+
model,
40+
geo_value,
41+
forecast_date,
42+
target_end_date,
2843
wis,
2944
ae = ae_median,
3045
coverage_50 = interval_coverage_50,

0 commit comments

Comments
 (0)