@@ -16,6 +16,8 @@ ggplot2::autoplot
1616# ' @param object An `epi_workflow`
1717# ' @param predictions A data frame with predictions. If `NULL`, only the
1818# ' original data is shown.
19+ # ' @param plot_data An epi_df of the data to plot against. This is for the case
20+ # ' where you have the actual results to compare the forecast against.
1921# ' @param .levels A numeric vector of levels to plot for any prediction bands.
2022# ' More than 3 levels begins to be difficult to see.
2123# ' @param ... Ignored
8284# ' @export
8385# ' @rdname autoplot-epipred
8486autoplot.epi_workflow <- function (
85- object , predictions = NULL ,
87+ object ,
88+ predictions = NULL ,
89+ plot_data = NULL ,
8690 .levels = c(.5 , .8 , .9 ), ... ,
8791 .color_by = c(" all_keys" , " geo_value" , " other_keys" , " .response" , " all" , " none" ),
8892 .facet_by = c(" .response" , " other_keys" , " all_keys" , " geo_value" , " all" , " none" ),
@@ -109,30 +113,32 @@ autoplot.epi_workflow <- function(
109113 }
110114 keys <- c(" geo_value" , " time_value" , " key" )
111115 mold_roles <- names(mold $ extras $ roles )
112- edf <- bind_cols(mold $ extras $ roles [mold_roles %in% keys ], y )
113- if (starts_with_impl(" ahead_" , names(y ))) {
114- old_name_y <- unlist(strsplit(names(y ), " _" ))
115- shift <- as.numeric(old_name_y [2 ])
116- new_name_y <- paste(old_name_y [- c(1 : 2 )], collapse = " _" )
117- edf <- rename(edf , !! new_name_y : = !! names(y ))
118- } else if (starts_with_impl(" lag_" , names(y ))) {
119- old_name_y <- unlist(strsplit(names(y ), " _" ))
120- shift <- - as.numeric(old_name_y [2 ])
121- new_name_y <- paste(old_name_y [- c(1 : 2 )], collapse = " _" )
122- edf <- rename(edf , !! new_name_y : = !! names(y ))
123- }
124-
125- if (! is.null(shift )) {
126- edf <- mutate(edf , time_value = time_value + shift )
116+ # extract the relevant column names for plotting
117+ old_name_y <- unlist(strsplit(names(y ), " _" ))
118+ new_name_y <- paste(old_name_y [- c(1 : 2 )], collapse = " _" )
119+ if (is.null(plot_data )) {
120+ # the outcome has shifted, so we need to shift it forward (or back)
121+ # by the corresponding amount
122+ plot_data <- bind_cols(mold $ extras $ roles [mold_roles %in% keys ], y )
123+ if (starts_with_impl(" ahead_" , names(y ))) {
124+ shift <- as.numeric(old_name_y [2 ])
125+ } else if (starts_with_impl(" lag_" , names(y ))) {
126+ old_name_y <- unlist(strsplit(names(y ), " _" ))
127+ shift <- - as.numeric(old_name_y [2 ])
128+ }
129+ plot_data <- rename(plot_data , !! new_name_y : = !! names(y ))
130+ if (! is.null(shift )) {
131+ plot_data <- mutate(plot_data , time_value = time_value + shift )
132+ }
133+ other_keys <- setdiff(key_colnames(object ), c(" geo_value" , " time_value" ))
134+ plot_data <- as_epi_df(plot_data ,
135+ as_of = object $ fit $ meta $ as_of ,
136+ other_keys = other_keys
137+ )
127138 }
128- other_keys <- setdiff(key_colnames(object ), c(" geo_value" , " time_value" ))
129- edf <- as_epi_df(edf ,
130- as_of = object $ fit $ meta $ as_of ,
131- other_keys = other_keys
132- )
133139 if (is.null(predictions )) {
134140 return (autoplot(
135- edf , new_name_y ,
141+ plot_data , new_name_y ,
136142 .color_by = .color_by , .facet_by = .facet_by , .base_color = .base_color ,
137143 .max_facets = .max_facets
138144 ))
@@ -144,27 +150,27 @@ autoplot.epi_workflow <- function(
144150 }
145151 predictions <- rename(predictions , time_value = target_date )
146152 }
147- pred_cols_ok <- hardhat :: check_column_names(predictions , key_colnames(edf ))
153+ pred_cols_ok <- hardhat :: check_column_names(predictions , key_colnames(plot_data ))
148154 if (! pred_cols_ok $ ok ) {
149155 cli_warn(c(
150156 " `predictions` is missing required variables: {.var {pred_cols_ok$missing_names}}." ,
151157 i = " Plotting the original data."
152158 ))
153159 return (autoplot(
154- edf , !! new_name_y ,
160+ plot_data , !! new_name_y ,
155161 .color_by = .color_by , .facet_by = .facet_by , .base_color = .base_color ,
156162 .max_facets = .max_facets
157163 ))
158164 }
159165
160166 # First we plot the history, always faceted by everything
161- bp <- autoplot(edf , !! new_name_y ,
167+ bp <- autoplot(plot_data , !! new_name_y ,
162168 .color_by = " none" , .facet_by = " all_keys" ,
163169 .base_color = " black" , .max_facets = .max_facets
164170 )
165171
166172 # Now, prepare matching facets in the predictions
167- ek <- epi_keys_only(edf )
173+ ek <- epi_keys_only(plot_data )
168174 predictions <- predictions %> %
169175 mutate(
170176 .facets = interaction(!!! rlang :: syms(as.list(ek )), sep = " /" ),
@@ -202,7 +208,7 @@ autoplot.epi_workflow <- function(
202208# ' @export
203209# ' @rdname autoplot-epipred
204210autoplot.canned_epipred <- function (
205- object , ... ,
211+ object , plot_data = NULL , ... ,
206212 .color_by = c(" all_keys" , " geo_value" , " other_keys" , " .response" , " all" , " none" ),
207213 .facet_by = c(" .response" , " other_keys" , " all_keys" , " geo_value" , " all" , " none" ),
208214 .base_color = " dodgerblue4" ,
@@ -216,7 +222,7 @@ autoplot.canned_epipred <- function(
216222 predictions <- object $ predictions %> %
217223 rename(time_value = target_date )
218224
219- autoplot(ewf , predictions ,
225+ autoplot(ewf , predictions , plot_data , ... ,
220226 .color_by = .color_by , .facet_by = .facet_by ,
221227 .base_color = .base_color , .max_facets = .max_facets
222228 )
0 commit comments