Skip to content

Commit 80bdcaa

Browse files
committed
wip
1 parent a78a707 commit 80bdcaa

File tree

3 files changed

+63
-56
lines changed

3 files changed

+63
-56
lines changed

R/scoring.R

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,24 @@ evaluate_predictions <- function(forecasts, truth_data) {
1212
must.include = c("geo_value", "target_end_date", "true_value")
1313
)
1414

15-
# joined_forecasts <- left_join(forecasts, truth_data, by = c("geo_value", "target_end_date"))
15+
browser()
16+
joined_forecasts <- left_join(forecasts, truth_data, by = c("geo_value", "target_end_date"))
1617

17-
# joined_forecasts %>%
18-
# group_by(model, geo_value, forecast_date, target_end_date) %>%
19-
# summarize(increasing = all(prediction - shift(prediction, 1, 0) > 0)) %>%
20-
# ungroup() %>%
21-
# filter(!increasing)
22-
23-
pred_final %>%
24-
group_by(geo_value, forecast_date, target_end_date) %>%
25-
summarize(increasing = all(value - shift(value, 1, 0) > 0)) %>%
18+
joined_forecasts %>%
19+
group_by(model, geo_value, forecast_date, target_end_date) %>%
20+
summarize(increasing = all(prediction - shift(prediction, 1, 0) > 0)) %>%
2621
ungroup() %>%
2722
filter(!increasing)
2823

24+
joined_forecasts %>%
25+
filter(geo_value == "ma", forecast_date == "2023-10-21", target_end_date == "2023-10-21")
26+
27+
# pred_final %>%
28+
# group_by(geo_value, forecast_date, target_end_date) %>%
29+
# summarize(increasing = all(value - shift(value, 1, 0) > 0)) %>%
30+
# ungroup() %>%
31+
# filter(!increasing)
32+
2933
# joined_forecasts %>% filter(geo_value == "ma", forecast_date == "2023-10-07", target_end_date == "2023-10-21") %>% print(n=50)
3034

3135
forecast_obj <- left_join(forecasts, truth_data, by = c("geo_value", "target_end_date")) %>%
@@ -36,7 +40,6 @@ evaluate_predictions <- function(forecasts, truth_data) {
3640
forecast_unit = c("model", "geo_value", "forecast_date", "target_end_date")
3741
)
3842

39-
# browser()
4043
scores <- forecast_obj %>%
4144
scoringutils::score(metrics = get_metrics(.)) %>%
4245
as_tibble() %>%

