@@ -69,6 +69,10 @@ forecaster_fns <- list2(
6969# This is needed to build the data archive
# Reference dates spaced 7 days apart, running from the first as.Date() value
# below through the second one (2023-10-04 to 2024-04-24 as written).
7070ref_time_values_ <- seq.Date(as.Date(" 2023-10-04" ), as.Date(" 2024-04-24" ), by = 7L )
7171
#' Flatten the tail of a series onto the mean of its most recent values
#'
#' Overwrites the final `n` elements of `x` with the mean of the final `k`
#' elements (missing values are ignored when averaging).
#'
#' @param x A numeric vector, ordered so that the newest value is last.
#' @param n Number of trailing elements to overwrite. Values larger than
#'   `length(x)` are clamped to `length(x)`; values below 1 leave `x` as is.
#' @param k Number of trailing elements to average. Values larger than
#'   `length(x)` are clamped to `length(x)`; values below 1 leave `x` as is.
#' @return A vector of the same length as `x`.
smooth_last_n <- function(x, n = 1, k = 2) {
  len <- length(x)
  # Nothing to smooth (or nothing to average): hand the input back untouched.
  if (len == 0L || n < 1 || k < 1) {
    return(x)
  }

  # Clamp both windows to the series length. Without this, `a:b` indexing
  # produces zero/negative subscripts for short series and either errors or
  # silently picks the wrong elements.
  n <- min(n, len)
  k <- min(k, len)

  recent <- x[seq.int(len - k + 1L, len)]
  # If the whole averaging window is missing, the mean would be NaN; keep the
  # original values rather than writing NaN into the series.
  if (all(is.na(recent))) {
    return(x)
  }

  x[seq.int(len - n + 1L, len)] <- mean(recent, na.rm = TRUE)
  x
}
7276
7377rlang :: list2(
7478 rlang :: list2(
@@ -143,7 +147,7 @@ rlang::list2(
143147 command = exclude_geos(geo_forecasters_weights )
144148 ),
145149 tar_target(
146- forecast_res ,
150+ full_data ,
147151 command = {
148152 if (as.Date(forecast_generation_date_int ) < Sys.Date()) {
149153 train_data <- nhsn_archive_data %> %
@@ -161,6 +165,12 @@ rlang::list2(
161165 bind_rows(joined_latest_extra_data )
162166 attributes(full_data )$ metadata $ other_keys <- " source"
163167 attributes(full_data )$ metadata $ as_of <- as.Date(forecast_date_int )
168+ full_data
169+ }
170+ ),
171+ tar_target(
172+ forecast_res ,
173+ command = {
164174 full_data %> %
165175 forecaster_fns [[forecasters ]](ahead = aheads ) %> %
166176 mutate(
@@ -171,6 +181,39 @@ rlang::list2(
171181 pattern = cross(aheads , forecasters ),
172182 cue = tar_cue(mode = " always" )
173183 ),
# Workaround for uncertainty in the most recent data: the tail of each
# geo's "nhsn" series is smoothed before forecasting so that the forecasts
# are more stable.
tar_target(
  forecast_res_modified,
  command = {
    # Remember the metadata fields that are re-attached once the rows have
    # been rebuilt below.
    as_of <- attributes(full_data)$metadata$as_of
    other_keys <- attributes(full_data)$metadata$other_keys

    # Smooth the trailing values of every geo's nhsn series.
    # TODO: This is a hack; more sophisticated smoothing/nowcasting could be
    # tried here.
    nhsn_smoothed <- ungroup(
      mutate(
        group_by(
          arrange(filter(full_data, source == "nhsn"), geo_value, time_value),
          geo_value
        ),
        value = smooth_last_n(value)
      )
    )
    # Rows from every other source pass through unchanged and are appended
    # after the smoothed nhsn rows.
    other_sources <- filter(full_data, source != "nhsn")
    modified_full_data <- bind_rows(nhsn_smoothed, other_sources)

    attributes(modified_full_data)$metadata$as_of <- as_of
    attributes(modified_full_data)$metadata$other_keys <- other_keys

    # Run the forecaster selected by this branch on the modified data.
    forecaster_fn <- forecaster_fns[[forecasters]]
    mutate(
      forecaster_fn(modified_full_data, ahead = aheads),
      forecaster = names(forecaster_fns[forecasters]),
      geo_value = as.factor(geo_value)
    )
  },
  pattern = cross(aheads, forecasters),
  cue = tar_cue(mode = "always")
),
174217 tar_target(
175218 name = ensemble_res ,
176219 command = {
@@ -199,6 +242,48 @@ rlang::list2(
199242 sort_by_quantile()
200243 },
201244 ),
tar_target(
  name = ensemble_mixture_res_2,
  command = {
    # Ahead-by-quantile weighting scheme applied to the forecasts made from
    # the smoothed data, with excluded geos dropped.
    climate_weighted <- forecast_res_modified %>%
      ensemble_linear_climate(aheads, other_weights = geo_forecasters_weights)
    climate_weighted <- ungroup(
      filter(climate_weighted, geo_value %nin% geo_exclusions)
    )

    # The windowed_seasonal forecasts are the second member of the mixture.
    seasonal_only <- filter(
      forecast_res_modified,
      forecaster == "windowed_seasonal"
    )

    # Average the stacked members per geo / date / quantile.
    averaged <- bind_rows(climate_weighted, seasonal_only) %>%
      group_by(geo_value, forecast_date, target_end_date, quantile) %>%
      summarize(value = mean(value, na.rm = TRUE), .groups = "drop")
    sort_by_quantile(averaged)
  }
),
tar_target(
  name = combo_ensemble_mixture_res,
  command = {
    # Combine the two mixture ensembles quantile by quantile, keeping the
    # wider of the two distributions: the larger value above the median, the
    # smaller value below it, and the average of the two at the median.
    inner_join(
      ensemble_mixture_res, ensemble_mixture_res_2,
      by = join_by(geo_value, forecast_date, target_end_date, quantile)
    ) %>%
      # Vectorized over all rows at once; pmax()/pmin() propagate NA the same
      # way the scalar max()/min() did in the row-by-row version.
      mutate(
        value = case_when(
          quantile > 0.5 ~ pmax(value.x, value.y),
          quantile < 0.5 ~ pmin(value.x, value.y),
          quantile == 0.5 ~ (value.x + value.y) / 2
        )
      ) %>%
      select(geo_value, forecast_date, target_end_date, quantile, value) %>%
      # The previous rowwise() %>% ungroup() round trip returned a plain
      # tibble; keep that output class.
      as_tibble()
  }
),
tar_target(
  name = forecasts_and_ensembles,
  command = {
    # Label each ensemble with its own forecaster name so that it can be
    # stacked underneath the individual forecasts.
    ensemble_labelled <- mutate(ensemble_res, forecaster = "ensemble")
    mix_labelled <- mutate(ensemble_mixture_res, forecaster = "ensemble_mix")
    mix_2_labelled <- mutate(
      ensemble_mixture_res_2,
      forecaster = "ensemble_mix_2"
    )
    combo_labelled <- mutate(
      combo_ensemble_mixture_res,
      forecaster = "combo_ensemble_mix"
    )

    bind_rows(
      forecast_res,
      ensemble_labelled,
      mix_labelled,
      mix_2_labelled,
      combo_labelled
    )
  }
),
202287 tar_target(
203288 name = make_submission_csv ,
204289 command = {
@@ -321,8 +406,7 @@ rlang::list2(
321406 ),
322407 params = list (
323408 disease = " flu" ,
324- forecast_res = forecast_res %> % bind_rows(ensemble_mixture_res %> % mutate(forecaster = " ensemble_mix" )),
325- ensemble_res = ensemble_res ,
409+ forecast_res = forecasts_and_ensembles ,
326410 forecast_date = as.Date(forecast_date_int ),
327411 truth_data = truth_data
328412 )
0 commit comments