Skip to content

Commit 5f3baaa

Browse files
committed
fix: add modified data forecaster, organize forecasts
1 parent 25867bf commit 5f3baaa

File tree

3 files changed

+103
-9
lines changed

3 files changed

+103
-9
lines changed

scripts/covid_hosp_prod.R

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,19 @@ rlang::list2(
147147
ungroup()
148148
},
149149
),
150+
tar_target(
  name = forecasts_and_ensembles,
  command = {
    # Tag each ensemble with the forecaster label it is reported under, then
    # stack everything into one long table alongside the individual forecasts.
    labelled_ensemble <- mutate(ensemble_res, forecaster = "ensemble")
    labelled_mixture <- mutate(ensemble_mixture_res, forecaster = "ensemble_mix")
    # TODO: Maybe later, match with flu_hosp_prod
    # ensemble_mixture_res_2 %>% mutate(forecaster = "ensemble_mix_2"),
    # combo_ensemble_mixture_res %>% mutate(forecaster = "combo_ensemble_mix")
    bind_rows(forecast_res, labelled_ensemble, labelled_mixture)
  }
),
150163
tar_target(
151164
name = make_submission_csv,
152165
command = {
@@ -262,8 +275,7 @@ rlang::list2(
262275
),
263276
params = list(
264277
disease = "covid",
265-
forecast_res = forecast_res %>% bind_rows(ensemble_mixture_res %>% mutate(forecaster = "ensemble_mix")),
266-
ensemble_res = ensemble_res,
278+
forecast_res = forecasts_and_ensembles,
267279
forecast_date = as.Date(forecast_date_int),
268280
truth_data = truth_data
269281
)

scripts/flu_hosp_prod.R

Lines changed: 87 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@ forecaster_fns <- list2(
6969
# This is needed to build the data archive
7070
ref_time_values_ <- seq.Date(as.Date("2023-10-04"), as.Date("2024-04-24"), by = 7L)
7171

72+
#' Overwrite the tail of a vector with the mean of its most recent values
#'
#' Replaces the last `n` elements of `x` with the mean of the last `k`
#' elements of `x` (missing values are ignored when averaging).
#'
#' Both windows are clamped to `length(x)`, so short inputs no longer trigger
#' an "only 0's may be mixed with negative subscripts" error, and a zero-length
#' window leaves `x` untouched instead of growing it.
#'
#' @param x A numeric vector, ordered so that the most recent value is last.
#' @param n Number of trailing elements to overwrite.
#' @param k Number of trailing elements to average.
#' @return A vector the same length as `x` with its last `n` values replaced.
#'   If every value in the averaging window is missing, the replacement is
#'   `NaN` (the result of `mean()` on an empty set).
smooth_last_n <- function(x, n = 1, k = 2) {
  len <- length(x)
  # Never reach further back than the data we actually have.
  n_tail <- max(0, min(n, len))
  k_window <- max(0, min(k, len))
  if (n_tail == 0 || k_window == 0) {
    return(x)
  }
  # Both windows are >= 1 and <= len here, so these ranges are always
  # ascending and strictly positive.
  window_mean <- mean(x[(len - k_window + 1):len], na.rm = TRUE)
  x[(len - n_tail + 1):len] <- window_mean
  x
}
7276

7377
rlang::list2(
7478
rlang::list2(
@@ -143,7 +147,7 @@ rlang::list2(
143147
command = exclude_geos(geo_forecasters_weights)
144148
),
145149
tar_target(
146-
forecast_res,
150+
full_data,
147151
command = {
148152
if (as.Date(forecast_generation_date_int) < Sys.Date()) {
149153
train_data <- nhsn_archive_data %>%
@@ -161,6 +165,12 @@ rlang::list2(
161165
bind_rows(joined_latest_extra_data)
162166
attributes(full_data)$metadata$other_keys <- "source"
163167
attributes(full_data)$metadata$as_of <- as.Date(forecast_date_int)
168+
full_data
169+
}
170+
),
171+
tar_target(
172+
forecast_res,
173+
command = {
164174
full_data %>%
165175
forecaster_fns[[forecasters]](ahead = aheads) %>%
166176
mutate(
@@ -171,6 +181,39 @@ rlang::list2(
171181
pattern = cross(aheads, forecasters),
172182
cue = tar_cue(mode = "always")
173183
),
184+
# A hack to model our uncertainty in the data. We smooth the last few points
185+
# to make the forecast more stable.
186+
tar_target(
  forecast_res_modified,
  command = {
    # Grab the metadata once up front so it can be re-attached after the
    # dplyr steps below rebuild the data frame.
    meta <- attributes(full_data)$metadata

    # Smooth last few points for every geo.
    # TODO: This is a hack, we can try some more sophisticated
    # smoothing/nowcasting here.
    nhsn_smoothed <- full_data %>%
      filter(source == "nhsn") %>%
      arrange(geo_value, time_value) %>%
      group_by(geo_value) %>%
      mutate(value = smooth_last_n(value)) %>%
      ungroup()

    # Everything that is not nhsn passes through unchanged. Note that rows
    # with a missing `source` match neither filter and are dropped.
    other_sources <- full_data %>%
      filter(source != "nhsn")

    modified_full_data <- bind_rows(nhsn_smoothed, other_sources)
    attributes(modified_full_data)$metadata$as_of <- meta$as_of
    attributes(modified_full_data)$metadata$other_keys <- meta$other_keys

    # Run the selected forecaster on the smoothed data and label its output.
    modified_full_data %>%
      forecaster_fns[[forecasters]](ahead = aheads) %>%
      mutate(
        forecaster = names(forecaster_fns[forecasters]),
        geo_value = as.factor(geo_value)
      )
  },
  pattern = cross(aheads, forecasters),
  cue = tar_cue(mode = "always")
),
174217
tar_target(
175218
name = ensemble_res,
176219
command = {
@@ -199,6 +242,48 @@ rlang::list2(
199242
sort_by_quantile()
200243
},
201244
),
245+
tar_target(
  name = ensemble_mixture_res_2,
  command = {
    # Apply the ahead-by-quantile weighting scheme to the forecasts built on
    # the modified data, then drop the excluded geos.
    climate_weighted <- forecast_res_modified %>%
      ensemble_linear_climate(aheads, other_weights = geo_forecasters_weights) %>%
      filter(geo_value %nin% geo_exclusions) %>%
      ungroup()

    # Ensemble with windowed_seasonal: stack both sets of forecasts and
    # average them per geo / date / target / quantile.
    seasonal_only <- forecast_res_modified %>%
      filter(forecaster == "windowed_seasonal")

    bind_rows(climate_weighted, seasonal_only) %>%
      group_by(geo_value, forecast_date, target_end_date, quantile) %>%
      summarize(value = mean(value, na.rm = TRUE), .groups = "drop") %>%
      sort_by_quantile()
  }
),
260+
tar_target(
  name = combo_ensemble_mixture_res,
  command = {
    # Combine the two mixture ensembles quantile by quantile, taking the wider
    # of the two at each level: the larger value above the median, the smaller
    # value below it, and the average of the two at the median itself.
    # `value.x` comes from ensemble_mixture_res and `value.y` from
    # ensemble_mixture_res_2 (the default inner_join suffixes).
    #
    # This is done column-wise with pmax()/pmin() rather than rowwise() plus
    # scalar ifelse(), which gives the same values without a per-row loop.
    # Rows matching none of the conditions (missing quantile) get NA.
    inner_join(
      ensemble_mixture_res, ensemble_mixture_res_2,
      by = join_by(geo_value, forecast_date, target_end_date, quantile)
    ) %>%
      mutate(
        value = case_when(
          quantile > 0.5 ~ pmax(value.x, value.y),
          quantile < 0.5 ~ pmin(value.x, value.y),
          quantile == 0.5 ~ (value.x + value.y) / 2
        )
      ) %>%
      select(geo_value, forecast_date, target_end_date, quantile, value) %>%
      ungroup()
  }
),
275+
tar_target(
  name = forecasts_and_ensembles,
  command = {
    # Tag each ensemble with the forecaster label it is reported under, then
    # stack everything into one long table alongside the individual forecasts.
    labelled_ensemble <- mutate(ensemble_res, forecaster = "ensemble")
    labelled_mixture <- mutate(ensemble_mixture_res, forecaster = "ensemble_mix")
    labelled_mixture_2 <- mutate(ensemble_mixture_res_2, forecaster = "ensemble_mix_2")
    labelled_combo <- mutate(combo_ensemble_mixture_res, forecaster = "combo_ensemble_mix")
    bind_rows(
      forecast_res,
      labelled_ensemble,
      labelled_mixture,
      labelled_mixture_2,
      labelled_combo
    )
  }
),
202287
tar_target(
203288
name = make_submission_csv,
204289
command = {
@@ -321,8 +406,7 @@ rlang::list2(
321406
),
322407
params = list(
323408
disease = "flu",
324-
forecast_res = forecast_res %>% bind_rows(ensemble_mixture_res %>% mutate(forecaster = "ensemble_mix")),
325-
ensemble_res = ensemble_res,
409+
forecast_res = forecasts_and_ensembles,
326410
forecast_date = as.Date(forecast_date_int),
327411
truth_data = truth_data
328412
)

scripts/reports/forecast_report.Rmd

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ output:
1010
params:
1111
disease: "covid"
1212
forecast_res: !r ""
13-
ensemble_res: !r ""
1413
forecast_date: !r ""
1514
truth_data: !r ""
1615
---
@@ -51,7 +50,7 @@ Fan displays 20-80 quantiles for coverage.
5150

5251
```{r, fig.height = 60, fig.width = 12, echo=FALSE}
5352
the_plot <- plot_forecasts(
54-
params$forecast_res %>% bind_rows(params$ensemble_res %>% mutate(forecaster = "ensemble")),
53+
params$forecast_res,
5554
params$forecast_date,
5655
params$truth_data,
5756
quantiles = c(0.8),
@@ -70,7 +69,7 @@ Fan displays 20-80, 5-95, and 1-99 quantiles.
7069

7170
```{r, fig.height = 60, fig.width = 12, echo=FALSE}
7271
the_plot <- plot_forecasts(
73-
params$forecast_res %>% bind_rows(params$ensemble_res %>% mutate(forecaster = "ensemble")),
72+
params$forecast_res,
7473
params$forecast_date,
7574
params$truth_data,
7675
quantiles = c(0.8, 0.95, 0.99),
@@ -88,7 +87,6 @@ ggplotly(the_plot, tooltip = "text", height = 9000, width = 2000) %>%
8887
<!-- ```{r, fig.width = 12, echo=FALSE} -->
8988
<!-- the_plot <- plot_forecasts( -->
9089
<!-- params$forecast_res %>% -->
91-
<!-- bind_rows(params$ensemble_res %>% mutate(forecaster = "ensemble")) %>% -->
9290
<!-- filter(geo_value != "usa") %>% -->
9391
<!-- summarise( -->
9492
<!-- value = sum(value), -->

0 commit comments

Comments
 (0)