@@ -16,6 +16,8 @@ ggplot2::autoplot
1616# ' @param object An `epi_workflow`
1717# ' @param predictions A data frame with predictions. If `NULL`, only the
1818# ' original data is shown.
19+ # ' @param plot_data An epi_df of the data to plot against. This is for the case
20+ # ' where you have the actual results to compare the forecast against.
1921# ' @param .levels A numeric vector of levels to plot for any prediction bands.
2022# ' More than 3 levels begins to be difficult to see.
2123# ' @param ... Ignored
8183# ' @export
8284# ' @rdname autoplot-epipred
8385autoplot.epi_workflow <- function (
84- object , predictions = NULL ,
86+ object ,
87+ predictions = NULL ,
88+ plot_data = NULL ,
8589 .levels = c(.5 , .8 , .9 ), ... ,
8690 .color_by = c(" all_keys" , " geo_value" , " other_keys" , " .response" , " all" , " none" ),
8791 .facet_by = c(" .response" , " other_keys" , " all_keys" , " geo_value" , " all" , " none" ),
@@ -108,30 +112,32 @@ autoplot.epi_workflow <- function(
108112 }
109113 keys <- c(" geo_value" , " time_value" , " key" )
110114 mold_roles <- names(mold $ extras $ roles )
111- edf <- bind_cols(mold $ extras $ roles [mold_roles %in% keys ], y )
112- if (starts_with_impl(" ahead_" , names(y ))) {
113- old_name_y <- unlist(strsplit(names(y ), " _" ))
114- shift <- as.numeric(old_name_y [2 ])
115- new_name_y <- paste(old_name_y [- c(1 : 2 )], collapse = " _" )
116- edf <- rename(edf , !! new_name_y : = !! names(y ))
117- } else if (starts_with_impl(" lag_" , names(y ))) {
118- old_name_y <- unlist(strsplit(names(y ), " _" ))
119- shift <- - as.numeric(old_name_y [2 ])
120- new_name_y <- paste(old_name_y [- c(1 : 2 )], collapse = " _" )
121- edf <- rename(edf , !! new_name_y : = !! names(y ))
122- }
123-
124- if (! is.null(shift )) {
125- edf <- mutate(edf , time_value = time_value + shift )
115+ # extract the relevant column names for plotting
116+ old_name_y <- unlist(strsplit(names(y ), " _" ))
117+ new_name_y <- paste(old_name_y [- c(1 : 2 )], collapse = " _" )
118+ if (is.null(plot_data )) {
119+ # the outcome has shifted, so we need to shift it forward (or back)
120+ # by the corresponding amount
121+ plot_data <- bind_cols(mold $ extras $ roles [mold_roles %in% keys ], y )
122+ if (starts_with_impl(" ahead_" , names(y ))) {
123+ shift <- as.numeric(old_name_y [2 ])
124+ } else if (starts_with_impl(" lag_" , names(y ))) {
125+ old_name_y <- unlist(strsplit(names(y ), " _" ))
126+ shift <- - as.numeric(old_name_y [2 ])
127+ }
128+ plot_data <- rename(plot_data , !! new_name_y : = !! names(y ))
129+ if (! is.null(shift )) {
130+ plot_data <- mutate(plot_data , time_value = time_value + shift )
131+ }
132+ other_keys <- setdiff(key_colnames(object ), c(" geo_value" , " time_value" ))
133+ plot_data <- as_epi_df(plot_data ,
134+ as_of = object $ fit $ meta $ as_of ,
135+ other_keys = other_keys
136+ )
126137 }
127- other_keys <- setdiff(key_colnames(object ), c(" geo_value" , " time_value" ))
128- edf <- as_epi_df(edf ,
129- as_of = object $ fit $ meta $ as_of ,
130- other_keys = other_keys
131- )
132138 if (is.null(predictions )) {
133139 return (autoplot(
134- edf , new_name_y ,
140+ plot_data , new_name_y ,
135141 .color_by = .color_by , .facet_by = .facet_by , .base_color = .base_color ,
136142 .max_facets = .max_facets
137143 ))
@@ -143,27 +149,27 @@ autoplot.epi_workflow <- function(
143149 }
144150 predictions <- rename(predictions , time_value = target_date )
145151 }
146- pred_cols_ok <- hardhat :: check_column_names(predictions , key_colnames(edf ))
152+ pred_cols_ok <- hardhat :: check_column_names(predictions , key_colnames(plot_data ))
147153 if (! pred_cols_ok $ ok ) {
148154 cli_warn(c(
149155 " `predictions` is missing required variables: {.var {pred_cols_ok$missing_names}}." ,
150156 i = " Plotting the original data."
151157 ))
152158 return (autoplot(
153- edf , !! new_name_y ,
159+ plot_data , !! new_name_y ,
154160 .color_by = .color_by , .facet_by = .facet_by , .base_color = .base_color ,
155161 .max_facets = .max_facets
156162 ))
157163 }
158164
159165 # First we plot the history, always faceted by everything
160- bp <- autoplot(edf , !! new_name_y ,
166+ bp <- autoplot(plot_data , !! new_name_y ,
161167 .color_by = " none" , .facet_by = " all_keys" ,
162168 .base_color = " black" , .max_facets = .max_facets
163169 )
164170
165171 # Now, prepare matching facets in the predictions
166- ek <- epi_keys_only(edf )
172+ ek <- epi_keys_only(plot_data )
167173 predictions <- predictions %> %
168174 mutate(
169175 .facets = interaction(!!! rlang :: syms(as.list(ek )), sep = " /" ),
@@ -201,7 +207,7 @@ autoplot.epi_workflow <- function(
201207# ' @export
202208# ' @rdname autoplot-epipred
203209autoplot.canned_epipred <- function (
204- object , ... ,
210+ object , plot_data = NULL , ... ,
205211 .color_by = c(" all_keys" , " geo_value" , " other_keys" , " .response" , " all" , " none" ),
206212 .facet_by = c(" .response" , " other_keys" , " all_keys" , " geo_value" , " all" , " none" ),
207213 .base_color = " dodgerblue4" ,
@@ -215,7 +221,7 @@ autoplot.canned_epipred <- function(
215221 predictions <- object $ predictions %> %
216222 rename(time_value = target_date )
217223
218- autoplot(ewf , predictions ,
224+ autoplot(ewf , predictions , plot_data , ... ,
219225 .color_by = .color_by , .facet_by = .facet_by ,
220226 .base_color = .base_color , .max_facets = .max_facets
221227 )
0 commit comments