Skip to content

Commit c760aa0

Browse files
committed
autoplot new data
1 parent a68e817 commit c760aa0

File tree

1 file changed

+34
-28
lines changed

1 file changed

+34
-28
lines changed

R/autoplot.R

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ ggplot2::autoplot
1616
#' @param object An `epi_workflow`
1717
#' @param predictions A data frame with predictions. If `NULL`, only the
1818
#' original data is shown.
19+
#' @param plot_data An epi_df of the data to plot against. This is for the case
20+
#' where you have the actual results to compare the forecast against.
1921
#' @param .levels A numeric vector of levels to plot for any prediction bands.
2022
#' More than 3 levels begins to be difficult to see.
2123
#' @param ... Ignored
@@ -81,7 +83,9 @@ NULL
8183
#' @export
8284
#' @rdname autoplot-epipred
8385
autoplot.epi_workflow <- function(
84-
object, predictions = NULL,
86+
object,
87+
predictions = NULL,
88+
plot_data = NULL,
8589
.levels = c(.5, .8, .9), ...,
8690
.color_by = c("all_keys", "geo_value", "other_keys", ".response", "all", "none"),
8791
.facet_by = c(".response", "other_keys", "all_keys", "geo_value", "all", "none"),
@@ -108,30 +112,32 @@ autoplot.epi_workflow <- function(
108112
}
109113
keys <- c("geo_value", "time_value", "key")
110114
mold_roles <- names(mold$extras$roles)
111-
edf <- bind_cols(mold$extras$roles[mold_roles %in% keys], y)
112-
if (starts_with_impl("ahead_", names(y))) {
113-
old_name_y <- unlist(strsplit(names(y), "_"))
114-
shift <- as.numeric(old_name_y[2])
115-
new_name_y <- paste(old_name_y[-c(1:2)], collapse = "_")
116-
edf <- rename(edf, !!new_name_y := !!names(y))
117-
} else if (starts_with_impl("lag_", names(y))) {
118-
old_name_y <- unlist(strsplit(names(y), "_"))
119-
shift <- -as.numeric(old_name_y[2])
120-
new_name_y <- paste(old_name_y[-c(1:2)], collapse = "_")
121-
edf <- rename(edf, !!new_name_y := !!names(y))
122-
}
123-
124-
if (!is.null(shift)) {
125-
edf <- mutate(edf, time_value = time_value + shift)
115+
# extract the relevant column names for plotting
116+
old_name_y <- unlist(strsplit(names(y), "_"))
117+
new_name_y <- paste(old_name_y[-c(1:2)], collapse = "_")
118+
if (is.null(plot_data)) {
119+
# the outcome has shifted, so we need to shift it forward (or back)
120+
# by the corresponding amount
121+
plot_data <- bind_cols(mold$extras$roles[mold_roles %in% keys], y)
122+
if (starts_with_impl("ahead_", names(y))) {
123+
shift <- as.numeric(old_name_y[2])
124+
} else if (starts_with_impl("lag_", names(y))) {
125+
old_name_y <- unlist(strsplit(names(y), "_"))
126+
shift <- -as.numeric(old_name_y[2])
127+
}
128+
plot_data <- rename(plot_data, !!new_name_y := !!names(y))
129+
if (!is.null(shift)) {
130+
plot_data <- mutate(plot_data, time_value = time_value + shift)
131+
}
132+
other_keys <- setdiff(key_colnames(object), c("geo_value", "time_value"))
133+
plot_data <- as_epi_df(plot_data,
134+
as_of = object$fit$meta$as_of,
135+
other_keys = other_keys
136+
)
126137
}
127-
other_keys <- setdiff(key_colnames(object), c("geo_value", "time_value"))
128-
edf <- as_epi_df(edf,
129-
as_of = object$fit$meta$as_of,
130-
other_keys = other_keys
131-
)
132138
if (is.null(predictions)) {
133139
return(autoplot(
134-
edf, new_name_y,
140+
plot_data, new_name_y,
135141
.color_by = .color_by, .facet_by = .facet_by, .base_color = .base_color,
136142
.max_facets = .max_facets
137143
))
@@ -143,27 +149,27 @@ autoplot.epi_workflow <- function(
143149
}
144150
predictions <- rename(predictions, time_value = target_date)
145151
}
146-
pred_cols_ok <- hardhat::check_column_names(predictions, key_colnames(edf))
152+
pred_cols_ok <- hardhat::check_column_names(predictions, key_colnames(plot_data))
147153
if (!pred_cols_ok$ok) {
148154
cli_warn(c(
149155
"`predictions` is missing required variables: {.var {pred_cols_ok$missing_names}}.",
150156
i = "Plotting the original data."
151157
))
152158
return(autoplot(
153-
edf, !!new_name_y,
159+
plot_data, !!new_name_y,
154160
.color_by = .color_by, .facet_by = .facet_by, .base_color = .base_color,
155161
.max_facets = .max_facets
156162
))
157163
}
158164

159165
# First we plot the history, always faceted by everything
160-
bp <- autoplot(edf, !!new_name_y,
166+
bp <- autoplot(plot_data, !!new_name_y,
161167
.color_by = "none", .facet_by = "all_keys",
162168
.base_color = "black", .max_facets = .max_facets
163169
)
164170

165171
# Now, prepare matching facets in the predictions
166-
ek <- epi_keys_only(edf)
172+
ek <- epi_keys_only(plot_data)
167173
predictions <- predictions %>%
168174
mutate(
169175
.facets = interaction(!!!rlang::syms(as.list(ek)), sep = "/"),
@@ -201,7 +207,7 @@ autoplot.epi_workflow <- function(
201207
#' @export
202208
#' @rdname autoplot-epipred
203209
autoplot.canned_epipred <- function(
204-
object, ...,
210+
object, plot_data = NULL, ...,
205211
.color_by = c("all_keys", "geo_value", "other_keys", ".response", "all", "none"),
206212
.facet_by = c(".response", "other_keys", "all_keys", "geo_value", "all", "none"),
207213
.base_color = "dodgerblue4",
@@ -215,7 +221,7 @@ autoplot.canned_epipred <- function(
215221
predictions <- object$predictions %>%
216222
rename(time_value = target_date)
217223

218-
autoplot(ewf, predictions,
224+
autoplot(ewf, predictions, plot_data, ...,
219225
.color_by = .color_by, .facet_by = .facet_by,
220226
.base_color = .base_color, .max_facets = .max_facets
221227
)

0 commit comments

Comments
 (0)