R/utils.R

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -7,25 +7,25 @@
77
#' @param pattern string to search in the forecaster name.
88
#'
99
#' @export
10-
forecaster_lookup <- function(pattern, forecaster_grid = NULL, printing = TRUE) {
10+
forecaster_lookup <- function(pattern, forecaster_grid = NULL) {
1111
if (is.null(forecaster_grid)) {
12-
cli::cli_warn("Reading `forecaster_param_combinations` target. If it's not up to date, results will be off. Update with `tar_make(forecaster_parameter_combinations)`.")
13-
forecaster_grid <- tar_read_raw("forecaster_parameter_combinations") %>%
14-
map(make_forecaster_grid) %>%
15-
bind_rows()
12+
cli::cli_warn("Reading `forecaster_param_combinations` target. If it's not up to date, results will be off.
13+
Update with `tar_make(forecaster_parameter_combinations)`.")
14+
forecaster_grid <- tar_read_raw("forecaster_parameter_grid")
1615
}
17-
fc_row <- forecaster_grid %>% filter(grepl(pattern, id))
18-
if (printing) {
19-
params <- fc_row$params[[1]]
20-
if (!is.null(params$trainer)) {
21-
params$trainer <- as_string(params$trainer)
16+
17+
# Remove the "forecaster_" prefix from the pattern if it exists.
18+
if (grepl("forecaster_", pattern)) {
19+
pattern <- gsub("forecaster_", "", pattern)
20+
}
21+
22+
for (table in forecaster_grid) {
23+
filtered_table <- table %>% filter(grepl(pattern, id))
24+
if (nrow(filtered_table) > 0) {
25+
filtered_table %>% glimpse()
26+
break
2227
}
23-
print(glue::glue("name: {fc_row %>% pull(id)}"))
24-
print(glue::glue("forecaster: {fc_row$forecaster[[1]]}"))
25-
print(glue::glue("params:"))
26-
print(params %>% data.table::as.data.table())
2728
}
28-
return(fc_row)
2929
}
3030

3131
#' Add a unique id based on the column contents
@@ -37,8 +37,6 @@ forecaster_lookup <- function(pattern, forecaster_grid = NULL, printing = TRUE)
3737
#'
3838
#' @export
3939
add_id <- function(tib, exclude = c()) {
40-
browser()
41-
# TODO
4240
ids <- tib %>%
4341
select(-all_of(exclude)) %>%
4442
purrr::transpose() %>%
@@ -63,7 +61,7 @@ get_single_id <- function(param_list) {
6361
}
6462

6563
#' Turn a tibble of parameters into a list of named lists.
66-
make_params_list <- function(df, singleton_cols = c("trainer")) {
64+
make_params_list <- function(df, unlist_cols = c("lags", "trainer"), get_cols = c("trainer")) {
6765
params_list <- df %>%
6866
select(-forecaster, -id) %>%
6967
split(seq_len(nrow(.))) %>%
@@ -72,9 +70,10 @@ make_params_list <- function(df, singleton_cols = c("trainer")) {
7270
names(params_list) <- df$id
7371

7472
# Some columns need to be unlisted.
75-
if (length(singleton_cols) > 0) {
73+
unlist_cols <- unlist_cols[unlist_cols %in% names(params_list[[1]])]
74+
if (length(unlist_cols) > 0) {
7675
params_list %<>% lapply(function(x) {
77-
for (col in singleton_cols) {
76+
for (col in unlist_cols) {
7877
if (length(x[[col]]) == 1) {
7978
x[[col]] <- x[[col]][[1]]
8079
}
@@ -83,13 +82,16 @@ make_params_list <- function(df, singleton_cols = c("trainer")) {
8382
})
8483
}
8584

86-
# We also should remove one layer of nesting from lags, but only if they're a list of length 1.
87-
params_list %<>% map(function(x) {
88-
if (!is.null(x$lags) && length(x$lags) == 1) {
89-
x$lags <- x$lags[[1]]
90-
}
91-
x
92-
})
85+
# Some columns hold object names as strings; look up the object each one names.
86+
get_cols <- get_cols[get_cols %in% names(params_list[[1]])]
87+
if (length(get_cols) > 0) {
88+
params_list %<>% lapply(function(x) {
89+
for (col in get_cols) {
90+
x[[col]] <- get(x[[col]])
91+
}
92+
x
93+
})
94+
}
9395

9496
return(params_list)
9597
}
@@ -371,7 +373,7 @@ get_recent_targets_errors <- function(recent_minutes = 60) {
371373
}
372374

373375
other_errors <- targets::tar_meta() %>%
374-
filter(time > Sys.time() - days(recent_days), !is.na(error)) %>%
376+
filter(time > Sys.time() - minutes(recent_minutes), !is.na(error)) %>%
375377
arrange(desc(time)) %>%
376378
distinct(error, .keep_all = TRUE) %>%
377379
select(time, name, error)
@@ -388,5 +390,3 @@ get_recent_targets_errors <- function(recent_minutes = 60) {
388390
}
389391
}
390392
}
391-
# Alias
392-
grte <- get_recent_targets_errors

scripts/covid_hosp_explore.R

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ forecast_dates <- forecast_dates[1:10]
2727
# Human-readable object to be used for inspecting the forecasters in the pipeline.
2828
forecaster_parameter_combinations <- rlang::list2(
2929
scaled_pop_main = tidyr::expand_grid(
30-
forecaster = list(scaled_pop),
31-
trainer = list(linreg, quantreg),
30+
forecaster = "scaled_pop",
31+
trainer = list("linreg", "quantreg"),
3232
lags = list(
3333
c(0, 7),
3434
c(0, 7, 14),
@@ -39,13 +39,13 @@ forecaster_parameter_combinations <- rlang::list2(
3939
n_training = Inf
4040
),
4141
flatline_forecaster = tidyr::expand_grid(
42-
forecaster = list(flatline_fc),
42+
forecaster = "flatline_fc",
4343
),
4444
# using exogenous variables
4545
scaled_pop_exogenous = bind_rows(
4646
expand_grid(
47-
forecaster = list(scaled_pop),
48-
trainer = list(quantreg),
47+
forecaster = "scaled_pop",
48+
trainer = "quantreg",
4949
# since it's a list, this gets expanded out to a single one in each row
5050
extra_sources = list2("nssp", "google_symptoms_4_bronchitis", "google_symptoms", "nwss", "nwss_region"),
5151
lags = list2(
@@ -68,8 +68,8 @@ forecaster_parameter_combinations <- rlang::list2(
6868
drop_non_seasons = FALSE,
6969
),
7070
expand_grid(
71-
forecaster = list(scaled_pop),
72-
trainer = list(quantreg),
71+
forecaster = "scaled_pop",
72+
trainer = "quantreg",
7373
extra_sources = list2(
7474
## c("dr_visits", "google_symptoms"),
7575
## c("dr_visits", "nssp"),
@@ -99,8 +99,8 @@ forecaster_parameter_combinations <- rlang::list2(
9999
drop_non_seasons = FALSE,
100100
),
101101
expand_grid(
102-
forecaster = list(scaled_pop),
103-
trainer = list(quantreg),
102+
forecaster = "scaled_pop",
103+
trainer = "quantreg",
104104
extra_sources = list2(
105105
c("nssp", "google_symptoms", "nwss", "nwss_region"),
106106
),
@@ -131,8 +131,8 @@ forecaster_parameter_combinations <- rlang::list2(
131131
)
132132
),
133133
scaled_pop_season = tidyr::expand_grid(
134-
forecaster = list(scaled_pop_seasonal),
135-
trainer = list(quantreg),
134+
forecaster = "scaled_pop_seasonal",
135+
trainer = "quantreg",
136136
lags = list(
137137
c(0, 7, 14, 21),
138138
c(0, 7)
@@ -144,7 +144,7 @@ forecaster_parameter_combinations <- rlang::list2(
144144
) %>%
145145
map(function(x) {
146146
if (dummy_mode) {
147-
x$forecaster <- list(dummy_forecaster)
147+
x$forecaster <- "dummy_forecaster"
148148
}
149149
x
150150
}) %>%
@@ -193,17 +193,21 @@ get_partially_applied_forecaster <- function(id) {
193193
!!!params_list[[id]],
194194
.homonyms = "last"
195195
)
196-
rlang::inject(forecaster_functions_list[[id]](epi_data = epi_data, !!!forecaster_args))
196+
# This uses string lookup to get the function.
197+
forecaster_fn <- get(forecaster_functions_list[[id]])
198+
rlang::inject(forecaster_fn(epi_data = epi_data, !!!forecaster_args))
197199
}
198200
}
199201

200202
# ================================ TARGETS =================================
201203
# ================================ PARAMETERS TARGETS ======================
202204
parameter_targets <- list2(
203-
# tar_target(name = forecaster_names, command = forecaster_functions_list %>% names()),
204205
tar_target(name = aheads, command = c(0, 7, 14, 21)),
206+
tar_target(name = ref_time_values, command = forecast_dates),
207+
# This is used for parameter lookup.
208+
tar_target(name = forecaster_parameter_grid, command = forecaster_parameter_combinations),
209+
# This is used for generating notebooks.
205210
tar_target(name = forecaster_families, command = forecaster_parameter_combinations %>% names()),
206-
tar_target(name = ref_time_values, command = forecast_dates)
207211
)
208212
# ================================ DATA TARGETS ==============================
209213
data_targets <- list2(

0 commit comments

Comments
 (0)