Skip to content

Commit 83020c6

Browse files
committed
added augment method just to avoid downstream trouble
1 parent bd40bcf commit 83020c6

File tree

2 files changed

+26
-0
lines changed

2 files changed

+26
-0
lines changed

R/epi_keys.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,4 @@ epi_keys_mold <- function(mold) {
3232
mold_keys <- purrr::map_chr(mold$extras$roles[molded_names %in% keys], names)
3333
unname(mold_keys)
3434
}
35+

R/epi_workflow.R

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,31 @@ grab_forged_keys <- function(forged, mold, new_data) {
169169
extras
170170
}
171171

172+
173+
#' Augment data with predictions
174+
#'
175+
#' @param x A trained epi_workflow
176+
#' @param new_data A epi_df of predictors
177+
#' @param ... Arguments passed on to the predict method.
178+
#'
179+
#' @return new_data with additional columns containing the predicted values
180+
#' @export
181+
augment.epi_workflow <- function (x, new_data, ...) {
182+
predictions <- predict(x, new_data, ...)
183+
if (is_epi_df(predictions)) join_by <- epi_keys(predictions)
184+
else rlang::abort(
185+
c("Cannot determine how to join new_data with the predictions.",
186+
"Try converting new_data to an epi_df with `as_epi_df(new_data)`."))
187+
complete_overlap <- intersect(names(new_data), join_by)
188+
if (length(complete_overlap) < length(join_by)) {
189+
rlang::warn(
190+
glue::glue("Your original training data had keys {join_by}, but",
191+
"`new_data` only has {complete_overlap}. The output",
192+
"may be strange."))
193+
}
194+
dplyr::full_join(predictions, new_data, by = join_by)
195+
}
196+
172197
new_epi_workflow <- function(
173198
pre = workflows:::new_stage_pre(),
174199
fit = workflows:::new_stage_fit(),

0 commit comments

Comments
 (0)