Skip to content

Commit 910dd19

Browse files
authored
Merge branch 'main' into main
2 parents 98f5326 + 5464bd1 commit 910dd19

14 files changed

+411
-192
lines changed

DESCRIPTION

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ Imports:
2020
glue,
2121
hardhat (>= 1.0.0.9000),
2222
magrittr,
23+
parsnip,
2324
purrr,
2425
recipes (>= 0.2.0.9001),
2526
rlang,
@@ -28,6 +29,7 @@ Imports:
2829
tibble,
2930
tidyr,
3031
tidyselect,
32+
tensr,
3133
workflows
3234
Suggests:
3335
covidcast,

NAMESPACE

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# Generated by roxygen2: do not edit by hand
22

3-
S3method(bake,step_epi_ahead)
4-
S3method(bake,step_epi_lag)
3+
S3method(bake,step_epi_shift)
54
S3method(epi_keys,default)
65
S3method(epi_keys,epi_df)
76
S3method(epi_keys,recipe)
@@ -36,6 +35,7 @@ export(smooth_arx_args_list)
3635
export(smooth_arx_forecaster)
3736
export(step_epi_ahead)
3837
export(step_epi_lag)
38+
export(step_epi_naomit)
3939
import(recipes)
4040
importFrom(magrittr,"%>%")
4141
importFrom(rlang,"!!")

R/epi_ahead.R

