11#' epi_data is expected to have: geo_value, time_value, and value columns.
2- forecaster_baseline_linear <- function (epi_data , ahead , log = FALSE , sort = FALSE , residual_tail = 0.85 , residual_center = 0.085 ) {
2+ forecaster_baseline_linear <- function (epi_data , ahead , log = FALSE , sort = FALSE , residual_tail = 0.85 , residual_center = 0.085 , no_intercept = FALSE ) {
33 forecast_date <- attributes(epi_data )$ metadata $ as_of
44 population_data <- get_population_data() %> %
55 rename(geo_value = state_id ) %> %
@@ -29,15 +29,45 @@ forecaster_baseline_linear <- function(epi_data, ahead, log = FALSE, sort = FALS
2929 group_by(geo_value ) %> %
3030 filter(! (is.na(value ) | is.infinite(value ))) %> %
3131 filter(n() > = 2 ) %> %
32- mutate(weeks_back = as.integer(time_value - epi_as_of(df_processed )) / 7 )
33-
34- point_forecast <- tibble(
35- geo_value = train_data $ geo_value %> % unique(),
36- model = map(geo_value , ~ lm(value ~ weeks_back , data = train_data %> % filter(geo_value == .x ))),
37- value = map_dbl(model , ~ predict(.x , newdata = data.frame (weeks_back = ahead ))),
38- season_week = target_season_week
39- )
32+ mutate(weeks_back = as.integer(time_value - max(df_processed $ time_value )) / 7 )
33+ latency <- as.integer(ceiling((epi_as_of(df_processed ) - max(df_processed $ time_value )) / 7 ))
4034
35+ if (no_intercept ) {
36+ # set the intercept column to be the last value for that geo
37+ intercept_values <-
38+ train_data %> %
39+ select(geo_value , time_value , value , weeks_back ) %> %
40+ group_by(geo_value ) %> %
41+ filter(time_value == max(time_value )) %> %
42+ rename(intercept = value ) %> %
43+ select(geo_value , intercept ) %> %
44+ ungroup()
45+ train_data <- train_data %> %
46+ left_join(intercept_values , by = join_by(geo_value ))
47+ point_forecast <- tibble(
48+ geo_value = train_data $ geo_value %> % unique(),
49+ model = purrr :: map(geo_value , ~ lm(value ~ weeks_back + 0 , data = train_data %> % filter(geo_value == .x ), offset = intercept )),
50+ season_week = target_season_week
51+ ) %> %
52+ left_join(intercept_values , by = join_by(geo_value )) %> %
53+ mutate(
54+ value = map2_vec(
55+ model , intercept ,
56+ \(model_x , intercept_y )
57+ # need to add the latency
58+ predict(model_x , newdata = data.frame (weeks_back = ahead + latency , intercept = intercept_y ))
59+ )
60+ ) %> %
61+ select(- intercept )
62+ } else {
63+ point_forecast <- tibble(
64+ geo_value = train_data $ geo_value %> % unique(),
65+ model = map(geo_value , ~ lm(value ~ weeks_back , data = train_data %> % filter(geo_value == .x ))),
66+ # shift ahead by the data latency (in weeks): weeks_back is measured from the last observed time_value, not from the as_of date
67+ value = map_dbl(model , ~ predict(.x , newdata = data.frame (weeks_back = ahead + latency ))),
68+ season_week = target_season_week
69+ )
70+ }
4171 missing_geos <- setdiff(geos , unique(point_forecast $ geo_value ))
4272
4373 point_forecast <- bind_rows(
@@ -75,7 +105,8 @@ forecaster_baseline_linear <- function(epi_data, ahead, log = FALSE, sort = FALS
75105 epipredict ::: propagate_samples(residuals , point , quantile_levels = covidhub_probs(), aheads = ahead + 2 , nsim = 1e4 , symmetrize = TRUE , nonneg = FALSE )[[1 ]] %> %
76106 pull(.pred_distn )
77107 }
78- quantile_forecast <- point_forecast %> %
108+ quantile_forecast <-
109+ point_forecast %> %
79110 rowwise() %> %
80111 mutate(quantile = get_quantile(value , ahead ) %> % nested_quantiles()) %> %
81112 unnest(quantile ) %> %
@@ -95,4 +126,5 @@ forecaster_baseline_linear <- function(epi_data, ahead, log = FALSE, sort = FALS
95126 ) %> %
96127 select(- model , - values , - population , - season_week ) %> %
97128 mutate(value = pmax(0 , value ))
129+ quantile_forecast
98130}
0 commit comments