Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@
^pkg.Rproj$
figure$
cache$
^.*\.Rproj$
^\.Rproj\.user$
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,5 @@ inst/doc

test.R

*-vignette.pdf
*-vignette.pdf
.Rproj.user
15 changes: 8 additions & 7 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,12 @@ Imports:
Matrix,
osqp,
rlang,
purrr,
Remotes:
susanathey/MCPanel
License: MIT + file LICENSE
Encoding: UTF-8
LazyData: true
RoxygenNote: 7.1.1
purrr
Suggests:
testthat,
CausalImpact,
keras,
FNN,
gsynth,
knitr,
rmarkdown,
Expand All @@ -38,4 +33,10 @@ Suggests:
randomForest,
kableExtra,
ggrepel
Remotes:
susanathey/MCPanel
License: MIT + file LICENSE
Encoding: UTF-8
LazyData: true
RoxygenNote: 7.1.2
VignetteBuilder: knitr
5 changes: 5 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ importFrom(ggplot2,aes)
importFrom(graphics,plot)
importFrom(magrittr,"%>%")
importFrom(purrr,reduce)
importFrom(stats,as.formula)
importFrom(stats,coef)
importFrom(stats,delete.response)
importFrom(stats,formula)
Expand All @@ -40,6 +41,10 @@ importFrom(stats,model.matrix)
importFrom(stats,na.omit)
importFrom(stats,poly)
importFrom(stats,predict)
importFrom(stats,qnorm)
importFrom(stats,quantile)
importFrom(stats,rgamma)
importFrom(stats,rmultinom)
importFrom(stats,sd)
importFrom(stats,terms)
importFrom(stats,update)
Expand Down
105 changes: 52 additions & 53 deletions R/augsynth.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
## Main functions for single-period treatment augmented synthetic controls Method
################################################################################