Lines changed: 11 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -63,101 +63,16 @@ step_epi_ahead <-
6363
columns = NULL,
6464
skip = FALSE,
6565
id = rand_id("epi_ahead")) {
66-
add_step(
67-
recipe,
68-
step_epi_ahead_new(
69-
terms = dplyr::enquos(...),
70-
role = role,
71-
trained = trained,
72-
ahead = ahead,
73-
prefix = prefix,
74-
default = default,
75-
keys = keys,
76-
columns = columns,
77-
skip = skip,
78-
id = id
79-
)
66+
step_epi_shift(recipe,
67+
...,
68+
role = role,
69+
trained = trained,
70+
shift = ahead,
71+
prefix = prefix,
72+
default = default,
73+
keys = keys,
74+
columns = columns,
75+
skip = skip,
76+
id = id
8077
)
8178
}
82-
83-
step_epi_ahead_new <-
84-
function(terms, role, trained, ahead, prefix, default, keys,
85-
columns, skip, id) {
86-
step(
87-
subclass = "epi_ahead",
88-
terms = terms,
89-
role = role,
90-
trained = trained,
91-
ahead = ahead,
92-
prefix = prefix,
93-
default = default,
94-
keys = keys,
95-
columns = columns,
96-
skip = skip,
97-
id = id
98-
)
99-
}
100-
101-
#' @export
102-
prep.step_epi_ahead <- function(x, training, info = NULL, ...) {
103-
step_epi_ahead_new(
104-
terms = x$terms,
105-
role = x$role,
106-
trained = TRUE,
107-
ahead = x$ahead,
108-
prefix = x$prefix,
109-
default = x$default,
110-
keys = x$keys,
111-
columns = recipes_eval_select(x$terms, training, info),
112-
skip = x$skip,
113-
id = x$id
114-
)
115-
}
116-
117-
#' @export
118-
bake.step_epi_ahead <- function(object, new_data, ...) {
119-
if (!all(object$ahead == as.integer(object$ahead))) {
120-
rlang::abort("step_epi_ahead requires 'ahead' argument to be integer valued.")
121-
}
122-
123-
grid <- tidyr::expand_grid(
124-
col = object$columns, lag_val = -object$ahead) %>%
125-
dplyr::mutate(
126-
ahead_val = -lag_val,
127-
newname = glue::glue("{object$prefix}{ahead_val}_{col}")
128-
) %>%
129-
dplyr::select(-ahead_val)
130-
131-
## ensure no name clashes
132-
new_data_names <- colnames(new_data)
133-
intersection <- new_data_names %in% grid$newname
134-
if (any(intersection)) {
135-
rlang::abort(
136-
paste0("Name collision occured in `", class(object)[1],
137-
"`. The following variable names already exists: ",
138-
paste0(new_data_names[intersection], collapse = ", "),
139-
"."))
140-
}
141-
142-
ok <- object$keys
143-
lagged <- purrr::reduce(
144-
purrr::pmap(grid, epi_shift_single, x = new_data, key_cols = ok),
145-
dplyr::full_join,
146-
by = ok
147-
)
148-
149-
dplyr::full_join(new_data, lagged, by = ok) %>%
150-
dplyr::group_by(dplyr::across(dplyr::all_of(ok[-1]))) %>%
151-
dplyr::arrange(time_value) %>%
152-
dplyr::ungroup()
153-
154-
}
155-
156-
#' @export
157-
print.step_epi_ahead <-
158-
function(x, width = max(20, options()$width - 30), ...) {
159-
## TODO add printing of the lags
160-
title <- "Leading "
161-
recipes::print_step(x$columns, x$terms, x$trained, title, width)
162-
invisible(x)
163-
}

R/epi_lag.R

Lines changed: 11 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -29,95 +29,17 @@ step_epi_lag <-
2929
columns = NULL,
3030
skip = FALSE,
3131
id = rand_id("epi_lag")) {
32-
add_step(
33-
recipe,
34-
step_epi_lag_new(
35-
terms = dplyr::enquos(...),
36-
role = role,
37-
trained = trained,
38-
lag = lag,
39-
prefix = prefix,
40-
default = default,
41-
keys = keys,
42-
columns = columns,
43-
skip = skip,
44-
id = id
45-
)
32+
step_epi_shift(recipe,
33+
...,
34+
role = role,
35+
trained = trained,
36+
shift = -lag,
37+
prefix = prefix,
38+
default = default,
39+
keys = keys,
40+
columns = columns,
41+
skip = skip,
42+
id = id
4643
)
4744
}
4845

49-
step_epi_lag_new <-
50-
function(terms, role, trained, lag, prefix, default, keys,
51-
columns, skip, id) {
52-
step(
53-
subclass = "epi_lag",
54-
terms = terms,
55-
role = role,
56-
trained = trained,
57-
lag = lag,
58-
prefix = prefix,
59-
default = default,
60-
keys = keys,
61-
columns = columns,
62-
skip = skip,
63-
id = id
64-
)
65-
}
66-
67-
#' @export
68-
prep.step_epi_lag <- function(x, training, info = NULL, ...) {
69-
step_epi_lag_new(
70-
terms = x$terms,
71-
role = x$role,
72-
trained = TRUE,
73-
lag = x$lag,
74-
prefix = x$prefix,
75-
default = x$default,
76-
keys = x$keys,
77-
columns = recipes_eval_select(x$terms, training, info),
78-
skip = x$skip,
79-
id = x$id
80-
)
81-
}
82-
83-
#' @export
84-
bake.step_epi_lag <- function(object, new_data, ...) {
85-
if (!all(object$lag == as.integer(object$lag))) {
86-
rlang::abort("step_epi_lag requires 'lag' argument to be integer valued.")
87-
}
88-
89-
grid <- tidyr::expand_grid(col = object$columns, lag_val = object$lag) %>%
90-
dplyr::mutate(newname = glue::glue("{object$prefix}{lag_val}_{col}"))
91-
92-
## ensure no name clashes
93-
new_data_names <- colnames(new_data)
94-
intersection <- new_data_names %in% grid$newname
95-
if (any(intersection)) {
96-
rlang::abort(
97-
paste0("Name collision occured in `", class(object)[1],
98-
"`. The following variable names already exists: ",
99-
paste0(new_data_names[intersection], collapse = ", "),
100-
"."))
101-
}
102-
ok <- object$keys
103-
lagged <- purrr::reduce(
104-
purrr::pmap(grid, epi_shift_single, x = new_data, key_cols = ok),
105-
dplyr::full_join,
106-
by = ok
107-
)
108-
109-
dplyr::full_join(new_data, lagged, by = ok) %>%
110-
dplyr::group_by(dplyr::across(dplyr::all_of(ok[-1]))) %>%
111-
dplyr::arrange(time_value) %>%
112-
dplyr::ungroup()
113-
114-
}
115-
116-
#' @export
117-
print.step_epi_lag <-
118-
function(x, width = max(20, options()$width - 30), ...) {
119-
## TODO add printing of the lags
120-
title <- "Lagging "
121-
recipes::print_step(x$columns, x$terms, x$trained, title, width)
122-
invisible(x)
123-
}

R/epi_shift_internal.R

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
#' Create a shifted predictor
#'
#' `step_epi_shift` creates a *specification* of a recipe step that
#'  adds new columns of shifted data. By default, the shifted columns
#'  contain `NA` wherever the shift was induced.
#'  These rows can be removed with [step_naomit()], or you may
#'  supply an alternative filler value through the `default`
#'  argument.
#'
#' @param shift A vector of integers. Each specified column will be
#'  shifted for each value in the vector.
#' @template step-return
#'
#' @details The step assumes that the data are already _in the proper sequential
#'  order_ for shifting.
#'
#' @family row operation steps
#' @rdname step_epi_ahead
step_epi_shift <- function(recipe, ..., role, trained, shift, prefix,
                           default, keys, columns, skip, id) {
  # Shared back end for the lag/ahead steps: step_epi_ahead() forwards
  # `shift = ahead` and step_epi_lag() forwards `shift = -lag`.
  # No argument has a default here; the calling wrapper supplies them all.

  # Capture the column selectors unevaluated.
  selectors <- dplyr::enquos(...)

  shift_step <- step_epi_shift_new(
    terms = selectors,
    role = role,
    trained = trained,
    shift = shift,
    prefix = prefix,
    default = default,
    keys = keys,
    columns = columns,
    skip = skip,
    id = id
  )

  add_step(recipe, shift_step)
}
47+
48+
# Low-level constructor for the shift step. It passes every field straight
# through to `step()` with subclass "epi_shift"; no validation happens here.
# The field order below is kept as-is because it determines the order of the
# elements stored in the resulting step object.
step_epi_shift_new <- function(terms, role, trained, shift, prefix, default,
                               keys, columns, skip, id) {
  step(
    subclass = "epi_shift",
    terms = terms, role = role, trained = trained,
    shift = shift, prefix = prefix, default = default,
    keys = keys, columns = columns,
    skip = skip, id = id
  )
}
65+
66+
#' @export
prep.step_epi_shift <- function(x, training, info = NULL, ...) {
  # Resolve the stored selectors against the training data; the result is
  # recorded in `columns` for use at bake time.
  resolved_columns <- recipes_eval_select(x$terms, training, info)

  # Rebuild the step unchanged except for the resolved columns and the
  # `trained` flag.
  step_epi_shift_new(
    terms = x$terms,
    role = x$role,
    trained = TRUE,
    shift = x$shift,
    prefix = x$prefix,
    default = x$default,
    keys = x$keys,
    columns = resolved_columns,
    skip = x$skip,
    id = x$id
  )
}
81+
82+
#' @export
bake.step_epi_shift <- function(object, new_data, ...) {
  shift <- object$shift

  # isTRUE() makes NA or non-numeric shifts fail with the intended message
  # instead of `if` erroring with "missing value where TRUE/FALSE needed".
  if (!isTRUE(all(shift == as.integer(shift)))) {
    rlang::abort("step_epi_shift requires 'shift' argument to be integer valued.")
  }

  # One row per (column, shift) pair. `lag_val` is the negated shift and is
  # what gets handed to `epi_shift_single()` below.
  grid <- tidyr::expand_grid(col = object$columns, lag_val = -shift)

  # Steps with role "predictor" are labelled by their lag value; any other
  # role is labelled by the ahead value (the negated lag).
  is_lag <- object$role == "predictor"
  shift_label <- if (is_lag) grid$lag_val else -grid$lag_val
  grid$newname <- glue::glue("{object$prefix}{shift_label}_{grid$col}")

  ## ensure no name clashes
  existing_names <- colnames(new_data)
  clashes <- existing_names %in% grid$newname
  if (any(clashes)) {
    rlang::abort(
      paste0("Name collision occured in `", class(object)[1],
             "`. The following variable names already exists: ",
             paste0(existing_names[clashes], collapse = ", "),
             "."))
  }

  # Shift each (column, lag) pair separately, then stitch the pieces
  # together on the key columns.
  ok <- object$keys
  shifted <- purrr::reduce(
    purrr::pmap(grid, epi_shift_single, x = new_data, key_cols = ok),
    dplyr::full_join,
    by = ok
  )

  # NOTE(review): arrange() ignores grouping unless `.by_group = TRUE`, so
  # the result is ordered by `time_value` only; confirm that is intended.
  dplyr::full_join(new_data, shifted, by = ok) %>%
    dplyr::group_by(dplyr::across(dplyr::all_of(ok[-1]))) %>%
    dplyr::arrange(time_value) %>%
    dplyr::ungroup()
}
127+
128+
#' @export
print.step_epi_shift <-
  function(x, width = max(20, options()$width - 30), ...) {
    # Steps with role "predictor" are reported as lags, all others as leads.
    direction <- if (x$role == "predictor") "Lagging" else "Leading"

    # Collapse all shift magnitudes into a single string so the title stays
    # length one even when several shifts were requested. For a single
    # shift this yields the same text as before, e.g. "Lagging: 1,".
    shift_list <- paste(abs(x$shift), collapse = ", ")
    title <- paste0(direction, ": ", shift_list, ",")

    recipes::print_step(x$columns, x$terms, x$trained, title, width)
    invisible(x)
  }

0 commit comments

Comments
 (0)