@@ -38,7 +38,39 @@ is_epi_workflow <- function(x) {
3838 inherits(x , " epi_workflow" )
3939}
4040
41-
41+ # ' Predict from an epi_workflow
42+ # '
43+ # ' @description
44+ # ' This is the `predict()` method for a fit epi_workflow object. The nice thing
45+ # ' about predicting from an epi_workflow is that it will:
46+ # '
47+ # ' - Preprocess `new_data` using the preprocessing method specified when the
48+ # ' workflow was created and fit. This is accomplished using
49+ # ' [hardhat::forge()], which will apply any formula preprocessing or call
50+ # ' [recipes::bake()] if a recipe was supplied.
51+ # '
52+ # ' - Call [parsnip::predict.model_fit()] for you using the underlying fit
53+ # ' parsnip model.
54+ # '
55+ # ' - Ensure that the returned object is an [epiprocess::epi_df] where
56+ # ' possible. Specifically, the output will have `time_value` and
57+ # ' `geo_value` columns as well as the prediction.
58+ # '
59+ # ' @inheritParams parsnip::predict.model_fit
60+ # ' @param forecast_date The date on which the forecast is (was) made.
61+ # '
62+ # ' @param object An epi_workflow that has been fit by [fit.workflow()]
63+ # '
64+ # ' @param new_data A data frame containing the new predictors to preprocess
65+ # ' and predict on
66+ # '
67+ # ' @return
68+ # ' A data frame of model predictions, with as many rows as `new_data` has.
69+ # ' If `new_data` is an `epi_df` or a data frame with `time_value` or
70+ # ' `geo_value` columns, then the result will have those as well.
71+ # '
72+ # ' @name predict-epi_workflow
73+ # ' @export
4274predict.epi_workflow <-
4375 function (object , new_data , type = NULL , opts = list (),
4476 forecast_date = NULL , ... ) {
@@ -47,24 +79,24 @@ predict.epi_workflow <-
4779 c(" Can't predict on an untrained epi_workflow." ,
4880 i = " Do you need to call `fit()`?" ))
4981 }
82+ if (! is_null(forecast_date )) forecast_date <- as.Date(forecast_date )
5083 the_fit <- workflows :: extract_fit_parsnip(object )
5184 mold <- workflows :: extract_mold(object )
5285 forged <- hardhat :: forge(new_data , blueprint = mold $ blueprint )
5386 preds <- predict(the_fit , forged $ predictors , type = type , opts = opts , ... )
5487 keys <- grab_forged_keys(forged , mold , new_data )
55- out <- dplyr :: bind_cols(keys , preds , forecast_date )
88+ out <- dplyr :: bind_cols(keys , forecast_date = forecast_date , preds )
5689 out
5790 }
5891
5992grab_forged_keys <- function (forged , mold , new_data ) {
6093 keys <- c(" time_value" , " geo_value" , " key" )
61- forged_names <- names(forged $ extras $ roles )
62- molded_names <- names(mold $ extras $ roles )
63- extras <- dplyr :: bind_cols(forged $ extras $ roles [forged_names %in% keys ])
94+ forged_roles <- names(forged $ extras $ roles )
95+ extras <- dplyr :: bind_cols(forged $ extras $ roles [forged_roles %in% keys ])
6496 # 1. these are the keys in the test data after prep/bake
6597 new_keys <- names(extras )
6698 # 2. these are the keys in the training data
67- old_keys <- purrr :: map_chr (mold $ extras $ roles [ molded_names %in% keys ], names )
99+ old_keys <- epi_keys_mold (mold )
68100 # 3. these are the keys in the test data as input
69101 new_df_keys <- epi_keys(new_data )
70102 if (! (setequal(old_keys , new_df_keys ) && setequal(new_keys , new_df_keys ))) {
@@ -73,10 +105,13 @@ grab_forged_keys <- function(forged, mold, new_data) {
73105 " in `new_data`. Predictions will have only the available keys." )
74106 )
75107 }
76- if (epiprocess :: is_epi_df(new_data ) || keys [1 : 2 ] %in% new_keys ) {
108+ if (epiprocess :: is_epi_df(new_data )) {
109+ extras <- epiprocess :: as_epi_df(extras )
110+ attr(extras , " metadata" ) <- attr(new_data , " metadata" )
111+ } else if (keys [1 : 2 ] %in% new_keys ) {
77112 l <- list ()
78113 if (length(new_keys ) > 2 ) l <- list (other_keys = new_keys [- c(1 : 2 )])
79- extras <- as_epi_df(extras , additional_metadata = l )
114+ extras <- epiprocess :: as_epi_df(extras , additional_metadata = l )
80115 }
81116 extras
82117}
0 commit comments