#' Fit Augmented SCM
#'
#'
#' @param form outcome ~ treatment | auxiliary covariates
#' @param unit Name of unit column
#' @param time Name of time column
Expand All @@ -14,11 +13,11 @@
#' ridge=Ridge regression (allows for standard errors),
#' none=No outcome model,
#' en=Elastic Net, RF=Random Forest, GSYN=gSynth,
#' mcp=MCPanel,
#' mcp=MCPanel,
#' cits=Comparative Interrupted Time Series,
#' causalimpact=Bayesian structural time series with CausalImpact
#' @param scm Whether the SCM weighting function is used
#' @param fixedeff Whether to include a unit fixed effect, default F
#' @param fixedeff Whether to include a unit fixed effect, default F
#' @param cov_agg Covariate aggregation functions, if NULL then use mean with NAs omitted
#' @param ... optional arguments for outcome model
#'
Expand Down Expand Up @@ -49,27 +48,27 @@ single_augsynth <- function(form, unit, time, t_int, data,

wide <- format_data(outcome, trt, unit, time, t_int, data)
synth_data <- do.call(format_synth, wide)

treated_units <- data %>% filter(!!trt == 1) %>% distinct(!!unit) %>% pull(!!unit)
control_units <- data %>% filter(!(!!unit %in% treated_units)) %>%
control_units <- data %>% filter(!(!!unit %in% treated_units)) %>%
distinct(!!unit) %>% arrange(!!unit) %>% pull(!!unit)
## add covariates
if(length(form)[2] == 2) {
Z <- extract_covariates(form, unit, time, t_int, data, cov_agg)
} else {
Z <- NULL
}

# fit augmented SCM
augsynth <- fit_augsynth_internal(wide, synth_data, Z, progfunc,
augsynth <- fit_augsynth_internal(wide, synth_data, Z, progfunc,
scm, fixedeff, ...)

# add some extra data
augsynth$data$time <- data %>% distinct(!!time) %>%
arrange(!!time) %>% pull(!!time)
augsynth$call <- call_name
augsynth$t_int <- t_int
augsynth$t_int <- t_int

augsynth$weights <- matrix(augsynth$weights)
rownames(augsynth$weights) <- control_units

Expand All @@ -86,9 +85,9 @@ single_augsynth <- function(form, unit, time, t_int, data,
#' @param fixedeff Whether to de-mean synth
#' @param V V matrix for Synth, default NULL
#' @param ... Extra args for outcome model
#'
#'
#' @noRd
#'
#'
fit_augsynth_internal <- function(wide, synth_data, Z, progfunc,
scm, fixedeff, V = NULL, ...) {

Expand Down Expand Up @@ -119,23 +118,23 @@ fit_augsynth_internal <- function(wide, synth_data, Z, progfunc,
} else if(progfunc == "none") {
## Just SCM
augsynth <- do.call(fit_ridgeaug_formatted,
c(list(wide_data = fit_wide,
c(list(wide_data = fit_wide,
synth_data = fit_synth_data,
Z = Z, ridge = F, scm = T, V = V, ...)))
} else {
## Other outcome models
progfuncs = c("ridge", "none", "en", "rf", "gsyn", "mcp",
"cits", "causalimpact", "seq2seq")
if (progfunc %in% progfuncs) {
augsynth <- fit_augsyn(fit_wide, fit_synth_data,
augsynth <- fit_augsyn(fit_wide, fit_synth_data,
progfunc, scm, ...)
} else {
stop("progfunc must be one of 'EN', 'RF', 'GSYN', 'MCP', 'CITS', 'CausalImpact', 'seq2seq', 'None'")
}

}

augsynth$mhat <- mhat + cbind(matrix(0, nrow = n, ncol = t0),
augsynth$mhat <- mhat + cbind(matrix(0, nrow = n, ncol = t0),
augsynth$mhat)
augsynth$data <- wide
augsynth$data$Z <- Z
Expand Down Expand Up @@ -169,13 +168,13 @@ predict.augsynth <- function(object, att = F, ...) {
# att <- F
# }
augsynth <- object

X <- augsynth$data$X
y <- augsynth$data$y
comb <- cbind(X, y)
trt <- augsynth$data$trt
mhat <- augsynth$mhat

m1 <- colMeans(mhat[trt==1,,drop=F])

resid <- (comb[trt==0,,drop=F] - mhat[trt==0,drop=F])
Expand All @@ -198,7 +197,7 @@ predict.augsynth <- function(object, att = F, ...) {
#' @export
print.augsynth <- function(x, ...) {
augsynth <- x

## straight from lm
cat("\nCall:\n", paste(deparse(augsynth$call), sep="\n", collapse="\n"), "\n\n", sep="")

Expand All @@ -214,7 +213,7 @@ print.augsynth <- function(x, ...) {

#' Plot function for augsynth
#' @importFrom graphics plot
#'
#'
#' @param x Augsynth object to be plotted
#' @param inf Boolean, whether to get confidence intervals around the point estimates
#' @param cv If True, plot cross validation MSE against hyper-parameter, otherwise plot effects
Expand All @@ -228,22 +227,22 @@ plot.augsynth <- function(x, inf = T, cv = F, ...) {
# }

augsynth <- x

if (cv == T) {
errors = data.frame(lambdas = augsynth$lambdas,
errors = augsynth$lambda_errors,
errors_se = augsynth$lambda_errors_se)
p <- ggplot2::ggplot(errors, ggplot2::aes(x = lambdas, y = errors)) +
ggplot2::geom_point(size = 2) +
ggplot2::geom_point(size = 2) +
ggplot2::geom_errorbar(
ggplot2::aes(ymin = errors,
ymax = errors + errors_se),
width=0.2, size = 0.5)
width=0.2, size = 0.5)
p <- p + ggplot2::labs(title = bquote("Cross Validation MSE over " ~ lambda),
x = expression(lambda), y = "Cross Validation MSE",
x = expression(lambda), y = "Cross Validation MSE",
parse = TRUE)
p <- p + ggplot2::scale_x_log10()

# find minimum and min + 1se lambda to plot
min_lambda <- choose_lambda(augsynth$lambdas,
augsynth$lambda_errors,
Expand All @@ -257,7 +256,7 @@ plot.augsynth <- function(x, inf = T, cv = F, ...) {
min_1se_lambda_index <- which(augsynth$lambdas == min_1se_lambda)

p <- p + ggplot2::geom_point(
ggplot2::aes(x = min_lambda,
ggplot2::aes(x = min_lambda,
y = augsynth$lambda_errors[min_lambda_index]),
color = "gold")
p + ggplot2::geom_point(
Expand Down Expand Up @@ -299,8 +298,8 @@ summary.augsynth <- function(object, inf = T, inf_type = "conformal", ...) {
# } else {
# inf_type <- "conformal"
# }


summ <- list()

t0 <- ncol(augsynth$data$X)
Expand Down Expand Up @@ -382,8 +381,8 @@ summary.augsynth <- function(object, inf = T, inf_type = "conformal", ...) {
} else {
summ$bias_est <- m1 - t(mhat[trt==0,,drop=F]) %*% w
}


summ$inf_type <- if(inf) inf_type else "None"
class(summ) <- "summary.augsynth"
return(summ)
Expand All @@ -395,7 +394,7 @@ summary.augsynth <- function(object, inf = T, inf_type = "conformal", ...) {
#' @export
print.summary.augsynth <- function(x, ...) {
summ <- x

## straight from lm
cat("\nCall:\n", paste(deparse(summ$call), sep="\n", collapse="\n"), "\n\n", sep="")

Expand All @@ -405,7 +404,7 @@ print.summary.augsynth <- function(x, ...) {
att_est <- summ$att$Estimate
t_total <- length(att_est)
t_int <- summ$att %>% filter(Time <= summ$t_int) %>% nrow()

att_pre <- att_est[1:(t_int-1)]
att_post <- att_est[t_int:t_total]

Expand All @@ -420,14 +419,14 @@ print.summary.augsynth <- function(x, ...) {
se_avg <- summ$average_att$Std.Error

out_msg <- paste("Average ATT Estimate (Jackknife Std. Error): ",
format(round(att_post,3), nsmall=3),
format(round(att_post,3), nsmall=3),
" (",
format(round(se_avg,3)), ")\n")
inf_type <- "Jackknife over units"
} else if(summ$inf_type == "conformal") {
p_val <- summ$average_att$p_val
out_msg <- paste("Average ATT Estimate (p Value for Joint Null): ",
format(round(att_post,3), nsmall=3),
format(round(att_post,3), nsmall=3),
" (",
format(round(p_val,3)), ")\n")
inf_type <- "Conformal inference"
Expand All @@ -442,7 +441,7 @@ print.summary.augsynth <- function(x, ...) {
}


out_msg <- paste(out_msg,
out_msg <- paste(out_msg,
"L2 Imbalance: ",
format(round(summ$l2_imbalance,3), nsmall=3), "\n",
"Percent improvement from uniform weights: ",
Expand All @@ -452,16 +451,16 @@ print.summary.augsynth <- function(x, ...) {

out_msg <- paste(out_msg,
"Covariate L2 Imbalance: ",
format(round(summ$covariate_l2_imbalance,3),
format(round(summ$covariate_l2_imbalance,3),
nsmall=3),
"\n",
"Percent improvement from uniform weights: ",
format(round(1 - summ$scaled_covariate_l2_imbalance,3)*100),
format(round(1 - summ$scaled_covariate_l2_imbalance,3)*100),
"%\n\n",
sep="")

}
out_msg <- paste(out_msg,
out_msg <- paste(out_msg,
"Avg Estimated Bias: ",
format(round(mean(summ$bias_est), 3),nsmall=3), "\n\n",
"Inference type: ",
Expand All @@ -471,30 +470,30 @@ print.summary.augsynth <- function(x, ...) {
cat(out_msg)

if(summ$inf_type == "jackknife") {
out_att <- summ$att[t_int:t_final,] %>%
out_att <- summ$att[t_int:t_final,] %>%
select(Time, Estimate, Std.Error)
} else if(summ$inf_type == "conformal") {
out_att <- summ$att[t_int:t_final,] %>%
out_att <- summ$att[t_int:t_final,] %>%
select(Time, Estimate, lower_bound, upper_bound, p_val)
names(out_att) <- c("Time", "Estimate",
names(out_att) <- c("Time", "Estimate",
paste0((1 - summ$alpha) * 100, "% CI Lower Bound"),
paste0((1 - summ$alpha) * 100, "% CI Upper Bound"),
paste0("p Value"))
} else if(summ$inf_type == "jackknife+") {
out_att <- summ$att[t_int:t_final,] %>%
out_att <- summ$att[t_int:t_final,] %>%
select(Time, Estimate, lower_bound, upper_bound)
names(out_att) <- c("Time", "Estimate",
names(out_att) <- c("Time", "Estimate",
paste0((1 - summ$alpha) * 100, "% CI Lower Bound"),
paste0((1 - summ$alpha) * 100, "% CI Upper Bound"))
} else {
out_att <- summ$att[t_int:t_final,] %>%
out_att <- summ$att[t_int:t_final,] %>%
select(Time, Estimate)
}
out_att %>%
mutate_at(vars(-Time), ~ round(., 3)) %>%
print(row.names = F)


}

#' Plot function for summary function for augsynth
Expand All @@ -509,7 +508,7 @@ plot.summary.augsynth <- function(x, inf = T, ...) {
# } else {
# inf <- T
# }

p <- summ$att %>%
ggplot2::ggplot(ggplot2::aes(x=Time, y=Estimate))
if(inf) {
Expand All @@ -526,15 +525,15 @@ plot.summary.augsynth <- function(x, inf = T, ...) {
}
p + ggplot2::geom_line() +
ggplot2::geom_vline(xintercept=summ$t_int, lty=2) +
ggplot2::geom_hline(yintercept=0, lty=2) +
ggplot2::geom_hline(yintercept=0, lty=2) +
ggplot2::theme_bw()

}



#' augsynth
#'
#'
#' @description A package implementing the Augmented Synthetic Controls Method
#' @docType package
#' @name augsynth-package
Expand All @@ -545,9 +544,9 @@ plot.summary.augsynth <- function(x, inf = T, ...) {
#' @import tidyr
#' @importFrom stats terms
#' @importFrom stats formula
#' @importFrom stats update
#' @importFrom stats delete.response
#' @importFrom stats model.matrix
#' @importFrom stats model.frame
#' @importFrom stats update
#' @importFrom stats delete.response
#' @importFrom stats model.matrix
#' @importFrom stats model.frame
#' @importFrom stats na.omit
NULL
Loading