diff --git a/.Rbuildignore b/.Rbuildignore index 9c86736..224a30a 100644 --- a/.Rbuildignore +++ b/.Rbuildignore @@ -5,3 +5,5 @@ ^pkg.Rproj$ figure$ cache$ +^.*\.Rproj$ +^\.Rproj\.user$ diff --git a/.gitignore b/.gitignore index f53558f..af74ec9 100644 --- a/.gitignore +++ b/.gitignore @@ -27,4 +27,5 @@ inst/doc test.R -*-vignette.pdf \ No newline at end of file +*-vignette.pdf +.Rproj.user diff --git a/DESCRIPTION b/DESCRIPTION index bc179df..39d57fe 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -18,17 +18,12 @@ Imports: Matrix, osqp, rlang, - purrr, -Remotes: - susanathey/MCPanel -License: MIT + file LICENSE -Encoding: UTF-8 -LazyData: true -RoxygenNote: 7.1.1 + purrr Suggests: testthat, CausalImpact, keras, + FNN, gsynth, knitr, rmarkdown, @@ -38,4 +33,10 @@ Suggests: randomForest, kableExtra, ggrepel +Remotes: + susanathey/MCPanel +License: MIT + file LICENSE +Encoding: UTF-8 +LazyData: true +RoxygenNote: 7.1.2 VignetteBuilder: knitr diff --git a/NAMESPACE b/NAMESPACE index ac926a3..96c327a 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -31,6 +31,7 @@ importFrom(ggplot2,aes) importFrom(graphics,plot) importFrom(magrittr,"%>%") importFrom(purrr,reduce) +importFrom(stats,as.formula) importFrom(stats,coef) importFrom(stats,delete.response) importFrom(stats,formula) @@ -40,6 +41,10 @@ importFrom(stats,model.matrix) importFrom(stats,na.omit) importFrom(stats,poly) importFrom(stats,predict) +importFrom(stats,qnorm) +importFrom(stats,quantile) +importFrom(stats,rgamma) +importFrom(stats,rmultinom) importFrom(stats,sd) importFrom(stats,terms) importFrom(stats,update) diff --git a/R/augsynth.R b/R/augsynth.R index 00c7f96..d856cf0 100644 --- a/R/augsynth.R +++ b/R/augsynth.R @@ -2,9 +2,8 @@ ## Main functions for single-period treatment augmented synthetic controls Method ################################################################################ - #' Fit Augmented SCM -#' +#' #' @param form outcome ~ treatment | auxillary covariates #' @param unit Name of unit column #' 
@param time Name of time column @@ -14,11 +13,11 @@ #' ridge=Ridge regression (allows for standard errors), #' none=No outcome model, #' en=Elastic Net, RF=Random Forest, GSYN=gSynth, -#' mcp=MCPanel, +#' mcp=MCPanel, #' cits=Comparitive Interuppted Time Series #' causalimpact=Bayesian structural time series with CausalImpact #' @param scm Whether the SCM weighting function is used -#' @param fixedeff Whether to include a unit fixed effect, default F +#' @param fixedeff Whether to include a unit fixed effect, default F #' @param cov_agg Covariate aggregation functions, if NULL then use mean with NAs omitted #' @param ... optional arguments for outcome model #' @@ -49,9 +48,9 @@ single_augsynth <- function(form, unit, time, t_int, data, wide <- format_data(outcome, trt, unit, time, t_int, data) synth_data <- do.call(format_synth, wide) - + treated_units <- data %>% filter(!!trt == 1) %>% distinct(!!unit) %>% pull(!!unit) - control_units <- data %>% filter(!(!!unit %in% treated_units)) %>% + control_units <- data %>% filter(!(!!unit %in% treated_units)) %>% distinct(!!unit) %>% arrange(!!unit) %>% pull(!!unit) ## add covariates if(length(form)[2] == 2) { @@ -59,17 +58,17 @@ single_augsynth <- function(form, unit, time, t_int, data, } else { Z <- NULL } - + # fit augmented SCM - augsynth <- fit_augsynth_internal(wide, synth_data, Z, progfunc, + augsynth <- fit_augsynth_internal(wide, synth_data, Z, progfunc, scm, fixedeff, ...) - + # add some extra data augsynth$data$time <- data %>% distinct(!!time) %>% arrange(!!time) %>% pull(!!time) augsynth$call <- call_name - augsynth$t_int <- t_int - + augsynth$t_int <- t_int + augsynth$weights <- matrix(augsynth$weights) rownames(augsynth$weights) <- control_units @@ -86,9 +85,9 @@ single_augsynth <- function(form, unit, time, t_int, data, #' @param fixedeff Whether to de-mean synth #' @param V V matrix for Synth, default NULL #' @param ... 
Extra args for outcome model -#' +#' #' @noRd -#' +#' fit_augsynth_internal <- function(wide, synth_data, Z, progfunc, scm, fixedeff, V = NULL, ...) { @@ -119,7 +118,7 @@ fit_augsynth_internal <- function(wide, synth_data, Z, progfunc, } else if(progfunc == "none") { ## Just SCM augsynth <- do.call(fit_ridgeaug_formatted, - c(list(wide_data = fit_wide, + c(list(wide_data = fit_wide, synth_data = fit_synth_data, Z = Z, ridge = F, scm = T, V = V, ...))) } else { @@ -127,15 +126,15 @@ fit_augsynth_internal <- function(wide, synth_data, Z, progfunc, progfuncs = c("ridge", "none", "en", "rf", "gsyn", "mcp", "cits", "causalimpact", "seq2seq") if (progfunc %in% progfuncs) { - augsynth <- fit_augsyn(fit_wide, fit_synth_data, + augsynth <- fit_augsyn(fit_wide, fit_synth_data, progfunc, scm, ...) } else { stop("progfunc must be one of 'EN', 'RF', 'GSYN', 'MCP', 'CITS', 'CausalImpact', 'seq2seq', 'None'") } - + } - augsynth$mhat <- mhat + cbind(matrix(0, nrow = n, ncol = t0), + augsynth$mhat <- mhat + cbind(matrix(0, nrow = n, ncol = t0), augsynth$mhat) augsynth$data <- wide augsynth$data$Z <- Z @@ -169,13 +168,13 @@ predict.augsynth <- function(object, att = F, ...) { # att <- F # } augsynth <- object - + X <- augsynth$data$X y <- augsynth$data$y comb <- cbind(X, y) trt <- augsynth$data$trt mhat <- augsynth$mhat - + m1 <- colMeans(mhat[trt==1,,drop=F]) resid <- (comb[trt==0,,drop=F] - mhat[trt==0,drop=F]) @@ -198,7 +197,7 @@ predict.augsynth <- function(object, att = F, ...) { #' @export print.augsynth <- function(x, ...) { augsynth <- x - + ## straight from lm cat("\nCall:\n", paste(deparse(augsynth$call), sep="\n", collapse="\n"), "\n\n", sep="") @@ -214,7 +213,7 @@ print.augsynth <- function(x, ...) 
{ #' Plot function for augsynth #' @importFrom graphics plot -#' +#' #' @param x Augsynth object to be plotted #' @param inf Boolean, whether to get confidence intervals around the point estimates #' @param cv If True, plot cross validation MSE against hyper-parameter, otherwise plot effects @@ -228,22 +227,22 @@ plot.augsynth <- function(x, inf = T, cv = F, ...) { # } augsynth <- x - + if (cv == T) { errors = data.frame(lambdas = augsynth$lambdas, errors = augsynth$lambda_errors, errors_se = augsynth$lambda_errors_se) p <- ggplot2::ggplot(errors, ggplot2::aes(x = lambdas, y = errors)) + - ggplot2::geom_point(size = 2) + + ggplot2::geom_point(size = 2) + ggplot2::geom_errorbar( ggplot2::aes(ymin = errors, ymax = errors + errors_se), - width=0.2, size = 0.5) + width=0.2, size = 0.5) p <- p + ggplot2::labs(title = bquote("Cross Validation MSE over " ~ lambda), - x = expression(lambda), y = "Cross Validation MSE", + x = expression(lambda), y = "Cross Validation MSE", parse = TRUE) p <- p + ggplot2::scale_x_log10() - + # find minimum and min + 1se lambda to plot min_lambda <- choose_lambda(augsynth$lambdas, augsynth$lambda_errors, @@ -257,7 +256,7 @@ plot.augsynth <- function(x, inf = T, cv = F, ...) { min_1se_lambda_index <- which(augsynth$lambdas == min_1se_lambda) p <- p + ggplot2::geom_point( - ggplot2::aes(x = min_lambda, + ggplot2::aes(x = min_lambda, y = augsynth$lambda_errors[min_lambda_index]), color = "gold") p + ggplot2::geom_point( @@ -299,8 +298,8 @@ summary.augsynth <- function(object, inf = T, inf_type = "conformal", ...) { # } else { # inf_type <- "conformal" # } - - + + summ <- list() t0 <- ncol(augsynth$data$X) @@ -382,8 +381,8 @@ summary.augsynth <- function(object, inf = T, inf_type = "conformal", ...) 
{ } else { summ$bias_est <- m1 - t(mhat[trt==0,,drop=F]) %*% w } - - + + summ$inf_type <- if(inf) inf_type else "None" class(summ) <- "summary.augsynth" return(summ) @@ -395,7 +394,7 @@ summary.augsynth <- function(object, inf = T, inf_type = "conformal", ...) { #' @export print.summary.augsynth <- function(x, ...) { summ <- x - + ## straight from lm cat("\nCall:\n", paste(deparse(summ$call), sep="\n", collapse="\n"), "\n\n", sep="") @@ -405,7 +404,7 @@ print.summary.augsynth <- function(x, ...) { att_est <- summ$att$Estimate t_total <- length(att_est) t_int <- summ$att %>% filter(Time <= summ$t_int) %>% nrow() - + att_pre <- att_est[1:(t_int-1)] att_post <- att_est[t_int:t_total] @@ -420,14 +419,14 @@ print.summary.augsynth <- function(x, ...) { se_avg <- summ$average_att$Std.Error out_msg <- paste("Average ATT Estimate (Jackknife Std. Error): ", - format(round(att_post,3), nsmall=3), + format(round(att_post,3), nsmall=3), " (", format(round(se_avg,3)), ")\n") inf_type <- "Jackknife over units" } else if(summ$inf_type == "conformal") { p_val <- summ$average_att$p_val out_msg <- paste("Average ATT Estimate (p Value for Joint Null): ", - format(round(att_post,3), nsmall=3), + format(round(att_post,3), nsmall=3), " (", format(round(p_val,3)), ")\n") inf_type <- "Conformal inference" @@ -442,7 +441,7 @@ print.summary.augsynth <- function(x, ...) { } - out_msg <- paste(out_msg, + out_msg <- paste(out_msg, "L2 Imbalance: ", format(round(summ$l2_imbalance,3), nsmall=3), "\n", "Percent improvement from uniform weights: ", @@ -452,16 +451,16 @@ print.summary.augsynth <- function(x, ...) 
{ out_msg <- paste(out_msg, "Covariate L2 Imbalance: ", - format(round(summ$covariate_l2_imbalance,3), + format(round(summ$covariate_l2_imbalance,3), nsmall=3), "\n", "Percent improvement from uniform weights: ", - format(round(1 - summ$scaled_covariate_l2_imbalance,3)*100), + format(round(1 - summ$scaled_covariate_l2_imbalance,3)*100), "%\n\n", sep="") } - out_msg <- paste(out_msg, + out_msg <- paste(out_msg, "Avg Estimated Bias: ", format(round(mean(summ$bias_est), 3),nsmall=3), "\n\n", "Inference type: ", @@ -471,30 +470,30 @@ print.summary.augsynth <- function(x, ...) { cat(out_msg) if(summ$inf_type == "jackknife") { - out_att <- summ$att[t_int:t_final,] %>% + out_att <- summ$att[t_int:t_final,] %>% select(Time, Estimate, Std.Error) } else if(summ$inf_type == "conformal") { - out_att <- summ$att[t_int:t_final,] %>% + out_att <- summ$att[t_int:t_final,] %>% select(Time, Estimate, lower_bound, upper_bound, p_val) - names(out_att) <- c("Time", "Estimate", + names(out_att) <- c("Time", "Estimate", paste0((1 - summ$alpha) * 100, "% CI Lower Bound"), paste0((1 - summ$alpha) * 100, "% CI Upper Bound"), paste0("p Value")) } else if(summ$inf_type == "jackknife+") { - out_att <- summ$att[t_int:t_final,] %>% + out_att <- summ$att[t_int:t_final,] %>% select(Time, Estimate, lower_bound, upper_bound) - names(out_att) <- c("Time", "Estimate", + names(out_att) <- c("Time", "Estimate", paste0((1 - summ$alpha) * 100, "% CI Lower Bound"), paste0((1 - summ$alpha) * 100, "% CI Upper Bound")) } else { - out_att <- summ$att[t_int:t_final,] %>% + out_att <- summ$att[t_int:t_final,] %>% select(Time, Estimate) } out_att %>% mutate_at(vars(-Time), ~ round(., 3)) %>% print(row.names = F) - + } #' Plot function for summary function for augsynth @@ -509,7 +508,7 @@ plot.summary.augsynth <- function(x, inf = T, ...) 
{ # } else { # inf <- T # } - + p <- summ$att %>% ggplot2::ggplot(ggplot2::aes(x=Time, y=Estimate)) if(inf) { @@ -526,7 +525,7 @@ plot.summary.augsynth <- function(x, inf = T, ...) { } p + ggplot2::geom_line() + ggplot2::geom_vline(xintercept=summ$t_int, lty=2) + - ggplot2::geom_hline(yintercept=0, lty=2) + + ggplot2::geom_hline(yintercept=0, lty=2) + ggplot2::theme_bw() } @@ -534,7 +533,7 @@ plot.summary.augsynth <- function(x, inf = T, ...) { #' augsynth -#' +#' #' @description A package implementing the Augmented Synthetic Controls Method #' @docType package #' @name augsynth-package @@ -545,9 +544,9 @@ plot.summary.augsynth <- function(x, inf = T, ...) { #' @import tidyr #' @importFrom stats terms #' @importFrom stats formula -#' @importFrom stats update -#' @importFrom stats delete.response -#' @importFrom stats model.matrix -#' @importFrom stats model.frame +#' @importFrom stats update +#' @importFrom stats delete.response +#' @importFrom stats model.matrix +#' @importFrom stats model.frame #' @importFrom stats na.omit NULL diff --git a/R/globalVariables.R b/R/globalVariables.R index 1cdab31..50afee5 100644 --- a/R/globalVariables.R +++ b/R/globalVariables.R @@ -1,5 +1,5 @@ utils::globalVariables(c("time", "val", "post", "weight", ".", "Time", "Estimate", "Std.Error", "Level", "last_time", "is_avg", "label", "Outcome", "unit", "obs", - "lambdas", "errors_se", - "upper_bound", "lower_bound")) \ No newline at end of file + "lambdas", "errors_se", "p_val", + "upper_bound", "lower_bound")) diff --git a/R/inference.R b/R/inference.R index d834e9a..f8329a7 100644 --- a/R/inference.R +++ b/R/inference.R @@ -15,7 +15,7 @@ #' \item{"ub"}{Upper bound of 1 - alpha confidence interval} #' \item{"alpha"}{Level of confidence interval} #' } -time_jackknife_plus <- function(ascm, alpha = 0.05, conservative = F) { +time_jackknife_plus <- function(ascm, alpha = 0.05, conservative = FALSE) { wide_data <- ascm$data synth_data <- ascm$data$synth_data n <- nrow(wide_data$X) @@ 
-26,7 +26,7 @@ time_jackknife_plus <- function(ascm, alpha = 0.05, conservative = F) { tpost <- ncol(wide_data$y) t_final <- dim(synth_data$Y0plot)[1] - jack_ests <- lapply(1:t0, + jack_ests <- lapply(1:t0, function(tdrop) { # drop unit i new_data <- drop_time_t(wide_data, Z, tdrop) @@ -56,11 +56,11 @@ time_jackknife_plus <- function(ascm, alpha = 0.05, conservative = F) { out <- list() att <- predict(ascm, att = T) - out$att <- c(att, + out$att <- c(att, mean(att[(t0 + 1):t_final])) # held out ATT - out$heldout_att <- c(held_out_errs, - att[(t0 + 1):t_final], + out$heldout_att <- c(held_out_errs, + att[(t0 + 1):t_final], mean(att[(t0 + 1):t_final])) # out$se <- rep(NA, 10 + tpost) @@ -95,7 +95,7 @@ drop_time_t <- function(wide_data, Z, t_drop) { new_wide_data <- list() new_wide_data$trt <- wide_data$trt new_wide_data$X <- wide_data$X[, -t_drop, drop = F] - new_wide_data$y <- cbind(wide_data$X[, t_drop, drop = F], + new_wide_data$y <- cbind(wide_data$X[, t_drop, drop = F], wide_data$y) X0 <- new_wide_data$X[new_wide_data$trt == 0,, drop = F] @@ -113,7 +113,7 @@ drop_time_t <- function(wide_data, Z, t_drop) { return(list(wide_data = new_wide_data, synth_data = new_synth_data, - Z = Z)) + Z = Z)) } #' Conformal inference procedure to compute p-values and point-wise confidence intervals @@ -134,7 +134,7 @@ drop_time_t <- function(wide_data, Z, t_drop) { #' \item{"p_val"}{p-value for test of no post-treatment effect} #' \item{"alpha"}{Level of confidence interval} #' } -conformal_inf <- function(ascm, alpha = 0.05, +conformal_inf <- function(ascm, alpha = 0.05, stat_func = NULL, type = "iid", q = 1, ns = 1000, grid_size = 50) { wide_data <- ascm$data @@ -177,9 +177,9 @@ conformal_inf <- function(ascm, alpha = 0.05, new_wide_data <- wide_data new_wide_data$X <- cbind(wide_data$X, wide_data$y) new_wide_data$y <- matrix(1, nrow = n, ncol = 1) - null_p <- compute_permute_pval(new_wide_data, ascm, 0, ncol(wide_data$y), + null_p <- compute_permute_pval(new_wide_data, ascm, 
0, ncol(wide_data$y), type, q, ns, stat_func) - + out <- list() att <- predict(ascm, att = T) out$att <- c(att, mean(att[(t0 + 1):t_final])) @@ -201,7 +201,7 @@ conformal_inf <- function(ascm, alpha = 0.05, #' @param q The norm for the test static `((sum(x ^ q))) ^ (1/q)` #' @param ns Number of resamples for "iid" permutations #' @param stat_func Function to compute test statistic -#' +#' #' @return List that contains: #' \itemize{ #' \item{"resids"}{Residuals after enforcing the null} @@ -243,7 +243,7 @@ compute_permute_test_stats <- function(wide_data, ascm, h0, stat_func <- function(x) (sum(abs(x) ^ q) / sqrt(length(x))) ^ (1 / q) } if(type == "iid") { - test_stats <- sapply(1:ns, + test_stats <- sapply(1:ns, function(x) { reorder <- sample(resids) stat_func(reorder[(t0 + 1):tpost]) @@ -256,7 +256,7 @@ compute_permute_test_stats <- function(wide_data, ascm, h0, stat_func(reorder[(t0 + 1):tpost]) }) } - + return(list(resids = resids, test_stats = test_stats, stat_func = stat_func)) @@ -272,7 +272,7 @@ compute_permute_test_stats <- function(wide_data, ascm, h0, #' @param q The norm for the test static `((sum(x ^ q))) ^ (1/q)` #' @param ns Number of resamples for "iid" permutations #' @param stat_func Function to compute test statistic -#' +#' #' @return Computed p-value #' @noRd compute_permute_pval <- function(wide_data, ascm, h0, @@ -294,7 +294,7 @@ compute_permute_pval <- function(wide_data, ascm, h0, #' @param q The norm for the test static `((sum(x ^ q))) ^ (1/q)` #' @param ns Number of resamples for "iid" permutations #' @param stat_func Function to compute test statistic -#' +#' #' @return (lower bound of interval, upper bound of interval, p-value for null of 0 effect) #' @noRd compute_permute_ci <- function(wide_data, ascm, grid, @@ -302,9 +302,9 @@ compute_permute_ci <- function(wide_data, ascm, grid, q, ns, stat_func) { # make sure 0 is in the grid grid <- c(grid, 0) - ps <-sapply(grid, + ps <-sapply(grid, function(x) { - compute_permute_pval(wide_data, 
ascm, x, + compute_permute_pval(wide_data, ascm, x, post_length, type, q, ns, stat_func) }) c(min(grid[ps >= alpha]), max(grid[ps >= alpha]), ps[grid == 0]) @@ -312,6 +312,7 @@ compute_permute_ci <- function(wide_data, ascm, grid, #' Jackknife+ algorithm over time +#' @param ascm_multi data.frame #' @param ascm Fitted `augsynth` object #' @param alpha Confidence level #' @param conservative Whether to use the conservative jackknife+ procedure @@ -337,7 +338,7 @@ time_jackknife_plus_multiout <- function(ascm_multi, alpha = 0.05, conservative t_final <- t0 + tpost Z <- wide_data$Z - jack_ests <- lapply(1:t0, + jack_ests <- lapply(1:t0, function(tdrop) { # drop unit i new_data_list <- drop_time_t_multiout(data_list, Z, tdrop) @@ -373,15 +374,15 @@ time_jackknife_plus_multiout <- function(ascm_multi, alpha = 0.05, conservative out <- list() att <- predict(ascm_multi, att = T) - out$att <- rbind(att, + out$att <- rbind(att, colMeans(att[(t0 + 1):t_final, , drop = F])) # held out ATT - out$heldout_att <- rbind(t(held_out_errs), - att[(t0 + 1):t_final, , drop = F], + out$heldout_att <- rbind(t(held_out_errs), + att[(t0 + 1):t_final, , drop = F], colMeans(att[(t0 + 1):t_final, , drop = F])) if(conservative) { - qerr <- apply(abs(held_out_errs), 1, + qerr <- apply(abs(held_out_errs), 1, stats::quantile, 1 - alpha, type = 1) out$lb <- rbind(matrix(NA, nrow = t0, ncol = k), t(t(apply(jack_dist_cons, 1:2, min)) - qerr)) @@ -392,8 +393,8 @@ time_jackknife_plus_multiout <- function(ascm_multi, alpha = 0.05, conservative out$lb <- rbind(matrix(NA, nrow = t0, ncol = k), apply(jack_dist_low, 1:2, stats::quantile, alpha, type = 1)) - out$ub <- rbind(matrix(NA, nrow = t0, ncol = k), - apply(jack_dist_high, 1:2, + out$ub <- rbind(matrix(NA, nrow = t0, ncol = k), + apply(jack_dist_high, 1:2, stats::quantile, 1 - alpha, type = 1)) } # shift back to ATT scale @@ -422,7 +423,7 @@ drop_time_t_multiout <- function(data_list, Z, t_drop) { function(x) x[, -t_drop, drop = F]) new_data_list$y 
<- lapply(1:length(data_list$y), function(k) { - cbind(data_list$X[[k]][, t_drop, drop = F], + cbind(data_list$X[[k]][, t_drop, drop = F], data_list$y[[k]]) }) return(new_data_list) @@ -430,6 +431,7 @@ drop_time_t_multiout <- function(data_list, Z, t_drop) { #' Conformal inference procedure to compute p-values and point-wise confidence intervals +#' @inheritParams time_jackknife_plus_multiout #' @param ascm Fitted `augsynth` object #' @param alpha Confidence level #' @param stat_func Function to compute test statistic @@ -437,6 +439,7 @@ drop_time_t_multiout <- function(data_list, Z, t_drop) { #' @param q The norm for the test static `((sum(x ^ q))) ^ (1/q)` #' @param ns Number of resamples for "iid" permutations #' @param grid_size Number of grid points to use when inverting the hypothesis test +#' @param lin_h0 Optional null hypothesis; if NULL (default), a grid of nulls is built around the estimated ATT for each outcome #' @return List that contains: #' \itemize{ #' \item{"att"}{Vector of ATT estimates} @@ -447,7 +450,7 @@ drop_time_t_multiout <- function(data_list, Z, t_drop) { #' \item{"p_val"}{p-value for test of no post-treatment effect} #' \item{"alpha"}{Level of confidence interval} #' } -conformal_inf_multiout <- function(ascm_multi, alpha = 0.05, +conformal_inf_multiout <- function(ascm_multi, alpha = 0.05, stat_func = NULL, type = "iid", q = 1, ns = 1000, grid_size = 50, lin_h0 = NULL) { @@ -463,8 +466,8 @@ conformal_inf_multiout <- function(ascm_multi, alpha = 0.05, t_final <- t0 + tpost # grid of nulls - att <- predict(ascm_multi, att = T) - post_att <- att[(t0 +1):t_final,, drop = F] + att <- predict(ascm_multi, att = TRUE) + post_att <- att[(t0 +1):t_final,, drop = FALSE] post_sd <- apply(post_att, 2, function(x) sqrt(mean(x ^ 2, na.rm = T))) # iterate over post-treatment periods to get pointwise CIs vapply(1:tpost, @@ -478,8 +481,8 @@ conformal_inf_multiout <- function(ascm_multi, alpha = 0.05, colnames(data_list$y[[i]])[j]) Xi }) - - + + if(tpost > 1) { new_data_list$y <- lapply(1:k, function(i) { @@ -498,7 +501,7 @@ conformal_inf_multiout <- 
function(ascm_multi, alpha = 0.05, # make a grid around the estimated ATT if(is.null(lin_h0)) { - grid <- lapply(1:k, + grid <- lapply(1:k, function(i) { seq(att[t0 + j, i] - 2 * post_sd[i], att[t0 + j, i] + 2 * post_sd[i], length.out = grid_size) @@ -528,10 +531,10 @@ conformal_inf_multiout <- function(ascm_multi, alpha = 0.05, data_list$y[[i]][, 1, drop = FALSE] }) null_p <- compute_permute_pval_multiout(new_data_list, ascm_multi, - numeric(k), + numeric(k), tpost, type, q, ns, stat_func) if(is.null(lin_h0)) { - grid <- lapply(1:k, + grid <- lapply(1:k, function(i) { seq(min(att[(t0 + 1):tpost, i]) - 4 * post_sd[i], max(att[(t0 + 1):tpost, i]) + 4 * post_sd[i], @@ -572,6 +575,7 @@ conformal_inf_multiout <- function(ascm_multi, alpha = 0.05, #' Compute conformal test statistics +#' @inheritParams time_jackknife_plus_multiout #' @param wide_data List containing pre- and post-treatment outcomes and outcome vector #' @param ascm Fitted `augsynth` object #' @param h0 Null hypothesis to test @@ -580,7 +584,7 @@ conformal_inf_multiout <- function(ascm_multi, alpha = 0.05, #' @param q The norm for the test static `((sum(x ^ q))) ^ (1/q)` #' @param ns Number of resamples for "iid" permutations #' @param stat_func Function to compute test statistic -#' +#' #' @return List that contains: #' \itemize{ #' \item{"resids"}{Residuals after enforcing the null} @@ -618,7 +622,7 @@ compute_permute_test_stats_multiout <- function(data_list, ascm_multi, h0, stat_func <- function(x) (sum(abs(x) ^ q) / sqrt(length(x))) ^ (1 / q) } if(type == "iid") { - test_stats <- sapply(1:ns, + test_stats <- sapply(1:ns, function(x) { idxs <- sample(1:nrow(resids)) reorder <- resids[idxs, , drop = F] @@ -635,7 +639,7 @@ compute_permute_test_stats_multiout <- function(data_list, ascm_multi, h0, apply(reorder[(t0 + 1):tpost, , drop = F], 2, stat_func) }) } - + return(list(resids = resids, test_stats = matrix(test_stats, nrow = k), stat_func = stat_func)) @@ -643,6 +647,7 @@ 
compute_permute_test_stats_multiout <- function(data_list, ascm_multi, h0, #' Compute conformal p-value +#' @inheritParams time_jackknife_plus_multiout #' @param wide_data List containing pre- and post-treatment outcomes and outcome vector #' @param ascm Fitted `augsynth` object #' @param h0 Null hypothesis to test @@ -651,7 +656,7 @@ compute_permute_test_stats_multiout <- function(data_list, ascm_multi, h0, #' @param q The norm for the test static `((sum(x ^ q))) ^ (1/q)` #' @param ns Number of resamples for "iid" permutations #' @param stat_func Function to compute test statistic -#' +#' #' @return Computed p-value #' @noRd compute_permute_pval_multiout <- function(data_list, ascm_multi, h0, @@ -676,6 +681,7 @@ compute_permute_pval_multiout <- function(data_list, ascm_multi, h0, } #' Compute conformal p-value +#' @inheritParams time_jackknife_plus_multiout #' @param wide_data List containing pre- and post-treatment outcomes and outcome vector #' @param ascm Fitted `augsynth` object #' @param grid Set of null hypothesis to test for inversion @@ -684,7 +690,7 @@ compute_permute_pval_multiout <- function(data_list, ascm_multi, h0, #' @param q The norm for the test static `((sum(x ^ q))) ^ (1/q)` #' @param ns Number of resamples for "iid" permutations #' @param stat_func Function to compute test statistic -#' +#' #' @return (lower bound of interval, upper bound of interval, p-value for null of 0 effect) #' @noRd compute_permute_ci_multiout <- function(data_list, ascm_multi, grid, @@ -706,12 +712,12 @@ compute_permute_ci_multiout <- function(data_list, ascm_multi, grid, } ps <- apply(grid, 1, function(x) { - compute_permute_pval_multiout(data_list, ascm_multi, x, + compute_permute_pval_multiout(data_list, ascm_multi, x, post_length, type, q, ns, stat_func) }) - sapply(1:k, - function(i) c(min(grid[ps >= alpha, i]), - max(grid[ps >= alpha, i]), + sapply(1:k, + function(i) c(min(grid[ps >= alpha, i]), + max(grid[ps >= alpha, i]), ps[apply(grid == 0, 1, all)])) } @@ 
-769,7 +775,7 @@ drop_unit_i_multiout <- function(wide_list, Z, i) { #' Estimate standard errors for single ASCM with the jackknife #' Do this for ridge-augmented synth #' @param ascm Fitted augsynth object -#' +#' #' @return List that contains: #' \itemize{ #' \item{"att"}{Vector of ATT estimates} @@ -840,6 +846,7 @@ jackknife_se_single <- function(ascm) { #' Compute standard errors using the jackknife +#' @importFrom stats qnorm #' @param multisynth fitted multisynth object #' @param relative Whether to compute effects according to relative time #' @noRd @@ -994,16 +1001,17 @@ jackknife_se_multiout <- function(ascm) { #' Compute the weighted bootstrap distribution +#' @importFrom stats quantile #' @param multisynth fitted multisynth object #' @param rweight Function to draw random weights as a function of n (e.g rweight(n)) #' @param relative Whether to compute effects according to relative time #' @noRd weighted_bootstrap_multi <- function(multisynth, - rweight = rwild_b, - n_boot = 1000, - alpha = 0.05, - att_weight = NULL, - relative=NULL) { + rweight = rwild_b, + n_boot = 1000, + alpha = 0.05, + att_weight = NULL, + relative = NULL) { ## get info from the multisynth object if(is.null(relative)) { relative <- multisynth$relative @@ -1031,7 +1039,7 @@ weighted_bootstrap_multi <- function(multisynth, function(x) mean(x, na.rm=T)) upper_bound <- att - apply(bs_est, c(1,2), function(x) quantile(x, alpha / 2, na.rm = T)) - + lower_bound <- att - apply(bs_est, c(1,2), function(x) quantile(x, 1 - alpha / 2, na.rm = T)) @@ -1045,6 +1053,7 @@ weighted_bootstrap_multi <- function(multisynth, #' Bayesian bootstrap #' @param n Number of units +#' @importFrom stats rgamma #' @export rdirichlet_b <- function(n) { Z <- as.numeric(rgamma(n, 1, 1)) @@ -1053,6 +1062,7 @@ rdirichlet_b <- function(n) { #' Non-parametric bootstrap #' @param n Number of units +#' @importFrom stats rmultinom #' @export rmultinom_b <- function(n) as.numeric(rmultinom(1, n, rep(1 / n, n))) @@ -1063,4 
+1073,4 @@ rwild_b <- function(n) { sample(c(-(sqrt(5) - 1) / 2, (sqrt(5) + 1) / 2 ), n, replace = TRUE, prob = c((sqrt(5) + 1)/ (2 * sqrt(5)), (sqrt(5) - 1) / (2 * sqrt(5)))) -} \ No newline at end of file +} diff --git a/R/multi_outcomes.R b/R/multi_outcomes.R index e94ea75..f1569c0 100644 --- a/R/multi_outcomes.R +++ b/R/multi_outcomes.R @@ -1,4 +1,5 @@ #' Fit Augmented SCM with multiple outcomes +#' @importFrom stats as.formula #' @param form outcome ~ treatment | auxillary covariates #' @param unit Name of unit column #' @param time Name of time column @@ -12,7 +13,7 @@ #' CausalImpact=Bayesian structural time series with CausalImpact #' seq2seq=Sequence to sequence learning with feedforward nets #' @param scm Whether the SCM weighting function is used -#' @param fixedeff Whether to include a unit fixed effect, default F +#' @param fixedeff Whether to include a unit fixed effect, default F #' @param cov_agg Covariate aggregation functions, if NULL then use mean with NAs omitted #' @param ... 
optional arguments for outcome model #' @@ -39,7 +40,7 @@ augsynth_multiout <- function(form, unit, time, t_int, data, form <- Formula::Formula(form) unit <- enquo(unit) time <- enquo(time) - + ## format data outcome <- terms(formula(form, rhs=1))[[2]] trt <- terms(formula(form, rhs=1))[[3]] @@ -48,9 +49,9 @@ augsynth_multiout <- function(form, unit, time, t_int, data, outcomes <- sapply(outcomes_str, quo) # get outcomes as a list wide_list <- format_data_multi(outcomes, trt, unit, time, t_int, data) - - + + ## add covariates if(length(form)[2] == 2) { @@ -73,11 +74,11 @@ augsynth_multiout <- function(form, unit, time, t_int, data, # add some extra data augsynth$data$time <- data %>% distinct(!!time) %>% pull(!!time) augsynth$call <- call_name - augsynth$t_int <- t_int + augsynth$t_int <- t_int augsynth$combine_method <- combine_method treated_units <- data %>% filter(!!trt == 1) %>% distinct(!!unit) %>% pull(!!unit) - control_units <- data %>% filter(!(!!unit %in% treated_units)) %>% + control_units <- data %>% filter(!(!!unit %in% treated_units)) %>% distinct(!!unit) %>% pull(!!unit) augsynth$weights <- matrix(augsynth$weights) rownames(augsynth$weights) <- control_units @@ -96,7 +97,7 @@ augsynth_multiout <- function(form, unit, time, t_int, data, #' @param ... Extra args for outcome model #' @noRd fit_augsynth_multiout_internal <- function(wide_list, combine_method, Z, - progfunc, scm, fixedeff, + progfunc, scm, fixedeff, outcomes_str, ...) { @@ -115,7 +116,7 @@ fit_augsynth_multiout_internal <- function(wide_list, combine_method, Z, synth_data$Y1plot <- colMeans(cbind(X, y)[trt == 1,, drop = F]) - augsynth <- fit_augsynth_internal(wide_bal, synth_data, Z, progfunc, + augsynth <- fit_augsynth_internal(wide_bal, synth_data, Z, progfunc, scm, fixedeff, V = V, ...) 
# potentially add back in fixed effects @@ -184,10 +185,10 @@ combine_outcomes <- function(wide_list, combine_method, fixedeff, trt = wide_list$trt) # V matrix scales by inverse variance for outcome and number of periods - V <- do.call(c, - lapply(wide_list$X, - function(x) rep(1 / (sqrt(ncol(x)) * - sd(x[wide_list$trt == 0, , drop = F], na.rm=T)), + V <- do.call(c, + lapply(wide_list$X, + function(x) rep(1 / (sqrt(ncol(x)) * + sd(x[wide_list$trt == 0, , drop = F], na.rm=T)), ncol(x)))) } else if(combine_method == "svd") { wide_bal <- list(X = do.call(cbind, wide_list$X), @@ -195,8 +196,8 @@ combine_outcomes <- function(wide_list, combine_method, fixedeff, trt = wide_list$trt) # first get the standard deviations of the outcomes to put on the same scale - sds <- do.call(c, - lapply(wide_list$X, + sds <- do.call(c, + lapply(wide_list$X, function(x) rep((sqrt(ncol(x)) * sd(x, na.rm=T)), ncol(x)))) # do an SVD on centered and scaled outcomes @@ -206,7 +207,7 @@ combine_outcomes <- function(wide_list, combine_method, fixedeff, V <- diag(1 / sds) %*% svd(X0)$v[, 1:k, drop = FALSE] }else { - stop(paste("combine_method should be one of ('concat'),", + stop(paste("combine_method should be one of ('concat'),", combine_method, " is not a valid combining option")) } @@ -232,33 +233,33 @@ predict.augsynth_multiout <- function(object, ...) 
{ # separate out by outcome n_outs <- length(object$data_list$X) - max_t <- max(sapply(1:n_outs, + max_t <- max(sapply(1:n_outs, function(k) ncol(object$data_list$X[[k]]) + ncol(object$data_list$y[[k]]))) - pred_reshape <- matrix(NA, ncol = n_outs, + pred_reshape <- matrix(NA, ncol = n_outs, nrow = max_t) - colnames <- lapply(1:n_outs, - function(k) colnames(cbind(object$data_list$X[[k]], + colnames <- lapply(1:n_outs, + function(k) colnames(cbind(object$data_list$X[[k]], object$data_list$y[[k]]))) rownames(pred_reshape) <- colnames[[which.max(sapply(colnames, length))]] colnames(pred_reshape) <- object$outcomes # get outcome names for predictions - pre_outs <- do.call(c, - sapply(1:n_outs, + pre_outs <- do.call(c, + sapply(1:n_outs, function(j) { rep(object$outcomes[j], ncol(object$data_list$X[[j]])) }, simplify = FALSE)) - + post_outs <- do.call(c, - sapply(1:n_outs, + sapply(1:n_outs, function(j) { rep(object$outcomes[j], ncol(object$data_list$y[[j]])) }, simplify = FALSE)) # print(pred) # print(cbind(names(pred), c(pre_outs, post_outs))) - + pred_reshape[cbind(names(pred), c(pre_outs, post_outs))] <- pred return(pred_reshape) } @@ -286,8 +287,8 @@ print.augsynth_multiout <- function(x, ...) { #' @param object augsynth_multiout object #' @param ... Optional arguments, including \itemize{\item{"se"}{Whether to plot standard error}} #' @export -summary.augsynth_multiout <- function(object, inf = T, inf_type = "jackknife", ...) { - +summary.augsynth_multiout <- function(object, inf = TRUE, inf_type = "jackknife", ...) { + summ <- list() @@ -304,66 +305,66 @@ summary.augsynth_multiout <- function(object, inf = T, inf_type = "jackknife", . 
t_final <- nrow(att_se$att) - att_df <- data.frame(att_se$att[1:(t_final - 1),, drop=F]) + att_df <- data.frame(att_se$att[1:(t_final - 1),, drop = FALSE]) names(att_df) <- object$outcomes att_df$Time <- object$data$time att_df <- att_df %>% gather(Outcome, Estimate, -Time) if(inf_type == "jackknife") { - se_df <- data.frame(att_se$se[1:(t_final - 1),, drop=F]) + se_df <- data.frame(att_se$se[1:(t_final - 1),, drop = FALSE]) names(se_df) <- object$outcomes se_df$Time <- object$data$time se_df <- se_df %>% gather(Outcome, Std.Error, -Time) att <- inner_join(att_df, se_df, by = c("Time", "Outcome")) } else if(inf_type %in% c("conformal", "jackknife+")) { - - lb_df <- data.frame(att_se$lb[1:(t_final - 1),, drop=F]) + + lb_df <- data.frame(att_se$lb[1:(t_final - 1),, drop = FALSE]) names(lb_df) <- object$outcomes lb_df$Time <- object$data$time lb_df <- lb_df %>% gather(Outcome, lower_bound, -Time) - ub_df <- data.frame(att_se$ub[1:(t_final - 1),, drop=F]) + ub_df <- data.frame(att_se$ub[1:(t_final - 1),, drop = FALSE]) names(ub_df) <- object$outcomes ub_df$Time <- object$data$time ub_df <- ub_df %>% gather(Outcome, upper_bound, -Time) att <- inner_join(att_df, lb_df, by = c("Time", "Outcome")) %>% - inner_join(ub_df, by = c("Time", "Outcome")) + inner_join(ub_df, by = c("Time", "Outcome")) if(inf_type == "conformal") { - pval_df <- data.frame(att_se$p_val[1:(t_final - 1),, drop=F]) + pval_df <- data.frame(att_se$p_val[1:(t_final - 1),, drop = FALSE]) names(pval_df) <- object$outcomes pval_df$Time <- object$data$time pval_df <- pval_df %>% gather(Outcome, p_val, -Time) - att <- inner_join(att, pval_df, by = c("Time", "Outcome")) + att <- inner_join(att, pval_df, by = c("Time", "Outcome")) } } - att_avg <- data.frame(att_se$att[t_final,, drop = F]) + att_avg <- data.frame(att_se$att[t_final,, drop = FALSE]) names(att_avg) <- object$outcomes att_avg <- gather(att_avg, Outcome, Estimate) if(inf_type == "jackknife") { - att_avg_se <- data.frame(att_se$se[t_final,, drop = 
F]) + att_avg_se <- data.frame(att_se$se[t_final,, drop = FALSE]) names(att_avg_se) <- object$outcomes att_avg_se <- gather(att_avg_se, Outcome, Std.Error) average_att <- inner_join(att_avg, att_avg_se, by="Outcome") } else if(inf_type %in% c("conformal", "jackknife+")){ - att_avg_lb <- data.frame(att_se$lb[t_final,, drop = F]) + att_avg_lb <- data.frame(att_se$lb[t_final,, drop = FALSE]) names(att_avg_lb) <- object$outcomes att_avg_lb <- gather(att_avg_lb, Outcome, lower_bound) - att_avg_ub <- data.frame(att_se$ub[t_final,, drop = F]) + att_avg_ub <- data.frame(att_se$ub[t_final,, drop = FALSE]) names(att_avg_ub) <- object$outcomes att_avg_ub <- gather(att_avg_ub, Outcome, upper_bound) - - average_att <- inner_join(att_avg, att_avg_lb, by="Outcome") %>% + + average_att <- inner_join(att_avg, att_avg_lb, by = "Outcome") %>% inner_join(att_avg_ub, by = "Outcome") - + if(inf_type == "conformal") { - att_avg_pval <- data.frame(att_se$p_val[t_final,, drop = F]) + att_avg_pval <- data.frame(att_se$p_val[t_final,, drop = FALSE]) names(att_avg_pval) <- object$outcomes att_avg_pval <- gather(att_avg_pval, Outcome, p_val) @@ -372,10 +373,10 @@ summary.augsynth_multiout <- function(object, inf = T, inf_type = "jackknife", . } else { average_att <- gather(att_avg, Outcome, Estimate) } - + } else { - att_est <- predict(object, att = T) + att_est <- predict(object, att = TRUE) att_df <- data.frame(att_est) names(att_df) <- object$outcomes att_df$Time <- object$data$time @@ -383,7 +384,7 @@ summary.augsynth_multiout <- function(object, inf = T, inf_type = "jackknife", . 
att$Std.Error <- NA t_int <- min(sapply(object$data_list$X, ncol)) att_avg <- data.frame(colMeans( - att_est[as.numeric(rownames(att)) >= t_int,, drop = F])) + att_est[as.numeric(rownames(att)) >= t_int,, drop = FALSE])) names(att_avg) <- object$outcomes average_att <- gather(att_avg, Outcome, Estimate) average_att$Std.Error <- NA @@ -412,7 +413,7 @@ summary.augsynth_multiout <- function(object, inf = T, inf_type = "jackknife", . if(object$progfunc == "None" | (!object$scm)) { summ$bias_est <- NA } - + class(summ) <- "summary.augsynth_multiout" return(summ) } @@ -425,17 +426,17 @@ summary.augsynth_multiout <- function(object, inf = T, inf_type = "jackknife", . print.summary.augsynth_multiout <- function(x, ...) { ## straight from lm cat("\nCall:\n", paste(deparse(x$call), sep="\n", collapse="\n"), "\n\n", sep="") - + att_est <- x$att$Estimate ## get pre-treatment fit by outcome - imbal <- x$att %>% - filter(Time < x$t_int) %>% - group_by(Outcome) %>% - summarise(Pre.RMSE = sqrt(mean(Estimate ^ 2))) + imbal <- x$att %>% + filter(.data$Time < x$t_int) %>% + group_by(.data$Outcome) %>% + summarise(Pre.RMSE = sqrt(mean(.data$Estimate ^ 2))) cat(paste("Overall L2 Imbalance (Scaled):", - format(round(x$l2_imbalance,3), nsmall=3), " (", - format(round(x$scaled_l2_imbalance,3), nsmall=3), ")\n\n", + format(round(x$l2_imbalance, 3), nsmall = 3), " (", + format(round(x$scaled_l2_imbalance, 3), nsmall = 3), ")\n\n", # "Avg Estimated Bias: ", # format(round(mean(summ$bias_est), 3),nsmall=3), "\n\n", sep="")) @@ -448,28 +449,30 @@ print.summary.augsynth_multiout <- function(x, ...) { #' Plot function for summary function for augsynth #' @param x summary.augsynth_multiout object #' @param ... Optional arguments, including \itemize{\item{"se"}{Whether to plot standard error}} -#' +#' #' @export -plot.summary.augsynth_multiout <- function(x, inf = T, ...) { +plot.summary.augsynth_multiout <- function(x, inf = TRUE, ...) 
{ p <- x$att %>% - ggplot2::ggplot(ggplot2::aes(x=Time, y=Estimate)) + ggplot2::ggplot(ggplot2::aes(x=.data$Time, y=.data$Estimate)) if(inf) { if(x$inf_type == "jackknife") { - p <- p + ggplot2::geom_ribbon(ggplot2::aes(ymin=Estimate-2*Std.Error, - ymax=Estimate+2*Std.Error), - alpha=0.2) + p <- p + ggplot2::geom_ribbon( + ggplot2::aes(ymin=.data$Estimate-2*.data$Std.Error, + ymax=.data$Estimate+2*.data$Std.Error), + alpha=0.2) } else if(x$inf_type %in% c("conformal", "jackknife+")) { - p <- p + ggplot2::geom_ribbon(ggplot2::aes(ymin=lower_bound, - ymax=upper_bound), - alpha=0.2) + p <- p + ggplot2::geom_ribbon( + ggplot2::aes(ymin=.data$lower_bound, + ymax=.data$upper_bound), + alpha=0.2) } } p + ggplot2::geom_line() + ggplot2::geom_vline(xintercept=x$t_int, lty=2) + ggplot2::geom_hline(yintercept=0, lty=2) + - ggplot2::facet_wrap(~ Outcome, scales = "free_y") + + ggplot2::facet_wrap(~ .data$Outcome, scales = "free_y") + ggplot2::theme_bw() } diff --git a/R/multisynth_class.R b/R/multisynth_class.R index 48b7e97..b64d1d0 100644 --- a/R/multisynth_class.R +++ b/R/multisynth_class.R @@ -24,12 +24,13 @@ #' @param n_factors Number of factors for interactive fixed effects, setting to NULL fits with CV, default is 0 #' @param scm Whether to fit scm weights #' @param time_cohort Whether to average synthetic controls into time cohorts, default FALSE +#' @param how_match "knn" #' @param cov_agg Covariate aggregation function #' @param eps_abs Absolute error tolerance for osqp #' @param eps_rel Relative error tolerance for osqp #' @param verbose Whether to print logs for osqp #' @param ... 
Extra arguments -#' +#' #' @return multisynth object that contains: #' \itemize{ #' \item{"weights"}{weights matrix where each column is a set of weights for a treated unit} @@ -142,10 +143,10 @@ multisynth <- function(form, unit, time, data, scm = scm, time_cohort = time_cohort, time_w = F, lambda_t = 0, fit_resids = TRUE, eps_abs = eps_abs, - eps_rel = eps_rel, verbose = verbose, long_df = long_df, + eps_rel = eps_rel, verbose = verbose, long_df = long_df, how_match = how_match, ...) - - + + units <- data %>% arrange(!!unit) %>% distinct(!!unit) %>% pull(!!unit) rownames(msynth$weights) <- units @@ -166,7 +167,7 @@ multisynth <- function(form, unit, time, data, V = V, time_cohort = time_cohort, donors = msynth$donors, - eps_rel = eps_rel, + eps_rel = eps_rel, eps_abs = eps_abs, verbose = verbose) ## scaled global balance @@ -186,6 +187,7 @@ multisynth <- function(form, unit, time, data, #' Internal funciton to fit staggered synth with formatted data +#' @inheritParams multisynth #' @param wide List containing data elements #' @param relative Whether to compute balance by relative time #' @param n_leads How long past treatment effects should be estimated for @@ -207,16 +209,17 @@ multisynth <- function(form, unit, time, data, #' @param ... Extra arguments #' @noRd #' @return multisynth object -multisynth_formatted <- function(wide, relative=T, n_leads, n_lags, - nu, lambda, V, - force, - n_factors, - scm, time_cohort, - time_w, lambda_t, - fit_resids, - eps_abs, eps_rel, - verbose, long_df, - how_match, ...) { +multisynth_formatted <- function(wide, relative = TRUE, + n_leads, n_lags, + nu, lambda, V, + force, + n_factors, + scm, time_cohort, + time_w, lambda_t, + fit_resids, + eps_abs, eps_rel, + verbose, long_df, + how_match, ...) 
{ ## average together treatment groups ## grps <- unique(wide$trt) %>% sort() if(time_cohort) { @@ -246,23 +249,23 @@ multisynth_formatted <- function(wide, relative=T, n_leads, n_lags, n_factors <- ncol(params$factor) ## get residuals from outcome model residuals <- cbind(wide$X, wide$y) - y0hat - + } else if (n_factors != 0) { ## if number of factors is provided don't do CV out <- fit_gsynth_multi(long_df, cbind(wide$X, wide$y), wide$trt, r=n_factors, CV=0, force=force) y0hat <- out$y0hat - params <- out$params - + params <- out$params + ## get residuals from outcome model residuals <- cbind(wide$X, wide$y) - y0hat } else if(force == 0 & n_factors == 0) { - # if no fixed effects or factors, just take out + # if no fixed effects or factors, just take out # control averages at each time point # time fixed effects from pure controls pure_ctrl <- cbind(wide$X, wide$y)[!is.finite(wide$trt), , drop = F] y0hat <- matrix(colMeans(pure_ctrl, na.rm = TRUE), - nrow = nrow(wide$X), ncol = ncol(pure_ctrl), + nrow = nrow(wide$X), ncol = ncol(pure_ctrl), byrow = T) residuals <- cbind(wide$X, wide$y) - y0hat params <- NULL @@ -297,13 +300,13 @@ multisynth_formatted <- function(wide, relative=T, n_leads, n_lags, bal_mat <- wide$X - ctrl_avg bal_mat <- wide$X } - + if(scm) { # get eligible set of donor units based on covariates donors <- get_donors(wide$X, wide$y, wide$trt, - wide$Z[, colnames(wide$Z) %in% + wide$Z[, colnames(wide$Z) %in% wide$match_covariates, drop = F], time_cohort, n_lags, n_leads, how = how_match, exact_covariates = wide$exact_covariates, ...) 
@@ -386,9 +389,9 @@ multisynth_formatted <- function(wide, relative=T, n_leads, n_lags, msynth$time_w <- time_w msynth$lambda_t <- lambda_t msynth$fit_resids <- fit_resids - msynth$extra_pars <- c(list(eps_abs = eps_abs, - eps_rel = eps_rel, - verbose = verbose), + msynth$extra_pars <- c(list(eps_abs = eps_abs, + eps_rel = eps_rel, + verbose = verbose), list(...)) msynth$long_df <- long_df @@ -417,7 +420,7 @@ predict.multisynth <- function(object, att = F, att_weight = NULL, bs_weight = N multisynth <- object relative <- T - + time_cohort <- multisynth$time_cohort if(is.null(relative)) { relative <- multisynth$relative @@ -436,19 +439,19 @@ predict.multisynth <- function(object, att = F, att_weight = NULL, bs_weight = N } if(time_cohort) { - which_t <- lapply(grps, + which_t <- lapply(grps, function(tj) (1:n)[multisynth$data$trt == tj]) mask <- unique(multisynth$data$mask) } else { which_t <- (1:n)[is.finite(multisynth$data$trt)] mask <- multisynth$data$mask } - + n1 <- sapply(1:J, function(j) length(which_t[[j]])) fullmask <- cbind(mask, matrix(0, nrow = J, ncol = (ttot - d))) - + ## estimate the post-treatment values to get att estimates mu1hat <- vapply(1:J, @@ -502,11 +505,11 @@ predict.multisynth <- function(object, att = F, att_weight = NULL, bs_weight = N c(vec, rep(NA, total_len - length(vec)), mean(mu0hat[(grps[j]+1):(min(grps[j] + n_leads, ttot)), j])) - + }, numeric(total_len +1 )) - + tauhat <- vapply(1:J, function(j) { vec <- c(rep(NA, d-grps[j]), @@ -536,11 +539,11 @@ predict.multisynth <- function(object, att = F, att_weight = NULL, bs_weight = N c(vec, rep(NA, total_len - length(vec)), mean(att_weight[(grps[j]+1):(min(grps[j] + n_leads, ttot)), j])) - + }, numeric(total_len +1 )) - + ## get the overall average estimate avg <- apply(mu0hat, 1, function(z) sum(n1 * z, na.rm=T) / sum(n1 * !is.na(z))) avg <- sapply(1:nrow(mu0hat), @@ -557,7 +560,7 @@ predict.multisynth <- function(object, att = F, att_weight = NULL, bs_weight = N sum(n1 * 
(!is.na(tauhat[k,])) * att_weight_new[k, ], na.rm = T) }) tauhat <- cbind(avg, tauhat) - + } else { ## remove all estimates for t > T_j + n_leads @@ -571,7 +574,7 @@ predict.multisynth <- function(object, att = F, att_weight = NULL, bs_weight = N rep(NA, max(0, ttot-(grps[j] + n_leads)))), numeric(ttot)) -> tauhat - + ## only average currently treated units avg1 <- rowSums(t(fullmask) * mu0hat * n1) / rowSums(t(fullmask) * n1) @@ -590,7 +593,7 @@ predict.multisynth <- function(object, att = F, att_weight = NULL, bs_weight = N replace_na(avg2,0) * apply(1 - fullmask, 2, max) cbind(avg, tauhat) -> tauhat } - + if(att) { return(tauhat) @@ -606,9 +609,9 @@ predict.multisynth <- function(object, att = F, att_weight = NULL, bs_weight = N #' @export print.multisynth <- function(x, att_weight = NULL, ...) { multisynth <- x - + ## straight from lm - cat("\nCall:\n", paste(deparse(multisynth$call), + cat("\nCall:\n", paste(deparse(multisynth$call), sep="\n", collapse="\n"), "\n\n", sep="") # print att estimates @@ -655,7 +658,7 @@ plot.multisynth <- function(x, inf_type = "bootstrap", inf = T, #' \item{jackknife}{Jackknife} #' } #' @param ... Optional arguments -#' +#' #' @return summary.multisynth object that contains: #' \itemize{ #' \item{"att"}{Dataframe with ATT estimates, standard errors for each treated unit} @@ -669,7 +672,7 @@ plot.multisynth <- function(x, inf_type = "bootstrap", inf = T, summary.multisynth <- function(object, inf_type = "bootstrap", att_weight = NULL, ...) 
{ multisynth <- object - + relative <- T n_leads <- multisynth$n_leads @@ -686,17 +689,17 @@ summary.multisynth <- function(object, inf_type = "bootstrap", att_weight = NULL grps <- trt[is.finite(trt)] which_t <- (1:n)[is.finite(trt)] } - + # grps <- unique(multisynth$data$trt) %>% sort() J <- length(grps) - + # which_t <- (1:n)[is.finite(multisynth$data$trt)] times <- multisynth$data$time - + summ <- list() ## post treatment estimate for each group and overall # att <- predict(multisynth, relative, att=T) - + if(inf_type == "jackknife") { attse <- jackknife_se_multi(multisynth, relative, att_weight = att_weight, ...) } else if(inf_type == "bootstrap") { @@ -712,16 +715,16 @@ summary.multisynth <- function(object, inf_type = "bootstrap", att_weight = NULL upper_bound = matrix(NA, nrow(att), ncol(att)), lower_bound = matrix(NA, nrow(att), ncol(att))) } - + if(relative) { att <- data.frame(cbind(c(-(d-1):min(n_leads, ttot-min(grps)), NA), attse$att)) if(time_cohort) { - col_names <- c("Time", "Average", + col_names <- c("Time", "Average", as.character(times[grps + 1])) } else { - col_names <- c("Time", "Average", + col_names <- c("Time", "Average", as.character(multisynth$data$units[which_t])) } names(att) <- col_names @@ -730,7 +733,7 @@ summary.multisynth <- function(object, inf_type = "bootstrap", att_weight = NULL mutate(Time=Time-1) -> att se <- data.frame(cbind(c(-(d-1):min(n_leads, ttot-min(grps)), NA), - attse$se)) + attse$se)) names(se) <- col_names se %>% gather(Level, Std.Error, -Time) %>% rename("Time"=Time) %>% @@ -751,11 +754,11 @@ summary.multisynth <- function(object, inf_type = "bootstrap", att_weight = NULL } else { att <- data.frame(cbind(times, attse$att)) - names(att) <- c("Time", "Average", times[grps[1:J]]) + names(att) <- c("Time", "Average", times[grps[1:J]]) att %>% gather(Level, Estimate, -Time) -> att se <- data.frame(cbind(times, attse$se)) - names(se) <- c("Time", "Average", times[grps[1:J]]) + names(se) <- c("Time", "Average", 
times[grps[1:J]]) se %>% gather(Level, Std.Error, -Time) -> se } @@ -788,12 +791,12 @@ summary.multisynth <- function(object, inf_type = "bootstrap", att_weight = NULL print.summary.multisynth <- function(x, level = "Average", ...) { summ <- x - + ## straight from lm cat("\nCall:\n", paste(deparse(summ$call), sep="\n", collapse="\n"), "\n\n", sep="") first_lvl <- summ$att %>% filter(Level != "Average") %>% pull(Level) %>% min() - + ## get ATT estimates for treatment level, post treatment if(summ$relative) { summ$att %>% @@ -816,18 +819,18 @@ print.summary.multisynth <- function(x, level = "Average", ...) { pull(Std.Error) %>% round(3) %>% format(nsmall=3), ")\n\n", sep="")) - + cat(paste("Global L2 Imbalance: ", format(round(summ$global_l2,3), nsmall=3), "\n", "Scaled Global L2 Imbalance: ", format(round(summ$scaled_global_l2,3), nsmall=3), "\n", - "Percent improvement from uniform global weights: ", + "Percent improvement from uniform global weights: ", format(round(1-summ$scaled_global_l2,3)*100), "\n\n", "Individual L2 Imbalance: ", format(round(summ$ind_l2,3), nsmall=3), "\n", - "Scaled Individual L2 Imbalance: ", + "Scaled Individual L2 Imbalance: ", format(round(summ$scaled_ind_l2,3), nsmall=3), "\n", - "Percent improvement from uniform individual weights: ", + "Percent improvement from uniform individual weights: ", format(round(1-summ$scaled_ind_l2,3)*100), "\t", "\n\n", sep="")) @@ -839,7 +842,7 @@ print.summary.multisynth <- function(x, level = "Average", ...) { #' Plot function for summary function for multisynth #' @importFrom ggplot2 aes -#' +#' #' @param x summary object #' @param inf Whether to plot confidence intervals #' @param levels Which units/groups to plot, default is every group @@ -849,7 +852,7 @@ print.summary.multisynth <- function(x, level = "Average", ...) { plot.summary.multisynth <- function(x, inf = T, levels = NULL, label = T, ...) 
{ summ <- x - + ## get the last time period for each level summ$att %>% filter(!is.na(Estimate), @@ -860,7 +863,7 @@ plot.summary.multisynth <- function(x, inf = T, levels = NULL, label = T, ...) { if(is.null(levels)) levels <- unique(summ$att$Level) - summ$att %>% inner_join(last_times) %>% + summ$att %>% inner_join(last_times, by = "Level") %>% filter(Level %in% levels) %>% mutate(label = ifelse(Time == last_time, Level, NA), is_avg = ifelse(("Average" %in% levels) * (Level == "Average"), @@ -869,13 +872,14 @@ plot.summary.multisynth <- function(x, inf = T, levels = NULL, label = T, ...) { group = Level, color = is_avg, alpha = is_avg)) + - ggplot2::geom_line(size = 1) + - ggplot2::geom_point(size = 1) -> p - + ggplot2::geom_line(size = 1, na.rm = TRUE) + + ggplot2::geom_point(size = 1, na.rm = TRUE) -> p + if(label) { - p <- p + ggrepel::geom_label_repel(ggplot2::aes(label = label), - nudge_x = 1, na.rm = T) - } + p <- p + ggrepel::geom_label_repel( + ggplot2::aes(label = label), + nudge_x = 1, na.rm = TRUE) + } p <- p + ggplot2::geom_hline(yintercept = 0, lty = 2) if(summ$relative) { @@ -904,10 +908,10 @@ plot.summary.multisynth <- function(x, inf = T, levels = NULL, label = T, ...) { ggplot2::aes(ymin=lower_bound, ymax=upper_bound), alpha = alph, color=clr, - data = summ$att %>% + data = summ$att %>% filter(Level == "Average", Time >= 0)) - + } else { p <- p + error_plt( ggplot2::aes(ymin=lower_bound, @@ -919,7 +923,7 @@ plot.summary.multisynth <- function(x, inf = T, levels = NULL, label = T, ...) 
{ p <- p + ggplot2::scale_alpha_manual(values=c(1, 0.5)) + ggplot2::scale_color_manual(values=c("#333333", "#818181")) + - ggplot2::guides(alpha=F, color=F) + + ggplot2::guides(alpha = "none", color = "none") + ggplot2::theme_bw() return(p) diff --git a/man/conformal_inf_multiout.Rd b/man/conformal_inf_multiout.Rd index 2c15d18..b5e6885 100644 --- a/man/conformal_inf_multiout.Rd +++ b/man/conformal_inf_multiout.Rd @@ -16,6 +16,8 @@ conformal_inf_multiout( ) } \arguments{ +\item{ascm_multi}{data.frame} + \item{alpha}{Confidence level} \item{stat_func}{Function to compute test statistic} @@ -28,6 +30,8 @@ conformal_inf_multiout( \item{grid_size}{Number of grid points to use when inverting the hypothesis test} +\item{lin_h0}{value} + \item{ascm}{Fitted `augsynth` object} } \value{ diff --git a/man/multisynth.Rd b/man/multisynth.Rd index f609dd8..3ba4cff 100644 --- a/man/multisynth.Rd +++ b/man/multisynth.Rd @@ -61,6 +61,8 @@ If covariates are time-varying, their average value before the first unit is tre \item{time_cohort}{Whether to average synthetic controls into time cohorts, default FALSE} +\item{how_match}{"knn"} + \item{cov_agg}{Covariate aggregation function} \item{eps_abs}{Absolute error tolerance for osqp} diff --git a/man/plot.summary.augsynth_multiout.Rd b/man/plot.summary.augsynth_multiout.Rd index 45b79ac..a8de434 100644 --- a/man/plot.summary.augsynth_multiout.Rd +++ b/man/plot.summary.augsynth_multiout.Rd @@ -4,7 +4,7 @@ \alias{plot.summary.augsynth_multiout} \title{Plot function for summary function for augsynth} \usage{ -\method{plot}{summary.augsynth_multiout}(x, inf = T, ...) +\method{plot}{summary.augsynth_multiout}(x, inf = TRUE, ...) 
} \arguments{ \item{x}{summary.augsynth_multiout object} diff --git a/man/single_augsynth.Rd b/man/single_augsynth.Rd index 10d7d72..d329b69 100644 --- a/man/single_augsynth.Rd +++ b/man/single_augsynth.Rd @@ -32,7 +32,7 @@ single_augsynth( ridge=Ridge regression (allows for standard errors), none=No outcome model, en=Elastic Net, RF=Random Forest, GSYN=gSynth, -mcp=MCPanel, +mcp=MCPanel, cits=Comparitive Interuppted Time Series causalimpact=Bayesian structural time series with CausalImpact} diff --git a/man/summary.augsynth_multiout.Rd b/man/summary.augsynth_multiout.Rd index 96e9bf7..ba349b6 100644 --- a/man/summary.augsynth_multiout.Rd +++ b/man/summary.augsynth_multiout.Rd @@ -4,7 +4,7 @@ \alias{summary.augsynth_multiout} \title{Summary function for augsynth} \usage{ -\method{summary}{augsynth_multiout}(object, inf = T, inf_type = "jackknife", ...) +\method{summary}{augsynth_multiout}(object, inf = TRUE, inf_type = "jackknife", ...) } \arguments{ \item{object}{augsynth_multiout object} diff --git a/man/time_jackknife_plus.Rd b/man/time_jackknife_plus.Rd index 5b9a3f3..ced761c 100644 --- a/man/time_jackknife_plus.Rd +++ b/man/time_jackknife_plus.Rd @@ -4,7 +4,7 @@ \alias{time_jackknife_plus} \title{Jackknife+ algorithm over time} \usage{ -time_jackknife_plus(ascm, alpha = 0.05, conservative = F) +time_jackknife_plus(ascm, alpha = 0.05, conservative = FALSE) } \arguments{ \item{ascm}{Fitted `augsynth` object} diff --git a/man/time_jackknife_plus_multiout.Rd b/man/time_jackknife_plus_multiout.Rd index 0d06cb2..576f6b5 100644 --- a/man/time_jackknife_plus_multiout.Rd +++ b/man/time_jackknife_plus_multiout.Rd @@ -7,6 +7,8 @@ time_jackknife_plus_multiout(ascm_multi, alpha = 0.05, conservative = F) } \arguments{ +\item{ascm_multi}{data.frame} + \item{alpha}{Confidence level} \item{conservative}{Whether to use the conservative jackknife+ procedure} diff --git a/pkg.Rproj b/pkg.Rproj index d848a9f..cba1b6b 100644 --- a/pkg.Rproj +++ b/pkg.Rproj @@ -5,8 +5,13 @@ 
SaveWorkspace: No AlwaysSaveHistory: Default EnableCodeIndexing: Yes +UseSpacesForTab: Yes +NumSpacesForTab: 2 Encoding: UTF-8 +RnwWeave: Sweave +LaTeX: pdfLaTeX + AutoAppendNewline: Yes StripTrailingWhitespace: Yes diff --git a/vignettes/multisynth-vignette.Rmd b/vignettes/multisynth-vignette.Rmd index 26b2a1b..f59a979 100644 --- a/vignettes/multisynth-vignette.Rmd +++ b/vignettes/multisynth-vignette.Rmd @@ -1,7 +1,8 @@ --- +title: "MultiSynth" output: rmarkdown::html_vignette vignette: > - %\VignetteIndexEntry{MultiSynth Vignette} + %\VignetteIndexEntry{MultiSynth} %\VignetteEngine{knitr::rmarkdown} %\VignetteEncoding{UTF-8} --- @@ -194,4 +195,4 @@ ppool_syn_cov_summ$att %>% Again we can plot the effects. ```{r ppool_syn_cov_plot, fig.width=8, fig.height=4.5, fig.align="center", warning=F, message=F} plot(ppool_syn_cov_summ, levels = "Average") -``` \ No newline at end of file +``` diff --git a/vignettes/singlesynth-vignette.Rmd b/vignettes/singlesynth-vignette.Rmd index 501de37..216f821 100644 --- a/vignettes/singlesynth-vignette.Rmd +++ b/vignettes/singlesynth-vignette.Rmd @@ -1,4 +1,5 @@ --- +title: "Single Outcome AugSynth Vignette" output: rmarkdown::html_vignette vignette: > %\VignetteIndexEntry{Single Outcome AugSynth Vignette}