@@ -169,6 +169,31 @@ grab_forged_keys <- function(forged, mold, new_data) {
169169 extras
170170}
171171
172+
173+ # ' Augment data with predictions
174+ # '
175+ # ' @param x A trained epi_workflow
176+ # ' @param new_data A epi_df of predictors
177+ # ' @param ... Arguments passed on to the predict method.
178+ # '
179+ # ' @return new_data with additional columns containing the predicted values
180+ # ' @export
181+ augment.epi_workflow <- function (x , new_data , ... ) {
182+ predictions <- predict(x , new_data , ... )
183+ if (is_epi_df(predictions )) join_by <- epi_keys(predictions )
184+ else rlang :: abort(
185+ c(" Cannot determine how to join new_data with the predictions." ,
186+ " Try converting new_data to an epi_df with `as_epi_df(new_data)`." ))
187+ complete_overlap <- intersect(names(new_data ), join_by )
188+ if (length(complete_overlap ) < length(join_by )) {
189+ rlang :: warn(
190+ glue :: glue(" Your original training data had keys {join_by}, but" ,
191+ " `new_data` only has {complete_overlap}. The output" ,
192+ " may be strange." ))
193+ }
194+ dplyr :: full_join(predictions , new_data , by = join_by )
195+ }
196+
172197new_epi_workflow <- function (
173198 pre = workflows ::: new_stage_pre(),
174199 fit = workflows ::: new_stage_fit(),
0 commit comments