Repository: ebenmichael/augsynth
Branch: master
Commit: 65c5a6f34f4e
Files: 83
Total size: 373.9 KB
Directory structure:
gitextract_wwvwxodd/
├── .Rbuildignore
├── .gitignore
├── .travis.yml
├── DESCRIPTION
├── LICENSE
├── NAMESPACE
├── R/
│ ├── augsynth.R
│ ├── augsynth_pre.R
│ ├── cv.R
│ ├── data.R
│ ├── eligible_donors.R
│ ├── fit_synth.R
│ ├── format.R
│ ├── globalVariables.R
│ ├── highdim.R
│ ├── inference.R
│ ├── multi_outcomes.R
│ ├── multi_synth_qp.R
│ ├── multisynth_class.R
│ ├── outcome_models.R
│ ├── outcome_multi.R
│ ├── ridge.R
│ ├── ridge_lambda.R
│ └── time_regression_multi.R
├── README.md
├── data/
│ └── kansas.rda
├── data-raw/
│ ├── clean_kansas.R
│ └── kansas_longer2.dta
├── man/
│ ├── augsynth-package.Rd
│ ├── augsynth.Rd
│ ├── augsynth_multiout.Rd
│ ├── check_data_stag.Rd
│ ├── conformal_inf.Rd
│ ├── conformal_inf_linear.Rd
│ ├── conformal_inf_multiout.Rd
│ ├── get_nona_donors.Rd
│ ├── jackknife_se_single.Rd
│ ├── kansas.Rd
│ ├── make_V_matrix.Rd
│ ├── multisynth.Rd
│ ├── plot.augsynth.Rd
│ ├── plot.augsynth_multiout.Rd
│ ├── plot.multisynth.Rd
│ ├── plot.summary.augsynth.Rd
│ ├── plot.summary.augsynth_multiout.Rd
│ ├── plot.summary.multisynth.Rd
│ ├── predict.augsynth.Rd
│ ├── predict.augsynth_multiout.Rd
│ ├── predict.multisynth.Rd
│ ├── print.augsynth.Rd
│ ├── print.augsynth_multiout.Rd
│ ├── print.multisynth.Rd
│ ├── print.summary.augsynth.Rd
│ ├── print.summary.augsynth_multiout.Rd
│ ├── print.summary.multisynth.Rd
│ ├── rdirichlet_b.Rd
│ ├── rmultinom_b.Rd
│ ├── rwild_b.Rd
│ ├── single_augsynth.Rd
│ ├── summary.augsynth.Rd
│ ├── summary.augsynth_multiout.Rd
│ ├── summary.multisynth.Rd
│ ├── time_jackknife_plus.Rd
│ └── time_jackknife_plus_multiout.Rd
├── pkg.Rproj
├── tests/
│ ├── testthat/
│ │ ├── test_augsynth_pre.R
│ │ ├── test_format.R
│ │ ├── test_general.R
│ │ ├── test_lambda.R
│ │ ├── test_load_data.R
│ │ ├── test_multiple_outcomes.R
│ │ ├── test_multisynth.R
│ │ ├── test_multisynth_covariates.R
│ │ ├── test_outcome_models.R
│ │ ├── test_time_cohort.R
│ │ └── test_unbalanced_multisynth.R
│ └── testthat.R
└── vignettes/
├── .gitignore
├── multi-outcomes-vignette.Rmd
├── multisynth-vignette.Rmd
├── multisynth-vignette.md
├── singlesynth-vignette.Rmd
└── singlesynth-vignette.md
================================================
FILE CONTENTS
================================================
================================================
FILE: .Rbuildignore
================================================
^data-raw$
^Meta$
^doc$
^\.travis\.yml$
^pkg.Rproj$
figure$
cache$
================================================
FILE: .gitignore
================================================
Meta
doc
inst/doc
## Files
# Emacs autosave files
*~
\#*#
# Don't put data in the repo
*.csv
*.feather
# R stuff
*.Rout
*.Rhistory
*.RData
*.Rapp.history
# Mac stuff
*.DS_store
# C++ stuff
*.o
*.so
*.dll
test.R
*-vignette.pdf
================================================
FILE: .travis.yml
================================================
# R for travis: see documentation at https://docs.travis-ci.com/user/languages/r
language: r
r:
- 3.5.1
sudo: false
cache: packages
warnings_are_errors: false
r_binary_packages:
- dplyr
- magrittr
- ggplot2
- glmnet
- plyr
- kableExtra
================================================
FILE: DESCRIPTION
================================================
Package: augsynth
Title: The Augmented Synthetic Control Method
Version: 0.2.0
Authors@R: person("Eli", "Ben-Michael", email = "ebenmichael@berkeley.edu", role = c("aut", "cre"))
Description: A package implementing the Augmented Synthetic Controls Method.
Depends:
R (>= 3.5.0)
Imports:
dplyr,
tidyr,
magrittr,
ggplot2,
MASS,
LiblineaR,
Formula,
Matrix,
osqp,
rlang,
purrr,
FNN
Remotes:
susanathey/MCPanel
License: MIT + file LICENSE
Encoding: UTF-8
LazyData: true
RoxygenNote: 7.2.3
Suggests:
testthat,
CausalImpact,
keras,
gsynth,
knitr,
rmarkdown,
softImpute,
MCPanel,
glmnet,
randomForest,
kableExtra,
ggrepel
VignetteBuilder: knitr
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2018 Elijahu Ben-Michael
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: NAMESPACE
================================================
# Generated by roxygen2: do not edit by hand
S3method(plot,augsynth)
S3method(plot,augsynth_multiout)
S3method(plot,multisynth)
S3method(plot,summary.augsynth)
S3method(plot,summary.augsynth_multiout)
S3method(plot,summary.multisynth)
S3method(predict,augsynth)
S3method(predict,augsynth_multiout)
S3method(predict,multisynth)
S3method(print,augsynth)
S3method(print,augsynth_multiout)
S3method(print,multisynth)
S3method(print,summary.augsynth)
S3method(print,summary.augsynth_multiout)
S3method(print,summary.multisynth)
S3method(summary,augsynth)
S3method(summary,augsynth_multiout)
S3method(summary,multisynth)
export(augsynth)
export(augsynth_multiout)
export(multisynth)
export(rdirichlet_b)
export(rmultinom_b)
export(rwild_b)
export(single_augsynth)
import(dplyr)
import(tidyr)
importFrom(ggplot2,aes)
importFrom(graphics,plot)
importFrom(magrittr,"%>%")
importFrom(purrr,reduce)
importFrom(stats,coef)
importFrom(stats,delete.response)
importFrom(stats,formula)
importFrom(stats,lm)
importFrom(stats,model.frame)
importFrom(stats,model.matrix)
importFrom(stats,na.omit)
importFrom(stats,poly)
importFrom(stats,predict)
importFrom(stats,sd)
importFrom(stats,terms)
importFrom(stats,update)
importFrom(utils,capture.output)
================================================
FILE: R/augsynth.R
================================================
################################################################################
## Main functions for single-period treatment augmented synthetic controls method
################################################################################
#' Fit Augmented SCM
#'
#' @param form outcome ~ treatment | auxiliary covariates
#' @param unit Name of unit column
#' @param time Name of time column
#' @param t_int Time of intervention
#' @param data Panel data as dataframe
#' @param progfunc What function to use to impute control outcomes
#' ridge=Ridge regression (allows for standard errors),
#' none=No outcome model,
#' en=Elastic Net, RF=Random Forest, GSYN=gSynth,
#' mcp=MCPanel,
#' cits=Comparative Interrupted Time Series,
#' causalimpact=Bayesian structural time series with CausalImpact,
#' seq2seq=Sequence to sequence learning with feedforward nets
#' @param scm Whether the SCM weighting function is used
#' @param fixedeff Whether to include a unit fixed effect, default F
#' @param cov_agg Covariate aggregation functions, if NULL then use mean with NAs omitted
#' @param ... optional arguments for outcome model
#'
#' @return augsynth object that contains:
#' \describe{
#' \item{\code{weights}}{Ridge ASCM weights}
#' \item{\code{l2_imbalance}}{Imbalance in pre-period outcomes, measured by the L2 norm}
#' \item{\code{scaled_l2_imbalance}}{L2 imbalance scaled by L2 imbalance of uniform weights}
#' \item{\code{mhat}}{Outcome model estimate}
#' \item{\code{data}}{Panel data as matrices}
#' }
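#' @examples
#' \donttest{
#' # Minimal sketch (not from the original documentation): ridge ASCM on the
#' # bundled `kansas` panel, assuming the Kansas tax cuts are the only
#' # treatment, so t_int is the first quarter with treated == 1.
#' data(kansas)
#' t_int <- min(kansas$year_qtr[kansas$treated == 1])
#' asyn <- single_augsynth(lngdpcapita ~ treated, fips, year_qtr, t_int,
#'                         data = kansas, progfunc = "ridge", scm = TRUE)
#' # weights are stored as a one-column matrix with control units as row names
#' head(asyn$weights)
#' }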
#' @export
single_augsynth <- function(form, unit, time, t_int, data,
progfunc = "ridge",
scm=T,
fixedeff = FALSE,
cov_agg=NULL, ...) {
call_name <- match.call()
form <- Formula::Formula(form)
unit <- enquo(unit)
time <- enquo(time)
## format data
outcome <- terms(formula(form, rhs=1))[[2]]
trt <- terms(formula(form, rhs=1))[[3]]
wide <- format_data(outcome, trt, unit, time, t_int, data)
synth_data <- do.call(format_synth, wide)
treated_units <- data %>% filter(!!trt == 1) %>% distinct(!!unit) %>% pull(!!unit)
control_units <- data %>% filter(!(!!unit %in% treated_units)) %>%
distinct(!!unit) %>% arrange(!!unit) %>% pull(!!unit)
## add covariates
if(length(form)[2] == 2) {
Z <- extract_covariates(form, unit, time, t_int, data, cov_agg)
} else {
Z <- NULL
}
# fit augmented SCM
augsynth <- fit_augsynth_internal(wide, synth_data, Z, progfunc,
scm, fixedeff, ...)
# add some extra data
augsynth$data$time <- data %>% distinct(!!time) %>%
arrange(!!time) %>% pull(!!time)
augsynth$call <- call_name
augsynth$t_int <- t_int
augsynth$weights <- matrix(augsynth$weights)
rownames(augsynth$weights) <- control_units
return(augsynth)
}
#' Internal function to fit augmented SCM
#' @param wide Data formatted from format_data
#' @param synth_data Data formatted from format_synth
#' @param Z Matrix of auxiliary covariates
#' @param progfunc outcome model to use
#' @param scm Whether to fit SCM
#' @param fixedeff Whether to de-mean synth
#' @param V V matrix for Synth, default NULL
#' @param ... Extra args for outcome model
#'
#' @noRd
#'
fit_augsynth_internal <- function(wide, synth_data, Z, progfunc,
scm, fixedeff, V = NULL, ...) {
n <- nrow(wide$X)
t0 <- ncol(wide$X)
ttot <- t0 + ncol(wide$y)
if(fixedeff) {
demeaned <- demean_data(wide, synth_data)
fit_wide <- demeaned$wide
fit_synth_data <- demeaned$synth_data
mhat <- demeaned$mhat
} else {
fit_wide <- wide
fit_synth_data <- synth_data
mhat <- matrix(0, n, ttot)
}
if (is.null(progfunc)) {
progfunc = "none"
}
progfunc = tolower(progfunc)
## fit augsynth
if(progfunc == "ridge") {
# Ridge ASCM
augsynth <- do.call(fit_ridgeaug_formatted,
list(wide_data = fit_wide,
synth_data = fit_synth_data,
Z = Z, V = V, scm = scm, ...))
} else if(progfunc == "none") {
## Just SCM
augsynth <- do.call(fit_ridgeaug_formatted,
c(list(wide_data = fit_wide,
synth_data = fit_synth_data,
Z = Z, ridge = F, scm = T, V = V, ...)))
} else {
## Other outcome models
progfuncs = c("ridge", "none", "en", "rf", "gsyn", "mcp",
"cits", "causalimpact", "seq2seq")
if (progfunc %in% progfuncs) {
augsynth <- fit_augsyn(fit_wide, fit_synth_data,
progfunc, scm, ...)
} else {
stop("progfunc must be one of 'Ridge', 'EN', 'RF', 'GSYN', 'MCP', 'CITS', 'CausalImpact', 'seq2seq', 'None'")
}
}
augsynth$mhat <- mhat + cbind(matrix(0, nrow = n, ncol = t0),
augsynth$mhat)
augsynth$data <- wide
augsynth$data$Z <- Z
augsynth$data$synth_data <- synth_data
augsynth$progfunc <- progfunc
augsynth$scm <- scm
augsynth$fixedeff <- fixedeff
augsynth$extra_args <- list(...)
if(progfunc == "ridge") {
augsynth$extra_args$lambda <- augsynth$lambda
} else if(progfunc == "gsyn") {
augsynth$extra_args$r <- ncol(augsynth$params$factor)
augsynth$extra_args$CV <- 0
}
##format output
class(augsynth) <- "augsynth"
return(augsynth)
}
#' Get prediction of ATT or average outcome under control
#' @param object augsynth object
#' @param att If TRUE, return the ATT, if FALSE, return imputed counterfactual
#' @param ... Optional arguments
#'
#' @return Vector with one entry per time period (pre- and post-treatment): the imputed control average, or the ATT estimate if \code{att = TRUE}
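#' @examples
#' \donttest{
#' # Minimal sketch (not from the original documentation): compare the
#' # imputed control path with the ATT path for a plain SCM fit
#' # (no outcome model) on the bundled `kansas` data.
#' data(kansas)
#' syn <- augsynth(lngdpcapita ~ treated, fips, year_qtr, kansas,
#'                 progfunc = "none", scm = TRUE)
#' y0_hat <- predict(syn)              # estimated control average, every period
#' att_hat <- predict(syn, att = TRUE) # treated average minus y0_hat
#' }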
#' @export
predict.augsynth <- function(object, att = F, ...) {
# if ("att" %in% names(list(...))) {
# att <- list(...)$att
# } else {
# att <- F
# }
augsynth <- object
X <- augsynth$data$X
y <- augsynth$data$y
comb <- cbind(X, y)
trt <- augsynth$data$trt
mhat <- augsynth$mhat
m1 <- colMeans(mhat[trt==1,,drop=F])
resid <- (comb[trt==0,,drop=F] - mhat[trt==0,,drop=F])
y0 <- m1 + t(resid) %*% augsynth$weights
if(att) {
return(colMeans(comb[trt == 1,, drop = F]) - c(y0))
} else {
rnames <- rownames(y0)
y0_vec <- c(y0)
names(y0_vec) <- rnames
return(y0_vec)
}
}
#' Print function for augsynth
#' @param x augsynth object
#' @param ... Optional arguments
#' @export
print.augsynth <- function(x, ...) {
augsynth <- x
## straight from lm
cat("\nCall:\n", paste(deparse(augsynth$call), sep="\n", collapse="\n"), "\n\n", sep="")
## print att estimates
tint <- ncol(augsynth$data$X)
ttotal <- tint + ncol(augsynth$data$y)
att_post <- predict(augsynth, att = T)[(tint + 1):ttotal]
cat(paste("Average ATT Estimate: ",
format(round(mean(att_post),3), nsmall = 3), "\n\n", sep=""))
}
#' Plot function for augsynth
#' @importFrom graphics plot
#'
#' @param x Augsynth object to be plotted
#' @param inf Boolean, whether to get confidence intervals around the point estimates
#' @param cv If True, plot cross validation MSE against hyper-parameter, otherwise plot effects
#' @param ... Optional arguments
#' @export
plot.augsynth <- function(x, inf = T, cv = F, ...) {
# if ("se" %in% names(list(...))) {
# se <- list(...)$se
# } else {
# se <- T
# }
augsynth <- x
if (cv == T) {
errors = data.frame(lambdas = augsynth$lambdas,
errors = augsynth$lambda_errors,
errors_se = augsynth$lambda_errors_se)
p <- ggplot2::ggplot(errors, ggplot2::aes(x = lambdas, y = errors)) +
ggplot2::geom_point(size = 2) +
ggplot2::geom_errorbar(
ggplot2::aes(ymin = errors,
ymax = errors + errors_se),
width=0.2, size = 0.5)
p <- p + ggplot2::labs(title = bquote("Cross Validation MSE over " ~ lambda),
x = expression(lambda), y = "Cross Validation MSE",
parse = TRUE)
p <- p + ggplot2::scale_x_log10()
# find minimum and min + 1se lambda to plot
min_lambda <- choose_lambda(augsynth$lambdas,
augsynth$lambda_errors,
augsynth$lambda_errors_se,
F)
min_1se_lambda <- choose_lambda(augsynth$lambdas,
augsynth$lambda_errors,
augsynth$lambda_errors_se,
T)
min_lambda_index <- which(augsynth$lambdas == min_lambda)
min_1se_lambda_index <- which(augsynth$lambdas == min_1se_lambda)
p <- p + ggplot2::geom_point(
ggplot2::aes(x = min_lambda,
y = augsynth$lambda_errors[min_lambda_index]),
color = "gold")
p + ggplot2::geom_point(
ggplot2::aes(x = min_1se_lambda,
y = augsynth$lambda_errors[min_1se_lambda_index]),
color = "gold") +
ggplot2::theme_bw()
} else {
plot(summary(augsynth, ...), inf = inf)
}
}
#' Summary function for augsynth
#' @param object augsynth object
#' @param inf Boolean, whether to get confidence intervals around the point estimates
#' @param inf_type Type of inference algorithm. Options are
#' \describe{
#' \item{\code{"conformal"}}{Conformal inference (default)}
#' \item{\code{"jackknife+"}}{Jackknife+ algorithm over time periods}
#' \item{\code{"jackknife"}}{Jackknife over units}
#' }
#' @param linear_effect Boolean, whether to invert the conformal inference hypothesis test to get confidence intervals for a linear-in-time treatment effect: intercept + slope * time
#' @param ... Optional arguments for inference; for details on each \code{inf_type}, see
#' \describe{
#' \item{\code{"conformal"}}{\code{conformal_inf}}
#' \item{\code{"jackknife+"}}{\code{time_jackknife_plus}}
#' \item{\code{"jackknife"}}{\code{jackknife_se_single}}
#' }
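#' @examples
#' \donttest{
#' # Minimal sketch (not from the original documentation): point estimates
#' # only, then the default conformal inference, on the bundled `kansas` data.
#' data(kansas)
#' syn <- augsynth(lngdpcapita ~ treated, fips, year_qtr, kansas,
#'                 progfunc = "none", scm = TRUE)
#' summary(syn, inf = FALSE)
#' summ <- summary(syn)  # inf_type = "conformal"
#' summ$average_att
#' plot(summ)
#' }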
#' @export
summary.augsynth <- function(object, inf = T, inf_type = "conformal",
linear_effect = F,
...) {
augsynth <- object
summ <- list()
t0 <- ncol(augsynth$data$X)
t_final <- t0 + ncol(augsynth$data$y)
if(inf) {
if(inf_type == "jackknife") {
att_se <- jackknife_se_single(augsynth)
} else if(inf_type == "jackknife+") {
att_se <- time_jackknife_plus(augsynth, ...)
} else if(inf_type == "conformal") {
att_se <- conformal_inf(augsynth, ...)
# get CIs for linear treatment effects
if(linear_effect) {
att_linear <- conformal_inf_linear(augsynth, ...)
}
} else {
stop(paste(inf_type, "is not a valid choice of 'inf_type'"))
}
att <- data.frame(Time = augsynth$data$time,
Estimate = att_se$att[1:t_final])
if(inf_type == "jackknife") {
att$Std.Error <- att_se$se[1:t_final]
att_avg_se <- att_se$se[t_final + 1]
} else {
att_avg_se <- NA
}
att_avg <- att_se$att[t_final + 1]
if(inf_type %in% c("jackknife+", "nonpar_bs", "t_dist", "conformal")) {
att$lower_bound <- att_se$lb[1:t_final]
att$upper_bound <- att_se$ub[1:t_final]
}
if(inf_type == "conformal") {
att$p_val <- att_se$p_val[1:t_final]
}
} else {
t0 <- ncol(augsynth$data$X)
t_final <- t0 + ncol(augsynth$data$y)
att_est <- predict(augsynth, att = T)
att <- data.frame(Time = augsynth$data$time,
Estimate = att_est)
att$Std.Error <- NA
att_avg <- mean(att_est[(t0 + 1):t_final])
att_avg_se <- NA
}
summ$att <- att
if(inf) {
if(inf_type %in% c("jackknife+")) {
summ$average_att <- data.frame(Value = "Average Post-Treatment Effect",
Estimate = att_avg, Std.Error = att_avg_se)
summ$average_att$lower_bound <- att_se$lb[t_final + 1]
summ$average_att$upper_bound <- att_se$ub[t_final + 1]
summ$alpha <- att_se$alpha
}
if(inf_type == "conformal") {
# summ$average_att$p_val <- att_se$p_val[t_final + 1]
# summ$average_att$lower_bound <- att_se$lb[t_final + 1]
# summ$average_att$upper_bound <- att_se$ub[t_final + 1]
# summ$alpha <- att_se$alpha
if(linear_effect) {
summ$average_att <- data.frame(
Value = c("Average Post-Treatment Effect",
"Treatment Effect Intercept",
"Treatment Effect Slope"),
Estimate = c(att_avg, att_linear$est_int,
att_linear$est_slope),
Std.Error = c(att_avg_se, NA, NA),
p_val = c(att_se$p_val[t_final + 1], NA, NA),
lower_bound = c(att_se$lb[t_final + 1],
att_linear$ci_int[1],
att_linear$ci_slope[1]),
upper_bound = c(att_se$ub[t_final + 1],
att_linear$ci_int[2],
att_linear$ci_slope[2])
)
} else {
summ$average_att <- data.frame(
Value = c("Average Post-Treatment Effect"),
Estimate = att_avg,
Std.Error = att_avg_se,
p_val = att_se$p_val[t_final + 1],
lower_bound = att_se$lb[t_final + 1],
upper_bound = att_se$ub[t_final + 1]
)
}
summ$alpha <- att_se$alpha
}
} else {
summ$average_att <- data.frame(Value = "Average Post-Treatment Effect",
Estimate = att_avg, Std.Error = att_avg_se)
}
summ$t_int <- augsynth$t_int
summ$call <- augsynth$call
summ$l2_imbalance <- augsynth$l2_imbalance
summ$scaled_l2_imbalance <- augsynth$scaled_l2_imbalance
if(!is.null(augsynth$covariate_l2_imbalance)) {
summ$covariate_l2_imbalance <- augsynth$covariate_l2_imbalance
summ$scaled_covariate_l2_imbalance <- augsynth$scaled_covariate_l2_imbalance
}
## get estimated bias
if(tolower(augsynth$progfunc) == "ridge") {
mhat <- augsynth$ridge_mhat
w <- augsynth$synw
} else {
mhat <- augsynth$mhat
w <- augsynth$weights
}
trt <- augsynth$data$trt
m1 <- colMeans(mhat[trt==1,,drop=F])
if(tolower(augsynth$progfunc) == "none" | (!augsynth$scm)) {
summ$bias_est <- NA
} else {
summ$bias_est <- m1 - t(mhat[trt==0,,drop=F]) %*% w
}
summ$inf_type <- if(inf) inf_type else "None"
class(summ) <- "summary.augsynth"
return(summ)
}
#' Print function for summary function for augsynth
#' @param x summary object
#' @param ... Optional arguments
#' @export
print.summary.augsynth <- function(x, ...) {
summ <- x
## straight from lm
cat("\nCall:\n", paste(deparse(summ$call), sep="\n", collapse="\n"), "\n\n", sep="")
t_final <- nrow(summ$att)
## distinction between pre and post treatment
att_est <- summ$att$Estimate
t_total <- length(att_est)
t_int <- summ$att %>% filter(Time <= summ$t_int) %>% nrow()
att_pre <- att_est[1:(t_int-1)]
att_post <- att_est[t_int:t_total]
out_msg <- ""
# print out average post treatment estimate
att_post <- summ$average_att$Estimate[1]
se_est <- summ$att$Std.Error
if(summ$inf_type == "jackknife") {
se_avg <- summ$average_att$Std.Error
out_msg <- paste("Average ATT Estimate (Jackknife Std. Error): ",
format(round(att_post,3), nsmall=3),
" (",
format(round(se_avg,3)), ")\n")
inf_type <- "Jackknife over units"
} else if(summ$inf_type == "conformal") {
p_val <- summ$average_att$p_val[1]
out_msg <- paste("Average ATT Estimate (p Value for Joint Null): ",
format(att_post, digits = 3),
" (",
format(p_val, digits = 2), ")\n")
inf_type <- "Conformal inference"
if("Treatment Effect Slope" %in% summ$average_att$Value) {
lowers <- summ$average_att$lower_bound[2:3]
uppers <- summ$average_att$upper_bound[2:3]
out_msg_line2 <- paste0("Confidence intervals for linear-in-time treatment effects (Intercept + Slope * Time)\n",
"\tIntercept: [", format(lowers[1], digits = 3), ",",
format(uppers[1], digits = 3), "]\n",
"\tSlope: [", format(lowers[2], digits = 3), ",",
format(uppers[2], digits = 3), "]\n")
out_msg <- paste0(out_msg, out_msg_line2)
}
} else if(summ$inf_type == "jackknife+") {
out_msg <- paste("Average ATT Estimate: ",
format(round(att_post,3), nsmall=3), "\n")
inf_type <- "Jackknife+ over time periods"
} else {
out_msg <- paste("Average ATT Estimate: ",
format(round(att_post,3), nsmall=3), "\n")
inf_type <- "None"
}
out_msg <- paste(out_msg,
"L2 Imbalance: ",
format(round(summ$l2_imbalance,3), nsmall=3), "\n",
"Percent improvement from uniform weights: ",
format(round(1 - summ$scaled_l2_imbalance,3)*100), "%\n\n",
sep="")
if(!is.null(summ$covariate_l2_imbalance)) {
out_msg <- paste(out_msg,
"Covariate L2 Imbalance: ",
format(round(summ$covariate_l2_imbalance,3),
nsmall=3),
"\n",
"Percent improvement from uniform weights: ",
format(round(1 - summ$scaled_covariate_l2_imbalance,3)*100),
"%\n\n",
sep="")
}
out_msg <- paste(out_msg,
"Avg Estimated Bias: ",
format(round(mean(summ$bias_est), 3),nsmall=3), "\n\n",
"Inference type: ",
inf_type,
"\n\n",
sep="")
cat(out_msg)
if(summ$inf_type == "jackknife") {
out_att <- summ$att[t_int:t_final,] %>%
select(Time, Estimate, Std.Error)
} else if(summ$inf_type == "conformal") {
out_att <- summ$att[t_int:t_final,] %>%
select(Time, Estimate, lower_bound, upper_bound, p_val)
names(out_att) <- c("Time", "Estimate",
paste0((1 - summ$alpha) * 100, "% CI Lower Bound"),
paste0((1 - summ$alpha) * 100, "% CI Upper Bound"),
paste0("p Value"))
} else if(summ$inf_type == "jackknife+") {
out_att <- summ$att[t_int:t_final,] %>%
select(Time, Estimate, lower_bound, upper_bound)
names(out_att) <- c("Time", "Estimate",
paste0((1 - summ$alpha) * 100, "% CI Lower Bound"),
paste0((1 - summ$alpha) * 100, "% CI Upper Bound"))
} else {
out_att <- summ$att[t_int:t_final,] %>%
select(Time, Estimate)
}
out_att %>%
mutate_at(vars(-Time), ~ round(., 3)) %>%
print(row.names = F)
}
#' Plot function for summary function for augsynth
#' @param x Summary object
#' @param inf Boolean, whether to plot confidence intervals
#' @param ... Optional arguments
#' @export
plot.summary.augsynth <- function(x, inf = T, ...) {
summ <- x
# if ("inf" %in% names(list(...))) {
# inf <- list(...)$inf
# } else {
# inf <- T
# }
p <- summ$att %>%
ggplot2::ggplot(ggplot2::aes(x=Time, y=Estimate))
if(inf) {
if(all(is.na(summ$att$lower_bound))) {
p <- p + ggplot2::geom_ribbon(ggplot2::aes(ymin=Estimate-2*Std.Error,
ymax=Estimate+2*Std.Error),
alpha=0.2)
} else {
p <- p + ggplot2::geom_ribbon(ggplot2::aes(ymin=lower_bound,
ymax=upper_bound),
alpha=0.2)
}
}
p + ggplot2::geom_line() +
ggplot2::geom_vline(xintercept=summ$t_int, lty=2) +
ggplot2::geom_hline(yintercept=0, lty=2) +
ggplot2::theme_bw()
}
#' augsynth
#'
#' @description A package implementing the Augmented Synthetic Controls Method
#' @docType package
#' @name augsynth-package
#' @importFrom magrittr "%>%"
#' @importFrom purrr reduce
#' @import dplyr
#' @import tidyr
#' @importFrom stats terms
#' @importFrom stats formula
#' @importFrom stats update
#' @importFrom stats delete.response
#' @importFrom stats model.matrix
#' @importFrom stats model.frame
#' @importFrom stats na.omit
NULL
================================================
FILE: R/augsynth_pre.R
================================================
################################################################################
## Main function for the augmented synthetic controls method
################################################################################
#' Fit Augmented SCM
#' @param form outcome ~ treatment | auxiliary covariates
#' @param unit Name of unit column
#' @param time Name of time column
#' @param data Panel data as dataframe
#' @param t_int Time of intervention (used for single-period treatment only)
#' @param ... Optional arguments
#' \itemize{
#' \item Single period augsynth with/without multiple outcomes
#' \describe{
#' \item{\code{progfunc}}{What function to use to impute control outcomes: Ridge=Ridge regression (allows for standard errors), None=No outcome model, EN=Elastic Net, RF=Random Forest, GSYN=gSynth, MCP=MCPanel, CITS=Comparative Interrupted Time Series, CausalImpact=Bayesian structural time series with CausalImpact, seq2seq=Sequence to sequence learning with feedforward nets}
#' \item{\code{scm}}{Whether the SCM weighting function is used}
#' \item{\code{fixedeff}}{Whether to include a unit fixed effect, default FALSE}
#' \item{\code{cov_agg}}{Covariate aggregation functions, if NULL then use mean with NAs omitted}
#' }
#' \item Multi period (staggered) augsynth
#' \describe{
#' \item{\code{relative}}{Whether to compute balance by relative time}
#' \item{\code{n_leads}}{Number of post-treatment periods for which to estimate effects}
#' \item{\code{n_lags}}{Number of pre-treatment periods to balance, default is to balance all periods}
#' \item{\code{alpha}}{Fraction of balance for individual balance}
#' \item{\code{lambda}}{Regularization hyperparameter, default = 0}
#' \item{\code{force}}{Include "none", "unit", "time", "two-way" fixed effects. Default: "two-way"}
#' \item{\code{n_factors}}{Number of factors for interactive fixed effects, default does CV}
#' }
#' }
#'
#' @return augsynth object that contains:
#' \describe{
#' \item{\code{weights}}{Estimated weights}
#' \item{\code{data}}{Panel data as matrices}
#' }
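#' @examples
#' \donttest{
#' # Minimal sketch (not from the original documentation): with one outcome
#' # and one treatment time, augsynth() dispatches to single_augsynth() and
#' # infers t_int from the `treated` column of the bundled `kansas` data.
#' data(kansas)
#' asyn <- augsynth(lngdpcapita ~ treated, fips, year_qtr, kansas,
#'                  progfunc = "ridge", scm = TRUE)
#' asyn
#' summary(asyn)
#' }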
#' @export
#'
augsynth <- function(form, unit, time, data, t_int=NULL, ...) {
call_name <- match.call()
form <- Formula::Formula(form)
unit_quosure <- enquo(unit)
time_quosure <- enquo(time)
## format data
outcome <- terms(formula(form, rhs=1))[[2]]
trt <- terms(formula(form, rhs=1))[[3]]
# check for multiple outcomes
multi_outcome <- length(outcome) != 1
## get first treatment times
trt_time <- data %>%
group_by(!!unit_quosure) %>%
filter(!all(!!trt == 0)) %>%
summarise(trt_time = min((!!time_quosure)[(!!trt) == 1])) %>%
mutate(trt_time = replace_na(as.numeric(trt_time), Inf))
num_trt_years <- sum(is.finite(unique(trt_time$trt_time)))
if(multi_outcome & num_trt_years > 1) {
stop("augsynth is not currently implemented for more than one outcome and more than one treatment time")
} else if(num_trt_years > 1) {
message("More than one treatment time found. Running multisynth.")
if("progfunc" %in% names(list(...))) {
warning("`progfunc` is not an argument for multisynth, so it is ignored")
}
return(multisynth(form, !!enquo(unit), !!enquo(time), data, ...))
} else {
if (is.null(t_int)) {
t_int <- trt_time %>% filter(is.finite(trt_time)) %>%
summarise(t_int = max(trt_time)) %>% pull(t_int)
}
if(!multi_outcome) {
message("One outcome and one treatment time found. Running single_augsynth.")
return(single_augsynth(form, !!enquo(unit), !!enquo(time), t_int,
data = data, ...))
} else {
message("Multiple outcomes and one treatment time found. Running augsynth_multiout.")
return(augsynth_multiout(form, !!enquo(unit), !!enquo(time), t_int,
data = data, ...))
}
}
}
================================================
FILE: R/cv.R
================================================
drop_time_t <- function(wide_data, Z, t_drop) {
new_wide_data <- list()
new_wide_data$trt <- wide_data$trt
if (is.list(wide_data$X)) {
# TODO
} else {
new_wide_data$X <- wide_data$X[, -t_drop, drop = F]
new_wide_data$y <- cbind(wide_data$X[, t_drop, drop = F],
wide_data$y)
X0 <- new_wide_data$X[new_wide_data$trt == 0,, drop = F]
x1 <- matrix(colMeans(new_wide_data$X[new_wide_data$trt == 1,, drop = F]),
ncol=1)
y0 <- new_wide_data$y[new_wide_data$trt == 0,, drop = F]
y1 <- colMeans(new_wide_data$y[new_wide_data$trt == 1,, drop = F])
new_synth_data <- list()
new_synth_data$Z0 <- t(X0)
new_synth_data$X0 <- t(X0)
new_synth_data$Z1 <- x1
new_synth_data$X1 <- x1
return(list(wide_data = new_wide_data,
synth_data = new_synth_data,
Z = Z))
}
}
drop_time_and_refit <- function(wide_data, Z, t_drop, progfunc, scm, fixedeff, ...) {
new_data <- drop_time_t(wide_data, Z, t_drop)
new_ascm <- do.call(fit_augsynth_internal,
c(list(wide = new_data$wide_data,
synth_data = new_data$synth_data,
Z = new_data$Z,
progfunc = progfunc,
scm = scm,
fixedeff = fixedeff, ...)))
return(new_ascm)
}
cv_internal <- function(wide_data, Z, progfunc, scm, fixedeff, lambdas, holdout_periods, ...) {
X <- wide_data$X
lambda_error_vals <- vapply(lambdas, function(lambda){
errors <- apply(holdout_periods, 1, function(t_drop){
new_ascm <- drop_time_and_refit(wide_data, Z, t_drop, progfunc, scm, fixedeff, lambda = lambda, ...)
err <- sum((predict(new_ascm, att = T)[(ncol(X)-length(t_drop)+1):ncol(X)])^2)
err
})
lambda_error <- mean(errors)
lambda_error_se <- sd(errors) / sqrt(length(errors))
c(lambda_error, lambda_error_se)
}, numeric(2))
return(list(lambda_errors = lambda_error_vals[1,], lambda_errors_se = lambda_error_vals[2,]))
}
cv_ridge <- function(wide_data, synth_data, Z, progfunc, scm, fixedeff, how = 'time', holdout_length = 1, lambdas = NULL,
lambda_min_ratio = 1e-8, n_lambda = 20, lambda_max = NULL, min_1se = T, V = NULL, ...) {
X <- wide_data$X
trt <- wide_data$trt
if (is.null(lambdas)) {
if(is.null(lambda_max)) {
X_cent <- apply(X, 2, function(x) x - mean(x[trt==0]))
X_c <- X_cent[trt==0,,drop=FALSE]
t0 <- ncol(X_c)
if(is.null(V)) {
V <- diag(rep(1, t0))
} else if(is.vector(V)) {
V <- diag(V)
} else if(ncol(V) == 1 & nrow(V) == t0) {
V <- diag(c(V))
} else if(ncol(V) == t0 & nrow(V) == 1) {
V <- diag(c(V))
} else if(nrow(V) == t0) {
# V is already a t0 x t0 matrix, so use it as is
} else {
stop("`V` must be a vector with t0 elements or a t0xt0 matrix")
}
X_c <- X_c %*% V
if(!is.null(Z)) {
Z_cent <- apply(Z, 2, function(x) x - mean(x[trt==0]))
Z_c <- Z_cent[trt==0,,drop=FALSE]
Xc_hat <- Z_c %*% solve(t(Z_c) %*% Z_c) %*% t(Z_c) %*% X_c
res_c <- X_c - Xc_hat
X_c <- res_c
}
lambda_max <- svd(X_c)$d[1] ^ 2
}
lambdas <- create_lambda_list(lambda_max, lambda_min_ratio, n_lambda)
}
if (how == 'time') {
period_starts <- 1:(ncol(X) - holdout_length)
if (holdout_length == 1) {
holdout_periods <- matrix(period_starts, nrow = length(period_starts), ncol = 1)
} else {
holdout_periods <- t(vapply(period_starts, function(t) t:(t+holdout_length-1), numeric(holdout_length)))
}
results <- cv_internal(wide_data, Z, progfunc, scm, fixedeff, lambdas, holdout_periods, ...)
lambda <- choose_lambda(lambdas, results$lambda_errors, results$lambda_errors_se, min_1se)
return(list(lambda = lambda, lambdas = lambdas, lambda_errors = results$lambda_errors, lambda_errors_se = results$lambda_errors_se))
}
}
================================================
FILE: R/data.R
================================================
#' Economic indicators for US states from 1990-2016
#'
#'
#' @format A dataframe with 5250 rows and 32 variables:
#' \describe{
#' \item{fips}{FIPS code for each state}
#' \item{year}{Year of measurement}
#' \item{qtr}{Quarter (1-4) of measurement}
#' \item{state}{Name of State}
#' \item{gdp}{Gross State Product (millions of $). Values before 2005 are linearly interpolated between years}
#' \item{revenuepop}{State and local revenue per capita}
#' \item{rev_state_total}{State total general revenue (millions of $)}
#' \item{rev_local_total}{Local total general revenue (millions of $)}
#' \item{popestimate}{Population estimate}
#' \item{qtrly_estabs_count}{Count of establishments for a given quarter}
#' \item{month1_emplvl, month2_emplvl, month3_emplvl}{ Employment level for first, second, and third months of a given quarter}
#' \item{total_qtrly_wages}{Total wages for a given quarter}
#' \item{taxable_qtrly_wage}{Taxable wages for a given quarter}
#' \item{avg_wkly_wage}{Average weekly wage for a given quarter}
#' \item{year_qtr}{Year and quarter combined into one continuous variable}
#' \item{treated}{Whether the state passed tax cuts before the given year and quarter}
#' \item{lngdpcapita}{Natural log of GDP per capita}
#' \item{emplvlcapita}{Average employment level per capita}
#' \item{Xcapita}{Per capita value of X}
#' \item{abb}{State abbreviation}
#' }
"kansas"
================================================
FILE: R/eligible_donors.R
================================================
##############################################################################
## Code to get eligible donor units based on covariates
##############################################################################
get_donors <- function(X, y, trt, Z, time_cohort, n_lags,
n_leads, how = "knn",
exact_covariates = NULL, ...) {
# first get eligible donors by treatment time
donors <- get_eligible_donors(trt, time_cohort, n_leads)
# get donors with no NA values
nona_donors <- get_nona_donors(X, y, trt, n_lags, n_leads, time_cohort)
donors <- lapply(1:length(donors),
function(j) donors[[j]] & nona_donors[[j]])
# if Z isn't NULL, further restrict the donors by matching
if(!is.null(Z)) {
if(ncol(Z) != 0) {
donors <- get_matched_donors(trt, Z, donors, how, exact_covariates, ...)
}
}
return(donors)
}
get_eligible_donors <- function(trt, time_cohort, n_leads) {
# get treatment times
if(time_cohort) {
grps <- unique(trt[is.finite(trt)])
} else {
grps <- trt[is.finite(trt)]
}
J <- length(grps)
# only allow weights on donors treated after n_leads
donors <- lapply(1:J, function(j) trt > n_leads + grps[j])
return(donors)
}
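# Sketch (hypothetical helper, not part of the package API): the rule in
# get_eligible_donors() keeps, for each treated unit j, only units whose
# treatment index is strictly greater than grps[j] + n_leads, so that no
# donor is itself treated during unit j's estimation window. Never-treated
# units (trt = Inf) always qualify. This assumes `trt` is coded as in
# format_data_stag() and time_cohort = FALSE (one entry per treated unit).
sketch_eligible_donors <- function(trt, n_leads) {
    treated_times <- trt[is.finite(trt)]
    # one logical vector over all units for each treated unit
    lapply(treated_times, function(tj) trt > n_leads + tj)
}
# Example (not run):
# sketch_eligible_donors(c(3, 5, Inf, 4), n_leads = 1)
# gives c(F, T, T, F), c(F, F, T, F) and c(F, F, T, F)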
#' Get donors that don't have missing outcomes where treated units have outcomes
get_nona_donors <- function(X, y, trt, n_lags, n_leads, time_cohort) {
n <- length(trt)
# find na treatment times
fulldat <- cbind(X, y)
is_na <- is.na(fulldat[is.finite(trt), , drop = F])
# aggregate by time cohort
if(time_cohort) {
grps <- unique(trt[is.finite(trt)])
# if doing a time cohort, convert the boolean mask
finite_trt <- trt[is.finite(trt)]
is_na <- t(sapply(grps,
function(tj) apply(is_na[finite_trt == tj, , drop = F],
2, all)))
} else {
grps <- trt[is.finite(trt)]
}
not_na <- !is.na(fulldat)
J <- length(grps)
lapply(1:J,
function(j) {
idxs <- max(grps[j] - n_lags + 1, 1):min(grps[j] + n_leads,
ncol(fulldat))
isna_j <- is_na[j, idxs]
apply(not_na[, idxs, drop = F][, !isna_j, drop = F], 1, all)
}) -> donors
return(donors)
}
get_matched_donors <- function(trt, Z, donors, how, exact_covariates = NULL, k = NULL, ...) {
J <- sum(is.finite(trt))
trt_idx <- which(is.finite(trt))
if(is.null(exact_covariates)) {
if(how == "exact") {
return(
lapply(1:J,
function(j) donors[[j]] & apply(t(Z) == Z[trt_idx[j], ], 2, all)
)
)
} else if(how == "knn") {
return(get_knn_donors(trt, Z, donors, k))
} else {
stop("Option for matching (`how`) must be one of 'exact' or 'knn'")
}
} else {
if(how == "exact") {
return(
lapply(1:J,
function(j) donors[[j]] &
apply(t(Z[, exact_covariates, drop = F]) ==
Z[trt_idx[j], exact_covariates], 2, all)
)
)
} else if(how == "knn") {
donors <- lapply(1:J,
function(j) { donors[[j]] &
apply(t(Z[, exact_covariates, drop = F]) ==
Z[trt_idx[j],exact_covariates], 2, all)
}
)
approx_covs <- which(!colnames(Z) %in% exact_covariates)
if(length(approx_covs) != 0) {
return(get_knn_donors(trt, Z[, approx_covs, drop = F], donors, k))
} else {
return(donors)
}
} else {
stop("Option for matching (`how`) must be one of 'exact' or 'knn'")
}
}
}
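# Sketch (hypothetical helper, not part of the package API) of the exact
# matching step in get_matched_donors(). R recycles a length-p vector down
# the columns of the p x n matrix t(Z), so `t(Zc) == Zc[treated_row, ]`
# compares every unit's covariates elementwise with the treated unit's, and
# apply(..., 2, all) flags the units that match on every compared covariate.
# Both sides must be restricted to the same columns (as with
# `exact_covariates`), otherwise the recycling misaligns.
sketch_exact_match <- function(Z, treated_row, cols = seq_len(ncol(Z))) {
    Zc <- Z[, cols, drop = FALSE]
    apply(t(Zc) == Zc[treated_row, ], 2, all)
}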
get_knn_donors <- function(trt, Z, donors, k) {
if(is.null(k)) {
stop("Number of neighbors for knn not selected, please choose k.")
}
# knn matching within time cohort
trt_idxs <- which(is.finite(trt))
lapply(1:length(trt_idxs),
function(j) {
idx <- trt_idxs[j]
# covariates of the j-th treated unit
Z_tj <- Z[idx, , drop = F]
# get donors for treated cohort
donors_tj <- donors[[j]]
Z_donors_tj <- Z[donors_tj, , drop = F]
# check that k is less than the number of donors
# if not, warn and return all eligible donors as matches
if(k >= nrow(Z_donors_tj)) {
warning(paste("Number of potential donor units is less than",
"the number of required matches,",
"returning all donors as matches"))
return(donors_tj)
}
# do knn matching
nn <- FNN::get.knnx(data = Z_donors_tj, query = Z_tj, k = k)
# keep track of which indices these are
donors_j <- logical(length(donors_tj))
true_idx <- which(donors_tj)[nn$nn.index[1, ]]
donors_j[true_idx] <- TRUE
return(donors_j)
}) -> matches
names(matches) <- trt_idxs
return(matches)
}
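# Sketch (hypothetical helper, not part of the package API) of the index
# bookkeeping in get_knn_donors(). The neighbour search runs only over the
# eligible donors, so the returned neighbour indices refer to rows of the
# donor subset. They are mapped back to positions in the full unit list
# with which(donors)[idx]. To stay dependency-free, this sketch replaces
# FNN::get.knnx() with base R squared Euclidean distances, with ties broken
# by order().
sketch_knn_donors <- function(Z, treated_row, donors, k) {
    # mirror the package: too few donors means every donor is kept
    if (k >= sum(donors)) return(donors)
    Z_donors <- Z[donors, , drop = FALSE]
    # squared distance from each eligible donor to the treated unit
    d2 <- colSums((t(Z_donors) - Z[treated_row, ])^2)
    nn_idx <- order(d2)[seq_len(k)]
    matched <- logical(length(donors))
    matched[which(donors)[nn_idx]] <- TRUE
    matched
}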
================================================
FILE: R/fit_synth.R
================================================
#######################################################
# Helper scripts to fit synthetic controls to simulations
#######################################################
#' Make a V matrix from a vector (or null)
make_V_matrix <- function(t0, V) {
if(is.null(V)) {
V <- diag(rep(1, t0))
} else if(is.vector(V)) {
if(length(V) != t0) {
stop(paste("`V` must be a vector with", t0, "elements or a", t0,
"x", t0, "matrix"))
}
V <- diag(V)
} else if(ncol(V) == 1 & nrow(V) == t0) {
V <- diag(c(V))
} else if(ncol(V) == t0 & nrow(V) == 1) {
V <- diag(c(V))
} else if(nrow(V) == t0 && ncol(V) == t0) {
## already a t0 x t0 matrix, use it as is
} else {
stop(paste("`V` must be a vector with", t0, "elements or a", t0,
"x", t0, "matrix"))
}
return(V)
}
#' Fit synthetic controls on outcomes after formatting data
#' @param synth_data Panel data in format of Synth::dataprep
#' @param V Matrix to scale the objective by
#' @noRd
#' @return \itemize{
#' \item{"weights"}{Synth weights}
#' \item{"l2_imbalance"}{Imbalance in pre-period outcomes, measured by the L2 norm}
#' \item{"scaled_l2_imbalance"}{L2 imbalance scaled by L2 imbalance of uniform weights}
#' }
fit_synth_formatted <- function(synth_data, V = NULL) {
t0 <- dim(synth_data$Z0)[1]
## if no V is supplied, use the identity matrix
V <- make_V_matrix(t0, V)
weights <- synth_qp(synth_data$X1, t(synth_data$X0), V)
l2_imbalance <- sqrt(sum((synth_data$Z0 %*% weights - synth_data$Z1)^2))
## scale the imbalance by the imbalance of uniform weights (the raw control mean)
uni_w <- matrix(1/ncol(synth_data$Z0), nrow=ncol(synth_data$Z0), ncol=1)
unif_l2_imbalance <- sqrt(sum((synth_data$Z0 %*% uni_w - synth_data$Z1)^2))
scaled_l2_imbalance <- l2_imbalance / unif_l2_imbalance
return(list(weights=weights,
l2_imbalance=l2_imbalance,
scaled_l2_imbalance=scaled_l2_imbalance))
}
#' Solve the synth QP directly
#' @param X1 Target vector of pre-treatment outcomes (length t0)
#' @param X0 Matrix of control outcomes with units in rows (n0 x t0)
#' @param V Scaling matrix
#' @noRd
synth_qp <- function(X1, X0, V) {
Pmat <- X0 %*% V %*% t(X0)
qvec <- - t(X1) %*% V %*% t(X0)
n0 <- nrow(X0)
## constraints: weights sum to one (first row of A) and each lies in [0, 1]
A <- rbind(rep(1, n0), diag(n0))
l <- c(1, numeric(n0))
u <- c(1, rep(1, n0))
settings <- osqp::osqpSettings(verbose = FALSE,
eps_rel = 1e-8,
eps_abs = 1e-8)
sol <- osqp::solve_osqp(P = Pmat, q = qvec,
A = A, l = l, u = u,
pars = settings)
return(sol$x)
}
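# Sketch (hypothetical helper, not part of the package API) checking the
# QP set up in synth_qp(). osqp minimizes (1/2) w'Pw + q'w subject to
# l <= Aw <= u. With P = X0 V X0' and q = -X1' V X0', the objective equals
# (1/2) (X0'w - X1)' V (X0'w - X1) minus the constant (1/2) X1' V X1, so the
# solver returns the V-weighted least-squares synthetic control weights on
# the simplex. The sketch verifies the identity numerically for any w and
# does not call osqp.
sketch_qp_objective_check <- function(X1, X0, V, w) {
    # X1: length-t0 target; X0: n0 x t0 control matrix (units in rows),
    # the orientation synth_qp() receives
    P <- X0 %*% V %*% t(X0)
    q <- -t(X1) %*% V %*% t(X0)
    qp_obj <- 0.5 * drop(t(w) %*% P %*% w) + drop(q %*% w)
    resid <- drop(t(X0) %*% w) - X1
    ls_obj <- 0.5 * drop(t(resid) %*% V %*% resid)
    const <- 0.5 * drop(t(X1) %*% V %*% X1)
    c(qp = qp_obj + const, least_squares = ls_obj)
}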
================================================
FILE: R/format.R
================================================
################################################################################
## Scripts to format panel data into matrices
################################################################################
#' Format "long" panel data into "wide" program evaluation matrices
#' @param outcome Name of outcome column
#' @param trt Name of treatment column
#' @param unit Name of unit column
#' @param time Name of time column
#' @param t_int Time of intervention
#' @param data Panel data as dataframe
#' @noRd
#' @return \itemize{
#' \item{"X"}{Matrix of pre-treatment outcomes}
#' \item{"trt"}{Vector of treatment assignments}
#' \item{"y"}{Matrix of post-treatment outcomes}
#' }
format_data <- function(outcome, trt, unit, time, t_int, data) {
## pre treatment outcomes
X <- data %>%
filter(!!time < t_int) %>%
select(!!unit, !!time, !!outcome) %>%
spread(!!time, !!outcome) %>%
select(-!!unit) %>%
as.matrix()
## post treatment outcomes
y <- data %>%
filter(!!time >= t_int) %>%
select(!!unit, !!time, !!outcome) %>%
spread(!!time, !!outcome) %>%
select(-!!unit) %>%
as.matrix()
## treatment status
trt <- data %>%
select(!!unit, !!trt) %>%
group_by(!!unit) %>%
summarise(trt = max(!!trt)) %>%
ungroup() %>%
pull(trt)
return(list(X=X, trt=trt, y=y))
}
#' Format "long" panel data into "wide" program evaluation matrices
#' @param outcomes Vectors of names of outcome columns
#' @param trt Name of treatment column
#' @param unit Name of unit column
#' @param time Name of time column
#' @param t_int Time of intervention
#' @param data Panel data as dataframe
#' @noRd
#' @return \itemize{
#' \item{"X"}{List of matrices of pre-treatment outcomes}
#' \item{"trt"}{Vector of treatment assignments}
#' \item{"y"}{List of matrices of post-treatment outcomes}
#' }
format_data_multi <- function(outcomes, trt, unit, time, t_int, data) {
lapply(outcomes,
function(outcome) format_data(outcome, trt, unit,
time, t_int, data)
) -> formats
# X <- simplify2array(lapply(formats, function(x) x$X))
# y <- simplify2array(lapply(formats, function(x) x$y))
# X <- lapply(formats, function(x) t(na.omit(t(x$X))))
X <- lapply(formats, `[[`, "X")
y <- lapply(formats, function(x) t(na.omit(t(x$y))))
trt <- formats[[1]]$trt
return(list(X = X, trt = trt, y = y))
}
#' Format "long" panel data into "wide" program evaluation matrices with staggered adoption
#' @param outcome Name of outcome column
#' @param trt Name of treatment column
#' @param unit Name of unit column
#' @param time Name of time column
#' @param data Panel data as dataframe
#' @noRd
#' @return \itemize{
#' \item{"X"}{Matrix of pre-treatment outcomes}
#' \item{"trt"}{Vector of treatment assignments}
#' \item{"y"}{Matrix of post-treatment outcomes}
#' }
format_data_stag <- function(outcome, trt, unit, time, data) {
# arrange data by time first
data <- data %>% arrange(!!time)
## get first treatment times
trt_time <- data %>%
group_by(!!unit) %>%
summarise(trt_time=(!!time)[(!!trt) == 1][1]) %>%
mutate(trt_time=replace_na(as.numeric(trt_time), Inf))
t_int <- trt_time %>% filter(is.finite(trt_time)) %>%
summarise(t_int=max(trt_time)) %>% pull(t_int)
## ## boolean mask of available data for treatment groups
## mask <- data %>% inner_join(trt_time %>%
## filter(is.finite(trt_time))) %>%
## filter(!!time < t_int) %>%
## mutate(trt=1-!!trt) %>%
## select(!!unit, !!time, trt_time, trt) %>%
## spread(!!time, trt) %>%
## group_by(trt_time) %>%
## summarise_all(list(max)) %>%
## arrange(trt_time) %>%
## select(-trt_time, -!!unit) %>%
## as.matrix()
## boolean mask of available data for treatment groups
mask <- data %>% inner_join(trt_time %>%
filter(is.finite(trt_time)),
by = rlang::as_name(unit)) %>%
filter(!!time < t_int) %>%
mutate(trt=1-!!trt) %>%
select(!!unit, !!time, trt_time, trt) %>%
spread(!!time, trt) %>%
## arrange(!!unit) %>%
select(-trt_time, -!!unit) %>%
as.matrix()
# outcomes as a matrix
Xy <- data %>%
select(!!unit, !!time, !!outcome) %>%
spread(!!time, !!outcome) %>%
select(-!!unit) %>%
as.matrix()
pre_times <- data %>% filter(!!time < t_int) %>%
distinct(!!time) %>% pull(!!time)
post_times <- data %>% filter(!!time >= t_int) %>%
distinct(!!time) %>% pull(!!time)
X <- Xy[, as.character(pre_times), drop = F]
y <- Xy[, as.character(post_times), drop = F]
if(nrow(X) != nrow(y)) {
stop("The number of units observed before the last treatment time differs from the number observed after it")
}
t_vec <- data %>% pull(!!time) %>% unique() %>% sort()
trt <- sapply(trt_time$trt_time, function(x) which(t_vec == x)-1) %>%
as.numeric() %>%
replace_na(Inf)
units <- data %>%
filter(!!time < t_int) %>%
select(!!unit, !!time, !!outcome) %>%
spread(!!time, !!outcome) %>%
pull(!!unit)
return(list(X=X,
trt=trt,
y=y,
mask=mask,
time = t_vec,
units=units))
}
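# Sketch (hypothetical helper, not part of the package API) of how
# format_data_stag() codes treatment timing. Each unit's first treated
# period is stored as its 1-based position in the sorted time vector minus
# one, i.e. the number of periods observed before adoption. Never-treated
# units get Inf. Base R only. This assumes `first_trt_time` holds each
# unit's first treated period, with NA for never-treated units.
sketch_trt_index <- function(first_trt_time, t_vec) {
    t_vec <- sort(unique(t_vec))
    idx <- match(first_trt_time, t_vec) - 1
    idx[is.na(idx)] <- Inf
    idx
}
# Example (not run):
# sketch_trt_index(c(2003, NA, 2001), 2000:2005)  # c(3, Inf, 1)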
#' Format program eval matrices into synth form
#'
#' @param X Matrix of pre-treatment outcomes
#' @param trt Vector of treatment assignments
#' @param y Matrix of post-treatment outcomes
#' @noRd
#' @return List with data formatted as Synth::dataprep
format_synth <- function(X, trt, y) {
synth_data <- list()
## pre-treatment values as covariates
synth_data$Z0 <- t(X[trt==0,,drop=F])
## average treated units together
synth_data$Z1 <- as.matrix((colMeans(X[trt==1,,drop=F])), ncol=1)
## combine everything together also
synth_data$Y0plot <- t(cbind(X[trt==0,,drop=F], y[trt==0,,drop=F]))
synth_data$Y1plot <- as.matrix(colMeans(
cbind(X[trt==1,,drop=F], y[trt==1,,drop=F])), ncol=1)
## predictors are pre-period outcomes
synth_data$X0 <- synth_data$Z0
synth_data$X1 <- synth_data$Z1
return(synth_data)
}
#' Remove unit means
#' @param wide_data X, y, trt
#' @param synth_data List with data formatted as Synth::dataprep
#' @noRd
demean_data <- function(wide_data, synth_data) {
# pre treatment means
means <- rowMeans(wide_data$X)
new_wide_data <- list()
new_X <- wide_data$X - means
trt <- wide_data$trt
new_wide_data$X <- new_X
new_wide_data$y <- wide_data$y - means
new_wide_data$trt <- trt
new_synth_data <- list()
new_synth_data$X0 <- t(new_X[trt == 0,, drop = FALSE])
new_synth_data$Z0 <- new_synth_data$X0
new_synth_data$X1 <- as.matrix((colMeans(new_X[trt==1,,drop = F])),
ncol = 1)
new_synth_data$Z1 <- new_synth_data$X1
# estimate post-treatment as pre-treatment means
mhat <- replicate(ncol(wide_data$X) + ncol(wide_data$y), means)
return(list(wide = new_wide_data,
synth_data = new_synth_data,
mhat = mhat))
}
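# Sketch (hypothetical helper, not part of the package API) of
# demean_data(). Each unit's pre-treatment mean is subtracted from both its
# pre- and post-treatment outcomes, which acts as a unit fixed effect. The
# outcome model estimate `mhat` is that mean repeated across every period.
# matrix(means, ...) fills column by column, so it produces the same n x T
# matrix as replicate(T, means).
sketch_demean <- function(X, y) {
    means <- rowMeans(X)
    list(X = X - means,   # recycled down columns: row i minus means[i]
         y = y - means,
         mhat = matrix(means, nrow = nrow(X), ncol = ncol(X) + ncol(y)))
}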
#' Helper function to extract covariate matrix from data
#' @param form Formula as outcome ~ treatment | covariates
#' @param unit Name of unit column
#' @param time Name of time column
#' @param t_int Time of intervention
#' @param data Panel data as dataframe
#' @param cov_agg Covariate aggregation function
#' @noRd
extract_covariates <- function(form, unit, time, t_int, data, cov_agg) {
## if no aggregation functions, use the mean (omitting NAs)
if(is.null(cov_agg)) {
cov_agg <- c(function(x) mean(x, na.rm=T))
}
cov_form <- update(formula(delete.response(terms(form, rhs=2, data=data))),
~. - 1) ## ensure that there is no intercept
## pull out relevant covariates and aggregate
pre_data <- data %>%
filter(!! (time) < t_int)
model.matrix(cov_form,
model.frame(cov_form, pre_data,
na.action=NULL) ) %>%
data.frame() %>%
mutate(unit=pull(pre_data, !!unit)) %>%
group_by(unit) %>%
summarise_all(cov_agg) -> Z
# recombine with any missing units and convert to matrix
data %>% distinct(!!unit) %>%
rename(unit = !!unit) %>%
left_join(Z, by = "unit") %>%
arrange(unit) %>%
select(-unit) %>%
as.matrix() -> Z
if(nrow(distinct(data, !!unit)) != nrow(Z)) {
stop("Some units missing all covariate data")
}
# check if any covariates have no variation
Zsds <- apply(Z, 2, sd)
if(any(Zsds == 0)) {
zero_covs <- paste(colnames(Z)[Zsds == 0], collapse = ", ")
stop(paste("The following covariates have no variation across units:",
zero_covs))
}
return(Z)
}
#' Check that we can actually run multisynth on the data
#' @param wide Output of format_data_stag
#' @param fixedeff Whether to include a unit fixed effect
#' @param n_leads How long past treatment effects should be estimated for, default is number of post treatment periods for last treated unit
#' @param n_lags Number of pre-treatment periods to balance, default is to balance all periods
check_data_stag <- function(wide, fixedeff, n_leads, n_lags) {
# If there are fewer than 5 pre-treatment outcomes, give a warning
less_5 <- wide$units[wide$trt < 5]
warn_msg <- ""
if(length(less_5) != 0) {
warn_msg <- paste0(
warn_msg,
"The following units have fewer than 5 pre-treatment outcomes: (",
paste(less_5, collapse = ","),
"). Be cautious!"
)
}
# check if there are any always treated units
always_trt <- wide$units[wide$trt == 0]
# If including a fixed effect, check that there is more than one pretreatment
# outcome for each unit
n1 <- wide$units[wide$trt == 1]
err_msg <- ""
if(length(always_trt) != 0) {
err_msg <- paste0(
err_msg,
"The following units are always treated and should be removed: (",
paste(always_trt, collapse = ","),
")\n")
}
if(length(n1) != 0 & fixedeff) {
if(nchar(err_msg) > 0) {
err_msg <- paste0(err_msg, " Also: ")
}
err_msg <- paste0(
err_msg,
"You are including a fixed effect with `fixedeff = T`, but the ",
"following units only have one pre-treatment outcome: (",
paste(n1, collapse = ","),
"). Either remove these units or set `fixedeff = F`.\n"
)
}
# check if there are never treated units
if(max(wide$trt) < ncol(wide$X) + ncol(wide$y)) {
if(nchar(err_msg) > 0) {
err_msg <- paste0(err_msg, " Also: ")
}
err_msg <- paste0(
err_msg,
"All units are eventually treated. The last treatment time is ",
wide$time[max(wide$trt) + 1], # trt counts pre-treatment periods, so add 1
". To run multisynth, remove all periods after this time. ",
"Units treated at this time will be considered 'never treated' in the ",
"narrowed sample.\n"
)
}
if(nchar(warn_msg) > 0) {
warning(warn_msg)
}
if(nchar(err_msg) > 0) {
stop(err_msg)
}
}
================================================
FILE: R/globalVariables.R
================================================
utils::globalVariables(c("time", "val", "post", "weight", ".", "Time",
"Estimate", "Std.Error", "Level", "last_time",
"is_avg", "label", "Outcome", "unit", "obs",
"lambdas", "errors_se",
"upper_bound", "lower_bound"))
================================================
FILE: R/highdim.R
================================================
################################################################################
## Methods to use flexible outcome models
################################################################################
##### Augmented SCM with general outcome models
#' Use zero weights, do nothing but output everything in the right way
#' @param synth_data Panel data in format of Synth::dataprep
#' @noRd
#' @return \itemize{
#' \item{"weights"}{Synth weights}
#' \item{"l2_imbalance"}{Imbalance in pre-period outcomes, measured by the L2 norm}
#' \item{"scaled_l2_imbalance"}{L2 imbalance scaled by L2 imbalance of uniform weights}
#' }
fit_zero_weights <- function(synth_data) {
## Imbalance is uniform weights imbalance
uni_w <- matrix(1/ncol(synth_data$Z0), nrow=ncol(synth_data$Z0), ncol=1)
unif_l2_imbalance <- sqrt(sum((synth_data$Z0 %*% uni_w - synth_data$Z1)^2))
scaled_l2_imbalance <- 1
return(list(weights=rep(0, ncol(synth_data$Z0)),
l2_imbalance=unif_l2_imbalance,
scaled_l2_imbalance=scaled_l2_imbalance))
}
#' Fit E[Y(0)|X] for each post-period and balance the pre-period
#'
#' @param wide_data Output of `format_data`
#' @param synth_data Output of `format_synth`
#' @param fit_progscore Function to fit prognostic score
#' @param fit_weights Function to fit synth weights
#' @param ... optional arguments for outcome model
#' @noRd
#' @return \itemize{
#' \item{"weights"}{Ridge ASCM weights}
#' \item{"l2_imbalance"}{Imbalance in pre-period outcomes, measured by the L2 norm}
#' \item{"scaled_l2_imbalance"}{L2 imbalance scaled by L2 imbalance of uniform weights}
#' \item{"mhat"}{Outcome model estimate}
#' }
fit_augsyn_formatted <- function(wide_data, synth_data,
fit_progscore, fit_weights, ...) {
X <- wide_data$X
y <- wide_data$y
trt <- wide_data$trt
## fit prognostic scores
fitout <- do.call(fit_progscore,
list(X=X, y=y, trt=trt, ...))
## fit synth
syn <- fit_weights(synth_data)
syn$params <- fitout$params
syn$mhat <- fitout$y0hat
return(syn)
}
#' Fit outcome model and balance pre-period
#' @param wide_data Output of `format_data`
#' @param synth_data Output of `format_synth`
#' @param progfunc What function to use to impute control outcomes
#' EN=Elastic Net, RF=Random Forest, GSYN=gSynth,
#' MCP=MCPanel, CITS=Comparative interrupted time series,
#' CausalImpact=Bayesian structural time series with CausalImpact,
#' seq2seq=Sequence to sequence learning with feedforward nets
#' @param scm Whether the SCM weighting function is used
#' @param ... optional arguments for outcome model
#' @noRd
#' @return \itemize{
#' \item{"weights"}{Ridge ASCM weights}
#' \item{"l2_imbalance"}{Imbalance in pre-period outcomes, measured by the L2 norm}
#' \item{"scaled_l2_imbalance"}{L2 imbalance scaled by L2 imbalance of uniform weights}
#' \item{"mhat"}{Outcome model estimate}
#' }
fit_augsyn <- function(wide_data, synth_data,
progfunc=c("EN", "RF", "GSYN", "MCP","CITS", "CausalImpact", "seq2seq"),
scm=T, ...) {
## prognostic score and weight functions to use
## use the first option if the full default vector is passed
progfunc <- tolower(progfunc[1])
if(progfunc == "en") {
progf <- fit_prog_reg
} else if(progfunc == "rf") {
progf <- fit_prog_rf
} else if(progfunc == "gsyn"){
progf <- fit_prog_gsynth
} else if(progfunc == "mcp"){
progf <- fit_prog_mcpanel
} else if(progfunc == "cits") {
progf <- fit_prog_cits
} else if(progfunc == "causalimpact") {
progf <- fit_prog_causalimpact
} else if(progfunc == "seq2seq"){
progf <- fit_prog_seq2seq
} else {
stop("progfunc must be one of 'EN', 'RF', 'GSYN', 'MCP', 'CITS', 'CausalImpact', 'seq2seq'")
}
if(scm) {
weightf <- fit_synth_formatted
} else {
## no SCM weighting: use zero weights so that only the outcome model is used
weightf <- fit_zero_weights
}
return(fit_augsyn_formatted(wide_data, synth_data,
progf, weightf, ...))
}
### Combine synth and gsynth by balancing pre-period residuals
#' Fit outcome model and balance residuals
#'
#' @param wide_data Output of `format_data`
#' @param synth_data Output of `format_synth`
#' @param fit_progscore Function to fit prognostic score
#' @param fit_weights Function to fit synth weights
#' @param ... optional arguments for outcome model
#' @noRd
#' @return \itemize{
#' \item{"weights"}{Ridge ASCM weights}
#' \item{"l2_imbalance"}{Imbalance in pre-period outcomes, measured by the L2 norm}
#' \item{"scaled_l2_imbalance"}{L2 imbalance scaled by L2 imbalance of uniform weights}
#' \item{"mhat"}{Outcome model estimate}
#' }
fit_residaug_formatted <- function(wide_data, synth_data,
fit_progscore, fit_weights, ...) {
X <- wide_data$X
y <- wide_data$y
trt <- wide_data$trt
## fit prognostic scores
fitout <- do.call(fit_progscore, list(X=X, y=y, trt=trt, ...))
y0hat <- fitout$y0hat
## get residuals
ctrl_resids <- fitout$params$ctrl_resids
trt_resids <- fitout$params$trt_resids
## replace outcomes with pre-period residuals
t0 <- dim(X)[2]
synth_data$Z0 <- ctrl_resids[1:t0, ]
synth_data$Z1 <- as.matrix(trt_resids[1:t0])
## fit synth weights
syn <- fit_weights(synth_data)
syn$params <- fitout$params
## return predicted values for treatment and control
syn$mhat <- y0hat
return(syn)
}
#' Fit outcome model and balance residuals
#'
#' @param wide_data Output of `format_data`
#' @param synth_data Output of `format_synth`
#' @param progfunc What function to use to impute control outcomes
#' GSYN=gSynth, MCP=MCPanel,
#' CITS=Comparative interrupted time series
#' CausalImpact=Bayesian structural time series with CausalImpact
#' @param weightfunc What function to use to fit weights
#' SCM=Vanilla Synthetic Controls
#' NONE=No reweighting, just outcome model
#' @param ... optional arguments for outcome model
#' @noRd
#' @return \itemize{
#' \item{"weights"}{Ridge ASCM weights}
#' \item{"l2_imbalance"}{Imbalance in pre-period outcomes, measured by the L2 norm}
#' \item{"scaled_l2_imbalance"}{L2 imbalance scaled by L2 imbalance of uniform weights}
#' \item{"mhat"}{Outcome model estimate}
#' }
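fit_residaug <- function(wide_data, synth_data,
progfunc=c("GSYN", "MCP", "CITS", "CausalImpact"),
weightfunc=c("SCM", "NONE"), ...) {
## resolve each option vector to a single choice (defaults to the first)
progfunc <- match.arg(progfunc)
weightfunc <- match.arg(weightfunc)
## prognostic score and weight functions to use
if(progfunc == "GSYN"){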
progf <- fit_prog_gsynth
} else if(progfunc == "MCP"){
progf <- fit_prog_mcpanel
} else if(progfunc == "CITS") {
progf <- fit_prog_cits
} else if(progfunc == "CausalImpact") {
progf <- fit_prog_causalimpact
} else {
stop("progfunc must be one of 'GSYN', 'MCP', 'CITS', 'CausalImpact'")
}
## weight function to use
if(weightfunc == "SCM") {
weightf <- fit_synth_formatted
} else if(weightfunc == "NONE") {
## still fit synth even if none
## TODO: This is a dumb wasteful hack
weightf <- fit_synth_formatted
} else {
stop("weightfunc must be one of 'SCM', 'NONE'")
}
return(fit_residaug_formatted(wide_data, synth_data,
progf, weightf, ...))
}
================================================
FILE: R/inference.R
================================================
################################################################################
## Code for inference
################################################################################
#' Jackknife+ algorithm over time
#' @param ascm Fitted `augsynth` object
#' @param alpha Significance level; intervals have nominal 1 - alpha coverage
#' @param conservative Whether to use the conservative jackknife+ procedure
#' @return List that contains:
#' \itemize{
#' \item{"att"}{Vector of ATT estimates}
#' \item{"heldout_att"}{Held-out prediction errors for each pre-treatment period, followed by the post-treatment ATT estimates and their average}
#' \item{"lb"}{Lower bound of 1 - alpha confidence interval}
#' \item{"ub"}{Upper bound of 1 - alpha confidence interval}
#' \item{"alpha"}{Level of confidence interval}
#' }
time_jackknife_plus <- function(ascm, alpha = 0.05, conservative = F) {
wide_data <- ascm$data
synth_data <- ascm$data$synth_data
n <- nrow(wide_data$X)
n_c <- dim(synth_data$Z0)[2]
Z <- wide_data$Z
t0 <- dim(synth_data$Z0)[1]
tpost <- ncol(wide_data$y)
t_final <- dim(synth_data$Y0plot)[1]
jack_ests <- lapply(1:t0,
function(tdrop) {
# drop pre-treatment period tdrop
new_data <- drop_time_t(wide_data, Z, tdrop)
# refit
new_ascm <- do.call(fit_augsynth_internal,
c(list(wide = new_data$wide_data,
synth_data = new_data$synth_data,
Z = new_data$Z,
progfunc = ascm$progfunc,
scm = ascm$scm,
fixedeff = ascm$fixedeff),
ascm$extra_args))
# get ATT estimates and held out error for time t
# t0 is prediction for held out time
est <- predict(new_ascm, att = F)[(t0 +1):t_final]
est <- c(est, mean(est))
err <- c(colMeans(wide_data$X[wide_data$trt == 1,
tdrop,
drop = F]) -
predict(new_ascm, att = F)[t0])
list(err, rbind(est + abs(err), est - abs(err), est + err, est))
})
# get errors and jackknife distribution
held_out_errs <- vapply(jack_ests, `[[`, numeric(1), 1)
jack_dist <- vapply(jack_ests, `[[`,
matrix(0, nrow = 4, ncol = tpost + 1), 2)
out <- list()
att <- predict(ascm, att = T)
out$att <- c(att,
mean(att[(t0 + 1):t_final]))
# held out ATT
out$heldout_att <- c(held_out_errs,
att[(t0 + 1):t_final],
mean(att[(t0 + 1):t_final]))
# out$se <- rep(NA, 10 + tpost)
if(conservative) {
qerr <- stats::quantile(abs(held_out_errs), 1 - alpha)
out$lb <- c(rep(NA, t0), apply(jack_dist[4,,], 1, min) - qerr)
out$ub <- c(rep(NA, t0), apply(jack_dist[4,,], 1, max) + qerr)
} else {
out$lb <- c(rep(NA, t0), apply(jack_dist[2,,], 1, stats::quantile, alpha / 2))
out$ub <- c(rep(NA, t0), apply(jack_dist[1,,], 1, stats::quantile, 1 - alpha / 2))
}
# shift back to ATT scale
y1 <- predict(ascm, att = F) + att
y1 <- c(y1, mean(y1[(t0 + 1):t_final]))
shifted_lb <- y1 - out$ub
shifted_ub <- y1 - out$lb
out$lb <- shifted_lb
out$ub <- shifted_ub
out$alpha <- alpha
return(out)
}
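# Sketch (hypothetical helper, not part of the package API) of the
# non-conservative interval in time_jackknife_plus() for a single
# post-treatment period. Each leave-one-period-out fit gives a prediction
# of Y(0) and a held-out pre-period error. The Y(0) interval takes the
# alpha/2 quantile of (prediction - |error|) and the 1 - alpha/2 quantile of
# (prediction + |error|), then subtracts it from the observed treated outcome
# to move to the ATT scale, as in the "shift back" step above.
sketch_jackknife_plus <- function(y1, loo_pred, loo_err, alpha = 0.05) {
    # y1: observed treated outcome in one post-period
    # loo_pred: leave-one-period-out predictions of Y(0) for that period
    # loo_err: matching held-out pre-period prediction errors
    lb0 <- stats::quantile(loo_pred - abs(loo_err), alpha / 2, names = FALSE)
    ub0 <- stats::quantile(loo_pred + abs(loo_err), 1 - alpha / 2, names = FALSE)
    c(lower = y1 - ub0, upper = y1 - lb0)
}
# Example (not run):
# sketch_jackknife_plus(12, rep(10, 5), c(-1, 1, 0.5, -0.5, 0))  # lower 1, upper 3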
#' Drop time period from pre-treatment data
#' @param wide_data (X, y, trt)
#' @param Z Covariates matrix
#' @param t_drop Time to drop
#' @noRd
drop_time_t <- function(wide_data, Z, t_drop) {
new_wide_data <- list()
new_wide_data$trt <- wide_data$trt
new_wide_data$X <- wide_data$X[, -t_drop, drop = F]
new_wide_data$y <- cbind(wide_data$X[, t_drop, drop = F],
wide_data$y)
X0 <- new_wide_data$X[new_wide_data$trt == 0,, drop = F]
x1 <- matrix(colMeans(new_wide_data$X[new_wide_data$trt == 1,,
drop = F]),
ncol=1)
y0 <- new_wide_data$y[new_wide_data$trt == 0,, drop = F]
y1 <- colMeans(new_wide_data$y[new_wide_data$trt == 1,, drop = F])
new_synth_data <- list()
new_synth_data$Z0 <- t(X0)
new_synth_data$X0 <- t(X0)
new_synth_data$Z1 <- x1
new_synth_data$X1 <- x1
return(list(wide_data = new_wide_data,
synth_data = new_synth_data,
Z = Z))
}
#' Conformal inference procedure to compute p-values and point-wise confidence intervals
#' @param ascm Fitted `augsynth` object
#' @param alpha Significance level; intervals have nominal 1 - alpha coverage
#' @param stat_func Function to compute test statistic
#' @param type Either "iid" for iid permutations or "block" for moving block permutations; default is "iid"
#' @param q The norm for the test statistic `((sum(x ^ q))) ^ (1/q)`
#' @param ns Number of resamples for "iid" permutations
#' @param grid_size Number of grid points to use when inverting the hypothesis test
#' @return List that contains:
#' \itemize{
#' \item{"att"}{Vector of ATT estimates}
#' \item{"lb"}{Lower bound of 1 - alpha confidence interval}
#' \item{"ub"}{Upper bound of 1 - alpha confidence interval}
#' \item{"p_val"}{p-value for test of no post-treatment effect}
#' \item{"alpha"}{Level of confidence interval}
#' }
conformal_inf <- function(ascm, alpha = 0.05,
stat_func = NULL, type = "iid",
q = 1, ns = 1000, grid_size = 50) {
wide_data <- ascm$data
synth_data <- ascm$data$synth_data
n <- nrow(wide_data$X)
n_c <- dim(synth_data$Z0)[2]
Z <- wide_data$Z
t0 <- dim(synth_data$Z0)[1]
tpost <- ncol(wide_data$y)
t_final <- dim(synth_data$Y0plot)[1]
# grid of nulls
att <- predict(ascm, att = T)
post_att <- att[(t0 +1):t_final]
post_sd <- sqrt(mean(post_att ^ 2))
# iterate over post-treatment periods to get pointwise CIs
vapply(1:tpost,
function(j) {
# fit using t0 + j as a pre-treatment period and get residuals
new_wide_data <- wide_data
new_wide_data$X <- cbind(wide_data$X, wide_data$y[, j, drop = TRUE])
if(tpost > 1) {
new_wide_data$y <- wide_data$y[, -j, drop = FALSE]
} else {
# the post-period matrix has to contain *something*, so use a placeholder
new_wide_data$y <- matrix(1, nrow = n, ncol = 1)
}
# make a grid around the estimated ATT
grid <- seq(att[t0 + j] - 2 * post_sd, att[t0 + j] + 2 * post_sd,
length.out = grid_size)
compute_permute_ci(new_wide_data, ascm, grid, 1, alpha, type,
q, ns, stat_func)
},
numeric(3)) -> cis
# test a null post-treatment effect
new_wide_data <- wide_data
new_wide_data$X <- cbind(wide_data$X, wide_data$y)
new_wide_data$y <- matrix(1, nrow = n, ncol = 1)
null_p <- compute_permute_pval(new_wide_data, ascm, 0, ncol(wide_data$y),
type, q, ns, stat_func)
out <- list()
att <- predict(ascm, att = T)
out$att <- c(att, mean(att[(t0 + 1):t_final]))
# out$se <- rep(NA, t_final)
# out$sigma <- NA
out$lb <- c(rep(NA, t0), cis[1, ], NA)
out$ub <- c(rep(NA, t0), cis[2, ], NA)
out$p_val <- c(rep(NA, t0), cis[3, ], null_p)
out$alpha <- alpha
return(out)
}
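# Sketch (hypothetical helper, not part of the package API) of the test
# inversion behind conformal_inf() for one post-treatment period. For each
# null value h0 on a grid, the effect is removed from the post-period
# residual. The residual series is permuted with moving blocks (all cyclic
# shifts), and the p-value is the share of permuted statistics at least as
# large as the observed one. The interval keeps the grid values with
# p-value > alpha. Simplifications: residuals are held fixed across nulls
# (the package refits under each null in compute_permute_ci(), not shown
# here), and the retained region is assumed to be contiguous.
sketch_conformal_ci <- function(resid, grid, alpha = 0.05, q = 1) {
    # resid: residuals for the pre-periods followed by ONE post-period
    n_t <- length(resid)
    stat <- function(x) sum(abs(x)^q)^(1 / q)
    pvals <- vapply(grid, function(h0) {
        r <- resid
        r[n_t] <- r[n_t] - h0                 # enforce the null effect h0
        perm_stats <- vapply(0:(n_t - 1), function(k) {
            shifted <- r[((seq_len(n_t) - 1 + k) %% n_t) + 1]
            stat(shifted[n_t])                # statistic on the "post" slot
        }, numeric(1))
        mean(perm_stats >= stat(r[n_t]))
    }, numeric(1))
    kept <- grid[pvals > alpha]
    if (length(kept) == 0) return(c(lower = NA_real_, upper = NA_real_))
    c(lower = min(kept), upper = max(kept))
}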
#' Conformal inference procedure to compute a confidence interval for a linear in time effect
#' @param ascm Fitted `augsynth` object
#' @param alpha Significance level; intervals have nominal 1 - alpha coverage
#' @param stat_func Function to compute test statistic
#' @param type Either "iid" for iid permutations or "block" for moving block permutations; default is "iid"
#' @param q The norm for the test statistic `((sum(x ^ q))) ^ (1/q)`
#' @param ns Number of resamples for "iid" permutations
#' @param grid_size Number of grid points to use when inverting the hypothesis test
#' @return List that contains:
#' \itemize{
#' \item{"att"}{Vector of ATT estimates}
#' \item{"heldout_att"}{Vector of ATT estimates with the time period held out}
#' \item{"se"}{Standard error, always NA but returned for compatibility}
#' \item{"lb"}{Lower bound of 1 - alpha confidence interval}
#' \item{"ub"}{Upper bound of 1 - alpha confidence interval}
#' \item{"p_val"}{p-value for test of no post-treatment effect}
#' \item{"alpha"}{Level of confidence interval}
#' }
conformal_inf_linear <- function(ascm, alpha = 0.05,
stat_func = NULL, type = "iid",
q = 1, ns = 1000, grid_size = 50) {
wide_data <- ascm$data
synth_data <- ascm$data$synth_data
n <- nrow(wide_data$X)
n_c <- dim(synth_data$Z0)[2]
Z <- wide_data$Z
t0 <- dim(synth_data$Z0)[1]
tpost <- ncol(wide_data$y)
t_final <- dim(synth_data$Y0plot)[1]
# grid of nulls
att <- predict(ascm, att = T)
post_att <- att[(t0 +1):t_final]
post_second <- sqrt(mean(post_att^2))
# a line needs at least two post-treatment periods; check before fitting
if(tpost <= 1) {
stop("There is only one post-treatment time period, so an intercept and a slope cannot be computed.")
}
# use OLS of the post-treatment ATTs on time to get pilot estimates
ts <- 1:tpost
lm_out <- summary(lm(post_att ~ ts))$coefficients
# grid for intercept
grid_int <- seq(lm_out[1,1] - 2 * post_second,
lm_out[1,1] + 2 * post_second,
length.out = grid_size)
# grid for slope
if(tpost == 2) {
warning("There are 2 post-treatment time periods, so a linear model has a perfect fit. A confidence interval for the slope may not be reasonable here.")
grid_slope <- seq(lm_out[2,1] - abs(lm_out[2,1]),
lm_out[2,1] + abs(lm_out[2,1]),
length.out = grid_size)
} else {
grid_slope <- seq(lm_out[2,1] - 4 * lm_out[2,2] * sqrt(tpost),
lm_out[2,1] + 4 * lm_out[2,2] * sqrt(tpost),
length.out = grid_size)
}
# p-value for the null of zero effect in every post-treatment period
new_wide_data <- wide_data
new_wide_data$X <- cbind(wide_data$X, wide_data$y)
new_wide_data$y <- matrix(1, nrow = n, ncol = 1)
null_p <- compute_permute_pval(new_wide_data, ascm, 0, ncol(wide_data$y),
type, q, ns, stat_func)
# confidence interval for linear in time treatment effects
cis <- compute_permute_ci_linear(new_wide_data, ascm, grid_int, grid_slope,
ncol(wide_data$y), alpha, type, q, ns, stat_func)
return(cis)
}
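# Illustrative sketch (hypothetical helper `sketch_linear_grids`, not called by
# augsynth) of how conformal_inf_linear() builds its search grids. An OLS fit of
# the post-treatment ATTs on time 1..T gives pilot estimates; the intercept grid
# spans +/- 2 root-mean-square post-period ATTs around the pilot intercept and
# the slope grid spans +/- 4 * se * sqrt(T) around the pilot slope. This sketch
# assumes at least 3 post-treatment periods, so the slope has a finite standard
# error. Both grids are centred on the pilot estimates.
sketch_linear_grids <- function(post_att, grid_size = 50) {
  ts <- seq_along(post_att)
  if (length(ts) < 3) stop("need at least 3 post-treatment periods")
  coefs <- summary(stats::lm(post_att ~ ts))$coefficients
  rms <- sqrt(mean(post_att ^ 2))
  half_slope <- 4 * coefs[2, 2] * sqrt(length(ts))
  list(
    grid_int = seq(coefs[1, 1] - 2 * rms, coefs[1, 1] + 2 * rms,
                   length.out = grid_size),
    grid_slope = seq(coefs[2, 1] - half_slope, coefs[2, 1] + half_slope,
                     length.out = grid_size)
  )
}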
#' Compute conformal test statistics
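#' @param wide_data List with outcome matrix `X` (whose last `post_length` columns are the post-treatment periods under test), matrix `y`, and treatment indicator `trt`
#' @param ascm Fitted `augsynth` object
#' @param h0 Null effect to test: a scalar or a vector of length `post_length`, subtracted from the treated outcomes in the post-treatment periods
#' @param post_length Number of post-treatment periods
#' @param type Either "iid" for iid permutations or "block" for moving block permutations
#' @param q The order of the default test statistic `(sum(abs(x) ^ q) / sqrt(length(x))) ^ (1 / q)`, used when `stat_func` is NULL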
#' @param ns Number of resamples for "iid" permutations
#' @param stat_func Function to compute test statistic
#'
#' @return List that contains:
#' \itemize{
#' \item{"resids"}{Residuals after enforcing the null}
#' \item{"test_stats"}{Permutation distribution of test statistics}
#' \item{"stat_func"}{Test statistic function}
#' }
#' @noRd
compute_permute_test_stats <- function(wide_data, ascm, h0,
post_length, type,
q, ns, stat_func) {
# format data
new_wide_data <- wide_data
t0 <- ncol(wide_data$X) - post_length
tpost <- t0 + post_length
# adjust outcomes for null
new_wide_data$X[wide_data$trt == 1,(t0 + 1):tpost ] <- new_wide_data$X[wide_data$trt == 1,(t0 + 1):tpost] - h0
X0 <- new_wide_data$X[new_wide_data$trt == 0,, drop = F]
x1 <- matrix(colMeans(new_wide_data$X[new_wide_data$trt == 1,, drop = F]),
ncol=1)
new_synth_data <- list()
new_synth_data$Z0 <- t(X0)
new_synth_data$X0 <- t(X0)
new_synth_data$Z1 <- x1
new_synth_data$X1 <- x1
# fit synth with adjusted data and get residuals
new_ascm <- do.call(fit_augsynth_internal,
c(list(wide = new_wide_data,
synth_data = new_synth_data,
Z = wide_data$Z,
progfunc = ascm$progfunc,
scm = ascm$scm,
fixedeff = ascm$fixedeff),
ascm$extra_args))
resids <- predict(new_ascm, att = T)[1:tpost]
# permute residuals and compute test statistic
if(is.null(stat_func)) {
stat_func <- function(x) (sum(abs(x) ^ q) / sqrt(length(x))) ^ (1 / q)
}
if(type == "iid") {
test_stats <- sapply(1:ns,
function(x) {
reorder <- sample(resids)
stat_func(reorder[(t0 + 1):tpost])
})
} else {
## cyclically shift time by j steps, wrapping around; j = tpost is the identity
test_stats <- sapply(1:tpost,
function(j) {
reorder <- resids[(0:(tpost - 1) + j) %% tpost + 1]
stat_func(reorder[(t0 + 1):tpost])
})
}
return(list(resids = resids,
test_stats = test_stats,
stat_func = stat_func))
}
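# Illustrative sketch (hypothetical helper `sketch_permutation_pval`, not called
# by augsynth) of the logic in compute_permute_test_stats() and
# compute_permute_pval() for a single outcome. `resids` holds the residuals for
# all periods after the null has been imposed; the last `post_length` entries
# are the post-treatment periods. "iid" shuffles the whole vector, while
# "block" cyclically shifts it by each possible offset, which keeps serial
# dependence intact. The p-value is the share of permuted statistics at least
# as large as the observed one.
sketch_permutation_pval <- function(resids, post_length, type = "block",
                                    q = 1, ns = 1000) {
  tpost <- length(resids)
  post <- (tpost - post_length + 1):tpost
  # default statistic used by augsynth when stat_func is NULL
  stat_func <- function(x) (sum(abs(x) ^ q) / sqrt(length(x))) ^ (1 / q)
  if (type == "iid") {
    test_stats <- replicate(ns, stat_func(sample(resids)[post]))
  } else {
    # shift j = tpost maps every index to itself, so the observed order is included
    test_stats <- vapply(seq_len(tpost), function(j) {
      stat_func(resids[(0:(tpost - 1) + j) %% tpost + 1][post])
    }, numeric(1))
  }
  mean(stat_func(resids[post]) <= test_stats)
}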
#' Compute conformal p-value
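#' @param wide_data List with outcome matrix `X` (whose last `post_length` columns are the post-treatment periods under test), matrix `y`, and treatment indicator `trt`
#' @param ascm Fitted `augsynth` object
#' @param h0 Null effect to test: a scalar or a vector of length `post_length`
#' @param post_length Number of post-treatment periods
#' @param type Either "iid" for iid permutations or "block" for moving block permutations
#' @param q The order of the default test statistic `(sum(abs(x) ^ q) / sqrt(length(x))) ^ (1 / q)`, used when `stat_func` is NULL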
#' @param ns Number of resamples for "iid" permutations
#' @param stat_func Function to compute test statistic
#'
#' @return Computed p-value
#' @noRd
compute_permute_pval <- function(wide_data, ascm, h0,
post_length, type,
q, ns, stat_func) {
t0 <- ncol(wide_data$X) - post_length
tpost <- t0 + post_length
out <- compute_permute_test_stats(wide_data, ascm, h0,
post_length, type, q, ns, stat_func)
mean(out$stat_func(out$resids[(t0 + 1):tpost]) <= out$test_stats)
}
#' Compute a conformal confidence interval by inverting permutation tests
#' @param wide_data List with outcome matrix `X` (whose last `post_length` columns are the post-treatment periods under test), matrix `y`, and treatment indicator `trt`
#' @param ascm Fitted `augsynth` object
#' @param grid Grid of null effect values to test; 0 is always added
#' @param post_length Number of post-treatment periods
#' @param alpha Significance level; grid values with p-value >= alpha are kept in the confidence set
#' @param type Either "iid" for iid permutations or "block" for moving block permutations
#' @param q The order of the default test statistic `(sum(abs(x) ^ q) / sqrt(length(x))) ^ (1 / q)`, used when `stat_func` is NULL
#' @param ns Number of resamples for "iid" permutations
#' @param stat_func Function to compute test statistic
#'
#' @return (lower bound of interval, upper bound of interval, p-value for null of 0 effect)
#' @noRd
compute_permute_ci <- function(wide_data, ascm, grid,
post_length, alpha, type,
q, ns, stat_func) {
# make sure 0 is in the grid
grid <- c(grid, 0)
ps <-sapply(grid,
function(x) {
compute_permute_pval(wide_data, ascm, x,
post_length, type, q, ns, stat_func)
})
c(min(grid[ps >= alpha]), max(grid[ps >= alpha]), ps[grid == 0][1])
}
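# Illustrative sketch (hypothetical helper `sketch_invert_test`, not called by
# augsynth) of the test inversion in compute_permute_ci(): the confidence set is
# every grid value whose p-value is at least alpha, reported by its smallest and
# largest element together with the p-value at 0. `pval_fun` is an assumed
# stand-in for compute_permute_pval(); any function mapping a null value to a
# p-value works. The interval is only as fine as the grid spacing.
sketch_invert_test <- function(pval_fun, grid, alpha = 0.05) {
  grid <- c(grid, 0)                      # make sure 0 is tested
  ps <- vapply(grid, pval_fun, numeric(1))
  accepted <- grid[ps >= alpha]
  c(lb = min(accepted), ub = max(accepted), p_zero = ps[grid == 0][1])
}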
#' Compute conformal confidence intervals for a linear model of the effects,
#' int + slope * time, by inverting permutation tests over a two-dimensional grid
#' @param wide_data List with outcome matrix `X` (whose last `post_length` columns are the post-treatment periods under test), matrix `y`, and treatment indicator `trt`
#' @param ascm Fitted `augsynth` object
#' @param grid_int Set of null hypothesis values for the intercept
#' @param grid_slope Set of null hypothesis values for the slope
#' @param post_length Number of post-treatment periods
#' @param alpha Significance level
#' @param type Either "iid" for iid permutations or "block" for moving block permutations
#' @param q The order of the default test statistic `(sum(abs(x) ^ q) / sqrt(length(x))) ^ (1 / q)`, used when `stat_func` is NULL
#' @param ns Number of resamples for "iid" permutations
#' @param stat_func Function to compute test statistic
#'
#' @return List with the intercept and slope of the grid point with the largest p-value (`est_int`, `est_slope`) and the projections of the 1 - alpha joint confidence set onto the intercept (`ci_int`) and the slope (`ci_slope`)
#' @noRd
compute_permute_ci_linear <- function(wide_data, ascm, grid_int, grid_slope,
post_length, alpha, type,
q, ns, stat_func) {
# make sure 0 is in both grids
# grid_int <- c(grid_int, 0)
# grid_slope <- c(grid_slope, 0)
# make the combined grid
grid_comb <- expand.grid(grid_int, grid_slope)
grid_comb$p_val <-apply(grid_comb, 1,
function(x) {
compute_permute_pval(wide_data, ascm, x[1] + x[2] * (1:post_length),
post_length, type, q, ns, stat_func)
})
ci_int <- c(min(grid_comb[grid_comb$p_val >= alpha, 1]),
max(grid_comb[grid_comb$p_val >= alpha, 1]))
ci_slope <- c(min(grid_comb[grid_comb$p_val >= alpha, 2]),
max(grid_comb[grid_comb$p_val >= alpha, 2]))
int_slope_est <- as.numeric(grid_comb[which.max(grid_comb$p_val), 1:2])
return(list(est_int = int_slope_est[1], ci_int = ci_int,
est_slope = int_slope_est[2], ci_slope = ci_slope))
}
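# Illustrative sketch (hypothetical helper `sketch_invert_linear`, not called by
# augsynth) of compute_permute_ci_linear(): each (intercept, slope) pair defines
# the null effect path int + slope * (1:post_length). Pairs with p-value at least
# alpha form a joint confidence set, which is projected onto each axis, and the
# pair with the largest p-value serves as the point estimate. `pval_fun(h0)` is
# an assumed stand-in for compute_permute_pval() that receives the whole
# length-post_length vector of null effects. Cost grows as
# length(grid_int) * length(grid_slope) tests.
sketch_invert_linear <- function(pval_fun, grid_int, grid_slope,
                                 post_length, alpha = 0.05) {
  grid <- expand.grid(int = grid_int, slope = grid_slope)
  grid$p_val <- mapply(function(a, b) pval_fun(a + b * seq_len(post_length)),
                       grid$int, grid$slope)
  ok <- grid[grid$p_val >= alpha, ]
  best <- grid[which.max(grid$p_val), ]
  list(est_int = best$int, ci_int = range(ok$int),
       est_slope = best$slope, ci_slope = range(ok$slope))
}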
#' Jackknife+ algorithm over time for augmented SCM with multiple outcomes
#' @param ascm_multi Fitted `augsynth_multiout` object
#' @param alpha Significance level
#' @param conservative Whether to use the conservative jackknife+ procedure
#' @return List that contains:
#' \itemize{
#' \item{"att"}{Matrix of ATT estimates (one column per outcome), with a final row for the average post-treatment effect}
#' \item{"heldout_att"}{Held-out prediction errors for each pre-treatment period, followed by the post-treatment ATT estimates and their average}
#' \item{"lb"}{Lower bound of 1 - alpha confidence interval}
#' \item{"ub"}{Upper bound of 1 - alpha confidence interval}
#' \item{"alpha"}{Level of confidence interval}
#' }
time_jackknife_plus_multiout <- function(ascm_multi, alpha = 0.05, conservative = F) {
wide_data <- ascm_multi$data
data_list <- ascm_multi$data_list
n <- nrow(wide_data$X)
k <- length(data_list$X)
t0 <- min(sapply(data_list$X, ncol))
tpost <- max(sapply(data_list$y, ncol))
t_final <- t0 + tpost
Z <- wide_data$Z
jack_ests <- lapply(1:t0,
function(tdrop) {
# hold out pre-treatment period tdrop
new_data_list <- drop_time_t_multiout(data_list, Z, tdrop)
# refit
new_ascm <- do.call(fit_augsynth_multiout_internal,
c(list(wide_list = new_data_list,
combine_method = ascm_multi$combine_method,
Z = Z, # covariates are stored in ascm_multi$data, not in data_list
progfunc = ascm_multi$progfunc,
scm = ascm_multi$scm,
fixedeff = ascm_multi$fixedeff,
outcomes_str = ascm_multi$outcomes),
ascm_multi$extra_args))
# get ATT estimates and the prediction error for the held-out period,
# which is now the first post-treatment period (row t0)
est <- predict(new_ascm, att = F)[(t0 +1):t_final, , drop = F]
est <- rbind(est, colMeans(est))
# err <- c(colMeans(wide_data$X[wide_data$trt == 1,
# tdrop,
# drop = F]) -
# predict(new_ascm, att = F)[t0])
err <- c(predict(new_ascm, att = T)[t0, , drop = F])
list(err, t(t(est) + abs(err)), t(t(est) - abs(err)), t(t(est) + err), est)
})
# get errors and jackknife distribution
held_out_errs <- matrix(vapply(jack_ests, `[[`, numeric(k), 1), nrow = k)
jack_dist_high <- vapply(jack_ests, `[[`,
matrix(0, nrow = tpost + 1, ncol = k), 2)
jack_dist_low <- vapply(jack_ests, `[[`,
matrix(0, nrow = tpost + 1, ncol = k), 3)
jack_dist_cons <- vapply(jack_ests, `[[`,
matrix(0, nrow = tpost + 1, ncol = k), 4)
out <- list()
att <- predict(ascm_multi, att = T)
out$att <- rbind(att,
colMeans(att[(t0 + 1):t_final, , drop = F]))
# held out ATT
out$heldout_att <- rbind(t(held_out_errs),
att[(t0 + 1):t_final, , drop = F],
colMeans(att[(t0 + 1):t_final, , drop = F]))
if(conservative) {
qerr <- apply(abs(held_out_errs), 1,
stats::quantile, 1 - alpha, type = 1)
out$lb <- rbind(matrix(NA, nrow = t0, ncol = k),
t(t(apply(jack_dist_cons, 1:2, min)) - qerr))
out$ub <- rbind(matrix(NA, nrow = t0, ncol = k),
t(t(apply(jack_dist_cons, 1:2, max)) + qerr))
} else {
out$lb <- rbind(matrix(NA, nrow = t0, ncol = k),
apply(jack_dist_low, 1:2,
stats::quantile, alpha, type = 1))
out$ub <- rbind(matrix(NA, nrow = t0, ncol = k),
apply(jack_dist_high, 1:2,
stats::quantile, 1 - alpha, type = 1))
}
# shift back to ATT scale
y1 <- predict(ascm_multi, att = F) + att
y1 <- rbind(y1, colMeans(y1[(t0 + 1):t_final, , drop = F]))
shifted_lb <- y1 - out$ub
shifted_ub <- y1 - out$lb
out$lb <- shifted_lb
out$ub <- shifted_ub
out$alpha <- alpha
return(out)
}
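# Illustrative sketch (hypothetical helper `sketch_jackknife_plus`, not called by
# augsynth) of the non-conservative jackknife+ step in
# time_jackknife_plus_multiout() for one outcome and one post-treatment period.
# `loo_pred[t]` is the predicted control outcome when pre-period t is held out,
# and `loo_err[t]` is the prediction error at that held-out period. The interval
# for the control outcome Y(0) is converted into an interval for the effect via
# y1 - Y(0), which is why the bounds swap.
sketch_jackknife_plus <- function(y1, loo_pred, loo_err, alpha = 0.05) {
  lo_y0 <- stats::quantile(loo_pred - abs(loo_err), alpha, type = 1)
  hi_y0 <- stats::quantile(loo_pred + abs(loo_err), 1 - alpha, type = 1)
  c(lb = unname(y1 - hi_y0), ub = unname(y1 - lo_y0))
}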
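#' Drop time period from pre-treatment data
#' @param data_list List with `X` and `y` (lists of pre- and post-treatment outcome matrices, one per outcome) and treatment indicator `trt`
#' @param Z Covariates matrix (unused, since covariates do not vary over time)
#' @param t_drop Pre-treatment period to hold out; it becomes the first post-treatment period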
#' @noRd
drop_time_t_multiout <- function(data_list, Z, t_drop) {
new_data_list <- list()
new_data_list$trt <- data_list$trt
new_data_list$X <- lapply(data_list$X,
function(x) x[, -t_drop, drop = F])
new_data_list$y <- lapply(1:length(data_list$y),
function(k) {
cbind(data_list$X[[k]][, t_drop, drop = F],
data_list$y[[k]])
})
return(new_data_list)
}
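# Illustrative sketch (hypothetical helper `sketch_drop_time`, not called by
# augsynth) of what drop_time_t_multiout() does to a single outcome: column
# t_drop is removed from the pre-treatment matrix X and prepended to the
# post-treatment matrix y. After refitting, the held-out period is therefore
# the first "post-treatment" period, whose prediction error feeds jackknife+.
sketch_drop_time <- function(X, y, t_drop) {
  list(X = X[, -t_drop, drop = FALSE],
       y = cbind(X[, t_drop, drop = FALSE], y))
}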
#' Conformal inference procedure to compute p-values and point-wise confidence intervals with multiple outcomes
#' @param ascm_multi Fitted `augsynth_multiout` object
#' @param alpha Significance level
#' @param stat_func Function to compute test statistic
#' @param type Either "iid" for iid permutations or "block" for moving block permutations
#' @param q The order of the default test statistic `(sum(abs(x) ^ q) / sqrt(length(x))) ^ (1 / q)`, used when `stat_func` is NULL
#' @param ns Number of resamples for "iid" permutations
#' @param grid_size Number of grid points per outcome to use when inverting the hypothesis test (default is 1, so only the joint null of no effect is tested)
#' @param lin_h0 Optional vector with one entry per outcome; if supplied, only nulls of the form `tau * lin_h0` for scalar `tau` are tested
#' @return List that contains:
#' \itemize{
#' \item{"att"}{Matrix of ATT estimates (one column per outcome), with a final row for the average post-treatment effect}
#' \item{"lb"}{Lower bound of 1 - alpha confidence interval}
#' \item{"ub"}{Upper bound of 1 - alpha confidence interval}
#' \item{"p_val"}{p-values; the final row tests the null of zero effect in every post-treatment period}
#' \item{"alpha"}{Level of confidence interval}
#' }
conformal_inf_multiout <- function(ascm_multi, alpha = 0.05,
stat_func = NULL, type = "iid",
q = 1, ns = 1000, grid_size = 1,
lin_h0 = NULL) {
wide_data <- ascm_multi$data
data_list <- ascm_multi$data_list
n <- nrow(wide_data$X)
k <- length(data_list$X)
t0 <- min(sapply(data_list$X, ncol))
tpost <- max(sapply(data_list$y, ncol))
t_final <- t0 + tpost
# grid of nulls
att <- predict(ascm_multi, att = T)
post_att <- att[(t0 +1):t_final,, drop = F]
post_sd <- apply(post_att, 2, function(x) sqrt(mean(x ^ 2, na.rm = T)))
# iterate over post-treatment periods to get pointwise CIs
vapply(1:tpost,
function(j) {
# fit using t0 + j as a pre-treatment period and get residuals
new_data_list <- data_list
new_data_list$X <- lapply(1:k,
function(i) {
Xi <- cbind(data_list$X[[i]], data_list$y[[i]][, j, drop = TRUE])
colnames(Xi) <- c(colnames(data_list$X[[i]]),
colnames(data_list$y[[i]])[j])
Xi
})
if(tpost > 1) {
new_data_list$y <- lapply(1:k,
function(i) {
data_list$y[[i]][, -j, drop = FALSE]
})
} else {
# with a single post-treatment period, y still needs a placeholder column, so use a column of ones
new_data_list$y <- lapply(1:k,
function(i) {
x <- matrix(1, nrow = n, ncol = 1)
colnames(x) <- max(as.numeric(colnames(data_list$y[[i]]))) + 1
x
})
}
# make a grid around the estimated ATT
if(is.null(lin_h0)) {
grid <- lapply(1:k,
function(i) {
seq(att[t0 + j, i] - 2 * post_sd[i], att[t0 + j, i] + 2 * post_sd[i],
length.out = grid_size)
})
} else {
grid <- seq(min(att[t0 + j, ]) - 2 * max(post_sd),
max(att[t0 + j, ]) + 2 * max(post_sd),
length.out = grid_size)
}
if(grid_size > 1) {
compute_permute_ci_multiout(new_data_list, ascm_multi, grid, 1,
alpha, type, q, ns, lin_h0, stat_func)
} else {
rbind(matrix(0, ncol = k, nrow = 2),
compute_permute_pval_multiout(new_data_list, ascm_multi, numeric(k),
1, type, q, ns, stat_func))
}
},
matrix(0, ncol = k, nrow=3)) -> cis
# test the null of zero effect in every post-treatment period
new_data_list <- data_list
new_data_list$X <- lapply(1:k,
function(i) {
Xi <- cbind(data_list$X[[i]], data_list$y[[i]])
colnames(Xi) <- c(colnames(data_list$X[[i]]),
colnames(data_list$y[[i]]))
Xi
})
# set post treatment to be *something*
new_data_list$y <- lapply(1:k,
function(i) {
data_list$y[[i]][, 1, drop = FALSE]
})
null_p <- compute_permute_pval_multiout(new_data_list, ascm_multi,
numeric(k),
tpost, type, q, ns, stat_func)
if(is.null(lin_h0)) {
grid <- lapply(1:k,
function(i) {
seq(min(att[(t0 + 1):t_final, i]) - 4 * post_sd[i],
max(att[(t0 + 1):t_final, i]) + 4 * post_sd[i],
length.out = grid_size)
})
} else {
grid <- seq(min(att[t0 + 1, ]) - 3 * max(post_sd),
max(att[t0 + 1, ]) + 3 * max(post_sd),
length.out = grid_size)
}
null_ci <- compute_permute_ci_multiout(new_data_list, ascm_multi, grid,
tpost, alpha, type, q, ns,
lin_h0, stat_func)
out <- list()
att <- predict(ascm_multi, att = T)
out$att <- rbind(att, apply(att[(t0 + 1):t_final, , drop = F], 2, mean))
out$lb <- rbind(matrix(NA, nrow = t0, ncol = k),
t(matrix(cis[1, ,], nrow = k)),
# rep(NA, k)
null_ci[1,]
)
colnames(out$lb) <- ascm_multi$outcomes
out$ub <- rbind(matrix(NA, nrow = t0, ncol = k),
t(matrix(cis[2, ,], nrow = k)),
# rep(NA, k)
null_ci[2,]
)
colnames(out$ub) <- ascm_multi$outcomes
out$p_val <- rbind(matrix(NA, nrow = t0, ncol = k),
t(matrix(cis[3, ,], nrow = k)),
# rep(null_p, k)
null_ci[3,])
colnames(out$p_val) <- ascm_multi$outcomes
out$alpha <- alpha
return(out)
}
#' Compute conformal test statistics
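#' @param data_list List with `X` (outcome matrices, one per outcome, whose last `post_length` columns are the post-treatment periods under test), `y`, and treatment indicator `trt`
#' @param ascm_multi Fitted `augsynth_multiout` object
#' @param h0 Vector of null effects, one per outcome
#' @param post_length Number of post-treatment periods
#' @param type Either "iid" for iid permutations or "block" for moving block permutations
#' @param q The order of the default test statistic `(sum(abs(x) ^ q) / sqrt(length(x))) ^ (1 / q)`, used when `stat_func` is NULL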
#' @param ns Number of resamples for "iid" permutations
#' @param stat_func Function to compute test statistic
#'
#' @return List that contains:
#' \itemize{
#' \item{"resids"}{Residuals after enforcing the null}
#' \item{"test_stats"}{Permutation distribution of test statistics}
#' \item{"stat_func"}{Test statistic function}
#' }
#' @noRd
compute_permute_test_stats_multiout <- function(data_list, ascm_multi, h0,
post_length, type,
q, ns, stat_func) {
# format data
new_data_list <- data_list
t0 <- ncol(data_list$X[[1]]) - post_length
tpost <- t0 + post_length
k <- length(data_list$X)
# adjust outcomes for null
for(i in 1:k) {
new_data_list$X[[i]][data_list$trt == 1,(t0 + 1):tpost ] <- new_data_list$X[[i]][data_list$trt == 1,(t0 + 1):tpost] - h0[i]
}
# fit synth with adjusted data and get residuals
new_ascm <- do.call(fit_augsynth_multiout_internal,
c(list(wide_list = new_data_list,
combine_method = ascm_multi$combine_method,
Z = data_list$Z,
progfunc = ascm_multi$progfunc,
scm = ascm_multi$scm,
fixedeff = ascm_multi$fixedeff,
outcomes_str = ascm_multi$outcomes),
ascm_multi$extra_args))
resids <- predict(new_ascm, att = T)[1:tpost, , drop = F]
# permute residuals and compute test statistic
if(is.null(stat_func)) {
stat_func <- function(x) {
x <- na.omit(x)
(sum(abs(x) ^ q) / sqrt(length(x))) ^ (1 / q)
}
}
if(type == "iid") {
test_stats <- sapply(1:ns,
function(x) {
idxs <- sample(1:nrow(resids))
reorder <- resids[idxs, , drop = F]
apply(reorder[(t0 + 1):tpost, ,drop = F], 2, stat_func)
})
} else {
## increment time by one step and wrap
test_stats <- sapply(0:(tpost - 1),
function(j) {
reorder <- resids[(0:(tpost -1) + j) %% tpost + 1, ,drop = F]
if(!all(dim(reorder) == dim(resids))) {
stop("Error in block resampling")
}
apply(reorder[(t0 + 1):tpost, , drop = F], 2, stat_func)
})
}
return(list(resids = resids,
test_stats = matrix(test_stats, nrow = k),
stat_func = stat_func))
}
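# Illustrative sketch (hypothetical helper `sketch_block_stats_multi`, not
# called by augsynth) of the "block" branch of
# compute_permute_test_stats_multiout(). The rows of the residual matrix (time
# periods) are shifted together, which preserves the dependence across outcomes
# within a period. The default statistic is then computed for each outcome
# column over the last post_length rows. The result has one row per outcome and
# one column per shift.
sketch_block_stats_multi <- function(resids, post_length, q = 1) {
  tpost <- nrow(resids)
  post <- (tpost - post_length + 1):tpost
  stat_func <- function(x) {
    x <- stats::na.omit(x)
    (sum(abs(x) ^ q) / sqrt(length(x))) ^ (1 / q)
  }
  stats <- vapply(0:(tpost - 1), function(j) {
    shifted <- resids[(0:(tpost - 1) + j) %% tpost + 1, , drop = FALSE]
    apply(shifted[post, , drop = FALSE], 2, stat_func)
  }, numeric(ncol(resids)))
  matrix(stats, nrow = ncol(resids))
}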
#' Compute conformal p-value
#' @param data_list List with `X` (outcome matrices, one per outcome, whose last `post_length` columns are the post-treatment periods under test), `y`, and treatment indicator `trt`
#' @param ascm_multi Fitted `augsynth_multiout` object
#' @param h0 Vector of null effects, one per outcome
#' @param post_length Number of post-treatment periods
#' @param type Either "iid" for iid permutations or "block" for moving block permutations
#' @param q The order of the default test statistic `(sum(abs(x) ^ q) / sqrt(length(x))) ^ (1 / q)`, used when `stat_func` is NULL
#' @param ns Number of resamples for "iid" permutations
#' @param stat_func Function to compute test statistic
#'
#' @return p-value for the test statistic averaged across outcomes
#' @noRd
compute_permute_pval_multiout <- function(data_list, ascm_multi, h0,
post_length, type,
q, ns, stat_func) {
t0 <- ncol(data_list$X[[1]]) - post_length
tpost <- t0 + post_length
out <- compute_permute_test_stats_multiout(data_list, ascm_multi, h0,
post_length, type, q, ns, stat_func)
k <- length(data_list$X)
comb_stat <- mean(apply(out$resids[(t0 + 1):tpost, , drop = F], 2, out$stat_func), na.rm = TRUE)
comb_test_stats <- apply(out$test_stats, 2, mean, na.rm = TRUE)
# if(h0 == 0) {
# hist(comb_test_stats)
# abline(v = comb_stat)
# print(mean(comb_stat <= comb_test_stats))
# print(1 - mean(comb_stat > comb_test_stats))
# }
1 - mean(comb_stat > comb_test_stats)
}
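#' Compute conformal confidence intervals with multiple outcomes by inverting permutation tests
#' @param data_list List with `X` (outcome matrices, one per outcome, whose last `post_length` columns are the post-treatment periods under test), `y`, and treatment indicator `trt`
#' @param ascm_multi Fitted `augsynth_multiout` object
#' @param grid List of grids of null values, one per outcome; if `lin_h0` is supplied, a single grid of scalar multipliers
#' @param post_length Number of post-treatment periods
#' @param alpha Significance level
#' @param type Either "iid" for iid permutations or "block" for moving block permutations
#' @param q The order of the default test statistic `(sum(abs(x) ^ q) / sqrt(length(x))) ^ (1 / q)`, used when `stat_func` is NULL
#' @param ns Number of resamples for "iid" permutations
#' @param lin_h0 Optional vector with one entry per outcome; if supplied, the nulls tested are the grid values times `lin_h0`
#' @param stat_func Function to compute test statistic
#'
#' @return Matrix with one column per outcome and rows (lower bound of interval, upper bound of interval, p-value for null of 0 effect)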
#' @noRd
compute_permute_ci_multiout <- function(data_list, ascm_multi, grid,
post_length, alpha, type,
q, ns, lin_h0 = NULL, stat_func) {
# make sure 0 is in the grid
if(is.null(lin_h0)) {
grid <- lapply(grid, function(x) c(x, 0))
k <- length(grid)
# get all combinations of grid
grid <- expand.grid(grid)
grid_low <- NULL
} else {
k <- length(lin_h0)
# keep track of low dimensional grid
grid_low <- c(grid, 0)
# transform into high dimensional grid with linear hypothesis
grid <- sapply(lin_h0, function(x) x * grid_low)
}
ps <- apply(grid, 1,
function(x) {
compute_permute_pval_multiout(data_list, ascm_multi, x,
post_length, type, q, ns, stat_func)
})
sapply(1:k,
function(i) c(min(grid[ps >= alpha, i]),
max(grid[ps >= alpha, i]),
ps[apply(grid == 0, 1, all)]))
}
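# Illustrative sketch (hypothetical helper `sketch_linear_h0_grid`, not called
# by augsynth) of how compute_permute_ci_multiout() uses `lin_h0`. Instead of a
# Cartesian grid over all k outcomes, it tests only nulls of the form
# tau * lin_h0 for scalar tau on a one-dimensional grid (plus tau = 0). Each row
# of the result is one k-dimensional null. The number of tests therefore grows
# linearly in the grid size instead of as (grid size)^k.
sketch_linear_h0_grid <- function(grid, lin_h0) {
  grid_low <- c(grid, 0)
  outer(grid_low, lin_h0)   # same as sapply(lin_h0, function(x) x * grid_low)
}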
#' Drop unit i from data
#' @param wide_data (X, y, trt)
#' @param Z Covariates matrix
#' @param i Unit to drop
#' @noRd
drop_unit_i <- function(wide_data, Z, i) {
new_wide_data <- list()
new_wide_data$trt <- wide_data$trt[-i]
new_wide_data$X <- wide_data$X[-i,, drop = F]
new_wide_data$y <- wide_data$y[-i,, drop = F]
X0 <- new_wide_data$X[new_wide_data$trt == 0,, drop = F]
x1 <- matrix(colMeans(new_wide_data$X[new_wide_data$trt == 1,, drop = F]),
ncol=1)
y0 <- new_wide_data$y[new_wide_data$trt == 0,, drop = F]
y1 <- colMeans(new_wide_data$y[new_wide_data$trt == 1,, drop = F])
new_synth_data <- list()
new_synth_data$Z0 <- t(X0)
new_synth_data$X0 <- t(X0)
new_synth_data$Z1 <- x1
new_synth_data$X1 <- x1
new_Z <- if(!is.null(Z)) Z[-i, , drop = F] else NULL
return(list(wide_data = new_wide_data,
synth_data = new_synth_data,
Z = new_Z))
}
#' Drop unit i from data
#' @param wide_list (X, y, trt)
#' @param Z Covariates matrix
#' @param i Unit to drop
#' @noRd
drop_unit_i_multiout <- function(wide_list, Z, i) {
new_wide_data <- list()
new_wide_data$trt <- wide_list$trt[-i]
new_wide_data$X <- lapply(wide_list$X, function(x) x[-i,, drop = F])
new_wide_data$y <- lapply(wide_list$y, function(x) x[-i,, drop = F])
new_Z <- if(!is.null(Z)) Z[-i, , drop = F] else NULL
return(list(wide_list = new_wide_data,
Z = new_Z))
}
#' Estimate standard errors for single ASCM with the jackknife
#'
#' Refits the model after dropping, in turn, each control unit with a non-zero
#' weight (and each treated unit, when more than one unit is treated)
#' @param ascm Fitted augsynth object
#'
#' @return List that contains:
#' \itemize{
#' \item{"att"}{Vector of ATT estimates, with a final entry for the average post-treatment effect}
#' \item{"se"}{Standard error estimates, NA for pre-treatment periods}
#' }
jackknife_se_single <- function(ascm) {
wide_data <- ascm$data
synth_data <- ascm$data$synth_data
n <- nrow(wide_data$X)
n_c <- dim(synth_data$Z0)[2]
Z <- wide_data$Z
t0 <- dim(synth_data$Z0)[1]
tpost <- ncol(wide_data$y)
t_final <- dim(synth_data$Y0plot)[1]
errs <- matrix(0, n_c, t_final - t0)
# only drop out control units with non-zero weights
nnz_weights <- numeric(n)
nnz_weights[wide_data$trt == 0] <- round(ascm$weights, 3) != 0
# if more than one unit is treated, include them in the jackknife
if(sum(wide_data$trt) > 1) {
nnz_weights[wide_data$trt == 1] <- 1
}
trt_idxs <- (1:n)[as.logical(nnz_weights)]
# jackknife estimates
ests <- vapply(trt_idxs,
function(i) {
# drop unit i
new_data <- drop_unit_i(wide_data, Z, i)
# refit
new_ascm <- do.call(fit_augsynth_internal,
c(list(wide = new_data$wide_data,
synth_data = new_data$synth_data,
Z = new_data$Z,
progfunc = ascm$progfunc,
scm = ascm$scm,
fixedeff = ascm$fixedeff),
ascm$extra_args))
# get ATT estimates
est <- predict(new_ascm, att = T)[(t0 + 1):t_final]
c(est, mean(est))
},
numeric(tpost + 1))
# convert to matrix
ests <- matrix(ests, nrow = tpost + 1, ncol = length(trt_idxs))
## standard errors
se2 <- apply(ests, 1,
function(x) (n - 1) / n * sum((x - mean(x, na.rm = T)) ^ 2))
se <- sqrt(se2)
out <- list()
att <- predict(ascm, att = T)
out$att <- c(att, mean(att[(t0 + 1):t_final]))
out$se <- c(rep(NA, t0), se)
# out$sigma <- NA
return(out)
}
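# Illustrative sketch (hypothetical helper `sketch_jackknife_se`, not called by
# augsynth) of the leave-one-out variance formula used in
# jackknife_se_single(): se^2 = (n - 1) / n * sum((est_i - mean(est))^2), where
# est_i is the estimate with observation i dropped. For the sample mean this
# reproduces the usual standard error sd(x) / sqrt(n), which makes a handy
# check. Note that jackknife_se_single() uses the total number of units as n
# even though it only drops a subset of them.
sketch_jackknife_se <- function(x, estimator = mean) {
  n <- length(x)
  loo <- vapply(seq_len(n), function(i) estimator(x[-i]), numeric(1))
  sqrt((n - 1) / n * sum((loo - mean(loo)) ^ 2))
}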
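#' Compute standard errors using the jackknife
#' @param multisynth fitted multisynth object
#' @param relative Whether to compute effects according to relative time
#' @param alpha Significance level for the normal-approximation confidence intervals
#' @param att_weight Optional weights passed to `predict()` when aggregating effects
#' @return List with `att`, `se`, `lower_bound`, and `upper_bound`
#' @noRd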
jackknife_se_multi <- function(multisynth, relative=NULL, alpha = 0.05, att_weight = NULL) {
## get info from the multisynth object
if(is.null(relative)) {
relative <- multisynth$relative
}
n_leads <- multisynth$n_leads
n <- nrow(multisynth$data$X)
att <- predict(multisynth, att=T, att_weight = att_weight)
outddim <- nrow(att)
J <- length(multisynth$grps)
## drop each unit and estimate overall treatment effect
jack_est <- vapply(1:n,
function(i) {
msyn_i <- drop_unit_i_multi(multisynth, i)
pred <- predict(msyn_i[[1]], relative=relative, att=T, att_weight = att_weight)
if(nrow(pred) < outddim) {
pred <- rbind(
pred[1:(nrow(pred)-1), ],
matrix(NA, nrow=outddim-nrow(pred), ncol=ncol(pred)),
pred[nrow(pred), ]
)
}
if(length(msyn_i[[2]]) != 0) {
out <- matrix(NA, nrow=nrow(pred), ncol=(J+1))
out[,-(msyn_i[[2]]+1)] <- pred
} else {
out <- pred
}
out
},
matrix(0, nrow=outddim,ncol=(J+1)))
se2 <- apply(jack_est, c(1,2),
function(x) (n-1) / n * sum((x - mean(x,na.rm=T))^2, na.rm=T))
lower_bound <- att - qnorm(1 - alpha / 2) * sqrt(se2)
upper_bound <- att + qnorm(1 - alpha / 2) * sqrt(se2)
return(list(att = att, se = sqrt(se2),
lower_bound = lower_bound, upper_bound = upper_bound))
}
#' Helper function to drop unit i and refit
#' @param msyn multisynth_object
#' @param i Unit to drop
#' @noRd
drop_unit_i_multi <- function(msyn, i) {
n <- nrow(msyn$data$X)
time_cohort <- msyn$time_cohort
which_t <- (1:n)[is.finite(msyn$data$trt)]
not_miss_j <- which_t %in% setdiff(which_t, i)
# drop unit i from data
drop_i <- msyn$data
drop_i$X <- msyn$data$X[-i, , drop = F]
drop_i$y <- msyn$data$y[-i, , drop = F]
drop_i$trt <- msyn$data$trt[-i]
drop_i$mask <- msyn$data$mask[not_miss_j,, drop = F]
if(!is.null(msyn$data$Z)) {
drop_i$Z <- msyn$data$Z[-i, , drop = F]
} else {
drop_i$Z <- NULL
}
long_df <- msyn$long_df
unit <- colnames(long_df)[1]
# make alphabetical, because the ith unit is the index in alphabetical ordering
long_df <- long_df[order(long_df[, unit, drop = TRUE]),]
ith_unit <- unique(long_df[,unit, drop = TRUE])[i]
long_df <- long_df[long_df[,unit, drop = TRUE] != ith_unit,]
# re-fit everything
args_list <- list(wide = drop_i, relative = msyn$relative,
n_leads = msyn$n_leads, n_lags = msyn$n_lags,
nu = msyn$nu, lambda = msyn$lambda,
V = msyn$V,
force = msyn$force, n_factors = msyn$n_factors,
scm = msyn$scm, time_w = msyn$time_w,
lambda_t = msyn$lambda_t,
fit_resids = msyn$fit_resids,
time_cohort = msyn$time_cohort, long_df = long_df,
how_match = msyn$how_match)
msyn_i <- do.call(multisynth_formatted, c(args_list, msyn$extra_pars))
# check for dropped treated units/time periods
if(time_cohort) {
dropped <- which(!msyn$grps %in% msyn_i$grps)
} else {
dropped <- which(!not_miss_j)
}
return(list(msyn_i,
dropped))
}
#' Estimate standard errors for multi outcome ascm with jackknife
#' @param ascm Fitted `augsynth_multiout` object
#' @noRd
jackknife_se_multiout <- function(ascm) {
wide_data <- ascm$data
wide_list <- ascm$data_list
n <- nrow(wide_data$X)
Z <- wide_data$Z
# only drop out control units with non-zero weights
nnz_weights <- numeric(n)
nnz_weights[wide_data$trt == 0] <- round(ascm$weights, 3) != 0
trt_idxs <- (1:n)[as.logical(nnz_weights)]
# jackknife estimates
ests <- lapply(trt_idxs,
function(i) {
# drop unit i
new_data <- drop_unit_i_multiout(wide_list, Z, i)
# refit
new_ascm <- do.call(fit_augsynth_multiout_internal,
c(list(wide_list = new_data$wide_list,
combine_method = ascm$combine_method,
Z = new_data$Z,
progfunc = ascm$progfunc,
scm = ascm$scm,
fixedeff = ascm$fixedeff,
outcomes_str = ascm$outcomes),
ascm$extra_args))
new_ascm$outcomes <- ascm$outcomes
new_ascm$data_list <- ascm$data_list
new_ascm$data$time <- ascm$data$time
# get ATT estimates
est <- predict(new_ascm, att = T)
est <- est[as.numeric(rownames(est)) >= ascm$t_int,, drop = F]
rbind(est, colMeans(est, na.rm = T))
})
ests <- simplify2array(ests)
## standard errors
se2 <- apply(ests, c(1, 2),
function(x) (n - 1) / n * sum((x - mean(x, na.rm = T)) ^ 2))
se <- sqrt(se2)
out <- list()
att <- predict(ascm, att = T)
att_post <- colMeans(att[as.numeric(rownames(att)) >= ascm$t_int,, drop = F],
na.rm = T)
out$att <- rbind(att, att_post)
t0 <- sum(as.numeric(rownames(att)) < ascm$t_int)
out$se <- rbind(matrix(NA, t0, ncol(se)), se)
out$sigma <- NA
return(out)
}
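#' Compute the weighted bootstrap distribution
#' @param multisynth fitted multisynth object
#' @param rweight Function that draws n random weights, called as `rweight(n)`; default is `rwild_b`
#' @param n_boot Number of bootstrap draws
#' @param alpha Significance level
#' @param att_weight Optional weights passed to `predict()` when aggregating effects
#' @param relative Whether to compute effects according to relative time
#' @noRd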
weighted_bootstrap_multi <- function(multisynth,
rweight = rwild_b,
n_boot = 1000,
alpha = 0.05,
att_weight = NULL,
relative=NULL) {
## get info from the multisynth object
if(is.null(relative)) {
relative <- multisynth$relative
}
n <- nrow(multisynth$data$X)
att <- predict(multisynth, att=T, att_weight = att_weight)
outddim <- nrow(att)
n1 <- sum(is.finite(multisynth$data$trt))
J <- length(multisynth$grps)
# draw random weights to get bootstrap distribution
bs_est <- vapply(1:n_boot,
function(i) {
Z <- rweight(n)# / sqrt(n1)
predict(multisynth, att=T, att_weight = att_weight, bs_weight = Z) - sum(Z) / n1 * att
},
matrix(0, nrow=outddim,ncol=(J+1)))
se2 <- apply(bs_est, c(1,2),
function(x) mean((x - mean(x))^2, na.rm=T))
bias <- apply(bs_est, c(1,2),
function(x) mean(x, na.rm=T))
upper_bound <- att - apply(bs_est, c(1,2),
function(x) quantile(x, alpha / 2, na.rm = T))
lower_bound <- att - apply(bs_est, c(1,2),
function(x) quantile(x, 1 - alpha / 2, na.rm = T))
return(list(att = att,
bias = bias,
se = sqrt(se2),
upper_bound = upper_bound,
lower_bound = lower_bound))
}
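# Illustrative sketch (hypothetical helper `sketch_basic_bootstrap_ci`, not
# called by augsynth) of the interval in weighted_bootstrap_multi(). Each
# bootstrap draw is recentred so that `bs_dev` approximates the sampling error
# of the estimate. The interval is then
# [att - q_{1 - alpha/2}, att - q_{alpha/2}], the "basic" (reverse-percentile)
# bootstrap interval. Subtracting the quantiles, rather than adding them,
# corrects for skew in the bootstrap distribution.
sketch_basic_bootstrap_ci <- function(att, bs_dev, alpha = 0.05) {
  q <- stats::quantile(bs_dev, c(alpha / 2, 1 - alpha / 2), na.rm = TRUE,
                       names = FALSE)
  c(lower = att - q[2], upper = att - q[1])
}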
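#' Bayesian bootstrap
#' @param n Number of units
#' @return Numeric vector of n weights drawn from a flat Dirichlet distribution and rescaled to sum to n
#' @export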
rdirichlet_b <- function(n) {
Z <- as.numeric(rgamma(n, 1, 1))
return(Z / sum(Z) * n)
}
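#' Non-parametric bootstrap
#' @param n Number of units
#' @return Numeric vector of n multinomial counts summing to n (the number of times each unit is resampled)
#' @export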
rmultinom_b <- function(n) as.numeric(rmultinom(1, n, rep(1 / n, n)))
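#' Wild bootstrap (Mammen 1993)
#' @param n Number of units
#' @return Numeric vector of n iid draws from Mammen's two-point distribution, which has mean 0 and variance 1
#' @export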
rwild_b <- function(n) {
sample(c(-(sqrt(5) - 1) / 2, (sqrt(5) + 1) / 2 ), n,
replace = TRUE,
prob = c((sqrt(5) + 1)/ (2 * sqrt(5)), (sqrt(5) - 1) / (2 * sqrt(5))))
}
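# Illustrative check (hypothetical helper `sketch_mammen_moments`, not called by
# augsynth) of the two-point distribution behind rwild_b(). It takes the value
# -(sqrt(5) - 1) / 2 with probability (sqrt(5) + 1) / (2 * sqrt(5)), and
# (sqrt(5) + 1) / 2 otherwise. This gives mean 0, variance 1 and third moment 1
# (Mammen 1993). The moments are computed exactly here rather than by
# simulation.
sketch_mammen_moments <- function() {
  v <- c(-(sqrt(5) - 1) / 2, (sqrt(5) + 1) / 2)
  p <- c((sqrt(5) + 1) / (2 * sqrt(5)), (sqrt(5) - 1) / (2 * sqrt(5)))
  c(prob_sum = sum(p), mean = sum(p * v), second = sum(p * v ^ 2),
    third = sum(p * v ^ 3))
}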
================================================
FILE: R/multi_outcomes.R
================================================
#' Fit Augmented SCM with multiple outcomes
#' @param form outcomes ~ treatment | auxiliary covariates, with multiple outcomes on the left-hand side (e.g. `y1 + y2 ~ trt`)
#' @param unit Name of unit column
#' @param time Name of time column
#' @param t_int Time of intervention
#' @param data Panel data as dataframe
#' @param progfunc What function to use to impute control outcomes:
#' Ridge = ridge regression (allows for standard errors),
#' None = no outcome model
#' @param scm Whether the SCM weighting function is used
#' @param fixedeff Whether to include a unit fixed effect, default F
#' @param cov_agg Covariate aggregation functions, if NULL then use mean with NAs omitted
#' @param combine_method How to combine outcomes: `concat` concatenates outcomes and `avg` averages them, default: 'avg'
#' @param ... optional arguments for outcome model
#'
#' @return augsynth_multiout object that contains:
#' \itemize{
#' \item{"weights"}{Ridge ASCM weights}
#' \item{"l2_imbalance"}{Imbalance in pre-period outcomes, measured by the L2 norm}
#' \item{"scaled_l2_imbalance"}{L2 imbalance scaled by L2 imbalance of uniform weights}
#' \item{"mhat"}{Outcome model estimate}
#' \item{"data"}{Panel data as matrices}
#' }
#' @export
augsynth_multiout <- function(form, unit, time, t_int, data,
progfunc=c("Ridge", "None"),
scm=T,
fixedeff = FALSE,
cov_agg=NULL,
combine_method = "avg",
...) {
call_name <- match.call()
form <- Formula::Formula(form)
unit <- enquo(unit)
time <- enquo(time)
## format data
outcome <- terms(formula(form, rhs=1))[[2]]
trt <- terms(formula(form, rhs=1))[[3]]
outcomes_str <- all.vars(outcome)
outcomes <- sapply(outcomes_str, quo)
# get outcomes as a list
wide_list <- format_data_multi(outcomes, trt, unit, time, t_int, data)
## add covariates
if(length(form)[2] == 2) {
cov_form <- paste(deparse(terms(formula(form, rhs = 2))[[3]]), collapse = "")
new_form <- as.formula(paste("~", cov_form))
Z <- extract_covariates(new_form, unit, time, t_int, data, cov_agg)
} else {
Z <- NULL
}
# only allow ridge augmentation; take the first option when progfunc is left at its default
progfunc <- progfunc[1]
if(! tolower(progfunc) %in% c("none", "ridge")) {
stop(paste(progfunc, "is not a valid augmentation function with multiple outcomes. Only `none` or `ridge` are allowable options for `progfunc`"))
}
# fit augmented SCM
augsynth <- fit_augsynth_multiout_internal(wide_list, combine_method, Z,
progfunc, scm,
fixedeff, outcomes_str, ...)
# add some extra data
augsynth$data$time <- data %>% distinct(!!time) %>% pull(!!time)
augsynth$call <- call_name
augsynth$t_int <- t_int
augsynth$combine_method <- combine_method
treated_units <- data %>% filter(!!trt == 1) %>% distinct(!!unit) %>% pull(!!unit)
control_units <- data %>% filter(!(!!unit %in% treated_units)) %>%
distinct(!!unit) %>% pull(!!unit)
augsynth$weights <- matrix(augsynth$weights)
rownames(augsynth$weights) <- control_units
return(augsynth)
}
#' Internal function to fit augmented SCM with multiple outcomes
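#' @param wide_list List holding per-outcome lists of pre-treatment (X) and post-treatment (y) matrices plus the treatment vector trt, formatted from format_data
#' @param combine_method How to combine outcomes
#' @param Z Matrix of auxiliary covariates
#' @param progfunc outcome model to use
#' @param scm Whether to fit SCM
#' @param fixedeff Whether to de-mean each outcome by its unit-level pre-treatment mean
#' @param outcomes_str Character vector of outcome names
#' @param ... Extra args for outcome model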
#' @noRd
fit_augsynth_multiout_internal <- function(wide_list, combine_method, Z,
progfunc, scm, fixedeff,
outcomes_str, ...) {
# combine into a matrix for fitting and balancing
out <- combine_outcomes(wide_list, combine_method, fixedeff, ...)
wide_bal <- out$wide_bal
mhat <- out$mhat
V <- out$V
synth_data <- do.call(format_synth, wide_bal)
# set Y1 and Y0plot to be raw concatenated outcomes
X <- do.call(cbind, wide_list$X)
y <- do.call(cbind, wide_list$y)
trt <- wide_list$trt
synth_data$Y0plot <- t(cbind(X, y)[trt == 0,, drop = F])
synth_data$Y1plot <- colMeans(cbind(X, y)[trt == 1,, drop = F])
augsynth <- fit_augsynth_internal(wide_bal, synth_data, Z, progfunc,
scm, fixedeff = F, V = V, ...)
# potentially add back in fixed effects
augsynth$mhat <- mhat# + augsynth$mhat
augsynth$data <- list(X = X, trt = trt, y = y, Z = Z)
augsynth$data_list <- wide_list
augsynth$outcomes <- outcomes_str
# change fixedeff flag to match input (rather than fixedeff = F in fit_augsynth_internal)
augsynth$fixedeff <- fixedeff
##format output
class(augsynth) <- c("augsynth_multiout", "augsynth")
return(augsynth)
}
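# Illustrative sketch (hypothetical helper, not used by the package):
# fit_augsynth_multiout_internal keeps the raw outcomes for plotting by
# column-binding every outcome's pre-period (X) matrices, followed by every
# outcome's post-period (y) matrices. Y0plot is the transpose of the control
# rows and Y1plot holds the column means of the treated rows. Assuming `w` is a
# vector of control-unit weights, Y1plot - Y0plot %*% w gives the unaugmented
# treated-minus-synthetic gap for every outcome/period column.
sketch_raw_gap <- function(X_list, y_list, trt, w) {
  all_out <- cbind(do.call(cbind, X_list), do.call(cbind, y_list))
  Y0plot <- t(all_out[trt == 0, , drop = FALSE])      # columns x control units
  Y1plot <- colMeans(all_out[trt == 1, , drop = FALSE])
  drop(Y1plot - Y0plot %*% w)
}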
#' Helper function to combine multiple outcomes into a single balance matrix
#' @param wide_list List of lists of pre/post treatment data for each outcome
#' @param combine_method How to combine outcomes
#' @param fixedeff Whether to take out unit fixed effects or not
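#' @param nu Weight on the averaged objective relative to the concatenated objective (used only when combine_method = "avg_concat")
#' @param ... Extra arguments for combination
#' @noRd
#' @return \itemize{
#' \item{"wide_bal"}{List with the combined pre-treatment matrix X, post-treatment matrix y, and treatment vector trt}
#' \item{"mhat"}{Matrix of unit fixed-effect estimates (all zeros if fixedeff = FALSE)}
#' \item{"V"}{Scaling vector or matrix for the columns of wide_bal$X}
#' }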
combine_outcomes <- function(wide_list, combine_method, fixedeff,
nu = NULL, ...) {
n_outs <- length(wide_list$X)
n_units <- Map(nrow, wide_list$X) %>% Reduce(max, .)
# take out unit fixed effects
demean_j <- function(j) {
means <- rowMeans(wide_list$X[[j]], na.rm = TRUE)
new_wide_data <- list()
new_X <- wide_list$X[[j]] - means
new_y <- wide_list$y[[j]] - means
new_wide_data$X <- new_X
new_wide_data$y <- new_y
new_wide_data$mhat_pre <- replicate(ncol(wide_list$X[[j]]),
means)
new_wide_data$mhat_post <- replicate(ncol(wide_list$y[[j]]),
means)
return(new_wide_data)
}
if(fixedeff) {
new_wide_list <- lapply(1:n_outs, demean_j)
wide_list$X <- lapply(new_wide_list, function(x) x$X)
wide_list$y <- lapply(new_wide_list, function(x) x$y)
mhat_pre <- lapply(new_wide_list, function(x) x$mhat_pre)
mhat_post <- lapply(new_wide_list, function(x) x$mhat_post)
} else {
mhat_pre <- lapply(
1:n_outs,
function(j) matrix(0, nrow = n_units, ncol = ncol(wide_list$X[[j]])))
mhat_post <- lapply(
1:n_outs,
function(j) matrix(0, nrow = n_units, ncol = ncol(wide_list$y[[j]])))
}
# combine outcomes
if(combine_method == "concat") {
# concatenate pre-treatment outcomes, dropping time periods with missing values;
# per-outcome scaling is applied through V below
wide_bal <- list(X = do.call(cbind, lapply(wide_list$X, function(x) t(na.omit(t(x))))),
y = do.call(cbind, lapply(wide_list$y, function(x) t(na.omit(t(x))))),
trt = wide_list$trt)
# V scales each column by 1 / (sqrt(number of non-missing periods) * control-unit sd) of its outcome
V <- do.call(c,
lapply(wide_list$X,
function(x) rep(1 / (sqrt(nrow(na.omit(t(x)))) *
sd(x[wide_list$trt == 0, , drop = F], na.rm=T)),
nrow(na.omit(t(x))))))
# } else if(combine_method == "svd") {
# wide_bal <- list(X = do.call(cbind, wide_list$X),
# y = do.call(cbind, wide_list$y),
# trt = wide_list$trt)
# # first get the standard deviations of the outcomes to put on the same scale
# sds <- do.call(c,
# lapply(wide_list$X,
# function(x) rep((sqrt(ncol(x)) * sd(x, na.rm=T)), ncol(x))))
# # do an SVD on centered and scaled outcomes
# X0 <- wide_bal$X[wide_bal$trt == 0, , drop = FALSE]
# X0 <- t((t(X0) - colMeans(X0)) / sds)
# k <- if(is.null(k)) ncol(X0) else k
# V <- diag(1 / sds) %*% svd(X0)$v[, 1:k, drop = FALSE]
} else if(combine_method == "avg") {
# standardize each outcome with its control-unit mean and sd, then average across outcomes, ignoring missing values
X_avg <- rowMeans(simplify2array(lapply(wide_list$X,
function(x) (x - mean(x[wide_list$trt == 0,], na.rm = TRUE)) / sd(x[wide_list$trt == 0,], na.rm = TRUE))), dims = 2, na.rm = TRUE)
# remove any time periods with NAs
X_avg <- t(na.omit(t(X_avg)))
wide_bal <- list(X = X_avg,
y = rowMeans(simplify2array(wide_list$y), dims = 2, na.rm = TRUE),
trt = wide_list$trt)
V <- diag(ncol(wide_bal$X))
} else if(combine_method == "avg_concat") {
# standardize each outcome with its control-unit mean and sd
X_list_std<- lapply(wide_list$X,function(x) (x - mean(x[wide_list$trt == 0,], na.rm = TRUE)) / sd(x[wide_list$trt == 0,], na.rm = TRUE))
X_avg <- rowMeans(simplify2array(X_list_std), dims = 2, na.rm = TRUE)
# remove any time periods with NAs
X_avg <- t(na.omit(t(X_avg)))
X_concat <- do.call(cbind, lapply(X_list_std, function(x) t(na.omit(t(x)))))
if(is.null(nu)) {
stop("`nu` must be supplied when combine_method = 'avg_concat'")
}
# V assigns weight nu to the averaged objective and (1 - nu) to the concatenated objective
# V <- c(rep(sqrt(nu), ncol(X_avg)),
# sqrt(1 - nu) / sqrt(n_outs) * do.call(c,
# lapply(wide_list$X,
# function(x) rep(1 / (sqrt(nrow(na.omit(t(x)))) *
# sd(x[wide_list$trt == 0, , drop = F], na.rm=T)),
# nrow(na.omit(t(x))))))
# )
V <- c(rep(sqrt(nu), ncol(X_avg)), rep(sqrt(1 - nu) / sqrt(n_outs), ncol(X_concat)))
wide_bal <- list(
X = cbind(X_avg, X_concat),
y = do.call(cbind, lapply(wide_list$y, function(x) t(na.omit(t(x))))),
trt = wide_list$trt
)
} else {
stop(paste0("combine_method should be one of ('avg', 'concat', 'avg_concat'); '",
combine_method, "' is not a valid combining option"))
}
mhat_pre <- do.call(cbind, mhat_pre)
mhat_post <- do.call(cbind, mhat_post)
mhat <- cbind(mhat_pre, mhat_post)
return(list(wide_bal = wide_bal, mhat = mhat, V = V))
}
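# Illustrative sketch (hypothetical helpers, not used by the package) of the two
# simplest combine methods in combine_outcomes. It assumes complete data (no NA
# periods) and that every outcome matrix has the same units as rows.
# "concat": outcome k keeps its T_k pre-periods, and V scales each of them by
# 1 / (sqrt(T_k) * sd_k), where sd_k is the control-unit sd of outcome k. This
# keeps outcomes on large scales or with many periods from dominating.
# "avg": each outcome is standardized with its control-unit mean and sd, and the
# standardized matrices are averaged element-wise. This additionally assumes the
# outcomes share the same pre-periods.
sketch_concat_V <- function(X_list, trt) {
  unlist(lapply(X_list, function(x) {
    n_periods <- ncol(x)
    sd_ctrl <- sd(x[trt == 0, , drop = FALSE])
    rep(1 / (sqrt(n_periods) * sd_ctrl), n_periods)
  }))
}
sketch_avg_standardize <- function(X_list, trt) {
  std <- lapply(X_list, function(x) {
    ctrl <- x[trt == 0, , drop = FALSE]
    (x - mean(ctrl)) / sd(ctrl)
  })
  Reduce(`+`, std) / length(std)                     # element-wise mean over outcomes
}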
#' Get prediction of ATT or average outcome under control
#' @param object augsynth_multiout object
#' @param ... Optional arguments, including \itemize{\item{"att"}{Whether to return the ATT or average outcome under control}}
#'
#' @return Matrix of predictions with one row per time period and one column per outcome
#' @export
predict.augsynth_multiout <- function(object, ...) {
if ("att" %in% names(list(...))) {
att <- list(...)$att
} else {
att <- F
}
# call augsynth predict
pred <- NextMethod()
# separate out by outcome
n_outs <- length(object$data_list$X)
max_t <- max(sapply(1:n_outs,
function(k) ncol(object$data_list$X[[k]]) + ncol(object$data_list$y[[k]])))
pred_reshape <- matrix(NA, ncol = n_outs,
nrow = max_t)
time_names <- lapply(1:n_outs,
function(k) colnames(cbind(object$data_list$X[[k]],
object$data_list$y[[k]])))
rownames(pred_reshape) <- time_names[[which.max(sapply(time_names, length))]]
colnames(pred_reshape) <- object$outcomes
# get outcome names for predictions
pre_outs <- do.call(c,
sapply(1:n_outs,
function(j) {
rep(object$outcomes[j],
ncol(object$data_list$X[[j]]))
}, simplify = FALSE))
post_outs <- do.call(c,
sapply(1:n_outs,
function(j) {
rep(object$outcomes[j],
ncol(object$data_list$y[[j]]))
}, simplify = FALSE))
pred_reshape[cbind(names(pred), c(pre_outs, post_outs))] <- pred
return(pred_reshape)
}
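# Illustrative sketch (hypothetical helper, not used by the package):
# predict.augsynth_multiout receives one long prediction vector whose names are
# time periods. It fills a time-by-outcome matrix using character-matrix
# indexing, where each row of cbind(time_name, outcome_name) addresses one cell
# by its dimnames. Cells with no prediction stay NA.
sketch_reshape_predictions <- function(pred, outcome_of_each, times, outcomes) {
  out <- matrix(NA_real_, nrow = length(times), ncol = length(outcomes),
                dimnames = list(times, outcomes))
  out[cbind(names(pred), outcome_of_each)] <- pred
  out
}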
#' Print function for augsynth_multiout
#' @param x augsynth_multiout object
#' @param ... Optional arguments
#' @export
print.augsynth_multiout <- function(x, ...) {
## straight from lm
cat("\nCall:\n", paste(deparse(x$call), sep="\n", collapse="\n"), "\n\n", sep="")
## print att estimates
att <- predict(x, att = T)
att_post <- data.frame(
colMeans(att[as.numeric(rownames(att)) >= x$t_int,, drop = F]))
names(att_post) <- c("")
cat("Average ATT Estimate:\n")
print(att_post)
cat("\n\n")
}
#' Summary function for augsynth_multiout
#' @param object augsynth_multiout object
#' @param inf Whether or not to perform inference
#' @param inf_type Type of inference; only "conformal" (the default) is supported for multiple outcomes
#' @param grid_size Number of grid points per outcome for computing prediction intervals; the default of 1 computes only p-values
#' @param ... Optional arguments passed to conformal_inf_multiout
#' @export
summary.augsynth_multiout <- function(object, inf = T, inf_type = "conformal", grid_size = 1, ...) {
summ <- list()
if(inf) {
if(inf_type == "conformal") {
if(grid_size > 1) {
cat(paste0("A grid size of ", grid_size, " will require ",
grid_size, "^", length(object$outcomes),
" = ", grid_size ^ length(object$outcomes),
" evaluations. This could take a while...\n"))
}
att_se <- conformal_inf_multiout(object, grid_size = grid_size, ...)
} else {
stop("Only conformal inference is supported for multiple outcomes")
}
# if(inf_type == "jackknife") {
# att_se <- jackknife_se_multiout(object)
# } else if(inf_type == "jackknife+") {
# att_se <- time_jackknife_plus_multiout(object, ...)
# } else if(inf_type == "conformal") {
# att_se <- conformal_inf_multiout(object, ...)
# } else {
# stop(paste(inf_type, "is not a valid choice of 'inf_type'"))
# }
t_final <- nrow(att_se$att)
att_df <- data.frame(att_se$att[1:(t_final - 1),, drop=F])
names(att_df) <- object$outcomes
att_df$Time <- object$data$time
att_df <- att_df %>% gather(Outcome, Estimate, -Time)
# if(inf_type == "jackknife") {
# se_df <- data.frame(att_se$se[1:(t_final - 1),, drop=F])
# names(se_df) <- object$outcomes
# se_df$Time <- object$data$time
# se_df <- se_df %>% gather(Outcome, Std.Error, -Time)
# att <- inner_join(att_df, se_df, by = c("Time", "Outcome"))
# } else if(inf_type %in% c("conformal", "jackknife+")) {
lb_df <- data.frame(att_se$lb[1:(t_final - 1),, drop=F])
names(lb_df) <- object$outcomes
lb_df$Time <- object$data$time
lb_df <- lb_df %>% gather(Outcome, lower_bound, -Time)
ub_df <- data.frame(att_se$ub[1:(t_final - 1),, drop=F])
names(ub_df) <- object$outcomes
ub_df$Time <- object$data$time
ub_df <- ub_df %>% gather(Outcome, upper_bound, -Time)
att <- inner_join(att_df, lb_df, by = c("Time", "Outcome")) %>%
inner_join(ub_df, by = c("Time", "Outcome"))
# if(inf_type == "conformal") {
pval_df <- data.frame(att_se$p_val[1:(t_final - 1),, drop=F])
names(pval_df) <- object$outcomes
pval_df$Time <- object$data$time
pval_df <- pval_df %>% gather(Outcome, p_val, -Time)
att <- inner_join(att, pval_df, by = c("Time", "Outcome"))
# }
# }
if(grid_size == 1) {
att <- att %>% mutate(lower_bound = NA, upper_bound = NA)
}
att_avg <- data.frame(att_se$att[t_final,, drop = F])
names(att_avg) <- object$outcomes
att_avg <- gather(att_avg, Outcome, Estimate)
# if(inf_type == "jackknife") {
# att_avg_se <- data.frame(att_se$se[t_final,, drop = F])
# names(att_avg_se) <- object$outcomes
# att_avg_se <- gather(att_avg_se, Outcome, Std.Error)
# average_att <- inner_join(att_avg, att_avg_se, by="Outcome")
# } else if(inf_type %in% c("conformal", "jackknife+")){
att_avg_lb <- data.frame(att_se$lb[t_final,, drop = F])
names(att_avg_lb) <- object$outcomes
att_avg_lb <- gather(att_avg_lb, Outcome, lower_bound)
att_avg_ub <- data.frame(att_se$ub[t_final,, drop = F])
names(att_avg_ub) <- object$outcomes
att_avg_ub <- gather(att_avg_ub, Outcome, upper_bound)
average_att <- inner_join(att_avg, att_avg_lb, by="Outcome") %>%
inner_join(att_avg_ub, by = "Outcome")
# if(inf_type == "conformal") {
att_avg_pval <- data.frame(att_se$p_val[t_final,, drop = F])
names(att_avg_pval) <- object$outcomes
att_avg_pval <- gather(att_avg_pval, Outcome, p_val)
average_att <- inner_join(average_att, att_avg_pval, by = "Outcome")
if(grid_size == 1) {
average_att <- average_att %>% mutate(lower_bound = NA, upper_bound = NA)
}
# }
# } else {
# average_att <- gather(att_avg, Outcome, Estimate)
# }
} else {
att_est <- predict(object, att = T)
att_df <- data.frame(att_est)
names(att_df) <- object$outcomes
att_df$Time <- object$data$time
att <- att_df %>% gather(Outcome, Estimate, -Time)
att$Std.Error <- NA
# number of pre-treatment periods; post-treatment rows start right after it
t_int <- min(sapply(object$data_list$X, ncol))
att_avg <- data.frame(t(colMeans(
att_est[(t_int + 1):nrow(att_est),, drop = F])))
names(att_avg) <- object$outcomes
average_att <- gather(att_avg, Outcome, Estimate)
average_att$Std.Error <- NA
}
# add an "Average" outcome: divide each estimate by its outcome's control-unit sd, then average across outcomes
sds <- data.frame(Outcome = object$outcomes,
sdo = sapply(object$data_list$X,
function(x) sd(x[object$data_list$trt == 0,], na.rm = TRUE)))
att %>%
inner_join(sds, by = "Outcome") %>%
mutate(Estimate = Estimate / sdo) %>%
group_by(Time) %>%
summarise(Estimate = mean(Estimate, na.rm = TRUE)) %>%
mutate(Outcome = "Average") %>%
bind_rows(att, .) -> att
summ$att <- att
summ$average_att <- average_att
summ$t_int <- object$t_int
summ$call <- object$call
summ$l2_imbalance <- object$l2_imbalance
summ$scaled_l2_imbalance <- object$scaled_l2_imbalance
summ$inf_type <- inf_type
## get estimated bias
if(object$progfunc == "Ridge") {
mhat <- object$ridge_mhat
w <- object$synw
} else {
mhat <- object$mhat
w <- object$weights
}
trt <- object$data$trt
m1 <- colMeans(mhat[trt==1,,drop=F])
summ$bias_est <- m1 - t(mhat[trt==0,,drop=F]) %*% w
if(object$progfunc == "None" | (!object$scm)) {
summ$bias_est <- NA
}
class(summ) <- "summary.augsynth_multiout"
return(summ)
}
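# Illustrative sketch (hypothetical helper, not used by the package): the
# "Average" rows that summary.augsynth_multiout appends divide each outcome's
# estimate by that outcome's control-unit standard deviation and then average
# across outcomes within each time period. This base-R version assumes a long
# data frame with columns Time, Outcome and Estimate, plus a named vector of sds.
sketch_average_outcome <- function(att, sds) {
  scaled <- att$Estimate / sds[as.character(att$Outcome)]
  avg <- sapply(split(unname(scaled), att$Time), mean, na.rm = TRUE)
  data.frame(Time = as.numeric(names(avg)), Outcome = "Average",
             Estimate = unname(avg))
}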
#' Print function for summary.augsynth_multiout objects
#' @param x summary.augsynth_multiout object
#' @param ... Optional arguments
#' @export
print.summary.augsynth_multiout <- function(x, ...) {
## straight from lm
cat("\nCall:\n", paste(deparse(x$call), sep="\n", collapse="\n"), "\n\n", sep="")
att_est <- x$att$Estimate
## get pre-treatment fit by outcome
imbal <- x$att %>%
filter(Time < x$t_int) %>%
group_by(Outcome) %>%
summarise(Pre.RMSE = sqrt(mean(Estimate ^ 2, na.rm = TRUE)))
cat(paste("Overall L2 Imbalance (Scaled):",
format(round(x$l2_imbalance,3), nsmall=3), " (",
format(round(x$scaled_l2_imbalance,3), nsmall=3), ")\n\n",
# "Avg Estimated Bias: ",
# format(round(mean(summ$bias_est), 3),nsmall=3), "\n\n",
sep=""))
cat("Average ATT Estimate:\n")
print(inner_join(x$average_att, imbal, by = "Outcome"))
cat("\n\n")
}
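# Illustrative sketch (hypothetical helper, not used by the package):
# print.summary.augsynth_multiout reports, for each outcome, the root mean
# square of the estimated gaps over pre-treatment periods (Time < t_int). A
# base-R equivalent, assuming the same long Time/Outcome/Estimate layout:
sketch_pre_rmse <- function(att, t_int) {
  pre <- att[att$Time < t_int, , drop = FALSE]
  sapply(split(pre$Estimate, pre$Outcome),
         function(e) sqrt(mean(e^2, na.rm = TRUE)))
}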
#' Plot function for augsynth_multiout
#' @importFrom graphics plot
#' @param x augsynth_multiout object
#' @param inf Boolean, whether to plot uncertainty intervals, default TRUE
#' @param plt_avg Boolean, whether to plot the average of the outcomes, default FALSE
#' @param ... Optional arguments for summary function
#'
#' @export
plot.augsynth_multiout <- function(x, inf = T, plt_avg = F, ...) {
plot(summary(x, ...), inf = inf, plt_avg = plt_avg)
}
#' Plot function for summary function for augsynth
#' @param x summary.augsynth_multiout object
#' @param inf Boolean, whether to plot uncertainty intervals, default FALSE
#' @param plt_avg Boolean, whether to plot the average of the outcomes, default FALSE
#' @param ... Unused
#'
#' @export
plot.summary.augsynth_multiout <- function(x, inf = F, plt_avg = F, ...) {
if(plt_avg) {
p <- x$att %>%
ggplot2::ggplot(ggplot2::aes(x=Time, y=Estimate))
} else {
p <- x$att %>%
filter(Outcome != "Average") %>%
ggplot2::ggplot(ggplot2::aes(x=Time, y=Estimate))
}
if(inf) {
if(x$inf_type == "jackknife") {
p <- p + ggplot2::geom_ribbon(ggplot2::aes(ymin=Estimate-2*Std.Error,
ymax=Estimate+2*Std.Error),
alpha=0.2, data = . %>% filter(Outcome != "Average"))
} else if(x$inf_type %in% c("conformal", "jackknife+")) {
p <- p + ggplot2::geom_ribbon(ggplot2::aes(ymin=lower_bound,
ymax=upper_bound),
alpha=0.2, data = . %>% filter(Outcome != "Average"))
}
}
p + ggplot2::geom_line() +
ggplot2::geom_vline(xintercept=x$t_int, lty=2) +
ggplot2::geom_hline(yintercept=0, lty=2) +
ggplot2::facet_wrap(~ Outcome, scales = "free_y") +
ggplot2::theme_bw()
}
================================================
FILE: R/multi_synth_qp.R
================================================
################################################################################
## Solve the multisynth problem as a QP
################################################################################
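# Illustrative sketch (hypothetical helper, not the package's get_eligible_donors):
# under the n_leads rule described for multisynth_qp below, a unit is an
# eligible donor for a treated unit first treated at time T_j if its own
# treatment time is later than T_j + n_leads. Never-treated units have
# trt = Inf and are therefore always eligible. Returns one logical vector per
# treated unit.
sketch_eligible_donors <- function(trt, n_leads) {
  treated <- which(is.finite(trt))
  lapply(treated, function(i) trt > trt[i] + n_leads)
}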
#' Internal function to fit synth with staggered adoption with a QP solver
#' @param X Matrix of pre-final intervention outcomes, or list of such matrices after transformations
#' @param trt Vector of treatment levels/times
#' @param mask Matrix with indicators for observed pre-intervention times for each treatment group
#' @param n_leads Number of time periods after treatment to impute control values.
#' For units treated at time T_j, all units treated after T_j + n_leads
#' will be used as control values. If larger than the number of periods,
#' only never treated units (pure controls) will be used as comparison units
#' @param n_lags Number of pre-treatment periods to balance, default is to balance all periods
#' @param relative Whether to re-index time according to treatment date, default T
#' @param nu Hyper-parameter that controls trade-off between overall and individual balance.
#' Larger values of nu place more emphasis on overall (pooled) balance.
#' Balance measure is
#' nu ||global||^2 + (1-nu) ||individual||^2
#' Default: 0
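#' @param lambda Regularization hyper-parameter. Default, 0
#' @param Z Matrix of auxiliary covariates to balance, default NULL
#' @param V Scaling matrix for the pre-treatment outcomes, default NULL is identity
#' @param time_cohort Whether to average synthetic controls into time cohorts
#' @param donors List of logical vectors marking eligible donors for each treated unit or cohort; default NULL uses get_eligible_donors
#' @param norm_pool Normalizing value for pooled objective, applied on top of the 1 / J^2 factor, default: 1
#' @param norm_sep Normalizing value for separate objective, applied on top of the 1 / J factor, default: 1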
#' @param verbose Whether to print logs for osqp
#' @param eps_rel Relative error tolerance for osqp
#' @param eps_abs Absolute error tolerance for osqp
#' @noRd
#' @return \itemize{
#' \item{"weights"}{Matrix of unit weights}
#' \item{"imbalance"}{Matrix of overall and group specific imbalance}
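#' \item{"global_l2"}{Imbalance overall}
#' \item{"ind_l2"}{Root mean square of the group-specific imbalances}
#' \item{"avg_l2"}{Average L2 imbalance across groups}
#' \item{"V"}{Scaling matrix used in the balance measure}
#' \item{"imbalance_aux", "global_l2_aux", "ind_l2_aux"}{Covariate imbalance measures, returned only when Z is supplied}
#' }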
multisynth_qp <- function(X, trt, mask, Z = NULL, n_leads=NULL, n_lags=NULL,
relative=T, nu=0, lambda=0, V = NULL, time_cohort = FALSE,
donors = NULL, norm_pool = NULL, norm_sep = NULL,
verbose = FALSE,
eps_rel=1e-4, eps_abs=1e-4) {
# if Z has no columns then set it to NULL
if(!is.null(Z)) {
if(ncol(Z) == 0) {
Z <- NULL
}
}
n <- if(typeof(X) == "list") dim(X[[1]])[1] else dim(X)[1]
d <- if(typeof(X) == "list") dim(X[[1]])[2] else dim(X)[2]
if(is.null(n_leads)) {
n_leads <- d+1
} else if(n_leads > d) {
n_leads <- d+1
}
if(is.null(n_lags)) {
n_lags <- d
} else if(n_lags > d) {
n_lags <- d
}
V <- make_V_matrix(n_lags, V)
## treatment times
if(time_cohort) {
grps <- unique(trt[is.finite(trt)])
which_t <- lapply(grps, function(tj) (1:n)[trt == tj])
# if doing a time cohort, convert the boolean mask
mask <- unique(mask)
} else {
grps <- trt[is.finite(trt)]
which_t <- (1:n)[is.finite(trt)]
}
J <- length(grps)
if(is.null(norm_sep)) {
norm_sep <- 1#J
}
if(is.null(norm_pool)) {
norm_pool <- 1#J ^ 2
}
n1 <- sapply(1:J, function(j) length(which_t[[j]]))
# if no specific donors are passed in,
# then units treated more than n_leads periods later are eligible
if(is.null(donors)) {
donors <- get_eligible_donors(trt, time_cohort, n_leads)
}
## handle X differently if it is a list
if(typeof(X) == "list") {
x_t <- lapply(1:J, function(j) colSums(X[[j]][which_t[[j]], mask[j,]==1, drop=F]))
# Xc contains pre-treatment data for valid donor units
Xc <- lapply(1:nrow(mask),
function(j) X[[j]][donors[[j]], mask[j,]==1, drop=F])
# std dev of treated units' outcomes, taken from the first matrix in the list
sdx <- sd(X[[1]][is.finite(trt)])
} else {
x_t <- lapply(1:J, function(j) colSums(X[which_t[[j]], mask[j,]==1, drop=F]))
# Xc contains pre-treatment data for valid donor units
Xc <- lapply(1:nrow(mask),
function(j) X[donors[[j]], mask[j,]==1, drop=F])
# std dev of treated units' outcomes
sdx <- sd(X[is.finite(trt)])
}
# get covariates for donors
if(!is.null(Z)) {
# standardize covariates by their pure-control mean and sd, then rescale by sdx to put them on the outcome scale
Z_scale <- sdx * apply(Z, 2,
function(z) (z - mean(z[!is.finite(trt)])) / sd(z[!is.finite(trt)]))
z_t <- lapply(1:J, function(j) colSums(Z_scale[which_t[[j]], , drop = F]))
Zc <- lapply(1:J, function(j) Z_scale[donors[[j]], , drop = F])
} else {
z_t <- lapply(1:J, function(j) c(0))
Zc <- lapply(1:J, function(j) Matrix::Matrix(0,
nrow = sum(donors[[j]]),
ncol = 1))
}
dz <- ncol(Zc[[1]])
# replace NA values with zero
x_t <- lapply(x_t, function(xtk) tidyr::replace_na(xtk, 0))
Xc <- lapply(Xc, function(xck) apply(xck, 2, tidyr::replace_na, 0))
## make matrices for QP
n0s <- sapply(Xc, nrow)
if(any(n0s == 0)) {
stop("Some treated units have no possible donor units!")
}
n0 <- sum(n0s)
const_mats <- make_constraint_mats(trt, grps, n_leads, n_lags, Xc, Zc, d, n1)
Amat <- const_mats$Amat
lvec <- const_mats$lvec
uvec <- const_mats$uvec
## quadratic balance measures
qvec <- make_qvec(Xc, x_t, z_t, nu, n_lags, d, V, norm_pool, norm_sep)
Pmat <- make_Pmat(Xc, x_t, dz, nu, n_lags, lambda, d, V, norm_pool, norm_sep)
## Optimize
settings <- do.call(osqp::osqpSettings,
c(list(verbose = verbose,
eps_rel = eps_rel,
eps_abs = eps_abs)))
out <- osqp::solve_osqp(Pmat, qvec, Amat, lvec, uvec, pars = settings)
## get weights
nj0 <- as.numeric(lapply(Xc, nrow))
nj0cumsum <- c(0, cumsum(nj0))
imbalance <- vapply(1:J,
function(j) {
dj <- length(x_t[[j]])
ndim <- min(dj, n_lags)
c(numeric(d-ndim),
x_t[[j]][(dj-ndim+1):dj] -
t(Xc[[j]][,(dj-ndim+1):dj, drop = F]) %*%
out$x[(nj0cumsum[j] + 1):nj0cumsum[j + 1]])
},
numeric(d))
avg_imbal <- rowMeans(t(t(imbalance)))
Vsq <- t(V) %*% V
global_l2 <- c(sqrt(t(avg_imbal[(d - n_lags + 1):d]) %*% Vsq %*%
avg_imbal[(d - n_lags + 1):d])) / sqrt(d)
avg_l2 <- mean(apply(imbalance, 2,
function(x) c(sqrt(t(x[(d - n_lags + 1):d]) %*% Vsq %*%
x[(d - n_lags + 1):d]))))
ind_l2 <- sqrt(mean(
apply(imbalance, 2,
function(x) c(x[(d - n_lags + 1):d] %*% Vsq %*% x[(d - n_lags + 1):d]) /
sum(x[(d - n_lags + 1):d] != 0))))
# pad weights with zeros for treated units and divide by number of treated units
vapply(1:J,
function(j) {
weightj <- numeric(n)
weightj[donors[[j]]] <- out$x[(nj0cumsum[j] + 1):nj0cumsum[j + 1]]
weightj
},
numeric(n)) -> weights
weights <- t(t(weights) / n1)
# manually enforce non-negativity constraint
# (osqp solver only enforces constraint up to a tolerance)
weights <- pmax(weights, 0)
output <- list(weights = weights,
imbalance = cbind(avg_imbal, imbalance),
global_l2 = global_l2,
ind_l2 = ind_l2,
avg_l2 = avg_l2,
V = V)
if(!is.null(Z)) {
# imbalance in auxiliary covariates
z_t <- sapply(1:J, function(j) colMeans(Z[which_t[[j]], , drop = F]))
imbal_z <- z_t - t(Z) %*% weights
avg_imbal_z <- rowSums(t(t(imbal_z) * n1)) / sum(n1)
global_l2_z <- sqrt(sum(avg_imbal_z ^ 2))
ind_l2_z <- sum(apply(imbal_z, 2, function(x) sqrt(sum(x ^ 2))))
imbal_z <- cbind(avg_imbal_z, imbal_z)
rownames(imbal_z) <- colnames(Z)
output$imbalance_aux <- imbal_z
output$global_l2_aux <- global_l2_z
output$ind_l2_aux <- ind_l2_z
}
return(output)
}
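# Illustrative sketch (hypothetical helper, not used by the package): the three
# imbalance summaries that multisynth_qp reports, simplified to V = identity and
# n_lags = d (all pre-periods balanced). `imbalance` is a d x J matrix whose
# column j is treated minus synthetic control for group j.
sketch_imbalance_summaries <- function(imbalance) {
  d <- nrow(imbalance)
  avg_imbal <- rowMeans(imbalance)                         # pooled imbalance
  global_l2 <- sqrt(sum(avg_imbal^2)) / sqrt(d)
  avg_l2 <- mean(apply(imbalance, 2, function(x) sqrt(sum(x^2))))
  # per-group squared norm divided by that group's number of non-zero periods
  ind_l2 <- sqrt(mean(apply(imbalance, 2, function(x) sum(x^2) / sum(x != 0))))
  c(global_l2 = global_l2, avg_l2 = avg_l2, ind_l2 = ind_l2)
}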
#' Create constraint matrices for multisynth QP
#' @param trt Vector of treatment levels/times
#' @param grps Treatment times
#' @param n_leads Number of time periods after treatment to impute control values.
#' @param n_lags Number of pre-treatment periods to balance
#' @param Xc List of outcomes for possible comparison units
#' @param Zc List of auxiliary covariates for possible comparison units
#' @param d Max number of lagged outcomes
#' @param n1 Vector of number of treated units per cohort
#' @noRd
#' @return
#' \itemize{
#' \item{"Amat"}{Linear constraint matrix}
#' \item{"lvec"}{Lower bounds for linear constraints}
#' \item{"uvec"}{Upper bounds for linear constraints}
#' }
make_constraint_mats <- function(trt, grps, n_leads, n_lags, Xc, Zc, d, n1) {
J <- length(grps)
## sum to n1 constraint: one row of ones per donor pool
A1 <- Matrix::bdiag(lapply(1:J, function(j) rep(1, nrow(Xc[[j]]))))
## stack the sum constraints on top of an identity block for non-negativity
Amat <- Matrix::rbind2(Matrix::t(A1), Matrix::Diagonal(nrow(A1)))
dz <- ncol(Zc[[1]])
# constraints for transformed weights
A_trans1 <- do.call(Matrix::bdiag,
lapply(1:J,
function(j) {
dj <- ncol(Xc[[j]])
ndim <- min(dj, n_lags)
max_dim <- min(d, n_lags)
mat <- Xc[[j]][, (dj - ndim + 1):dj, drop = F]
n0 <- nrow(mat)
zero_mat <- Matrix::Matrix(0, n0, max_dim - ndim)
Matrix::t(cbind(zero_mat, mat))
}))
# sum of total number of pre-periods
sum_tj <- min(d, n_lags) * J
A_trans2 <- - Matrix::Diagonal(sum_tj)
A_trans <- Matrix::cbind2(
Matrix::cbind2(A_trans1, A_trans2),
Matrix::Matrix(0, nrow = nrow(A_trans1), ncol = dz * J))
# constraints for transformed weights on auxiliary covariates
A_transz <- Matrix::t(Matrix::bdiag(Zc))
A_transz <- Matrix::cbind2(
Matrix::cbind2(A_transz,
Matrix::Matrix(0, nrow = nrow(A_transz), ncol = sum_tj)),
-Matrix::Diagonal(dz * J))
# add in zero columns for transformed weights
Amat <- Matrix::cbind2(Amat,
Matrix::Matrix(0,
nrow = nrow(Amat),
ncol = sum_tj + dz * J))
Amat <- Matrix::rbind2(Matrix::rbind2(Amat, A_trans), A_transz)
lvec <- c(n1, # sum to n1 constraint
numeric(nrow(A1)), # lower bound by zero
numeric(sum_tj), # constrain transformed weights
numeric(dz * J) # constrain transformed weights
)
uvec <- c(n1, #sum to n1 constraint
rep(Inf, nrow(A1)),
numeric(sum_tj), # constrain transformed weights
numeric(dz * J) # constrain transformed weights
)
return(list(Amat = Amat, lvec = lvec, uvec = uvec))
}
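# Illustrative sketch (hypothetical helper): a dense base-R version of the
# first block built by make_constraint_mats. J rows force the weights in each
# donor pool j to sum to n1[j], followed by one row per weight that bounds it
# below by zero. Constraints take the form lvec <= A %*% gamma <= uvec, which is
# the form osqp expects. The auxiliary-variable blocks are omitted here.
sketch_weight_constraints <- function(n0s, n1) {
  J <- length(n0s)
  total <- sum(n0s)
  pool <- rep(seq_len(J), times = n0s)             # donor pool index of each weight
  A_sum <- t(sapply(seq_len(J), function(j) as.numeric(pool == j)))
  A <- rbind(A_sum, diag(total))
  list(A = A,
       lvec = c(n1, numeric(total)),
       uvec = c(n1, rep(Inf, total)))
}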
#' Make the vector in the QP
#' @param Xc List of outcomes for possible comparison units
#' @param x_t List of outcomes for treated units
#' @param z_t List of auxiliary covariates for treated units
#' @param nu Hyperparameter between global and individual balance
#' @param n_lags Number of lags to balance
#' @param d Largest number of pre-intervention time periods
#' @param V Scaling matrix
#' @param norm_pool Normalizing value for pooled objective
#' @param norm_sep Normalizing value for separate objective
#' @noRd
make_qvec <- function(Xc, x_t, z_t, nu, n_lags, d, V, norm_pool, norm_sep) {
J <- length(x_t)
Vsq <- t(V) %*% V
qvec <- lapply(1:J,
function(j) {
dj <- length(x_t[[j]])
ndim <- min(dj, n_lags)
max_dim <- min(d, n_lags)
vec <- x_t[[j]][(dj - ndim + 1):dj] / ndim
Vsq %*% c(numeric(max_dim - ndim), vec)
})
avg_target_vec <- lapply(x_t,
function(xtk) {
dk <- length(xtk)
ndim <- min(dk, n_lags)
max_dim <- min(d, n_lags)
c(numeric(max_dim - ndim),
xtk[(dk - ndim + 1):dk])
}) %>% reduce(`+`) %*% Vsq
qvec_avg <- rep(avg_target_vec, J)
# qvec <- - (nu * qvec_avg / n_lags + (1 - nu) * reduce(qvec, c))
# qvec <- - (nu * qvec_avg / (J ^ 2 * n_lags) +
# (1 - nu) * reduce(qvec, c) / J)
qvec <- - (nu * qvec_avg / (norm_pool * n_lags * J ^ 2) +
(1 - nu) * reduce(qvec, c) / (norm_sep * J))
qvec_avg_z <- z_t %>% reduce(`+`)
qvec_avg_z <- rep(qvec_avg_z, J)
# qvec_z <- - (nu * qvec_avg_z + (1 - nu) * reduce(z_t, c)) / length(z_t[[1]])
# qvec_z <- - (nu * qvec_avg_z / J ^2 +
# (1 - nu) * reduce(z_t, c) / J) / length(z_t[[1]])
qvec_z <- - (nu * qvec_avg_z / (norm_pool * J ^ 2) +
(1 - nu) * reduce(z_t, c) / (norm_sep * J)) / length(z_t[[1]])
total_ctrls <- lapply(Xc, nrow) %>% reduce(`+`)
return(c(numeric(total_ctrls), qvec, qvec_z))
}
#' Make the matrix in the QP
#' @param Xc List of outcomes for possible comparison units
#' @param x_t List of outcomes for treated units
#' @param dz Number of auxiliary covariates
#' @param nu Hyperparameter between global and individual balance
#' @param n_lags Number of lags to balance
#' @param lambda Regularization hyperparameter
#' @param d Largest number of pre-intervention time periods
#' @param V Scaling matrix
#' @param norm_pool Normalizing value for pooled objective
#' @param norm_sep Normalizing value for separate objective
#' @noRd
make_Pmat <- function(Xc, x_t, dz, nu, n_lags, lambda, d, V,
norm_pool, norm_sep) {
J <- length(x_t)
Vsq <- t(V) %*% V
ndims <- vapply(1:J,
function(j) min(length(x_t[[j]]), n_lags),
numeric(1))
max_dim <- min(d, n_lags)
total_dim <- max_dim * J
# separate balance: block-diagonal V'V scaled by the number of balanced periods per group
V1 <- Matrix::bdiag(lapply(ndims, function(ndim) Vsq / ndim))
tile_sparse <- function(j) {
kronecker(Matrix::Matrix(1, nrow = j, ncol = j), Vsq)
}
tile_sparse_cov <- function(d, j) {
kronecker(Matrix::Matrix(1, nrow = j, ncol = j),
Matrix::Diagonal(d))
}
V2 <- tile_sparse(J) / n_lags
# Pmat <- nu * V2 + (1 - nu) * V1
# Pmat <- nu * V2 / J ^ 2 + (1 - nu) * V1 / J
Pmat <- nu * V2 / (norm_pool * J ^ 2) + (1 - nu) * V1 / (norm_sep * J)
V1_z <- Matrix::Diagonal(dz * J, 1 / dz)
V2_z <- tile_sparse_cov(dz, J) / dz
# Pmat_z <- nu * V2_z + (1 - nu) * V1_z
# Pmat_z <- nu * V2_z / J ^ 2 + (1 - nu) * V1_z / J
Pmat_z <- nu * V2_z / (norm_pool * J ^ 2) + (1 - nu) * V1_z / (norm_sep * J)
# combine
total_ctrls <- lapply(Xc, nrow) %>% reduce(`+`)
Pmat <- Matrix::bdiag(Matrix::Matrix(0, nrow = total_ctrls,
ncol = total_ctrls),
Pmat, Pmat_z)
I0 <- Matrix::bdiag(Matrix::Diagonal(total_ctrls),
Matrix::Matrix(0, nrow = total_dim + dz * J,
ncol = total_dim + dz * J))
return(Pmat + lambda * I0)
}
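# Illustrative sketch (hypothetical helper, not used by the package). Assume
# V = identity, every group balances all L = n_lags periods, and
# norm_pool = norm_sep = 1. Under those assumptions, the auxiliary-variable
# blocks of make_Pmat and make_qvec encode, up to an additive constant,
#   0.5 * (nu * ||mean_j(u_j - x_j)||^2 + (1 - nu) * mean_j ||u_j - x_j||^2) / L
# where u_j = t(Xc_j) %*% gamma_j and x_j is the treated target for group j.
# osqp minimizes 0.5 * u' P u + q' u, so this function returns both sides for a
# numerical check.
sketch_qp_objective_check <- function(u_list, x_list, nu) {
  J <- length(u_list)
  L <- length(u_list[[1]])
  u <- unlist(u_list)
  P_pool <- kronecker(matrix(1, J, J), diag(L)) / L    # tiled block (V2)
  P_sep <- kronecker(diag(J), diag(L) / L)             # block diagonal (V1)
  P <- nu * P_pool / J^2 + (1 - nu) * P_sep / J
  sum_x <- Reduce(`+`, x_list)
  q <- -(nu * rep(sum_x, J) / (L * J^2) + (1 - nu) * unlist(x_list) / (L * J))
  const <- 0.5 * (nu * sum(sum_x^2) / (L * J^2) +
                  (1 - nu) * sum(unlist(x_list)^2) / (L * J))
  qp_value <- 0.5 * drop(t(u) %*% P %*% u) + sum(q * u) + const
  gaps <- Map(`-`, u_list, x_list)
  direct <- 0.5 * (nu * sum(Reduce(`+`, gaps)^2) / J^2 +
                   (1 - nu) * mean(sapply(gaps, function(g) sum(g^2)))) / L
  c(qp = qp_value, direct = direct)
}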
================================================
FILE: R/multisynth_class.R
================================================
################################################################################
## Fitting, plotting, summarizing staggered synth
################################################################################
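# Illustrative sketch (hypothetical helper, not used by the package): multisynth
# dispatches on how many `|`-separated parts the right-hand side has
# (length(form)[2] with Formula). Because `|` is left-associative,
# `y ~ trt | w | a | e` parses as ((trt | w) | a) | e. This base-R version
# unpacks such a formula into a list of variable names, one element per part.
sketch_rhs_parts <- function(form) {
  split_bar <- function(expr) {
    if (is.call(expr) && identical(expr[[1]], as.name("|"))) {
      c(split_bar(expr[[2]]), list(expr[[3]]))
    } else {
      list(expr)
    }
  }
  lapply(split_bar(form[[3]]), all.vars)
}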
#' Fit staggered synth
#' @param form outcome ~ treatment | weighting covariates | approximate matching covariates | exact matching covariates
#' \itemize{
#' \item{outcome}{Name of the outcome of interest}
#' \item{treatment}{Name of the treatment assignment variable}
#' \item{weighting covariates}{Auxiliary covariates to weight on}
#' \item{approximate matching covariates}{Auxiliary covariates to approximately match on before weighting}
#' \item{exact matching covariates}{Auxiliary covariates to exactly match on before weighting}
#' }
#' If covariates are time-varying, their average value before the first unit is treated will be used. This can be changed by supplying a custom aggregation function to cov_agg.
#' @param unit Name of unit column
#' @param time Name of time column
#' @param data Panel data as dataframe
#' @param n_leads How long past treatment effects should be estimated for, default is number of post treatment periods for last treated unit
#' @param n_lags Number of pre-treatment periods to balance, default is to balance all periods
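#' @param nu Weight on pooled (overall) balance relative to individual balance: 0 gives separate synthetic controls and 1 a fully pooled fit; default NULL, in which case nu is chosen during fitting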
#' @param lambda Regularization hyperparameter, default = 0
#' @param V Scaling matrix for synth optimization, default NULL is identity
#' @param fixedeff Whether to include a unit fixed effect, default TRUE
#' @param n_factors Number of factors for interactive fixed effects, setting to NULL fits with CV, default is 0
#' @param scm Whether to fit scm weights
#' @param time_cohort Whether to average synthetic controls into time cohorts, default FALSE
#' @param how_match Method for matching on the approximate matching covariates, default "knn"
#' @param cov_agg Covariate aggregation function
#' @param eps_abs Absolute error tolerance for osqp
#' @param eps_rel Relative error tolerance for osqp
#' @param verbose Whether to print logs for osqp
#' @param ... Extra arguments
#'
#' @return multisynth object that contains:
#' \itemize{
#' \item{"weights"}{weights matrix where each column is a set of weights for a treated unit}
#' \item{"data"}{Panel data as matrices}
#' \item{"imbalance"}{Matrix of treatment minus synthetic control for pre-treatment time periods, each column corresponds to a treated unit}
#' \item{"global_l2"}{L2 imbalance for the pooled synthetic control}
#' \item{"scaled_global_l2"}{L2 imbalance for the pooled synthetic control, scaled by the imbalance for uniform weights}
#' \item{"ind_l2"}{Average L2 imbalance for the individual synthetic controls}
#' \item{"scaled_ind_l2"}{Average L2 imbalance for the individual synthetic controls, scaled by the imbalance for uniform weights}
#' \item{"n_leads", "n_lags"}{Number of post treatment outcomes (leads) and pre-treatment outcomes (lags) to include in the analysis}
#' \item{"nu"}{Weight on pooled balance relative to individual balance}
#' \item{"lambda"}{Regularization hyperparameter}
#' \item{"scm"}{Whether to fit scm weights}
#' \item{"grps"}{Time periods for treated units}
#' \item{"y0hat"}{Pilot estimates of control outcomes}
#' \item{"residuals"}{Difference between the observed outcomes and the pilot estimates}
#' \item{"n_factors"}{Number of factors for interactive fixed effects}
#' }
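#' @examples
#' \dontrun{
#' # Sketch only: assumes the bundled `kansas` panel has the columns
#' # lngdpcapita (outcome), treated (0/1 treatment indicator),
#' # fips (unit identifier) and year_qtr (time period).
#' data(kansas)
#' fit <- multisynth(lngdpcapita ~ treated, fips, year_qtr, kansas)
#' fit                    # prints the call and the average ATT estimate
#' fit$nu                 # nu selected by the default rule (nu = NULL)
#' fit$scaled_global_l2   # pooled imbalance relative to uniform weights
#' }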
#' @export
multisynth <- function(form, unit, time, data,
n_leads=NULL, n_lags=NULL,
nu=NULL, lambda=0, V = NULL,
fixedeff = TRUE,
n_factors=0,
scm=T,
time_cohort = F,
how_match = "knn",
cov_agg = NULL,
eps_abs = 1e-4,
eps_rel = 1e-4,
verbose = FALSE, ...) {
call_name <- match.call()
form <- Formula::Formula(form)
unit <- enquo(unit)
time <- enquo(time)
## format data
outcome <- terms(formula(form, rhs=1))[[2]]
trt <- terms(formula(form, rhs=1))[[3]]
wide <- format_data_stag(outcome, trt, unit, time, data)
check_data_stag(wide, fixedeff, n_leads, n_lags)
force <- if(fixedeff) 3 else 2
# get covariates
if(length(form)[2] == 2) {
Z <- extract_covariates(form, unit, time, wide$time[min(wide$trt) + 1],
data, cov_agg)
} else if(length(form)[2] == 3) {
app_form <- Formula::Formula(formula(form, rhs = 1:2))
Z_weight <- extract_covariates(app_form, unit, time,
wide$time[min(wide$trt) + 1],
data, cov_agg)
exact_form <- Formula::Formula(formula(form, rhs = c(1,3)))
Z_match<- extract_covariates(exact_form, unit, time,
wide$time[min(wide$trt) + 1],
data, cov_agg)
Z <- cbind(Z_weight, Z_match)
wide$match_covariates <- colnames(Z_match)
} else if(length(form)[2] == 4) {
if(time_cohort) {
stop("Aggregating by time cohort and matching on covariates are not ",
"implemented together. If matching then you cannot aggregate ",
"by time cohort.")
}
weight_form <- Formula::Formula(formula(form, rhs = c(1,2)))
Z_weight <- extract_covariates(weight_form, unit, time,
wide$time[min(wide$trt) + 1],
data, cov_agg)
app_form <- Formula::Formula(formula(form, rhs = c(1,3)))
Z_app <- extract_covariates(app_form, unit, time,
wide$time[min(wide$trt) + 1],
data, cov_agg)
exact_form <- Formula::Formula(formula(form, rhs = c(1,4)))
Z_exact <- extract_covariates(exact_form, unit, time,
wide$time[min(wide$trt) + 1],
data, cov_agg)
Z <- cbind(Z_weight, Z_app, Z_exact)
wide$exact_covariates <- colnames(Z_exact)
wide$match_covariates <- c(colnames(Z_app), wide$exact_covariates)
} else {
Z <- NULL
}
wide$Z <- Z
# if n_leads is NULL set it to be the largest possible number of leads
# for the last treated unit
if(is.null(n_leads)) {
n_leads <- ncol(wide$y)
} else if(n_leads > max(apply(1-wide$mask, 1, sum, na.rm = T)) +
ncol(wide$y)) {
n_leads <- max(apply(1-wide$mask, 1, sum, na.rm = T)) + ncol(wide$y)
}
## if n_lags is NULL set it to the largest number of pre-treatment periods
if(is.null(n_lags)) {
n_lags <- ncol(wide$X)
} else if(n_lags > ncol(wide$X)) {
n_lags <- ncol(wide$X)
}
long_df <- data[c(quo_name(unit), quo_name(time), as.character(trt), as.character(outcome))]
msynth <- multisynth_formatted(wide = wide, relative = T,
n_leads = n_leads, n_lags = n_lags,
nu = nu, lambda = lambda, V = V,
force = force, n_factors = n_factors,
scm = scm, time_cohort = time_cohort,
time_w = F, lambda_t = 0,
fit_resids = TRUE, eps_abs = eps_abs,
eps_rel = eps_rel, verbose = verbose, long_df = long_df,
how_match = how_match, ...)
units <- data %>% arrange(!!unit) %>% distinct(!!unit) %>% pull(!!unit)
rownames(msynth$weights) <- units
if(scm) {
## Get imbalance for uniform weights on raw data
## TODO: replace this workaround, which refits the weights with a very large lambda to approximate uniform weights
unif <- multisynth_qp(X=wide$X, ## X=residuals[,1:ncol(wide$X)],
trt=wide$trt,
mask=wide$mask,
Z = Z[, ! colnames(Z) %in% wide$match_covariates,
drop = F],
n_leads=n_leads,
n_lags=n_lags,
relative=T,
nu=0, lambda=1e10,
V = V,
time_cohort = time_cohort,
donors = msynth$donors,
eps_rel = eps_rel,
eps_abs = eps_abs,
verbose = verbose)
## scaled global balance
## msynth$scaled_global_l2 <- msynth$global_l2 / sqrt(sum(unif$imbalance[,1]^2))
msynth$scaled_global_l2 <- msynth$global_l2 / unif$global_l2
## balance for individual estimates
## msynth$scaled_ind_l2 <- msynth$ind_l2 / sqrt(sum(unif$imbalance[,-1]^2))
msynth$scaled_ind_l2 <- msynth$ind_l2 / unif$ind_l2
}
msynth$call <- call_name
return(msynth)
}
#' Internal function to fit staggered synth with formatted data
#' @param wide List containing data elements
#' @param relative Whether to compute balance by relative time
#' @param n_leads How long past treatment effects should be estimated for
#' @param n_lags Number of pre-treatment periods to balance, default is to balance all periods
#' @param nu Fraction of balance for individual balance
#' @param lambda Regularization hyperparameter, default = 0
#' @param V Scaling matrix for synth optimization, default NULL is identity
#' @param force c(0,1,2,3) what type of fixed effects to include
#' @param n_factors Number of factors for interactive fixed effects, default does CV
#' @param scm Whether to fit scm weights
#' @param time_cohort Whether to average synthetic controls into time cohorts
#' @param time_w Whether to fit time weights
#' @param lambda_t Regularization for time regression
#' @param fit_resids Whether to fit SCM on the residuals or not
#' @param eps_abs Absolute error tolerance for osqp
#' @param eps_rel Relative error tolerance for osqp
#' @param verbose Whether to print logs for osqp
#' @param long_df A long dataframe with 4 columns in the order unit, time, trt, outcome
#' @param how_match How to match treated units to eligible donors on the matching covariates
#' @param ... Extra arguments
#' @noRd
#' @return multisynth object
multisynth_formatted <- function(wide, relative=T, n_leads, n_lags,
nu, lambda, V,
force,
n_factors,
scm, time_cohort,
time_w, lambda_t,
fit_resids,
eps_abs, eps_rel,
verbose, long_df,
how_match, ...) {
## average together treatment groups
## grps <- unique(wide$trt) %>% sort()
if(time_cohort) {
grps <- unique(wide$trt[is.finite(wide$trt)])
} else {
grps <- wide$trt[is.finite(wide$trt)]
}
J <- length(grps)
## fit outcome models
if(time_w) {
# Autoregressive model
out <- fit_time_reg(cbind(wide$X, wide$y), wide$trt,
n_leads, lambda_t, ...)
y0hat <- out$y0hat
residuals <- out$residuals
params <- out$time_weights
} else if(is.null(n_factors)) {
out <- tryCatch({
fit_gsynth_multi(long_df, cbind(wide$X, wide$y), wide$trt, force=force)
}, error = function(error_condition) {
stop("Cannot cross-validate the number of factors, likely because there ",
"are too few pre-treatment periods: ", conditionMessage(error_condition))
})
y0hat <- out$y0hat
params <- out$params
n_factors <- ncol(params$factor)
## get residuals from outcome model
residuals <- cbind(wide$X, wide$y) - y0hat
} else if (n_factors != 0) {
## if number of factors is provided don't do CV
out <- fit_gsynth_multi(long_df, cbind(wide$X, wide$y), wide$trt,
r=n_factors, CV=0, force=force)
y0hat <- out$y0hat
params <- out$params
## get residuals from outcome model
residuals <- cbind(wide$X, wide$y) - y0hat
} else if(force == 0 & n_factors == 0) {
# if no fixed effects or factors, just take out
# control averages at each time point
# time fixed effects from pure controls
pure_ctrl <- cbind(wide$X, wide$y)[!is.finite(wide$trt), , drop = F]
y0hat <- matrix(colMeans(pure_ctrl, na.rm = TRUE),
nrow = nrow(wide$X), ncol = ncol(pure_ctrl),
byrow = T)
residuals <- cbind(wide$X, wide$y) - y0hat
params <- NULL
} else {
## take out pre-treatment averages
fullmask <- cbind(wide$mask, matrix(0, nrow=nrow(wide$mask),
ncol=ncol(wide$y)))
out <- fit_feff(cbind(wide$X, wide$y), wide$trt, fullmask, force, time_cohort)
y0hat <- out$y0hat
residuals <- out$residuals
params <- NULL
}
## balance the residuals
if(fit_resids) {
if(time_w) {
# fit scm on residuals after taking out unit fixed effects
fullmask <- cbind(wide$mask, matrix(0, nrow=nrow(wide$mask),
ncol=ncol(wide$y)))
out <- fit_feff(cbind(wide$X, wide$y), wide$trt, fullmask, force, time_cohort)
bal_mat <- lapply(out$residuals, function(x) x[,1:ncol(wide$X)])
} else if(typeof(residuals) == "list") {
bal_mat <- lapply(residuals, function(x) x[,1:ncol(wide$X)])
} else {
bal_mat <- residuals[,1:ncol(wide$X)]
}
} else {
# if not balancing residuals, balance the raw pre-treatment outcomes
bal_mat <- wide$X
}
if(scm) {
# get eligible set of donor units based on covariates
donors <- get_donors(wide$X, wide$y, wide$trt,
wide$Z[, colnames(wide$Z) %in%
wide$match_covariates, drop = F],
time_cohort, n_lags, n_leads, how = how_match,
exact_covariates = wide$exact_covariates, ...)
# run separate synth for scaling
sep_fit <- multisynth_qp(X=bal_mat,
trt=wide$trt,
mask=wide$mask,
Z = wide$Z[, !colnames(wide$Z) %in%
wide$match_covariates,
drop = F],
n_leads=n_leads,
n_lags=n_lags,
relative=relative,
nu=0, lambda=lambda,
V = V,
time_cohort = time_cohort,
donors = donors,
eps_rel = eps_rel,
eps_abs = eps_abs,
verbose = verbose)
# if no nu value is provided, use default based on
# global and individual imbalance for separate synth
if(is.null(nu)) {
# select nu by triangle inequality ratio
glbl <- sep_fit$global_l2 * sqrt(nrow(sep_fit$imbalance))
ind <- sep_fit$avg_l2
nu <- glbl / ind
}
msynth <- multisynth_qp(X=bal_mat,
trt=wide$trt,
mask=wide$mask,
Z = wide$Z[, !colnames(wide$Z) %in%
wide$match_covariates,
drop = F],
n_leads=n_leads,
n_lags=n_lags,
relative=relative,
nu=nu, lambda=lambda,
V = V,
time_cohort = time_cohort,
donors = donors,
norm_pool = sep_fit$global_l2 ^ 2,
norm_sep = sep_fit$ind_l2 ^ 2,
eps_rel = eps_rel,
eps_abs = eps_abs,
verbose = verbose)
} else {
msynth <- list(weights = matrix(0, nrow = nrow(wide$X), ncol = J),
imbalance=NA,
global_l2=NA,
ind_l2=NA)
## no donor selection without SCM weights; avoids an undefined `donors` below
donors <- NULL
}
## put in data and hyperparams
msynth$data <- wide
msynth$relative <- relative
msynth$n_leads <- n_leads
msynth$n_lags <- n_lags
msynth$nu <- nu
msynth$lambda <- lambda
msynth$scm <- scm
msynth$time_cohort <- time_cohort
msynth$grps <- grps
msynth$y0hat <- y0hat
msynth$residuals <- residuals
msynth$n_factors <- n_factors
msynth$force <- force
## outcome model parameters
msynth$params <- params
# more arguments
msynth$time_w <- time_w
msynth$lambda_t <- lambda_t
msynth$fit_resids <- fit_resids
msynth$extra_pars <- c(list(eps_abs = eps_abs,
eps_rel = eps_rel,
verbose = verbose),
list(...))
msynth$long_df <- long_df
msynth$how_match <- how_match
msynth$donors <- donors
##format output
class(msynth) <- "multisynth"
return(msynth)
}
#' Get prediction of average outcome under control or ATT
#' @param object Fit multisynth object
#' @param att If TRUE, return the ATT, if FALSE, return imputed counterfactual
#' @param att_weight Weights to place on individual units/cohorts when averaging
#' @param bs_weight Weight to perturb units by for weighted bootstrap
#' @param ... Optional arguments
#'
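#' @return Matrix whose first column is the average over treated units and whose remaining columns correspond to each treated unit (or time cohort). Rows index time relative to treatment, and the final row is the post-treatment average. Entries are imputed control outcomes, or ATT estimates if \code{att = TRUE}.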
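#' @examples
#' \dontrun{
#' # Sketch, assuming `fit` is a multisynth fit, e.g.
#' # fit <- multisynth(lngdpcapita ~ treated, fips, year_qtr, kansas)
#' att <- predict(fit, att = TRUE)
#' # Column 1 averages over treated units; the last row is the
#' # post-treatment average, so this is the overall average ATT
#' att[nrow(att), 1]
#' # Imputed control outcomes, in the same layout
#' y0 <- predict(fit, att = FALSE)
#' }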
#' @export
predict.multisynth <- function(object, att = F, att_weight = NULL, bs_weight = NULL, ...) {
multisynth <- object
relative <- T
time_cohort <- multisynth$time_cohort
if(is.null(relative)) {
relative <- multisynth$relative
}
n_leads <- multisynth$n_leads
d <- ncol(multisynth$data$X)
n <- nrow(multisynth$data$X)
fulldat <- cbind(multisynth$data$X, multisynth$data$y)
ttot <- ncol(fulldat)
grps <- multisynth$grps
J <- length(grps)
if(is.null(bs_weight)) {
# bs_weight <- rep(1 / sqrt(sum(is.finite(multisynth$data$trt))), n)
bs_weight <- rep(1, n)
}
if(time_cohort) {
which_t <- lapply(grps,
function(tj) (1:n)[multisynth$data$trt == tj])
mask <- unique(multisynth$data$mask)
} else {
which_t <- (1:n)[is.finite(multisynth$data$trt)]
mask <- multisynth$data$mask
}
n1 <- sapply(1:J, function(j) length(which_t[[j]]))
fullmask <- cbind(mask, matrix(0, nrow = J, ncol = (ttot - d)))
## estimate the post-treatment values to get att estimates
mu1hat <- vapply(1:J,
function(j) colMeans((bs_weight * fulldat)[which_t[[j]],
, drop=FALSE]),
numeric(ttot))
## get average outcome model estimates and reweight residuals
if(typeof(multisynth$y0hat) == "list") {
mu0hat <- vapply(1:J,
function(j) {
y0hat <- colMeans(
(bs_weight * multisynth$y0hat[[j]])[which_t[[j]],
, drop=FALSE])
weightsj <- multisynth$weights[,j] * bs_weight
resj <- multisynth$residuals[[j]][weightsj != 0,, drop = F]
y0hat + t(resj) %*% weightsj[weightsj != 0]
}
, numeric(ttot)
)
} else {
mu0hat <- vapply(1:J,
function(j) {
y0hat <- colMeans(
(bs_weight * multisynth$y0hat)[which_t[[j]],
, drop=FALSE])
weightsj <- multisynth$weights[, j] * bs_weight
resj <- multisynth$residuals[weightsj != 0,, drop = F]
y0hat + t(resj) %*% weightsj[weightsj != 0]
}
, numeric(ttot)
)
}
tauhat <- mu1hat - mu0hat
if(is.null(att_weight)) {
att_weight <- rep(1, J)
}
## re-index time if relative to treatment
if(relative) {
total_len <- min(d + n_leads, ttot + d - min(grps)) ## total length of predictions
mu0hat <- vapply(1:J,
function(j) {
vec <- c(rep(NA, d-grps[j]),
mu0hat[1:grps[j],j],
mu0hat[(grps[j]+1):(min(grps[j] + n_leads, ttot)), j])
## last row is post-treatment average
c(vec,
rep(NA, total_len - length(vec)),
mean(mu0hat[(grps[j]+1):(min(grps[j] + n_leads, ttot)), j]))
},
numeric(total_len +1
))
tauhat <- vapply(1:J,
function(j) {
vec <- c(rep(NA, d-grps[j]),
tauhat[1:grps[j],j],
tauhat[(grps[j]+1):(min(grps[j] + n_leads, ttot)), j])
## last row is post-treatment average
c(vec,
rep(NA, total_len - length(vec)),
mean(tauhat[(grps[j]+1):(min(grps[j] + n_leads, ttot)), j]))
},
numeric(total_len +1
))
# re-index unit weights if they change over time
if(is.null(dim(att_weight))) {
if(J == 1) {
att_weight <- matrix(replicate(total_len + 1, att_weight), ncol = 1)
} else {
att_weight <- t(replicate(total_len + 1, att_weight))
}
}
att_weight_new <- vapply(1:J,
function(j) {
vec <- c(rep(NA, d-grps[j]),
att_weight[1:grps[j],j],
att_weight[(grps[j]+1):(min(grps[j] + n_leads, ttot)), j])
## last row is post-treatment average
c(vec,
rep(NA, total_len - length(vec)),
mean(att_weight[(grps[j]+1):(min(grps[j] + n_leads, ttot)), j]))
},
numeric(total_len +1
))
## get the overall average estimate, weighting each treated unit (or
## cohort) by its size and by att_weight
avg <- sapply(1:nrow(mu0hat),
function(k) {
sum(n1 * mu0hat[k,] * att_weight_new[k,], na.rm=T) /
sum(n1 * (!is.na(mu0hat[k,])) * att_weight_new[k, ], na.rm = T)
})
mu0hat <- cbind(avg, mu0hat)
avg <- sapply(1:nrow(tauhat),
function(k) {
sum(n1 * tauhat[k,] * att_weight_new[k,], na.rm=T) /
sum(n1 * (!is.na(tauhat[k,])) * att_weight_new[k, ], na.rm = T)
})
tauhat <- cbind(avg, tauhat)
} else {
## remove all estimates for t > T_j + n_leads
vapply(1:J,
function(j) c(mu0hat[1:min(grps[j]+n_leads, ttot),j],
rep(NA, max(0, ttot-(grps[j] + n_leads)))),
numeric(ttot)) -> mu0hat
vapply(1:J,
function(j) c(tauhat[1:min(grps[j]+n_leads, ttot),j],
rep(NA, max(0, ttot-(grps[j] + n_leads)))),
numeric(ttot)) -> tauhat
## only average currently treated units
avg1 <- rowSums(t(fullmask) * mu0hat * n1) /
rowSums(t(fullmask) * n1)
avg2 <- rowSums(t(1-fullmask) * mu0hat * n1) /
rowSums(t(1-fullmask) * n1)
avg <- replace_na(avg1, 0) * apply(fullmask, 2, min) +
replace_na(avg2,0) * apply(1-fullmask, 2, max)
cbind(avg, mu0hat) -> mu0hat
## only average currently treated units
avg1 <- rowSums(t(fullmask) * tauhat * n1) /
rowSums(t(fullmask) * n1)
avg2 <- rowSums(t(1-fullmask) * tauhat * n1) /
rowSums(t(1-fullmask) * n1)
avg <- replace_na(avg1, 0) * apply(fullmask, 2, min) +
replace_na(avg2,0) * apply(1 - fullmask, 2, max)
cbind(avg, tauhat) -> tauhat
}
if(att) {
return(tauhat)
} else {
return(mu0hat)
}
}
#' Print function for multisynth
#' @param x multisynth object
#' @param att_weight Weights to place on individual units/cohorts when averaging
#' @param ... Optional arguments
#' @export
print.multisynth <- function(x, att_weight = NULL, ...) {
multisynth <- x
## straight from lm
cat("\nCall:\n", paste(deparse(multisynth$call),
sep="\n", collapse="\n"), "\n\n", sep="")
# print att estimates
att_post <- predict(multisynth, att=T, att_weight = att_weight)[,1]
att_post <- att_post[length(att_post)]
cat(paste("Average ATT Estimate: ",
format(round(mean(att_post),3), nsmall = 3), "\n\n", sep=""))
}
#' Plot function for multisynth
#' @importFrom graphics plot
#' @param x multisynth object to be plotted
#' @param inf_type Type of inference to perform:
#' \itemize{
#' \item{bootstrap}{Wild bootstrap, the default option}
#' \item{jackknife}{Jackknife}
#' }
#' @param inf Whether to compute and plot confidence intervals
#' @param levels Which units/groups to plot, default is every group
#' @param label Whether to label the individual levels
#' @param weights Whether to plot the weights, default = FALSE
#' @param ... Optional arguments
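#' @examples
#' \dontrun{
#' # Sketch, assuming `fit` is a multisynth fit (see ?multisynth)
#' plot(fit)                   # estimates over time with bootstrap intervals
#' plot(fit, inf = FALSE)      # point estimates only
#' plot(fit, weights = TRUE)   # heatmap of donor weights for each treated unit
#' }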
#' @export
plot.multisynth <- function(x, inf_type = "bootstrap", inf = T,
levels = NULL, label = T,
weights = FALSE, ...) {
if(weights) {
ever_trt <- x$data$units[is.finite(x$data$trt)]
never_trt <- x$data$units[!is.finite(x$data$trt)]
weights <- data.frame(x$weights)
colnames(weights) <- ever_trt
weights$unit <- factor(rownames(weights),
levels = sort(rownames(weights), decreasing = TRUE))
# plotting the weights
weights %>%
tidyr::pivot_longer(-unit, names_to = "trt_unit", values_to = "weight") %>%
ggplot2::ggplot(aes(x = trt_unit, y = unit, fill = round(weight, 3))) +
ggplot2::geom_tile(color = "white", size=.5) +
ggplot2::scale_fill_gradient(low = "white", high = "black", limits=c(-0.01,1.01)) +
ggplot2::guides(fill = "none") +
ggplot2::xlab("Treated Unit") +
ggplot2::ylab("Donor Unit") +
ggplot2::theme_bw() +
ggplot2::theme(axis.ticks.x = ggplot2::element_blank(),
axis.ticks.y = ggplot2::element_blank())
}
else {
plot(summary(x, inf_type = inf_type, ...),
inf = inf, levels = levels, label = label)
}
}
#' Summary function for multisynth
#' @param object multisynth object
#' @param inf_type Type of inference to perform:
#' \itemize{
#' \item{bootstrap}{Wild bootstrap, the default option}
#' \item{jackknife}{Jackknife}
#' }
#' @param att_weight Weights to place on individual units/cohorts when averaging
#' @param ... Optional arguments
#'
#' @return summary.multisynth object that contains:
#' \itemize{
#' \item{"att"}{Data frame of ATT estimates, standard errors, and lower and upper confidence bounds by time and level (the overall average and each treated unit or cohort)}
#' \item{"global_l2"}{L2 imbalance for the pooled synthetic control}
#' \item{"scaled_global_l2"}{L2 imbalance for the pooled synthetic control, scaled by the imbalance for uniform weights}
#' \item{"ind_l2"}{Average L2 imbalance for the individual synthetic controls}
#' \item{"scaled_ind_l2"}{Average L2 imbalance for the individual synthetic controls, scaled by the imbalance for uniform weights}
#' \item{"n_leads", "n_lags"}{Number of post treatment outcomes (leads) and pre-treatment outcomes (lags) to include in the analysis}
#' }
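#' @examples
#' \dontrun{
#' # Sketch, assuming `fit` is a multisynth fit (see ?multisynth)
#' summ <- summary(fit)   # wild bootstrap inference by default
#' summ                   # average ATT plus imbalance diagnostics
#' # summ$att is a long data frame with columns Time, Level, Estimate,
#' # Std.Error, lower_bound and upper_bound; Time is NA in the row
#' # holding the post-treatment average
#' subset(summ$att, Level == "Average" & Time >= 0)
#' }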
#' @export
summary.multisynth <- function(object, inf_type = "bootstrap", att_weight = NULL, ...) {
multisynth <- object
relative <- T
n_leads <- multisynth$n_leads
d <- ncol(multisynth$data$X)
n <- nrow(multisynth$data$X)
ttot <- d + ncol(multisynth$data$y)
trt <- multisynth$data$trt
time_cohort <- multisynth$time_cohort
if(time_cohort) {
grps <- unique(trt[is.finite(trt)])
which_t <- lapply(grps, function(tj) (1:n)[trt == tj])
} else {
grps <- trt[is.finite(trt)]
which_t <- (1:n)[is.finite(trt)]
}
# grps <- unique(multisynth$data$trt) %>% sort()
J <- length(grps)
# which_t <- (1:n)[is.finite(multisynth$data$trt)]
times <- multisynth$data$time
summ <- list()
## post treatment estimate for each group and overall
# att <- predict(multisynth, relative, att=T)
if(inf_type == "jackknife") {
attse <- jackknife_se_multi(multisynth, relative, att_weight = att_weight, ...)
} else if(inf_type == "bootstrap") {
if(object$force == 2) {
warning("Wild bootstrap without including a unit fixed effect ",
"is likely to be very conservative!")
}
attse <- weighted_bootstrap_multi(multisynth, att_weight = att_weight, ...)
} else {
att <- predict(multisynth, att=T, att_weight = att_weight)
attse <- list(att = att,
se = matrix(NA, nrow(att), ncol(att)),
upper_bound = matrix(NA, nrow(att), ncol(att)),
lower_bound = matrix(NA, nrow(att), ncol(att)))
}
if(relative) {
att <- data.frame(cbind(c(-(d-1):min(n_leads, ttot-min(grps)), NA),
attse$att))
if(time_cohort) {
col_names <- c("Time", "Average",
as.character(times[grps + 1]))
} else {
col_names <- c("Time", "Average",
as.character(multisynth$data$units[which_t]))
}
names(att) <- col_names
att %>% gather(Level, Estimate, -Time) %>%
rename("Time"=Time) %>%
mutate(Time=Time-1) -> att
se <- data.frame(cbind(c(-(d-1):min(n_leads, ttot-min(grps)), NA),
attse$se))
names(se) <- col_names
se %>% gather(Level, Std.Error, -Time) %>%
rename("Time"=Time) %>%
mutate(Time=Time-1) -> se
lower_bound <- data.frame(cbind(c(-(d-1):min(n_leads, ttot-min(grps)), NA),
attse$lower_bound))
names(lower_bound) <- col_names
lower_bound %>% gather(Level, lower_bound, -Time) %>%
rename("Time"=Time) %>%
mutate(Time=Time-1) -> lower_bound
upper_bound <- data.frame(cbind(c(-(d-1):min(n_leads, ttot-min(grps)), NA),
attse$upper_bound))
names(upper_bound) <- col_names
upper_bound %>% gather(Level, upper_bound, -Time) %>%
rename("Time"=Time) %>%
mutate(Time=Time-1) -> upper_bound
} else {
att <- data.frame(cbind(times, attse$att))
names(att) <- c("Time", "Average", times[grps[1:J]])
att %>% gather(Level, Estimate, -Time) -> att
se <- data.frame(cbind(times, attse$se))
names(se) <- c("Time", "Average", times[grps[1:J]])
se %>% gather(Level, Std.Error, -Time) -> se
}
summ$att <- inner_join(att, se, by = c("Time", "Level")) %>%
inner_join(lower_bound, by = c("Time", "Level")) %>%
inner_join(upper_bound, by = c("Time", "Level"))
summ$relative <- relative
summ$grps <- grps
summ$call <- multisynth$call
summ$global_l2 <- multisynth$global_l2
summ$scaled_global_l2 <- multisynth$scaled_global_l2
summ$ind_l2 <- multisynth$ind_l2
summ$scaled_ind_l2 <- multisynth$scaled_ind_l2
summ$n_leads <- multisynth$n_leads
summ$n_lags <- multisynth$n_lags
class(summ) <- "summary.multisynth"
return(summ)
}
#' Print function for summary function for multisynth
#' @param x summary object
#' @param level Which unit/group to print results for, default is the overall average
#' @param ... Optional arguments
#' @export
print.summary.multisynth <- function(x, level = "Average", ...) {
summ <- x
## straight from lm
cat("\nCall:\n", paste(deparse(summ$call), sep="\n", collapse="\n"), "\n\n", sep="")
first_lvl <- summ$att %>% filter(Level != "Average") %>% pull(Level) %>% min()
## get ATT estimates for treatment level, post treatment
if(summ$relative) {
summ$att %>%
filter(Time >= 0, Level==level) %>%
rename("Time Since Treatment"=Time) -> att_est
} else if(level == "Average") {
summ$att %>% filter(Time > first_lvl, Level=="Average") -> att_est
} else {
summ$att %>% filter(Time > level, Level==level) -> att_est
}
cat(paste("Average ATT Estimate (Std. Error): ",
summ$att %>%
filter(Level == level, is.na(Time)) %>%
pull(Estimate) %>%
round(3) %>% format(nsmall=3),
" (",
summ$att %>%
filter(Level == level, is.na(Time)) %>%
pull(Std.Error) %>%
round(3) %>% format(nsmall=3),
")\n\n", sep=""))
cat(paste("Global L2 Imbalance: ",
format(round(summ$global_l2,3), nsmall=3), "\n",
"Scaled Global L2 Imbalance: ",
format(round(summ$scaled_global_l2,3), nsmall=3), "\n",
"Percent improvement from uniform global weights: ",
format(round(1-summ$scaled_global_l2,3)*100), "\n\n",
"Individual L2 Imbalance: ",
format(round(summ$ind_l2,3), nsmall=3), "\n",
"Scaled Individual L2 Imbalance: ",
format(round(summ$scaled_ind_l2,3), nsmall=3), "\n",
"Percent improvement from uniform individual weights: ",
format(round(1-summ$scaled_ind_l2,3)*100), "\t",
"\n\n",
sep=""))
print(att_est, row.names=F)
}
#' Plot function for summary function for multisynth
#' @importFrom ggplot2 aes
#'
#' @param x summary object
#' @param inf Whether to plot confidence intervals
#' @param levels Which units/groups to plot, default is every group
#' @param label Whether to label the individual levels
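#' @param weights Not supported for summary objects, which do not store the weights; call \code{plot()} on the multisynth object with \code{weights = TRUE} instead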
#' @param ... Optional arguments
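#' @examples
#' \dontrun{
#' # Sketch, assuming `fit` is a multisynth fit (see ?multisynth)
#' summ <- summary(fit)
#' plot(summ)                                       # every level, with intervals
#' plot(summ, levels = "Average", label = FALSE)    # pooled estimate only
#' }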
#' @export
plot.summary.multisynth <- function(x, inf = T, levels = NULL, label = T,
weights = FALSE, ...) {
if(weights) {
stop("Summary objects do not store the synthetic control weights; ",
"use plot() on the multisynth object with weights = TRUE instead.")
}
summ <- x
## get the last time period for each level
summ$att %>%
filter(!is.na(Estimate),
Time >= -summ$n_lags,
Time <= summ$n_leads) %>%
group_by(Level) %>%
summarise(last_time = max(Time)) -> last_times
if(is.null(levels)) levels <- unique(summ$att$Level)
summ$att %>% inner_join(last_times, by = "Level") %>%
filter(Level %in% levels) %>%
mutate(label = ifelse(Time == last_time, Level, NA),
is_avg = ifelse(("Average" %in% levels) * (Level == "Average"),
"A", "B")) %>%
ggplot2::ggplot(ggplot2::aes(x = Time, y = Estimate,
group = Level,
color = is_avg,
alpha = is_avg)) +
ggplot2::geom_line(size = 1) +
ggplot2::geom_point(size = 1) -> p
if(label) {
p <- p + ggrepel::geom_label_repel(ggplot2::aes(label = label),
nudge_x = 1, na.rm = T)
}
p <- p + ggplot2::geom_hline(yintercept = 0, lty = 2)
if(summ$relative) {
p <- p + ggplot2::geom_vline(xintercept = 0, lty = 2) +
ggplot2::xlab("Time Relative to Treatment")
} else {
p <- p + ggplot2::geom_vline(aes(xintercept = as.numeric(Level)),
lty = 2, alpha = 0.5,
summ$att %>% filter(Level != "Average"))
}
## add ses
if(inf) {
max_time <- max(summ$att$Time, na.rm = T)
if(max_time == 0) {
error_plt <- ggplot2::geom_errorbar
clr <- "black"
alph <- 1
} else {
error_plt <- ggplot2::geom_ribbon
clr <- NA
alph <- 0.2
}
if("Average" %in% levels) {
p <- p + error_plt(
ggplot2::aes(ymin=lower_bound,
ymax=upper_bound),
alpha = alph, color=clr,
data = summ$att %>%
filter(Level == "Average",
Time >= 0))
} else {
p <- p + error_plt(
ggplot2::aes(ymin=lower_bound,
ymax=upper_bound),
data = . %>% filter(Time >= 0),
alpha = alph, color = clr)
}
}
p <- p + ggplot2::scale_alpha_manual(values=c(1, 0.5)) +
ggplot2::scale_color_manual(values=c("#333333", "#818181")) +
ggplot2::guides(alpha = "none", color = "none") +
ggplot2::theme_bw()
return(p)
}
================================================
FILE: R/outcome_models.R
================================================
################################################################################
## Code to fit various outcome models
################################################################################
#' Use a separate regularized regression for each post period
#' to fit E[Y(0)|X]
#' @importFrom stats poly
#' @importFrom stats coef
#'
#' @param X Matrix of covariates/lagged outcomes
#' @param y Matrix of post-period outcomes
#' @param trt Vector of treatment indicator
#' @param alpha Mixing between L1 and L2, default: 1 (LASSO)
#' @param lambda Regularization hyperparameter, if null then CV
#' @param poly_order Order of polynomial to fit, default 1
#' @param type How to fit outcome model(s)
#' \itemize{
#' \item{sep }{Separate outcome models}
#' \item{avg }{Average responses into 1 outcome}
#' \item{multi }{Use multi response regression in glmnet}}
#' @param ... optional arguments for outcome model
#' @noRd
#' @return \itemize{
#' \item{y0hat }{Predicted outcome under control}
#' \item{params }{Regression parameters}}
fit_prog_reg <- function(X, y, trt, alpha=1, lambda=NULL,
poly_order=1, type="sep", ...) {
if(!requireNamespace("glmnet", quietly = TRUE)) {
stop("In order to fit an elastic net outcome model, you must install the glmnet package.")
}
extra_params = list(...)
if (length(extra_params) > 0) {
warning("Unused parameters when using elastic net: ", paste(names(extra_params), collapse = ", "))
}
X <- matrix(poly(matrix(X),degree=poly_order), nrow=dim(X)[1])
## helper function to fit regression with CV
outfit <- function(x, y) {
if(is.null(lambda)) {
lam <- glmnet::cv.glmnet(x, y, alpha=alpha, grouped=FALSE)$lambda.min
} else {
lam <- lambda
}
fit <- glmnet::glmnet(x, y, alpha=alpha,
lambda=lam)
return(as.matrix(coef(fit)))
}
if(type=="avg") {
## if fitting the average post period value, stack post periods together
stacky <- c(y)
stackx <- do.call(rbind,
lapply(1:dim(y)[2],
function(x) X))
stacktrt <- rep(trt, dim(y)[2])
regweights <- outfit(stackx[stacktrt==0,],
stacky[stacktrt==0])
} else if(type=="sep"){
## fit separate regressions for each post period
regweights <- apply(as.matrix(y), 2,
function(yt) outfit(X[trt==0,],
yt[trt==0]))
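} else {
## fit multi response regression on the control units only
Xc <- X[trt==0, , drop = FALSE]
yc <- as.matrix(y)[trt==0, , drop = FALSE]
if(is.null(lambda)) {
lam <- glmnet::cv.glmnet(Xc, yc, family="mgaussian",
alpha=alpha, grouped=FALSE)$lambda.min
} else {
lam <- lambda
}
fit <- glmnet::glmnet(Xc, yc, family="mgaussian",
alpha=alpha,
lambda=lam)
regweights <- as.matrix(do.call(cbind, coef(fit)))
}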
## Get predicted values
y0hat <- cbind(rep(1, dim(X)[1]),
X) %*% regweights
return(list(y0hat = y0hat,
params = regweights))
}
#' Use a separate random forest regression for each post period
#' to fit E[Y(0)|X]
#'
#' @param X Matrix of covariates/lagged outcomes
#' @param y Matrix of post-period outcomes
#' @param trt Vector of treatment indicator
#' @param avg Predict the average post-treatment outcome
#' @param ... optional arguments for outcome model
#' @noRd
#' @return \itemize{
#' \item{y0hat }{Predicted outcome under control}
#' \item{params }{Feature importances}}
fit_prog_rf <- function(X, y, trt, avg=FALSE, ...) {
if(!requireNamespace("randomForest", quietly = TRUE)) {
stop("In order to fit a random forest outcome model, you must install the randomForest package.")
}
extra_params = list(...)
if (length(extra_params) > 0) {
warning("Unused parameters when using random forest: ", paste(names(extra_params), collapse = ", "))
}
## helper function to fit RF
outfit <- function(x, y) {
fit <- randomForest::randomForest(x, y)
return(fit)
}
if(avg | dim(y)[2] == 1) {
## if fitting the average post period value, stack post periods together
stacky <- c(y)
stackx <- do.call(rbind,
lapply(1:dim(y)[2],
function(x) X))
stacktrt <- rep(trt, dim(y)[2])
fit <- outfit(stackx[stacktrt==0,],
stacky[stacktrt==0])
## predict outcome
y0hat <- matrix(predict(fit, X), ncol=1)
## keep feature importances
imports <- randomForest::importance(fit)
} else {
## fit separate regressions for each post period
fits <- apply(as.matrix(y), 2,
function(yt) outfit(X[trt==0,],
yt[trt==0]))
## predict outcome, one column per post period
y0hat <- do.call(cbind, lapply(fits, function(fit) predict(fit, X)))
## keep feature importances, one column per post period
imports <- do.call(cbind, lapply(fits, randomForest::importance))
}
return(list(y0hat=y0hat,
params=imports))
}
#' Use gsynth to fit factor model for E[Y(0)|X]
#'
#' @param X Matrix of covariates/lagged outcomes
#' @param y Matrix of post-period outcomes
#' @param trt Vector of treatment indicator
#' @param r Number of factors to use (or start with if CV==1)
#' @param r.end Max number of factors to consider if CV==1
#' @param force Fixed effects (0=none, 1=unit, 2=time, 3=two-way)
#' @param CV Whether to do CV (0=no CV, 1=yes CV)
#' @param ... optional arguments for outcome model
#' @noRd
#' @return \itemize{
#' \item{y0hat }{Predicted outcome under control}
#' \item{params }{Regression parameters}}
fit_prog_gsynth <- function(X, y, trt, r=0, r.end=5, force=3, CV=1, ...) {
if(!requireNamespace("gsynth", quietly = TRUE)) {
stop("In order to fit generalized synthetic controls, you must install the gsynth package.")
}
extra_params = list(...)
if (length(extra_params) > 0) {
warning("Unused parameters when using gSynth: ", paste(names(extra_params), collapse = ", "))
}
df_x = data.frame(X, check.names=FALSE)
df_x$unit = rownames(df_x)
df_x$trt = rep(0, nrow(df_x))
df_x <- df_x %>% select(unit, trt, everything())
long_df_x = gather(df_x, time, obs, -c(unit,trt))
df_y = data.frame(y, check.names=FALSE)
df_y$unit = rownames(df_y)
df_y$trt = trt
df_y <- df_y %>% select(unit, trt, everything())
long_df_y = gather(df_y, time, obs, -c(unit,trt))
long_df = rbind(long_df_x, long_df_y)
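## gather() returns the time key as character; convert it to numeric.
## Unit ids are left as-is because they may be non-numeric (e.g. state
## names), in which case as.numeric() would produce NAs.
long_df$time <- as.numeric(long_df$time)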
gsyn <- gsynth::gsynth(data = long_df, Y = "obs", D = "trt",
index = c("unit", "time"), force = force, CV = CV, r = r)
t0 <- dim(X)[2]
t_final <- t0 + dim(y)[2]
n <- dim(X)[1]
## get predicted outcomes
y0hat <- matrix(0, nrow=n, ncol=(t_final-t0))
y0hat[trt==0,] <- t(gsyn$Y.co[(t0+1):t_final,,drop=FALSE] -
gsyn$est.co$residuals[(t0+1):t_final,,drop=FALSE])
y0hat[trt==1,] <- t(gsyn$Y.ct[(t0+1):t_final,,drop=FALSE])
## keep the treated counterfactual predictions for all periods
gsyn$est.co$Y.ct <- gsyn$Y.ct
## control and treated residuals
gsyn$est.co$ctrl_resids <- gsyn$est.co$residuals
gsyn$est.co$trt_resids <- colMeans(cbind(X[trt==1,,drop=FALSE],
y[trt==1,,drop=FALSE])) -
rowMeans(gsyn$est.co$Y.ct)
return(list(y0hat=y0hat,
params=gsyn$est.co))
}
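## Illustrative sketch (hypothetical helper, not part of augsynth): the
## long-format panel that fit_prog_gsynth() builds with tidyr::gather(),
## written in base R. Pre-period rows always have trt = 0 and post-period
## rows carry each unit's treatment indicator, so gsynth sees treatment
## switch on only after t0. Assumes the column names of X and y are
## numeric time labels (e.g. "1990") and the row names of X are unit ids.
wide_to_long_panel_sketch <- function(X, y, trt) {
  X <- as.matrix(X)
  y <- as.matrix(y)
  units <- rownames(X)
  ## as.vector() stacks column-major: all units for period 1, then period 2, ...
  pre <- data.frame(unit = rep(units, times = ncol(X)),
                    trt = 0,
                    time = rep(as.numeric(colnames(X)), each = nrow(X)),
                    obs = as.vector(X),
                    stringsAsFactors = FALSE)
  post <- data.frame(unit = rep(units, times = ncol(y)),
                     trt = rep(trt, times = ncol(y)),
                     time = rep(as.numeric(colnames(y)), each = nrow(y)),
                     obs = as.vector(y),
                     stringsAsFactors = FALSE)
  rbind(pre, post)
}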
#' Use Athey et al. (2017) matrix completion panel data code
#'
#' @param X Matrix of covariates/lagged outcomes
#' @param y Matrix of post-period outcomes
#' @param trt Vector of treatment indicator
#' @param unit_fixed Whether to estimate unit fixed effects
#' @param time_fixed Whether to estimate time fixed effects
#' @param ... optional arguments for outcome model
#' @noRd
#' @return \itemize{
#' \item{y0hat }{Predicted outcome under control}
#' \item{params }{Regression parameters}}
fit_prog_mcpanel <- function(X, y, trt, unit_fixed=1, time_fixed=1, ...) {
if(!requireNamespace("MCPanel", quietly = TRUE)) {
stop("In order to fit matrix completion, you must install the MCPanel package.")
}
extra_params = list(...)
if (length(extra_params) > 0) {
warning("Unused parameters when using MCPanel: ", paste(names(extra_params), collapse = ", "))
}
## create matrix and missingness matrix
t0 <- dim(X)[2]
t_final <- t0 + dim(y)[2]
n <- dim(X)[1]
fullmat <- cbind(X, y)
maskmat <- matrix(1, nrow=nrow(fullmat), ncol=ncol(fullmat))
maskmat[trt==1, (t0+1):t_final] <- 0
## estimate matrix
mcp <- MCPanel::mcnnm_cv(fullmat, maskmat,
to_estimate_u=unit_fixed, to_estimate_v=time_fixed)
## impute matrix
imp_mat <- mcp$L +
sweep(matrix(0, nrow=nrow(fullmat), ncol=ncol(fullmat)), 1, mcp$u, "+") + # unit fixed
sweep(matrix(0, nrow=nrow(fullmat), ncol=ncol(fullmat)), 2, mcp$v, "+") # time fixed
## get predicted outcomes
y0hat <- imp_mat[,(t0+1):t_final,drop=FALSE]
params <- mcp
params$trt_resids <- colMeans(cbind(X[trt==1,,drop=FALSE],
y[trt==1,,drop=FALSE])) -
rowMeans(imp_mat[trt==1,,drop=FALSE])
params$ctrl_resids <- t(cbind(X[trt==0,,drop=FALSE],
y[trt==0,,drop=FALSE]) - imp_mat[trt==0,,drop=FALSE])
params$Y.ct <- t(imp_mat[trt==1,,drop=FALSE])
return(list(y0hat=y0hat,
params=params))
}
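## Illustrative sketch (hypothetical helpers, not part of augsynth or
## MCPanel): the observation mask passed to MCPanel::mcnnm_cv() in
## fit_prog_mcpanel(), and the imputation step L + u 1' + 1 v'. The
## outer() form gives the same matrix as adding the two sweep() calls.
make_mc_mask_sketch <- function(n, t0, t_final, trt) {
  ## 1 = observed, 0 = missing; treated units' post-period cells are missing
  mask <- matrix(1, nrow = n, ncol = t_final)
  if (t_final > t0) {
    mask[trt == 1, (t0 + 1):t_final] <- 0
  }
  mask
}
impute_low_rank_fe_sketch <- function(L, u, v) {
  ## L: n x T low-rank component; u: unit effects (length n);
  ## v: time effects (length T); entry (i, t) is L[i, t] + u[i] + v[t]
  stopifnot(nrow(L) == length(u), ncol(L) == length(v))
  L + outer(u, v, "+")
}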
#' Fit a comparative interrupted time series model
#' to estimate E[Y(0)|X]
#' @importFrom stats lm
#' @im
About this extraction
This page contains the full source code of the ebenmichael/augsynth GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 83 files (373.9 KB), approximately 110.9k tokens.
Extracted by GitExtract.