Skip to content

Commit 5ccb58a

Browse files
committed
moving to a workflow subclass
* add some required packages/versions * draft workflow class * add truthy testing * begin predict method
1 parent 4c2a4a2 commit 5ccb58a

File tree

3 files changed

+40
-3
lines changed

3 files changed

+40
-3
lines changed

DESCRIPTION

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ Imports:
2323
tibble,
2424
tidyr,
2525
tidyselect,
26-
tensr
26+
tensr,
27+
hardhat (>= 1.0.0.9000)
2728
Suggests:
2829
covidcast,
2930
data.table,
@@ -37,7 +38,9 @@ Suggests:
3738
VignetteBuilder:
3839
knitr
3940
Remotes:
40-
dajmcdon/epiprocess
41+
dajmcdon/epiprocess,
42+
tidymodels/hardhat,
43+
tidymodels/recipes
4144
Config/testthat/edition: 3
4245
Encoding: UTF-8
4346
Roxygen: list(markdown = TRUE)

R/epi_recipe.R

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ epi_recipe.epi_df <-
137137
levels = NULL,
138138
retained = NA
139139
)
140-
class(out) <- "recipe"
140+
class(out) <- c("epi_recipe", "recipe")
141141
out
142142
}
143143

@@ -210,3 +210,19 @@ epi_form2args <- function(formula, data, ...) {
210210
list(x = data, vars = vars, roles = roles)
211211
}
212212

213+
is_epi_recipe <- function(x) {
214+
inherits(x, "epi_recipe")
215+
}
216+
217+
add_epi_recipe <- function(
218+
x, recipe, ..., blueprint = default_epi_recipe_blueprint()) {
219+
add_recipe(x, recipe, ..., blueprint = blueprint)
220+
}
221+
222+
default_epi_recipe_blueprint <-
223+
function(intercept = FALSE, allow_novel_levels = FALSE, fresh = TRUE,
224+
bake_dependent_roles = c("time_value", "geo_value", "key", "raw"),
225+
composition = "tibble") {
226+
hardhat::default_recipe_blueprint(
227+
intercept, allow_novel_levels, fresh, bake_dependent_roles, composition)
228+
}

R/epi_workflow.R

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#' @export
2+
epi_workflow <- function(preprocessor = NULL, spec = NULL) {
3+
out <- workflow(preprocessor, spec)
4+
class(out) <- c("epi_workflow", class(out))
5+
}
6+
7+
predict.epi_workflow <-
8+
function(object, new_data, type = NULL, opts = list(), forecast_date, ...) {
9+
out <- predict(object, new_data, type = type, opts = opts, ...)
10+
if (is_epi_df(new_data)) {
11+
ek <- epi_keys(new_data)
12+
13+
}
14+
}
15+
16+
is_epi_workflow <- function(x) {
17+
inherits(x, "epi_workflow")
18+
}

0 commit comments

Comments
 (0)