Skip to content

Commit 5f7892b

Browse files
committed
adding forced intercept result
1 parent a29e080 commit 5f7892b

File tree

4 files changed

+53
-12
lines changed

4 files changed

+53
-12
lines changed

R/forecasters/forecaster_baseline_linear.R

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#' epi_data is expected to have: geo_value, time_value, and value columns.
2-
forecaster_baseline_linear <- function(epi_data, ahead, log = FALSE, sort = FALSE, residual_tail = 0.85, residual_center = 0.085) {
2+
forecaster_baseline_linear <- function(epi_data, ahead, log = FALSE, sort = FALSE, residual_tail = 0.85, residual_center = 0.085, no_intercept = FALSE) {
33
forecast_date <- attributes(epi_data)$metadata$as_of
44
population_data <- get_population_data() %>%
55
rename(geo_value = state_id) %>%
@@ -29,15 +29,45 @@ forecaster_baseline_linear <- function(epi_data, ahead, log = FALSE, sort = FALS
2929
group_by(geo_value) %>%
3030
filter(!(is.na(value) | is.infinite(value))) %>%
3131
filter(n() >= 2) %>%
32-
mutate(weeks_back = as.integer(time_value - epi_as_of(df_processed)) / 7)
33-
34-
point_forecast <- tibble(
35-
geo_value = train_data$geo_value %>% unique(),
36-
model = map(geo_value, ~ lm(value ~ weeks_back, data = train_data %>% filter(geo_value == .x))),
37-
value = map_dbl(model, ~ predict(.x, newdata = data.frame(weeks_back = ahead))),
38-
season_week = target_season_week
39-
)
32+
mutate(weeks_back = as.integer(time_value - max(df_processed$time_value)) / 7)
33+
latency <- as.integer(ceiling((epi_as_of(df_processed) - max(df_processed$time_value)) / 7))
4034

35+
if (no_intercept) {
36+
# set the intercept column to be the last value for that geo
37+
intercept_values <-
38+
train_data %>%
39+
select(geo_value, time_value, value, weeks_back) %>%
40+
group_by(geo_value) %>%
41+
filter(time_value == max(time_value)) %>%
42+
rename(intercept = value) %>%
43+
select(geo_value, intercept) %>%
44+
ungroup()
45+
train_data <- train_data %>%
46+
left_join(intercept_values, by = join_by(geo_value))
47+
point_forecast <- tibble(
48+
geo_value = train_data$geo_value %>% unique(),
49+
model = purrr::map(geo_value, ~ lm(value ~ weeks_back + 0, data = train_data %>% filter(geo_value == .x), offset = intercept)),
50+
season_week = target_season_week
51+
) %>%
52+
left_join(intercept_values, by = join_by(geo_value)) %>%
53+
mutate(
54+
value = map2_vec(
55+
model, intercept,
56+
\(model_x, intercept_y)
57+
# need to add the latency
58+
predict(model_x, newdata = data.frame(weeks_back = ahead + latency, intercept = intercept_y))
59+
)
60+
) %>%
61+
select(-intercept)
62+
} else {
63+
point_forecast <- tibble(
64+
geo_value = train_data$geo_value %>% unique(),
65+
model = map(geo_value, ~ lm(value ~ weeks_back, data = train_data %>% filter(geo_value == .x))),
66+
# ahead is +1 b/c the data is 1 week latent
67+
value = map_dbl(model, ~ predict(.x, newdata = data.frame(weeks_back = ahead + latency))),
68+
season_week = target_season_week
69+
)
70+
}
4171
missing_geos <- setdiff(geos, unique(point_forecast$geo_value))
4272

4373
point_forecast <- bind_rows(
@@ -75,7 +105,8 @@ forecaster_baseline_linear <- function(epi_data, ahead, log = FALSE, sort = FALS
75105
epipredict:::propagate_samples(residuals, point, quantile_levels = covidhub_probs(), aheads = ahead + 2, nsim = 1e4, symmetrize = TRUE, nonneg = FALSE)[[1]] %>%
76106
pull(.pred_distn)
77107
}
78-
quantile_forecast <- point_forecast %>%
108+
quantile_forecast <-
109+
point_forecast %>%
79110
rowwise() %>%
80111
mutate(quantile = get_quantile(value, ahead) %>% nested_quantiles()) %>%
81112
unnest(quantile) %>%
@@ -95,4 +126,5 @@ forecaster_baseline_linear <- function(epi_data, ahead, log = FALSE, sort = FALS
95126
) %>%
96127
select(-model, -values, -population, -season_week) %>%
97128
mutate(value = pmax(0, value))
129+
quantile_forecast
98130
}

flu_geo_exclusions.csv

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,14 @@ forecast_date,forecaster,geo_value,weight
77
2024-10-01, climate_base, all, 0.5
88
2024-10-01, climate_geo_agged, all, 0.25
99
2024-10-01, climate_quantile_extrapolated, all, 0
10+
# jan 8
11+
2025-01-08, all, mp, 0
12+
2025-01-08, windowed_seasonal, all, 3
13+
2025-01-08, linear, all, 0.5
14+
2025-01-08, linearlog, all, 0
15+
2025-01-08, climate_base, all, 0.5
16+
2025-01-08, climate_geo_agged, all, 0.25
17+
2025-01-08, climate_quantile_extrapolated, all, 0
1018
# nov 27
1119
2024-11-27, all, mp, 0
1220
2024-11-27, linear, all, 3

scripts/covid_hosp_prod.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ forecast_generation_date <- Sys.Date()
1212

1313
forecaster_fns <- list2(
1414
linear = function(...) {
15-
forecaster_baseline_linear(..., residual_tail = 0.97, residual_center = 0.097)
15+
forecaster_baseline_linear(..., residual_tail = 0.97, residual_center = 0.097, no_intercept = TRUE)
1616
},
1717
# linearlog = function(...) {
1818
# forecaster_baseline_linear(..., log = TRUE)

scripts/flu_hosp_prod.R

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ forecaster_fns <- list2(
2323
forecaster_baseline_linear(
2424
ahead, ...,
2525
residual_tail = 0.99,
26-
residual_center = 0.35
26+
residual_center = 0.35,
27+
no_intercept = TRUE
2728
)
2829
},
2930
# linearlog = function(...) {

0 commit comments

Comments
 (0)