Skip to content

Commit 1105511

Browse files
committed
fix: make flu_hosp_prod work
1 parent 1523d41 commit 1105511

File tree

2 files changed

+24
-28
lines changed

2 files changed

+24
-28
lines changed

R/utils.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,18 +172,18 @@ data_substitutions <- function(dataset, disease, forecast_generation_date) {
172172
}
173173

174174
parse_prod_weights <- function(filename = here::here("covid_geo_exclusions.csv"),
175-
gen_forecast_date) {
175+
forecast_date) {
176176
all_states <- c(
177177
unique(readr::read_csv("https://raw.githubusercontent.com/cmu-delphi/covidcast-indicators/refs/heads/main/_delphi_utils_python/delphi_utils/data/2020/state_pop.csv", show_col_types = FALSE)$state_id),
178178
"usa", "us"
179179
)
180180
all_prod_weights <- readr::read_csv(filename, comment = "#", show_col_types = FALSE)
181181
# if we haven't set specific weights, use the overall defaults
182-
useful_prod_weights <- filter(all_prod_weights, forecast_date == gen_forecast_date)
182+
useful_prod_weights <- filter(all_prod_weights, forecast_date == forecast_date)
183183
if (nrow(useful_prod_weights) == 0) {
184184
useful_prod_weights <- all_prod_weights %>%
185185
filter(forecast_date == min(forecast_date)) %>%
186-
mutate(forecast_date = gen_forecast_date)
186+
mutate(forecast_date = forecast_date)
187187
}
188188
useful_prod_weights %>%
189189
mutate(

scripts/flu_hosp_prod.R

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -118,16 +118,17 @@ rlang::list2(
118118
}
119119
),
120120
tar_map(
121+
# Because targets relies on R metaprogramming, it loses the Date class.
121122
values = tibble(
122-
forecast_date = forecast_date,
123-
forecast_generation_date = forecast_generation_date
123+
forecast_date_int = forecast_date,
124+
forecast_generation_date_int = forecast_generation_date
124125
),
125-
names = "forecast_date",
126+
names = "forecast_date_int",
126127
tar_target(
127128
name = geo_forecasters_weights,
128129
command = {
129-
geo_forecasters_weights <- parse_prod_weights(here::here("flu_geo_exclusions.csv"), forecast_date)
130-
if (nrow(geo_forecasters_weights %>% filter(forecast_date == forecast_date)) == 0) {
130+
geo_forecasters_weights <- parse_prod_weights(here::here("flu_geo_exclusions.csv"), forecast_date_int)
131+
if (nrow(geo_forecasters_weights %>% filter(forecast_date == forecast_date_int)) == 0) {
131132
cli_abort("there are no weights for the forecast date {forecast_date}")
132133
}
133134
geo_forecasters_weights
@@ -136,17 +137,14 @@ rlang::list2(
136137
),
137138
tar_target(
138139
name = geo_exclusions,
139-
command = {
140-
exclude_geos(geo_forecasters_weights)
141-
}
140+
command = exclude_geos(geo_forecasters_weights)
142141
),
143142
tar_target(
144143
forecast_res,
145144
command = {
146-
forecast_generation_date <- as.Date(forecast_generation_date)
147-
if (forecast_generation_date < Sys.Date()) {
145+
if (as.Date(forecast_generation_date_int) < Sys.Date()) {
148146
train_data <- nhsn_archive_data %>%
149-
epix_as_of(forecast_generation_date) %>%
147+
epix_as_of(as.Date(forecast_generation_date_int)) %>%
150148
add_season_info() %>%
151149
mutate(
152150
source = "nhsn",
@@ -159,7 +157,7 @@ rlang::list2(
159157
full_data <- train_data %>%
160158
bind_rows(joined_latest_extra_data)
161159
attributes(full_data)$metadata$other_keys <- "source"
162-
attributes(full_data)$metadata$as_of <- as.Date(forecast_date)
160+
attributes(full_data)$metadata$as_of <- as.Date(forecast_date_int)
163161
full_data %>%
164162
forecaster_fns[[forecasters]](ahead = aheads) %>%
165163
mutate(
@@ -173,8 +171,7 @@ rlang::list2(
173171
tar_target(
174172
name = ensemble_res,
175173
command = {
176-
forecasts <- forecast_res
177-
forecasts %>%
174+
forecast_res %>%
178175
mutate(quantile = round(quantile, digits = 3)) %>%
179176
left_join(geo_forecasters_weights, by = join_by(forecast_date, forecaster, geo_value)) %>%
180177
mutate(value = value * weight) %>%
@@ -205,7 +202,7 @@ rlang::list2(
205202
ensemble_mixture_res %>%
206203
format_flusight(disease = "flu") %>%
207204
write_submission_file(
208-
get_forecast_reference_date(forecast_date),
205+
get_forecast_reference_date(forecast_date_int),
209206
file.path(submission_directory, "model-output/CMU-TimeSeries")
210207
)
211208
},
@@ -223,7 +220,7 @@ rlang::list2(
223220
ungroup() %>%
224221
format_flusight(disease = "flu") %>%
225222
write_submission_file(
226-
get_forecast_reference_date(forecast_date),
223+
get_forecast_reference_date(forecast_date_int),
227224
submission_directory = file.path(submission_directory, "model-output/CMU-climatological-baseline"),
228225
file_name = "CMU-climatological-baseline"
229226
)
@@ -239,7 +236,7 @@ rlang::list2(
239236
if (submission_directory != "cache") {
240237
validation <- validate_submission(
241238
submission_directory,
242-
file_path = sprintf("CMU-TimeSeries/%s-CMU-TimeSeries.csv", get_forecast_reference_date(forecast_date))
239+
file_path = sprintf("CMU-TimeSeries/%s-CMU-TimeSeries.csv", get_forecast_reference_date(forecast_date_int))
243240
)
244241
} else {
245242
validation <- "not validating when there is no hub (set submission_directory)"
@@ -256,7 +253,7 @@ rlang::list2(
256253
if (submission_directory != "cache" && submit_climatological) {
257254
validation <- validate_submission(
258255
submission_directory,
259-
file_path = sprintf("CMU-climatological-baseline/%s-CMU-climatological-baseline.csv", get_forecast_reference_date(forecast_date))
256+
file_path = sprintf("CMU-climatological-baseline/%s-CMU-climatological-baseline.csv", get_forecast_reference_date(forecast_date_int))
260257
)
261258
} else {
262259
validation <- "not validating when there is no hub (set submission_directory)"
@@ -279,9 +276,8 @@ rlang::list2(
279276
select(geo_value, source, target_end_date = time_value, value) %>%
280277
filter(target_end_date > truth_data_date, geo_value %nin% insufficient_data_geos) %>%
281278
mutate(target_end_date = target_end_date + 6)
282-
forecast_generation_date <- as.Date(forecast_generation_date)
283-
if (forecast_generation_date < Sys.Date()) {
284-
truth_data <- nhsn_archive_data %>% epix_as_of(forecast_generation_date)
279+
if (as.Date(forecast_generation_date_int) < Sys.Date()) {
280+
truth_data <- nhsn_archive_data %>% epix_as_of(as.Date(forecast_generation_date_int))
285281
} else {
286282
truth_data <- nhsn_latest_data
287283
}
@@ -298,11 +294,11 @@ rlang::list2(
298294
full_join(
299295
truth_data %>%
300296
select(geo_value, target_end_date, value),
301-
by = join_by(geo_value, target_end_date)
297+
by = c("geo_value", "target_end_date")
302298
) %>%
303299
group_by(geo_value) %>%
304300
summarise(rel_max_value = max(value, na.rm = TRUE) / max(nssp, na.rm = TRUE)),
305-
by = join_by(geo_value)
301+
by = "geo_value"
306302
) %>%
307303
mutate(value = value * rel_max_value) %>%
308304
select(-rel_max_value)
@@ -318,13 +314,13 @@ rlang::list2(
318314
"scripts/reports/forecast_report.Rmd",
319315
output_file = here::here(
320316
"reports",
321-
sprintf("%s_flu_prod_on_%s.html", as.Date(forecast_date), as.Date(forecast_generation_date))
317+
sprintf("%s_flu_prod_on_%s.html", as.Date(forecast_date_int), as.Date(forecast_generation_date_int))
322318
),
323319
params = list(
324320
disease = "flu",
325321
forecast_res = forecast_res %>% bind_rows(ensemble_mixture_res %>% mutate(forecaster = "ensemble_mix")),
326322
ensemble_res = ensemble_res,
327-
forecast_date = as.Date(forecast_date),
323+
forecast_date = as.Date(forecast_date_int),
328324
truth_data = truth_data
329325
)
330326
)

0 commit comments

Comments
 (0)